Compare commits
10 Commits
f102e44cc4
...
e551efc0a4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e551efc0a4 | ||
|
|
01e352bc2e | ||
|
|
a129325ade | ||
|
|
4cd96b58b5 | ||
|
|
39ac5a86dd | ||
|
|
a631a1c0a9 | ||
|
|
8d91f64dbf | ||
|
|
02694cbacc | ||
|
|
3aca4a22ae | ||
|
|
cd0b4a1bbf |
633
Cargo.lock
generated
633
Cargo.lock
generated
@ -145,12 +145,61 @@ version = "0.22.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.69.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.12.1",
|
||||
"lazy_static",
|
||||
"lazycell",
|
||||
"log",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"rustc-hash",
|
||||
"shlex",
|
||||
"syn",
|
||||
"which",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.70.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.13.0",
|
||||
"log",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"rustc-hash",
|
||||
"shlex",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.5.0"
|
||||
@ -169,15 +218,46 @@ version = "1.2.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7777341816418c02e033934a09f20dc0ccaf65a5201ef8a450ae0105a573fda"
|
||||
dependencies = [
|
||||
"jobserver",
|
||||
"libc",
|
||||
"shlex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cexpr"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
|
||||
dependencies = [
|
||||
"nom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "clang-sys"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
|
||||
dependencies = [
|
||||
"glob",
|
||||
"libc",
|
||||
"libloading",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.54"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
@ -203,12 +283,79 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crunchy"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929"
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.20.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"darling_macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_core"
|
||||
version = "0.20.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"ident_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.20.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dyn-clone"
|
||||
version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35"
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
|
||||
|
||||
[[package]]
|
||||
name = "enumflags2"
|
||||
version = "0.7.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba2f4b465f5318854c6f8dd686ede6c0a9dc67d4b1ac241cf0eb51521a309147"
|
||||
dependencies = [
|
||||
"enumflags2_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "enumflags2_derive"
|
||||
version = "0.7.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc4caf64a58d7a6d65ab00639b046ff54399a39f5f2554728895ace4b297cd79"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.1"
|
||||
@ -222,25 +369,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exo-rs"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"glob",
|
||||
"llama-cpp-2",
|
||||
"mlx-rs",
|
||||
"network-interface",
|
||||
"phf",
|
||||
"prost",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"socket2",
|
||||
"system-configuration",
|
||||
"thiserror 2.0.11",
|
||||
"tinygrad",
|
||||
"tokio",
|
||||
"tonic",
|
||||
"tonic-build",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -339,6 +493,12 @@ version = "0.31.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.7"
|
||||
@ -358,6 +518,16 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crunchy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.12.3"
|
||||
@ -376,6 +546,15 @@ version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||
|
||||
[[package]]
|
||||
name = "home"
|
||||
version = "0.5.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "1.2.0"
|
||||
@ -475,6 +654,12 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ident_case"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "1.9.3"
|
||||
@ -495,6 +680,15 @@ dependencies = [
|
||||
"hashbrown 0.15.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.13.0"
|
||||
@ -504,30 +698,90 @@ dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
version = "0.1.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||
|
||||
[[package]]
|
||||
name = "lazycell"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.169"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
|
||||
|
||||
[[package]]
|
||||
name = "libloading"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
|
||||
|
||||
[[package]]
|
||||
name = "llama-cpp-2"
|
||||
version = "0.1.93"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44818b6967a77379b8c8e105e2684d2bd2bca999ad24cfd806d8476a80c53255"
|
||||
dependencies = [
|
||||
"enumflags2",
|
||||
"llama-cpp-sys-2",
|
||||
"thiserror 1.0.69",
|
||||
"tracing",
|
||||
"tracing-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "llama-cpp-sys-2"
|
||||
version = "0.1.93"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92add8e8dabf941518dd573075721a6e0568db86c6557e176cec99bf56883c0a"
|
||||
dependencies = [
|
||||
"bindgen 0.69.5",
|
||||
"cc",
|
||||
"cmake",
|
||||
"glob",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.12"
|
||||
@ -544,12 +798,31 @@ version = "0.4.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
|
||||
|
||||
[[package]]
|
||||
name = "mach-sys"
|
||||
version = "0.5.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48460c2e82a3a0de197152fdf8d2c2d5e43adc501501553e439bf2156e6f87c7"
|
||||
dependencies = [
|
||||
"fastrand",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.4"
|
||||
@ -562,6 +835,12 @@ version = "0.3.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
|
||||
|
||||
[[package]]
|
||||
name = "minimal-lexical"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.8.4"
|
||||
@ -579,7 +858,68 @@ checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
"windows-sys",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-internal-macros"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0756e4528d38dfd2c30551e3cb05f42b346d4b9fd14a867767d86353232056d"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-macros"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "177ff342309c789defa1552e763ea8fbb5548e3ec17134a45009a27fbddb6c26"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-rs"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c686ead28a57db28004d2c72f940bb3b4366ad01649899cacd06bca495f93ca"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"dyn-clone",
|
||||
"half",
|
||||
"itertools 0.14.0",
|
||||
"libc",
|
||||
"mach-sys",
|
||||
"mlx-internal-macros",
|
||||
"mlx-macros",
|
||||
"mlx-sys",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"num_enum",
|
||||
"parking_lot",
|
||||
"paste",
|
||||
"safetensors",
|
||||
"smallvec",
|
||||
"strum",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-sys"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "af33a6b662998e5bb4099b1a191b4352fcb11d97706e82e4c8922fe200bb11f2"
|
||||
dependencies = [
|
||||
"bindgen 0.70.1",
|
||||
"cc",
|
||||
"cmake",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -588,6 +928,19 @@ version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
|
||||
|
||||
[[package]]
|
||||
name = "ndarray"
|
||||
version = "0.15.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
|
||||
dependencies = [
|
||||
"matrixmultiply",
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "network-interface"
|
||||
version = "2.0.0"
|
||||
@ -600,6 +953,16 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "7.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nu-ansi-term"
|
||||
version = "0.46.0"
|
||||
@ -610,6 +973,54 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_enum"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179"
|
||||
dependencies = [
|
||||
"num_enum_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_enum_derive"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56"
|
||||
dependencies = [
|
||||
"proc-macro-crate",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "object"
|
||||
version = "0.36.7"
|
||||
@ -654,6 +1065,12 @@ dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.3.1"
|
||||
@ -670,6 +1087,48 @@ dependencies = [
|
||||
"indexmap 2.7.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf"
|
||||
version = "0.11.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078"
|
||||
dependencies = [
|
||||
"phf_macros",
|
||||
"phf_shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf_generator"
|
||||
version = "0.11.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d"
|
||||
dependencies = [
|
||||
"phf_shared",
|
||||
"rand",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf_macros"
|
||||
version = "0.11.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216"
|
||||
dependencies = [
|
||||
"phf_generator",
|
||||
"phf_shared",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf_shared"
|
||||
version = "0.11.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5"
|
||||
dependencies = [
|
||||
"siphasher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.1.9"
|
||||
@ -721,6 +1180,15 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-crate"
|
||||
version = "3.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b"
|
||||
dependencies = [
|
||||
"toml_edit",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.93"
|
||||
@ -747,7 +1215,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d0f3e5beed80eb580c68e2c600937ac2c4eedabdfd5ef1e5b7ea4f3fba84497b"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"itertools",
|
||||
"itertools 0.13.0",
|
||||
"log",
|
||||
"multimap",
|
||||
"once_cell",
|
||||
@ -767,7 +1235,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"itertools",
|
||||
"itertools 0.13.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
@ -821,6 +1289,12 @@ dependencies = [
|
||||
"getrandom 0.2.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rawpointer"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.8"
|
||||
@ -865,6 +1339,12 @@ version = "0.1.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.44"
|
||||
@ -875,7 +1355,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -890,6 +1370,25 @@ version = "1.0.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
|
||||
|
||||
[[package]]
|
||||
name = "safetensors"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0436dbfa2778e4ec1a00801b0ae24a1dd619499247d48b0589b679103379d0d4"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.2.0"
|
||||
@ -952,6 +1451,12 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "siphasher"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d"
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.9"
|
||||
@ -974,7 +1479,35 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||
|
||||
[[package]]
|
||||
name = "strum"
|
||||
version = "0.26.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
|
||||
dependencies = [
|
||||
"strum_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strum_macros"
|
||||
version = "0.26.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1026,7 +1559,7 @@ dependencies = [
|
||||
"getrandom 0.3.1",
|
||||
"once_cell",
|
||||
"rustix",
|
||||
"windows-sys",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1079,6 +1612,15 @@ dependencies = [
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinygrad"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc8951e9ced45095eb3ad7342c4e74b038bda930833df79b7debb019bb653c18"
|
||||
dependencies = [
|
||||
"ndarray",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.43.0"
|
||||
@ -1094,7 +1636,7 @@ dependencies = [
|
||||
"signal-hook-registry",
|
||||
"socket2",
|
||||
"tokio-macros",
|
||||
"windows-sys",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1132,6 +1674,23 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_datetime"
|
||||
version = "0.6.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41"
|
||||
|
||||
[[package]]
|
||||
name = "toml_edit"
|
||||
version = "0.22.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474"
|
||||
dependencies = [
|
||||
"indexmap 2.7.1",
|
||||
"toml_datetime",
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tonic"
|
||||
version = "0.12.3"
|
||||
@ -1292,12 +1851,31 @@ version = "1.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034"
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "1.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ced87ca4be083373936a67f8de945faa23b6b42384bd5b64434850802c6dccd0"
|
||||
dependencies = [
|
||||
"getrandom 0.3.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "valuable"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
|
||||
|
||||
[[package]]
|
||||
name = "walkdir"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
|
||||
dependencies = [
|
||||
"same-file",
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
version = "0.3.1"
|
||||
@ -1322,6 +1900,18 @@ dependencies = [
|
||||
"wit-bindgen-rt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "which"
|
||||
version = "4.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
|
||||
dependencies = [
|
||||
"either",
|
||||
"home",
|
||||
"once_cell",
|
||||
"rustix",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
@ -1338,6 +1928,15 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
|
||||
dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
@ -1353,6 +1952,15 @@ dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.59.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.52.6"
|
||||
@ -1417,6 +2025,15 @@ version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen-rt"
|
||||
version = "0.33.0"
|
||||
|
||||
@ -15,6 +15,13 @@ tracing-subscriber = "0.3"
|
||||
socket2 = "0.5.8"
|
||||
system-configuration = "0.6.1"
|
||||
network-interface = "2.0.0"
|
||||
uuid = { version = "1.13.1", features = ["v4"] }
|
||||
regex = "1.11.1"
|
||||
phf = { version = "0.11.3", features = ["macros"] }
|
||||
mlx-rs = { version = "0.21.0", features = ["metal", "accelerate", "safetensors"] }
|
||||
tinygrad = "0.1.0"
|
||||
llama-cpp-2 = { version = "0.1.93", features = ["metal", "native", "sampler"] }
|
||||
glob = "0.3.2"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.12.3"
|
||||
|
||||
111
src/device_capability_data.rs
Normal file
111
src/device_capability_data.rs
Normal file
@ -0,0 +1,111 @@
|
||||
use crate::topology::DeviceFlops;
|
||||
|
||||
use phf::phf_map;
|
||||
|
||||
const TFLOPS: f64 = 1.00;
|
||||
|
||||
pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
|
||||
// Source: https://www.cpu-monkey.com
|
||||
// Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
|
||||
// M chips
|
||||
"Apple M1" => DeviceFlops { fp32: 2.29*TFLOPS, fp16: 4.58*TFLOPS, int8: 9.16*TFLOPS },
|
||||
"Apple M1 Pro" => DeviceFlops { fp32: 5.30*TFLOPS, fp16: 10.60*TFLOPS, int8: 21.20*TFLOPS },
|
||||
"Apple M1 Max" => DeviceFlops { fp32: 10.60*TFLOPS, fp16: 21.20*TFLOPS, int8: 42.40*TFLOPS },
|
||||
"Apple M1 Ultra" => DeviceFlops { fp32: 21.20*TFLOPS, fp16: 42.40*TFLOPS, int8: 84.80*TFLOPS },
|
||||
"Apple M2" => DeviceFlops { fp32: 3.55*TFLOPS, fp16: 7.10*TFLOPS, int8: 14.20*TFLOPS },
|
||||
"Apple M2 Pro" => DeviceFlops { fp32: 5.68*TFLOPS, fp16: 11.36*TFLOPS, int8: 22.72*TFLOPS },
|
||||
"Apple M2 Max" => DeviceFlops { fp32: 13.49*TFLOPS, fp16: 26.98*TFLOPS, int8: 53.96*TFLOPS },
|
||||
"Apple M2 Ultra" => DeviceFlops { fp32: 26.98*TFLOPS, fp16: 53.96*TFLOPS, int8: 107.92*TFLOPS },
|
||||
"Apple M3" => DeviceFlops { fp32: 3.55*TFLOPS, fp16: 7.10*TFLOPS, int8: 14.20*TFLOPS },
|
||||
"Apple M3 Pro" => DeviceFlops { fp32: 4.97*TFLOPS, fp16: 9.94*TFLOPS, int8: 19.88*TFLOPS },
|
||||
"Apple M3 Max" => DeviceFlops { fp32: 14.20*TFLOPS, fp16: 28.40*TFLOPS, int8: 56.80*TFLOPS },
|
||||
"Apple M4" => DeviceFlops { fp32: 4.26*TFLOPS, fp16: 8.52*TFLOPS, int8: 17.04*TFLOPS },
|
||||
"Apple M4 Pro" => DeviceFlops { fp32: 5.72*TFLOPS, fp16: 11.44*TFLOPS, int8: 22.88*TFLOPS },
|
||||
"Apple M4 Max" => DeviceFlops { fp32: 18.03*TFLOPS, fp16: 36.07*TFLOPS, int8: 72.14*TFLOPS },
|
||||
// A chips
|
||||
"Apple A13 Bionic" => DeviceFlops { fp32: 0.69*TFLOPS, fp16: 1.38*TFLOPS, int8: 2.76*TFLOPS },
|
||||
"Apple A14 Bionic" => DeviceFlops { fp32: 0.75*TFLOPS, fp16: 1.50*TFLOPS, int8: 3.00*TFLOPS },
|
||||
"Apple A15 Bionic" => DeviceFlops { fp32: 1.37*TFLOPS, fp16: 2.74*TFLOPS, int8: 5.48*TFLOPS },
|
||||
"Apple A16 Bionic" => DeviceFlops { fp32: 1.79*TFLOPS, fp16: 3.58*TFLOPS, int8: 7.16*TFLOPS },
|
||||
"Apple A17 Pro" => DeviceFlops { fp32: 2.15*TFLOPS, fp16: 4.30*TFLOPS, int8: 8.60*TFLOPS },
|
||||
// NVIDIA GPUs
|
||||
// RTX 40 series
|
||||
"NVIDIA GEFORCE RTX 4090" => DeviceFlops { fp32: 82.58*TFLOPS, fp16: 165.16*TFLOPS, int8: 330.32*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 4080" => DeviceFlops { fp32: 48.74*TFLOPS, fp16: 97.48*TFLOPS, int8: 194.96*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 4080 SUPER" => DeviceFlops { fp32: 52.0*TFLOPS, fp16: 104.0*TFLOPS, int8: 208.0*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 4070 TI SUPER" => DeviceFlops { fp32: 40.0*TFLOPS, fp16: 80.0*TFLOPS, int8: 160.0*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 4070 TI" => DeviceFlops { fp32: 39.43*TFLOPS, fp16: 78.86*TFLOPS, int8: 157.72*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 4070 SUPER" => DeviceFlops { fp32: 30.0*TFLOPS, fp16: 60.0*TFLOPS, int8: 120.0*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 4070" => DeviceFlops { fp32: 29.0*TFLOPS, fp16: 58.0*TFLOPS, int8: 116.0*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 4060 TI 16GB" => DeviceFlops { fp32: 22.0*TFLOPS, fp16: 44.0*TFLOPS, int8: 88.0*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 4060 TI" => DeviceFlops { fp32: 22.0*TFLOPS, fp16: 44.0*TFLOPS, int8: 88.0*TFLOPS },
|
||||
// RTX 30 series
|
||||
"NVIDIA GEFORCE RTX 3050" => DeviceFlops { fp32: 9.11*TFLOPS, fp16: 18.22*TFLOPS, int8: 36.44*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 3060" => DeviceFlops { fp32: 13.0*TFLOPS, fp16: 26.0*TFLOPS, int8: 52.0*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 3060 TI" => DeviceFlops { fp32: 16.2*TFLOPS, fp16: 32.4*TFLOPS, int8: 64.8*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 3070" => DeviceFlops { fp32: 20.3*TFLOPS, fp16: 40.6*TFLOPS, int8: 81.2*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 3070 TI" => DeviceFlops { fp32: 21.8*TFLOPS, fp16: 43.6*TFLOPS, int8: 87.2*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 3080 (10 GB)" => DeviceFlops { fp32: 29.8*TFLOPS, fp16: 59.6*TFLOPS, int8: 119.2*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 3080 (12 GB)" => DeviceFlops { fp32: 30.6*TFLOPS, fp16: 61.2*TFLOPS, int8: 122.4*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 3080 TI" => DeviceFlops { fp32: 34.1*TFLOPS, fp16: 68.2*TFLOPS, int8: 136.4*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 3090" => DeviceFlops { fp32: 35.6*TFLOPS, fp16: 71.2*TFLOPS, int8: 142.4*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 3090 TI" => DeviceFlops { fp32: 40.0*TFLOPS, fp16: 80.0*TFLOPS, int8: 160.0*TFLOPS },
|
||||
// RTX 20 series
|
||||
"NVIDIA GEFORCE RTX 2060" => DeviceFlops { fp32: 6.45*TFLOPS, fp16: 12.9*TFLOPS, int8: 25.8*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 2060 SUPER" => DeviceFlops { fp32: 7.2*TFLOPS, fp16: 14.4*TFLOPS, int8: 28.8*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 2070" => DeviceFlops { fp32: 7.46*TFLOPS, fp16: 14.93*TFLOPS, int8: 29.86*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 2070 SUPER" => DeviceFlops { fp32: 9.06*TFLOPS, fp16: 18.12*TFLOPS, int8: 36.24*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 2080" => DeviceFlops { fp32: 10.07*TFLOPS, fp16: 20.14*TFLOPS, int8: 40.28*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 2080 TI" => DeviceFlops { fp32: 13.45*TFLOPS, fp16: 26.9*TFLOPS, int8: 40.28*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 2080 SUPER" => DeviceFlops { fp32: 11.15*TFLOPS, fp16: 22.30*TFLOPS, int8: 44.60*TFLOPS },
|
||||
"NVIDIA TITAN RTX" => DeviceFlops { fp32: 16.31*TFLOPS, fp16: 32.62*TFLOPS, int8: 65.24*TFLOPS },
|
||||
// GTX 10 series
|
||||
"NVIDIA GEFORCE GTX 1050 TI" => DeviceFlops { fp32: 2.0*TFLOPS, fp16: 4.0*TFLOPS, int8: 8.0*TFLOPS },
|
||||
"NVIDIA GEFORCE GTX 1070" => DeviceFlops { fp32: 6.463*TFLOPS, fp16: 0.101*TFLOPS, int8: 25.852*TFLOPS },
|
||||
"NVIDIA GEFORCE GTX 1080" => DeviceFlops { fp32: 8.873*TFLOPS, fp16: 0.138*TFLOPS, int8: 35.492*TFLOPS },
|
||||
"NVIDIA GEFORCE GTX 1080 TI" => DeviceFlops { fp32: 11.34*TFLOPS, fp16: 0.177*TFLOPS, int8: 45.36*TFLOPS },
|
||||
// GTX 16 series
|
||||
"NVIDIA GeForce GTX 1660 TI" => DeviceFlops { fp32: 4.8*TFLOPS, fp16: 9.6*TFLOPS, int8: 19.2*TFLOPS },
|
||||
// QUADRO RTX Ampere series
|
||||
"NVIDIA RTX A2000" => DeviceFlops { fp32: 7.99*TFLOPS, fp16: 7.99*TFLOPS, int8: 31.91*TFLOPS },
|
||||
"NVIDIA RTX A4000" => DeviceFlops { fp32: 19.17*TFLOPS, fp16: 19.17*TFLOPS, int8: 76.68*TFLOPS },
|
||||
"NVIDIA RTX A4500" => DeviceFlops { fp32: 23.65*TFLOPS, fp16: 23.65*TFLOPS, int8: 94.6*TFLOPS },
|
||||
"NVIDIA RTX A5000" => DeviceFlops { fp32: 27.8*TFLOPS, fp16: 27.8*TFLOPS, int8: 111.2*TFLOPS },
|
||||
"NVIDIA RTX A6000" => DeviceFlops { fp32: 38.71*TFLOPS, fp16: 38.71*TFLOPS, int8: 154.84*TFLOPS },
|
||||
// NVIDIA Ada Lovelace Architecture-Based
|
||||
"NVIDIA RTX 4000 ADA GENERATION" => DeviceFlops { fp32: 26.7*TFLOPS, fp16: 26.7*TFLOPS, int8: 258.0*TFLOPS },
|
||||
// Common Server GPUs
|
||||
"NVIDIA A40 48GB PCIE" => DeviceFlops { fp32: 37.4*TFLOPS, fp16: 149.7*TFLOPS, int8: 299.3*TFLOPS },
|
||||
"NVIDIA A100 40GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
||||
"NVIDIA A800 40GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
||||
"NVIDIA A100 80GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
||||
"NVIDIA A800 80GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
||||
"NVIDIA A100 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
||||
"NVIDIA A800 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
||||
// AMD GPUs
|
||||
// RX 6000 series
|
||||
"AMD Radeon RX 6900 XT" => DeviceFlops { fp32: 23.04*TFLOPS, fp16: 46.08*TFLOPS, int8: 92.16*TFLOPS },
|
||||
"AMD Radeon RX 6800 XT" => DeviceFlops { fp32: 20.74*TFLOPS, fp16: 41.48*TFLOPS, int8: 82.96*TFLOPS },
|
||||
"AMD Radeon RX 6800" => DeviceFlops { fp32: 16.17*TFLOPS, fp16: 32.34*TFLOPS, int8: 64.68*TFLOPS },
|
||||
"AMD Radeon RX 6700 XT" => DeviceFlops { fp32: 13.21*TFLOPS, fp16: 26.42*TFLOPS, int8: 52.84*TFLOPS },
|
||||
"AMD Radeon RX 6700" => DeviceFlops { fp32: 11.4*TFLOPS, fp16: 22.8*TFLOPS, int8: 45.6*TFLOPS },
|
||||
"AMD Radeon RX 6600 XT" => DeviceFlops { fp32: 10.6*TFLOPS, fp16: 21.2*TFLOPS, int8: 42.4*TFLOPS },
|
||||
"AMD Radeon RX 6600" => DeviceFlops { fp32: 8.93*TFLOPS, fp16: 17.86*TFLOPS, int8: 35.72*TFLOPS },
|
||||
"AMD Radeon RX 6500 XT" => DeviceFlops { fp32: 5.77*TFLOPS, fp16: 11.54*TFLOPS, int8: 23.08*TFLOPS },
|
||||
"AMD Radeon RX 6400" => DeviceFlops { fp32: 3.57*TFLOPS, fp16: 7.14*TFLOPS, int8: 14.28*TFLOPS },
|
||||
// RX 7000 series
|
||||
"AMD Radeon RX 7900 XTX" => DeviceFlops { fp32: 61.4*TFLOPS, fp16: 122.8*TFLOPS, int8: 245.6*TFLOPS },
|
||||
"AMD Radeon RX 7900 XT" => DeviceFlops { fp32: 53.4*TFLOPS, fp16: 106.8*TFLOPS, int8: 213.6*TFLOPS },
|
||||
"AMD Radeon RX 7800 XT" => DeviceFlops { fp32: 42.6*TFLOPS, fp16: 85.2*TFLOPS, int8: 170.4*TFLOPS },
|
||||
"AMD Radeon RX 7700 XT" => DeviceFlops { fp32: 34.2*TFLOPS, fp16: 68.4*TFLOPS, int8: 136.8*TFLOPS },
|
||||
"AMD Radeon RX 7600" => DeviceFlops { fp32: 21.5*TFLOPS, fp16: 43.0*TFLOPS, int8: 86.0*TFLOPS },
|
||||
"AMD Radeon RX 7500" => DeviceFlops { fp32: 16.2*TFLOPS, fp16: 32.4*TFLOPS, int8: 64.8*TFLOPS },
|
||||
};
|
||||
|
||||
pub fn look_up(chip: &str) -> Option<DeviceFlops> {
|
||||
CHIP_FLOPS
|
||||
.get(chip)
|
||||
.or_else(|| CHIP_FLOPS.get(&format!("Laptop GPU {}", chip)))
|
||||
.or_else(|| CHIP_FLOPS.get(&format!("{} Laptop GPU", chip)))
|
||||
.cloned()
|
||||
}
|
||||
@ -1,11 +1,17 @@
|
||||
use crate::network::get_broadcast_creation_info;
|
||||
use crate::orchestration::PeerHandle;
|
||||
use crate::topology::DeviceCapabilities;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use socket2::{Domain, Protocol, Socket, Type};
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::{debug, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
mod broadcast;
|
||||
mod udp_listen;
|
||||
@ -47,24 +53,49 @@ pub struct NodeInfo {
|
||||
pub device_capabilities: DeviceCapabilities,
|
||||
}
|
||||
|
||||
impl Default for NodeInfo {
|
||||
fn default() -> Self {
|
||||
NodeInfo {
|
||||
node_id: Uuid::new_v4().to_string(),
|
||||
discovery_listen_port: 5678,
|
||||
broadcast_port: 5678,
|
||||
broadcast_interval: Duration::from_secs_f32(2.5),
|
||||
grpc_port: 49152,
|
||||
allowed_peer_ids: None,
|
||||
allowed_interfaces: None,
|
||||
discovery_timeout: Duration::from_secs(30),
|
||||
device_capabilities: DeviceCapabilities::determine(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UdpDiscovery {
|
||||
node_info: NodeInfo,
|
||||
discovery_handle: JoinHandle<()>,
|
||||
presence_handle: JoinHandle<()>,
|
||||
peer_manager_handle: JoinHandle<()>,
|
||||
pub peers: Arc<Mutex<HashMap<String, PeerHandle>>>,
|
||||
}
|
||||
|
||||
impl UdpDiscovery {
|
||||
pub fn new(node_info: NodeInfo) -> Self {
|
||||
let broadcast_creation_info = get_broadcast_creation_info();
|
||||
let discovery_handle = tokio::spawn(broadcast::listen_all(node_info.clone(), broadcast_creation_info));
|
||||
let (presence_handle, peer_manager_handle) = udp_listen::manage_discovery(node_info.clone());
|
||||
info!("Found addresses: {:?}", broadcast_creation_info);
|
||||
|
||||
let peers = Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
let discovery_handle = tokio::spawn(broadcast::listen_all(
|
||||
node_info.clone(),
|
||||
broadcast_creation_info,
|
||||
));
|
||||
let (presence_handle, peer_manager_handle) =
|
||||
udp_listen::manage_discovery(node_info.clone(), peers.clone());
|
||||
|
||||
UdpDiscovery {
|
||||
node_info,
|
||||
discovery_handle,
|
||||
presence_handle,
|
||||
peer_manager_handle,
|
||||
peers,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,15 +1,13 @@
|
||||
use crate::discovery::{DiscoveryMessage, NodeInfo};
|
||||
use crate::orchestration::PeerHandle;
|
||||
use crate::topology::DeviceCapabilities;
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
use system_configuration::sys::libc::disconnectx;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::select;
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
use tonic::transport::Error;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
async fn listen_for_discovery(
|
||||
@ -119,8 +117,10 @@ async fn handle_new_peer(
|
||||
peers.insert(message.node_id, new_peer);
|
||||
}
|
||||
|
||||
pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>) {
|
||||
let mut peers: HashMap<String, PeerHandle> = HashMap::new();
|
||||
pub fn manage_discovery(
|
||||
node_info: NodeInfo,
|
||||
peers: Arc<Mutex<HashMap<String, PeerHandle>>>,
|
||||
) -> (JoinHandle<()>, JoinHandle<()>) {
|
||||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<(SocketAddr, DiscoveryMessage)>();
|
||||
|
||||
// TODO: How do we handle killing this?
|
||||
@ -128,11 +128,13 @@ pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>)
|
||||
|
||||
let peer_manager_handle = tokio::spawn(async move {
|
||||
loop {
|
||||
let action = select! {
|
||||
let action: Action = select! {
|
||||
_ = tokio::time::sleep(node_info.discovery_timeout) => Action::HealthChecks,
|
||||
Some((addr, message)) = rx.recv() => Action::NewPeer(addr, message),
|
||||
};
|
||||
|
||||
let mut peers = peers.lock().await;
|
||||
|
||||
match action {
|
||||
Action::NewPeer(addr, message) => handle_new_peer(&mut peers, addr, message).await,
|
||||
Action::HealthChecks => perform_health_checks(&mut peers).await,
|
||||
|
||||
20
src/inference.rs
Normal file
20
src/inference.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use std::collections::HashMap;
|
||||
use crate::node_service::{InferenceState, Tensor};
|
||||
use crate::Shard;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct InferenceEngine {
|
||||
// state_cache: HashMap<String, _>,
|
||||
}
|
||||
|
||||
impl InferenceEngine {
|
||||
pub(crate) fn infer_tensor(
|
||||
&self,
|
||||
request_id: String,
|
||||
shard: Shard,
|
||||
tensor: Option<Tensor>,
|
||||
inference_state: Option<InferenceState>,
|
||||
) -> Tensor {
|
||||
|
||||
}
|
||||
}
|
||||
160
src/llama_module.rs
Normal file
160
src/llama_module.rs
Normal file
@ -0,0 +1,160 @@
|
||||
use crate::Shard;
|
||||
use mlx_rs::builder::Builder;
|
||||
use mlx_rs::macros::ModuleParameters;
|
||||
use mlx_rs::module::Module;
|
||||
use mlx_rs::nn::{Embedding, RmsNorm, RmsNormBuilder};
|
||||
use mlx_rs::Array;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::env::args;
|
||||
use mlx_rs::ops::zeros;
|
||||
|
||||
#[derive(Debug, Deserialize, ModuleParameters)]
|
||||
struct ModelArgs {
|
||||
vocab_size: i32,
|
||||
hidden_size: i32,
|
||||
num_hidden_layers: i32,
|
||||
rms_norm_eps: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ShardedLayer {
|
||||
TransformerBlock,
|
||||
IdentityBlock,
|
||||
}
|
||||
|
||||
#[derive(Debug, ModuleParameters)]
|
||||
struct LlamaModel {
|
||||
args: ModelArgs,
|
||||
shard: Shard,
|
||||
layers: Vec<ShardedLayer>,
|
||||
|
||||
embed_tokens: Embedding,
|
||||
norm: RmsNorm,
|
||||
|
||||
cache: Vec<Option<Array>>,
|
||||
}
|
||||
|
||||
impl LlamaModel {
|
||||
fn new(args: ModelArgs, shard: Shard) -> Self {
|
||||
let embed_tokens = Embedding::new(args.vocab_size, args.hidden_size).unwrap();
|
||||
|
||||
let layers = (0..(args.num_hidden_layers as u32)).map(|i| {
|
||||
if shard.start_layer <= i && i <= shard.end_layer {
|
||||
ShardedLayer::TransformerBlock
|
||||
} else {
|
||||
ShardedLayer::IdentityBlock
|
||||
}
|
||||
}).collect::<Vec<_>>();
|
||||
|
||||
let norm = RmsNormBuilder::new(args.hidden_size)
|
||||
.eps(args.rms_norm_eps)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
cache: vec![None; args.num_hidden_layers as usize],
|
||||
args,
|
||||
shard,
|
||||
layers,
|
||||
embed_tokens,
|
||||
norm
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module<Array> for LlamaModel {
|
||||
type Output = Array;
|
||||
type Error = mlx_rs::error::Exception;
|
||||
|
||||
fn forward(&mut self, input: Array) -> Result<Self::Output, Self::Error> {
|
||||
let h = if self.shard.is_first_layer() {
|
||||
self.embed_tokens.forward(&input)?
|
||||
} else {
|
||||
input
|
||||
};
|
||||
|
||||
let mut mask = if h.ndim() > 1 && h.shape()[1] > 1 {
|
||||
Some(create_attention_mask(&h, &self.cache)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let h = self.layers.iter_mut().zip(self.cache.iter_mut())
|
||||
.fold(h, |h, (layer, c)| {
|
||||
layer.forward(&h, mask.as_ref(), c)?
|
||||
});
|
||||
|
||||
let h = if self.shard.is_last_layer() {
|
||||
self.norm.forward(&h)?
|
||||
} else {
|
||||
h
|
||||
};
|
||||
|
||||
Ok(h)
|
||||
}
|
||||
|
||||
fn training_mode(&mut self, mode: bool) {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn create_attention_mask(h: &Array, cache: &[Option<HashMap<String, i32>>]) -> Result<Array, mlx_rs::error::Exception> {
|
||||
let shape = h.shape();
|
||||
let t = shape[1];
|
||||
|
||||
if t > 1 {
|
||||
let (window_size, offset) = match cache {
|
||||
&[Some(ref cache), ..] => {
|
||||
let offset = *cache.get("offset").unwrap();
|
||||
|
||||
if let Some(max_size) = cache.get("max_size") {
|
||||
(Some(*max_size), i32::min(*max_size, offset))
|
||||
} else {
|
||||
(None, offset)
|
||||
}
|
||||
},
|
||||
_ => (None, 0),
|
||||
};
|
||||
|
||||
|
||||
let mask = create_causal_mask(t, offset, window_size, None)?;
|
||||
mask.as_dtype(h.dtype())
|
||||
} else {
|
||||
Ok(zeros(&[0])) // Return empty array when T <= 1
|
||||
}
|
||||
}
|
||||
|
||||
fn create_causal_mask(
|
||||
n: i32,
|
||||
offset: i32,
|
||||
window_size: Option<i32>,
|
||||
lengths: Option<&Array>
|
||||
) -> Result<Array, mlx_rs::error::Exception> {
|
||||
let rinds = Array::arange(0, offset + n, 1)?;
|
||||
let linds = if offset > 0 {
|
||||
Array::arange(0, offset + n, 1)?
|
||||
} else {
|
||||
rinds.clone()
|
||||
};
|
||||
|
||||
let linds = linds.reshape(&[-1, 1])?;
|
||||
let rinds = rinds.reshape(&[1, -1])?;
|
||||
|
||||
let mut mask = linds.lt(&rinds)?;
|
||||
|
||||
if let Some(w) = window_size {
|
||||
let window_mask = linds.gt(&(rinds + w))?;
|
||||
mask = mask.logical_or(&window_mask)?;
|
||||
}
|
||||
|
||||
if let Some(l) = lengths {
|
||||
let l = l.reshape(&[-1, 1, 1, 1])?;
|
||||
let length_mask = rinds.greater_equal(&l)?;
|
||||
mask = mask.logical_or(&length_mask)?;
|
||||
}
|
||||
|
||||
mask.multiply(-1e9)
|
||||
}
|
||||
|
||||
73
src/llama_test.rs
Normal file
73
src/llama_test.rs
Normal file
@ -0,0 +1,73 @@
|
||||
use llama_cpp_2::context::params::LlamaContextParams;
|
||||
use llama_cpp_2::llama_backend::LlamaBackend;
|
||||
use llama_cpp_2::llama_batch::LlamaBatch;
|
||||
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
|
||||
use llama_cpp_2::model::params::LlamaModelParams;
|
||||
use llama_cpp_2::sampling::LlamaSampler;
|
||||
|
||||
fn test() {
|
||||
let model_path = std::env::args().nth(1).expect("Please specify model path");
|
||||
let backend = LlamaBackend::init().unwrap();
|
||||
let params = LlamaModelParams::default();
|
||||
|
||||
let prompt =
|
||||
"<|im_start|>user\nHello! how are you?<|im_end|>\n<|im_start|>assistant\n".to_string();
|
||||
LlamaContextParams::default();
|
||||
let model =
|
||||
LlamaModel::load_from_file(&backend, model_path, ¶ms).expect("unable to load model");
|
||||
let ctx_params = LlamaContextParams::default();
|
||||
let mut ctx = model
|
||||
.new_context(&backend, ctx_params)
|
||||
.expect("unable to create the llama_context");
|
||||
let tokens_list = model
|
||||
.str_to_token(&prompt, AddBos::Always)
|
||||
.unwrap_or_else(|_| panic!("failed to tokenize {prompt}"));
|
||||
let n_len = 64;
|
||||
|
||||
// create a llama_batch with size 512
|
||||
// we use this object to submit token data for decoding
|
||||
let mut batch = LlamaBatch::new(512, 1);
|
||||
|
||||
let last_index = tokens_list.len() as i32 - 1;
|
||||
for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
|
||||
// llama_decode will output logits only for the last token of the prompt
|
||||
let is_last = i == last_index;
|
||||
batch.add(token, i, &[0], is_last).unwrap();
|
||||
}
|
||||
ctx.decode(&mut batch).expect("llama_decode() failed");
|
||||
|
||||
let mut n_cur = batch.n_tokens();
|
||||
|
||||
// The `Decoder`
|
||||
// let mut decoder = encoding_rs::UTF_8.new_decoder();
|
||||
let mut sampler = LlamaSampler::greedy();
|
||||
|
||||
while n_cur <= n_len {
|
||||
// sample the next token
|
||||
{
|
||||
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
|
||||
|
||||
sampler.accept(token);
|
||||
|
||||
// is it an end of stream?
|
||||
if token == model.token_eos() {
|
||||
eprintln!();
|
||||
break;
|
||||
}
|
||||
|
||||
let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap();
|
||||
// use `Decoder.decode_to_string()` to avoid the intermediate buffer
|
||||
let mut output_string = String::with_capacity(32);
|
||||
// let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
|
||||
// print!("{output_string}");
|
||||
// std::io::stdout().flush().unwrap();
|
||||
|
||||
batch.clear();
|
||||
batch.add(token, n_cur, &[0], true).unwrap();
|
||||
}
|
||||
|
||||
n_cur += 1;
|
||||
|
||||
ctx.decode(&mut batch).expect("failed to eval");
|
||||
}
|
||||
}
|
||||
181
src/main.rs
181
src/main.rs
@ -1,32 +1,100 @@
|
||||
mod topology;
|
||||
mod orchestration;
|
||||
mod device_capability_data;
|
||||
mod discovery;
|
||||
mod inference;
|
||||
mod network;
|
||||
mod orchestration;
|
||||
mod partitioning;
|
||||
mod topology;
|
||||
mod llama_test;
|
||||
mod module_loading;
|
||||
mod llama_module;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tonic::{transport::Server, Request, Response, Status};
|
||||
|
||||
use crate::discovery::{NodeInfo, UdpDiscovery};
|
||||
use crate::inference::InferenceEngine;
|
||||
use crate::node_service::{
|
||||
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
|
||||
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto,
|
||||
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse,
|
||||
InferenceState, Loss, PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor,
|
||||
Topology as TopologyProto,
|
||||
};
|
||||
use node_service::node_service_server::{NodeService, NodeServiceServer};
|
||||
use node_service::TensorRequest;
|
||||
use std::collections::HashSet;
|
||||
use topology::Topology;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub mod node_service {
|
||||
tonic::include_proto!("node_service"); // The string specified here must match the proto package name
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Node {
|
||||
node_info: NodeInfo,
|
||||
current_topology: Topology,
|
||||
udp_discovery: UdpDiscovery,
|
||||
inference_engine: InferenceEngine,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
#[tracing::instrument]
|
||||
pub(crate) async fn process_tensor(
|
||||
&self,
|
||||
base_shard: Shard,
|
||||
tensor: Option<Tensor>,
|
||||
request_id: String,
|
||||
inference_state: Option<InferenceState>,
|
||||
) -> Tensor {
|
||||
let shard = self
|
||||
.current_topology
|
||||
.get_shard_for_node(base_shard, &self.node_info.node_id);
|
||||
|
||||
let result: Tensor = self
|
||||
.inference_engine
|
||||
.infer_tensor(request_id, shard, tensor, inference_state);
|
||||
|
||||
let result = self.process_inference_result(shard, result, request_id, inference_state);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn process_prompt(
|
||||
&self,
|
||||
base_shard: Shard,
|
||||
prompt: String,
|
||||
request_id: String,
|
||||
inference_state: Option<InferenceState>,
|
||||
) {
|
||||
let shard = self
|
||||
.current_topology
|
||||
.get_shard_for_node(base_shard, &self.node_info.node_id);
|
||||
|
||||
todo!();
|
||||
// The python code is a little weird wrt return types here
|
||||
|
||||
// if shard.is_first_layer() {
|
||||
// let result = self
|
||||
// .inference_engine
|
||||
// .infer_prompt(request_id, shard, prompt, inference_state)
|
||||
// .await;
|
||||
// self.process_inference_result(shard, result, request_id, inference_state)
|
||||
// } else {
|
||||
// self.forward_prompt(shard, prompt, request_id, inference_state)
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Node {
|
||||
fn default() -> Self {
|
||||
let node_info = NodeInfo::default();
|
||||
|
||||
Self {
|
||||
node_info: node_info.clone(),
|
||||
current_topology: Topology::default(),
|
||||
udp_discovery: UdpDiscovery::new(node_info),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -42,11 +110,36 @@ enum OpaqueStatus {
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
struct Shard {
|
||||
pub model_id: String,
|
||||
pub start_layer: i32,
|
||||
pub end_layer: i32,
|
||||
pub n_layers: i32,
|
||||
pub start_layer: u32,
|
||||
pub end_layer: u32,
|
||||
#[serde(rename = "n_layers")]
|
||||
pub total_layers: u32,
|
||||
}
|
||||
|
||||
impl Shard {
|
||||
pub fn is_first_layer(&self) -> bool {
|
||||
self.start_layer == 0
|
||||
}
|
||||
|
||||
pub fn is_last_layer(&self) -> bool {
|
||||
self.end_layer == self.total_layers - 1
|
||||
}
|
||||
|
||||
pub fn len(&self) -> u32 {
|
||||
self.end_layer - self.start_layer + 1
|
||||
}
|
||||
}
|
||||
|
||||
impl From<node_service::Shard> for Shard {
|
||||
fn from(proto: node_service::Shard) -> Self {
|
||||
Self {
|
||||
model_id: proto.model_id,
|
||||
start_layer: proto.start_layer as u32,
|
||||
end_layer: proto.end_layer as u32,
|
||||
total_layers: proto.n_layers as u32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
struct NodeStatus {
|
||||
@ -132,14 +225,36 @@ impl NodeService for Node {
|
||||
&self,
|
||||
request: Request<PromptRequest>,
|
||||
) -> Result<Response<Tensor>, Status> {
|
||||
todo!()
|
||||
let request = request.into_inner();
|
||||
let request_id = request
|
||||
.request_id
|
||||
.unwrap_or_else(|| Uuid::new_v4().to_string());
|
||||
|
||||
let result = self.process_prompt(
|
||||
request
|
||||
.shard
|
||||
.expect("No shard given. ExoPy does not allow this")
|
||||
.into(),
|
||||
request.prompt,
|
||||
request_id,
|
||||
request.inference_state,
|
||||
);
|
||||
|
||||
todo!();
|
||||
}
|
||||
|
||||
async fn send_tensor(
|
||||
&self,
|
||||
request: Request<TensorRequest>,
|
||||
) -> Result<Response<Tensor>, Status> {
|
||||
todo!()
|
||||
let request = request.into_inner();
|
||||
let shard = request.shard.expect("No shard provided").into();
|
||||
let request_id = request.request_id.expect("No request id provided");
|
||||
|
||||
let result =
|
||||
self.process_tensor(shard, request.tensor, request_id, request.inference_state);
|
||||
|
||||
Ok(Response::new(result.into()))
|
||||
}
|
||||
|
||||
async fn send_example(
|
||||
@ -149,11 +264,19 @@ impl NodeService for Node {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// TODO: Why aren't we using the request?
|
||||
async fn collect_topology(
|
||||
&self,
|
||||
request: Request<CollectTopologyRequest>,
|
||||
) -> Result<Response<TopologyProto>, Status> {
|
||||
todo!()
|
||||
let request = request.into_inner();
|
||||
let max_depth = request.max_depth as u8;
|
||||
let visited = request.visited;
|
||||
|
||||
self.update_topology_inner(max_depth, visited.into_iter().collect())
|
||||
.await;
|
||||
|
||||
Ok(Response::new(self.current_topology.clone().into()))
|
||||
}
|
||||
|
||||
async fn send_result(
|
||||
@ -186,13 +309,49 @@ impl NodeService for Node {
|
||||
}
|
||||
}
|
||||
|
||||
impl Node {
|
||||
async fn update_topology(&mut self) {
|
||||
let overall_max_depth = 4;
|
||||
let visited: HashSet<String> = HashSet::new();
|
||||
|
||||
self.current_topology = self.update_topology_inner(overall_max_depth, visited).await;
|
||||
}
|
||||
|
||||
async fn update_topology_inner(&self, max_depth: u8, mut visited: HashSet<String>) -> Topology {
|
||||
let mut new_topology = Topology::default();
|
||||
|
||||
new_topology.update_node(
|
||||
self.node_info.node_id.clone(),
|
||||
self.node_info.device_capabilities.clone(),
|
||||
);
|
||||
|
||||
for peer in self.udp_discovery.peers.lock().await.values() {
|
||||
new_topology.update_node(peer.node_id.clone(), peer.device_capabilities.clone());
|
||||
new_topology.update_edge(
|
||||
self.node_info.node_id.clone(),
|
||||
peer.node_id.clone(),
|
||||
peer.description.clone(),
|
||||
);
|
||||
|
||||
visited.insert(peer.node_id.clone());
|
||||
|
||||
if !visited.contains(&peer.node_id) {
|
||||
let topology = peer.collect_topology(visited.clone(), max_depth - 1).await;
|
||||
new_topology.merge_restricted(&peer.node_id, topology);
|
||||
}
|
||||
}
|
||||
|
||||
new_topology
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// install global collector configured based on RUST_LOG env var.
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let grpc_addr = "[::1]:50051".parse()?;
|
||||
let node = Node::default();
|
||||
let node: Node = Node::default();
|
||||
|
||||
// TODO: Also implement discovery
|
||||
|
||||
|
||||
63
src/module_loading.rs
Normal file
63
src/module_loading.rs
Normal file
@ -0,0 +1,63 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use serde_json::Value;
|
||||
use crate::Shard;
|
||||
|
||||
fn load_config(
|
||||
model_path: &Path,
|
||||
) -> serde_json::Map<String, serde_json::Value> {
|
||||
let config_path = model_path.join("config.json");
|
||||
let model_index = model_path.join("model_index.json");
|
||||
|
||||
if config_path.exists() {
|
||||
let config = std::fs::read_to_string(config_path).unwrap();
|
||||
serde_json::from_str(&config).unwrap()
|
||||
} else {
|
||||
let model_index = std::fs::read_to_string(model_index).unwrap();
|
||||
serde_json::from_str(&model_index).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
async fn load_model_shard(
|
||||
model_path: &Path,
|
||||
shard: Shard,
|
||||
lazy: bool,
|
||||
model_config: serde_json::Map<String, serde_json::Value>,
|
||||
) {
|
||||
let mut config = load_config(model_path);
|
||||
config.extend(model_config.into_iter());
|
||||
|
||||
let model_name = model_path.file_name().unwrap().to_str().unwrap();
|
||||
|
||||
config["shard"] = serde_json::json!({
|
||||
"model_id": model_name,
|
||||
"start_layer": shard.start_layer,
|
||||
"end_layer": shard.end_layer,
|
||||
"n_layers": shard.total_layers,
|
||||
});
|
||||
|
||||
let mut weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
|
||||
.unwrap()
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.unwrap();
|
||||
|
||||
let model_args = serde_json::from_value::<ModelArgsRs>(Value::Object(config)).unwrap();
|
||||
let model = LlamaModelRs::new(model_args).unwrap();
|
||||
|
||||
weight_files.sort_by_key(|x| x.file_name().unwrap().to_str().unwrap().to_string());
|
||||
|
||||
let mut weights = HashMap::new();
|
||||
|
||||
for weight_file in weight_files {
|
||||
weights.extend(mlx_rs::Array::load_safetensors(&weight_file).unwrap());
|
||||
}
|
||||
|
||||
todo!();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_llama() {
|
||||
load_model_shard(
|
||||
|
||||
)
|
||||
}
|
||||
@ -1,13 +1,9 @@
|
||||
use crate::topology::DeviceCapabilities;
|
||||
use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use socket2::{Domain, Protocol, Socket, Type};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
use std::net::Ipv4Addr;
|
||||
use system_configuration::core_foundation::array::CFArray;
|
||||
use system_configuration::network_configuration::{get_interfaces, SCNetworkInterface, SCNetworkInterfaceType};
|
||||
use system_configuration::sys::network_configuration::SCNetworkInterfaceRef;
|
||||
use tokio::net::UdpSocket;
|
||||
use system_configuration::network_configuration::{
|
||||
get_interfaces, SCNetworkInterface, SCNetworkInterfaceType,
|
||||
};
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum InterfaceType {
|
||||
@ -76,7 +72,8 @@ pub fn get_broadcast_creation_info() -> Vec<BroadcastCreationInfo> {
|
||||
|
||||
for (bind_address, broadcast_address) in map {
|
||||
let interface_type = if cfg!(target_os = "macos") {
|
||||
get_sc_interface_type(&interface.name, &sc_interfaces).unwrap_or_else(|| determine_interface_type(&interface.name))
|
||||
get_sc_interface_type(&interface.name, &sc_interfaces)
|
||||
.unwrap_or_else(|| determine_interface_type(&interface.name))
|
||||
} else {
|
||||
determine_interface_type(&interface.name)
|
||||
};
|
||||
@ -94,7 +91,10 @@ pub fn get_broadcast_creation_info() -> Vec<BroadcastCreationInfo> {
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
fn get_sc_interface_type(name: &str, sc_interfaces: &CFArray<SCNetworkInterface>) -> Option<InterfaceType> {
|
||||
fn get_sc_interface_type(
|
||||
name: &str,
|
||||
sc_interfaces: &CFArray<SCNetworkInterface>,
|
||||
) -> Option<InterfaceType> {
|
||||
sc_interfaces.iter().find_map(|sc_if| {
|
||||
sc_if
|
||||
.bsd_name()
|
||||
@ -110,9 +110,14 @@ fn get_sc_interface_type(name: &str, sc_interfaces: &CFArray<SCNetworkInterface>
|
||||
|
||||
fn determine_interface_type(name: &str) -> InterfaceType {
|
||||
// Fallback to interface name pattern matching
|
||||
if name.starts_with("docker") || name.starts_with("br-") || name.starts_with("veth")
|
||||
|| name.starts_with("cni") || name.starts_with("flannel") || name.starts_with("calico")
|
||||
|| name.starts_with("weave") || name.contains("bridge")
|
||||
if name.starts_with("docker")
|
||||
|| name.starts_with("br-")
|
||||
|| name.starts_with("veth")
|
||||
|| name.starts_with("cni")
|
||||
|| name.starts_with("flannel")
|
||||
|| name.starts_with("calico")
|
||||
|| name.starts_with("weave")
|
||||
|| name.contains("bridge")
|
||||
{
|
||||
InterfaceType::ContainerVirtual
|
||||
} else if name.starts_with("lo") {
|
||||
@ -122,13 +127,20 @@ fn determine_interface_type(name: &str) -> InterfaceType {
|
||||
} else if (name.starts_with("eth") || name.starts_with("en")) && !matches!(name, "en0" | "en1")
|
||||
{
|
||||
InterfaceType::Ethernet
|
||||
} else if name.starts_with("wlan") || name.starts_with("wifi") || name.starts_with("wl")
|
||||
|| matches!(name, "en0" | "en1")
|
||||
} else if name.starts_with("wlan")
|
||||
|| name.starts_with("wifi")
|
||||
|| name.starts_with("wl")
|
||||
|| matches!(name, "en0" | "en1")
|
||||
{
|
||||
InterfaceType::WiFi
|
||||
} else if name.starts_with("tun") || name.starts_with("tap") || name.starts_with("vtun")
|
||||
|| name.starts_with("utun") || name.starts_with("gif") || name.starts_with("stf")
|
||||
|| name.starts_with("awdl") || name.starts_with("llw")
|
||||
} else if name.starts_with("tun")
|
||||
|| name.starts_with("tap")
|
||||
|| name.starts_with("vtun")
|
||||
|| name.starts_with("utun")
|
||||
|| name.starts_with("gif")
|
||||
|| name.starts_with("stf")
|
||||
|| name.starts_with("awdl")
|
||||
|| name.starts_with("llw")
|
||||
{
|
||||
InterfaceType::ExternalVirtual
|
||||
} else {
|
||||
|
||||
@ -1,15 +1,17 @@
|
||||
use crate::node_service::node_service_client::NodeServiceClient;
|
||||
use crate::node_service::HealthCheckRequest;
|
||||
use crate::topology::DeviceCapabilities;
|
||||
use crate::node_service::{CollectTopologyRequest, HealthCheckRequest};
|
||||
use crate::topology::{DeviceCapabilities, Topology};
|
||||
use std::collections::HashSet;
|
||||
use std::net::SocketAddr;
|
||||
use tonic::codec::CompressionEncoding;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PeerHandle {
|
||||
pub node_id: String,
|
||||
pub address: SocketAddr,
|
||||
pub address_priority: u8,
|
||||
pub description: Option<String>,
|
||||
pub client: tokio::sync::Mutex<NodeServiceClient<tonic::transport::Channel>>,
|
||||
client: NodeServiceClient<tonic::transport::Channel>,
|
||||
pub device_capabilities: DeviceCapabilities,
|
||||
}
|
||||
|
||||
@ -31,19 +33,35 @@ impl PeerHandle {
|
||||
description,
|
||||
address_priority,
|
||||
address,
|
||||
client: tokio::sync::Mutex::new(client),
|
||||
client,
|
||||
device_capabilities,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn client(&self) -> NodeServiceClient<tonic::transport::Channel> {
|
||||
self.client.clone()
|
||||
}
|
||||
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
self.client
|
||||
.lock()
|
||||
.await
|
||||
self.client()
|
||||
.health_check(HealthCheckRequest::default())
|
||||
.await
|
||||
.ok()
|
||||
.map(|x| x.into_inner().is_healthy)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub async fn collect_topology(&self, visited: HashSet<String>, max_depth: u8) -> Topology {
|
||||
let response = self
|
||||
.client()
|
||||
.collect_topology(CollectTopologyRequest {
|
||||
visited: visited.clone().into_iter().collect(),
|
||||
max_depth: max_depth as i32,
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
.into_inner();
|
||||
|
||||
response.into()
|
||||
}
|
||||
}
|
||||
|
||||
62
src/partitioning.rs
Normal file
62
src/partitioning.rs
Normal file
@ -0,0 +1,62 @@
|
||||
use crate::Shard;
|
||||
use crate::topology::Topology;
|
||||
|
||||
pub enum PartitionStrategy {
|
||||
RingMemoryWeighted,
|
||||
}
|
||||
|
||||
pub struct Partition {
|
||||
pub node_id: String,
|
||||
pub start: f32,
|
||||
pub end: f32,
|
||||
}
|
||||
|
||||
impl PartitionStrategy {
|
||||
pub fn partition(&self, topology: &Topology) -> Vec<Partition> {
|
||||
let mut entries = topology.nodes.iter().collect::<Vec<_>>();
|
||||
entries.sort_by_key(|(node_id, device_capabilities)| {
|
||||
(device_capabilities.memory, node_id.clone())
|
||||
});
|
||||
|
||||
let mut start = 0.0;
|
||||
let mut partitions = Vec::with_capacity(entries.len());
|
||||
|
||||
for (node_id, device_capabilities) in entries {
|
||||
let end = ((start + device_capabilities.memory as f32) * 100000.0).round() / 100000.0;
|
||||
partitions.push(Partition {
|
||||
node_id: node_id.to_string(),
|
||||
start,
|
||||
end,
|
||||
});
|
||||
start = end;
|
||||
}
|
||||
|
||||
partitions
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shard_model_by_partition(
|
||||
partition_set: &[Partition],
|
||||
total_layers: u32,
|
||||
model_id: &str,
|
||||
) -> Vec<Shard> {
|
||||
let mut shards: Vec<Shard> = Vec::with_capacity(partition_set.len());
|
||||
|
||||
for partition in partition_set {
|
||||
let start_layer = (partition.start * total_layers as f32).round() as u32;
|
||||
let mut end_layer = (partition.end * total_layers as f32).round() as u32 - 1;
|
||||
|
||||
if end_layer < start_layer {
|
||||
end_layer = total_layers - 1;
|
||||
}
|
||||
|
||||
shards.push(Shard {
|
||||
model_id: model_id.to_string(),
|
||||
start_layer,
|
||||
end_layer,
|
||||
total_layers,
|
||||
});
|
||||
}
|
||||
|
||||
shards
|
||||
}
|
||||
254
src/topology.rs
254
src/topology.rs
@ -1,11 +1,152 @@
|
||||
use std::collections::HashMap;
|
||||
use crate::partitioning::{shard_model_by_partition, PartitionStrategy};
|
||||
use crate::{device_capability_data, node_service, Shard};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::process::Command;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct Topology {
|
||||
nodes: HashMap<String, DeviceCapabilities>,
|
||||
peer_graph: HashMap<String, Vec<PeerConnection>>,
|
||||
active_node_id: Option<String>
|
||||
pub nodes: HashMap<String, DeviceCapabilities>,
|
||||
pub peer_graph: HashMap<String, Vec<PeerConnection>>,
|
||||
pub active_node_id: Option<String>,
|
||||
}
|
||||
|
||||
impl Topology {
|
||||
pub fn get_shard_for_node(&self, base_shard: Shard, node_id: &str) -> Shard {
|
||||
let partition_set = PartitionStrategy::RingMemoryWeighted.partition(&self);
|
||||
|
||||
// TODO: This feels like it could be a better data structure
|
||||
let partition_index = partition_set
|
||||
.iter()
|
||||
.position(|s| s.node_id == node_id)
|
||||
.expect("Did not find node in partition set");
|
||||
|
||||
let shards = shard_model_by_partition(
|
||||
&partition_set,
|
||||
base_shard.total_layers.try_into().unwrap(),
|
||||
base_shard.model_id.as_str(),
|
||||
);
|
||||
|
||||
shards[partition_index].clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Topology {
|
||||
pub fn update_node(&mut self, node_id: String, device_capabilities: DeviceCapabilities) {
|
||||
self.nodes.insert(node_id, device_capabilities);
|
||||
}
|
||||
|
||||
pub fn update_edge(&mut self, from_id: String, to_id: String, description: Option<String>) {
|
||||
let conn = PeerConnection {
|
||||
from_id: from_id.clone(),
|
||||
to_id,
|
||||
description,
|
||||
};
|
||||
|
||||
match self.peer_graph.get_mut(&from_id) {
|
||||
None => {
|
||||
self.peer_graph.insert(from_id, vec![conn]);
|
||||
}
|
||||
|
||||
Some(existing) => {
|
||||
existing.push(conn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn merge_restricted(&mut self, from_peer_id: &str, topology: Topology) {
|
||||
if let Some(peer_capabilities) = topology.nodes.get(from_peer_id) {
|
||||
self.nodes
|
||||
.insert(from_peer_id.to_string(), peer_capabilities.clone());
|
||||
}
|
||||
|
||||
self.peer_graph.extend(
|
||||
topology
|
||||
.peer_graph
|
||||
.into_iter()
|
||||
.filter(|(id, _)| id == from_peer_id),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<node_service::Topology> for Topology {
|
||||
fn from(proto: node_service::Topology) -> Self {
|
||||
let nodes = proto
|
||||
.nodes
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.into()))
|
||||
.collect();
|
||||
|
||||
let peer_graph = proto
|
||||
.peer_graph
|
||||
.into_iter()
|
||||
.map(|(from_id, connections)| {
|
||||
(
|
||||
from_id.clone(),
|
||||
connections
|
||||
.connections
|
||||
.into_iter()
|
||||
.map(|pc| PeerConnection {
|
||||
from_id: from_id.clone(),
|
||||
to_id: pc.to_id,
|
||||
description: pc.description,
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Topology {
|
||||
nodes,
|
||||
peer_graph,
|
||||
active_node_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<node_service::Topology> for Topology {
|
||||
fn into(self) -> node_service::Topology {
|
||||
let nodes = self
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|(node_id, cap)| {
|
||||
(
|
||||
node_id.clone(),
|
||||
node_service::DeviceCapabilities {
|
||||
model: cap.model.clone(),
|
||||
chip: cap.chip.clone(),
|
||||
memory: cap.memory as i32,
|
||||
flops: Some(node_service::DeviceFlops {
|
||||
fp32: cap.flops.fp32,
|
||||
fp16: cap.flops.fp16,
|
||||
int8: cap.flops.int8,
|
||||
}),
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<HashMap<String, node_service::DeviceCapabilities>>();
|
||||
|
||||
let peer_graph = self
|
||||
.peer_graph
|
||||
.iter()
|
||||
.map(|(node_id, connections)| {
|
||||
(
|
||||
node_id.clone(),
|
||||
node_service::PeerConnections {
|
||||
connections: connections
|
||||
.iter()
|
||||
.map(|conn| node_service::PeerConnection {
|
||||
to_id: conn.to_id.clone(),
|
||||
description: conn.description.clone(),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<HashMap<String, node_service::PeerConnections>>();
|
||||
|
||||
node_service::Topology { nodes, peer_graph }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Topology {
|
||||
@ -20,22 +161,107 @@ impl Default for Topology {
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct DeviceCapabilities {
|
||||
model: String,
|
||||
chip: String,
|
||||
memory: u64,
|
||||
flops: DeviceFlops,
|
||||
pub model: String,
|
||||
pub chip: String,
|
||||
pub memory: u64,
|
||||
pub flops: DeviceFlops,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
struct SystemProfilerOutputData {
|
||||
#[serde(rename = "SPHardwareDataType")]
|
||||
hardware: Vec<SPHardwareDataType>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
struct SPHardwareDataType {
|
||||
#[serde(rename = "_name")]
|
||||
name: String,
|
||||
activation_lock_status: String,
|
||||
boot_rom_version: String,
|
||||
chip_type: String,
|
||||
machine_model: String,
|
||||
machine_name: String,
|
||||
model_number: String,
|
||||
number_processors: String,
|
||||
os_loader_version: String,
|
||||
physical_memory: String,
|
||||
#[serde(rename = "platform_UUID")]
|
||||
platform_uuid: String,
|
||||
#[serde(rename = "provisioning_UDID")]
|
||||
provisioning_udid: String,
|
||||
serial_number: String,
|
||||
}
|
||||
|
||||
impl DeviceCapabilities {
|
||||
pub fn determine() -> DeviceCapabilities {
|
||||
let s = Command::new("system_profiler")
|
||||
.arg("SPHardwareDataType")
|
||||
.arg("-json")
|
||||
.output()
|
||||
.unwrap()
|
||||
.stdout;
|
||||
|
||||
let mut data = serde_json::from_slice::<SystemProfilerOutputData>(&s).unwrap();
|
||||
let hardware = data.hardware.remove(0);
|
||||
|
||||
let model = hardware.machine_name;
|
||||
let chip = hardware.chip_type;
|
||||
let memory = {
|
||||
let parts: Vec<&str> = hardware.physical_memory.split_ascii_whitespace().collect();
|
||||
if parts.len() >= 2 {
|
||||
let value = parts[0].parse::<u64>().unwrap_or(0);
|
||||
if parts[1] == "GB" {
|
||||
value * 1024
|
||||
} else {
|
||||
value
|
||||
}
|
||||
} else {
|
||||
0
|
||||
}
|
||||
};
|
||||
|
||||
DeviceCapabilities {
|
||||
flops: device_capability_data::look_up(&chip)
|
||||
.expect("Failed to find FLOPS data for chip"),
|
||||
model,
|
||||
chip,
|
||||
memory,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<node_service::DeviceCapabilities> for DeviceCapabilities {
|
||||
fn from(value: node_service::DeviceCapabilities) -> Self {
|
||||
DeviceCapabilities {
|
||||
model: value.model,
|
||||
chip: value.chip,
|
||||
memory: value.memory as u64,
|
||||
flops: value.flops.map(|x| x.into()).unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
||||
pub struct DeviceFlops {
|
||||
fp32: u64,
|
||||
fp16: u64,
|
||||
int8: u64,
|
||||
pub fp32: f64,
|
||||
pub fp16: f64,
|
||||
pub int8: f64,
|
||||
}
|
||||
|
||||
impl From<node_service::DeviceFlops> for DeviceFlops {
|
||||
fn from(value: node_service::DeviceFlops) -> Self {
|
||||
DeviceFlops {
|
||||
fp32: value.fp32,
|
||||
fp16: value.fp16,
|
||||
int8: value.int8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct PeerConnection {
|
||||
from_id: String,
|
||||
to_id: String,
|
||||
description: Option<String>,
|
||||
pub from_id: String,
|
||||
pub to_id: String,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user