Compare commits

...

10 Commits

Author SHA1 Message Date
Joshua Coles
e551efc0a4 Stash 2025-02-12 20:41:08 +00:00
Joshua Coles
01e352bc2e Stash so much code 2025-02-12 17:32:09 +00:00
Joshua Coles
a129325ade Stash gen of LLAMA mlx 2025-02-12 17:18:40 +00:00
Joshua Coles
4cd96b58b5 Stash gen of LLAMA mlx 2025-02-12 17:07:06 +00:00
Joshua Coles
39ac5a86dd Add basic process tensor 2025-02-12 15:00:43 +00:00
Joshua Coles
a631a1c0a9 Stash initial attempt at implementing process prompt before being sad 2025-02-12 14:55:26 +00:00
Joshua Coles
8d91f64dbf Reduce warning count 2025-02-12 14:23:54 +00:00
Joshua Coles
02694cbacc Add partition and shards 2025-02-12 14:09:07 +00:00
Joshua Coles
3aca4a22ae Implement collect topology 2025-02-12 13:01:14 +00:00
Joshua Coles
cd0b4a1bbf Maybe it works! 2025-02-12 12:02:41 +00:00
14 changed files with 1630 additions and 69 deletions

633
Cargo.lock generated
View File

@ -145,12 +145,61 @@ version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn",
"which",
]
[[package]]
name = "bindgen"
version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags",
"cexpr",
"clang-sys",
"itertools 0.13.0",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn",
]
[[package]]
name = "bitflags"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
[[package]]
name = "bytemuck"
version = "1.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3"
[[package]]
name = "byteorder"
version = "1.5.0"
@ -169,15 +218,46 @@ version = "1.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7777341816418c02e033934a09f20dc0ccaf65a5201ef8a450ae0105a573fda"
dependencies = [
"jobserver",
"libc",
"shlex",
]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@ -203,12 +283,79 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "crunchy"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929"
[[package]]
name = "darling"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn",
]
[[package]]
name = "darling_macro"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
dependencies = [
"darling_core",
"quote",
"syn",
]
[[package]]
name = "dyn-clone"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35"
[[package]]
name = "either"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]]
name = "enumflags2"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba2f4b465f5318854c6f8dd686ede6c0a9dc67d4b1ac241cf0eb51521a309147"
dependencies = [
"enumflags2_derive",
]
[[package]]
name = "enumflags2_derive"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc4caf64a58d7a6d65ab00639b046ff54399a39f5f2554728895ace4b297cd79"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "equivalent"
version = "1.0.1"
@ -222,25 +369,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
dependencies = [
"libc",
"windows-sys",
"windows-sys 0.52.0",
]
[[package]]
name = "exo-rs"
version = "0.1.0"
dependencies = [
"glob",
"llama-cpp-2",
"mlx-rs",
"network-interface",
"phf",
"prost",
"regex",
"serde",
"serde_json",
"socket2",
"system-configuration",
"thiserror 2.0.11",
"tinygrad",
"tokio",
"tonic",
"tonic-build",
"tracing",
"tracing-subscriber",
"uuid",
]
[[package]]
@ -339,6 +493,12 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]]
name = "h2"
version = "0.4.7"
@ -358,6 +518,16 @@ dependencies = [
"tracing",
]
[[package]]
name = "half"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [
"cfg-if",
"crunchy",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
@ -376,6 +546,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "http"
version = "1.2.0"
@ -475,6 +654,12 @@ dependencies = [
"tracing",
]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "indexmap"
version = "1.9.3"
@ -495,6 +680,15 @@ dependencies = [
"hashbrown 0.15.2",
]
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.13.0"
@ -504,30 +698,90 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
[[package]]
name = "jobserver"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0"
dependencies = [
"libc",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]]
name = "libloading"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]]
name = "llama-cpp-2"
version = "0.1.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44818b6967a77379b8c8e105e2684d2bd2bca999ad24cfd806d8476a80c53255"
dependencies = [
"enumflags2",
"llama-cpp-sys-2",
"thiserror 1.0.69",
"tracing",
"tracing-core",
]
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92add8e8dabf941518dd573075721a6e0568db86c6557e176cec99bf56883c0a"
dependencies = [
"bindgen 0.69.5",
"cc",
"cmake",
"glob",
"walkdir",
]
[[package]]
name = "lock_api"
version = "0.4.12"
@ -544,12 +798,31 @@ version = "0.4.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
[[package]]
name = "mach-sys"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48460c2e82a3a0de197152fdf8d2c2d5e43adc501501553e439bf2156e6f87c7"
dependencies = [
"fastrand",
]
[[package]]
name = "matchit"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "matrixmultiply"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]]
name = "memchr"
version = "2.7.4"
@ -562,6 +835,12 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "minimal-lexical"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.8.4"
@ -579,7 +858,68 @@ checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd"
dependencies = [
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys",
"windows-sys 0.52.0",
]
[[package]]
name = "mlx-internal-macros"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0756e4528d38dfd2c30551e3cb05f42b346d4b9fd14a867767d86353232056d"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "mlx-macros"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "177ff342309c789defa1552e763ea8fbb5548e3ec17134a45009a27fbddb6c26"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "mlx-rs"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c686ead28a57db28004d2c72f940bb3b4366ad01649899cacd06bca495f93ca"
dependencies = [
"bytemuck",
"dyn-clone",
"half",
"itertools 0.14.0",
"libc",
"mach-sys",
"mlx-internal-macros",
"mlx-macros",
"mlx-sys",
"num-complex",
"num-traits",
"num_enum",
"parking_lot",
"paste",
"safetensors",
"smallvec",
"strum",
"thiserror 1.0.69",
]
[[package]]
name = "mlx-sys"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af33a6b662998e5bb4099b1a191b4352fcb11d97706e82e4c8922fe200bb11f2"
dependencies = [
"bindgen 0.70.1",
"cc",
"cmake",
]
[[package]]
@ -588,6 +928,19 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
[[package]]
name = "ndarray"
version = "0.15.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"rawpointer",
]
[[package]]
name = "network-interface"
version = "2.0.0"
@ -600,6 +953,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "nom"
version = "7.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
dependencies = [
"memchr",
"minimal-lexical",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
@ -610,6 +973,54 @@ dependencies = [
"winapi",
]
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "num_enum"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179"
dependencies = [
"num_enum_derive",
]
[[package]]
name = "num_enum_derive"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56"
dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "object"
version = "0.36.7"
@ -654,6 +1065,12 @@ dependencies = [
"windows-targets",
]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "percent-encoding"
version = "2.3.1"
@ -670,6 +1087,48 @@ dependencies = [
"indexmap 2.7.1",
]
[[package]]
name = "phf"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078"
dependencies = [
"phf_macros",
"phf_shared",
]
[[package]]
name = "phf_generator"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d"
dependencies = [
"phf_shared",
"rand",
]
[[package]]
name = "phf_macros"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216"
dependencies = [
"phf_generator",
"phf_shared",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "phf_shared"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5"
dependencies = [
"siphasher",
]
[[package]]
name = "pin-project"
version = "1.1.9"
@ -721,6 +1180,15 @@ dependencies = [
"syn",
]
[[package]]
name = "proc-macro-crate"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b"
dependencies = [
"toml_edit",
]
[[package]]
name = "proc-macro2"
version = "1.0.93"
@ -747,7 +1215,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0f3e5beed80eb580c68e2c600937ac2c4eedabdfd5ef1e5b7ea4f3fba84497b"
dependencies = [
"heck",
"itertools",
"itertools 0.13.0",
"log",
"multimap",
"once_cell",
@ -767,7 +1235,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3"
dependencies = [
"anyhow",
"itertools",
"itertools 0.13.0",
"proc-macro2",
"quote",
"syn",
@ -821,6 +1289,12 @@ dependencies = [
"getrandom 0.2.15",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "redox_syscall"
version = "0.5.8"
@ -865,6 +1339,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustix"
version = "0.38.44"
@ -875,7 +1355,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys",
"windows-sys 0.52.0",
]
[[package]]
@ -890,6 +1370,25 @@ version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
[[package]]
name = "safetensors"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0436dbfa2778e4ec1a00801b0ae24a1dd619499247d48b0589b679103379d0d4"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
@ -952,6 +1451,12 @@ dependencies = [
"libc",
]
[[package]]
name = "siphasher"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d"
[[package]]
name = "slab"
version = "0.4.9"
@ -974,7 +1479,35 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8"
dependencies = [
"libc",
"windows-sys",
"windows-sys 0.52.0",
]
[[package]]
name = "strsim"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.26.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck",
"proc-macro2",
"quote",
"rustversion",
"syn",
]
[[package]]
@ -1026,7 +1559,7 @@ dependencies = [
"getrandom 0.3.1",
"once_cell",
"rustix",
"windows-sys",
"windows-sys 0.52.0",
]
[[package]]
@ -1079,6 +1612,15 @@ dependencies = [
"once_cell",
]
[[package]]
name = "tinygrad"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc8951e9ced45095eb3ad7342c4e74b038bda930833df79b7debb019bb653c18"
dependencies = [
"ndarray",
]
[[package]]
name = "tokio"
version = "1.43.0"
@ -1094,7 +1636,7 @@ dependencies = [
"signal-hook-registry",
"socket2",
"tokio-macros",
"windows-sys",
"windows-sys 0.52.0",
]
[[package]]
@ -1132,6 +1674,23 @@ dependencies = [
"tokio",
]
[[package]]
name = "toml_datetime"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41"
[[package]]
name = "toml_edit"
version = "0.22.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474"
dependencies = [
"indexmap 2.7.1",
"toml_datetime",
"winnow",
]
[[package]]
name = "tonic"
version = "0.12.3"
@ -1292,12 +1851,31 @@ version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034"
[[package]]
name = "uuid"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced87ca4be083373936a67f8de945faa23b6b42384bd5b64434850802c6dccd0"
dependencies = [
"getrandom 0.3.1",
]
[[package]]
name = "valuable"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"
@ -1322,6 +1900,18 @@ dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix",
]
[[package]]
name = "winapi"
version = "0.3.9"
@ -1338,6 +1928,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
@ -1353,6 +1952,15 @@ dependencies = [
"windows-targets",
]
[[package]]
name = "windows-sys"
version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-targets"
version = "0.52.6"
@ -1417,6 +2025,15 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603"
dependencies = [
"memchr",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.33.0"

View File

@ -15,6 +15,13 @@ tracing-subscriber = "0.3"
socket2 = "0.5.8"
system-configuration = "0.6.1"
network-interface = "2.0.0"
uuid = { version = "1.13.1", features = ["v4"] }
regex = "1.11.1"
phf = { version = "0.11.3", features = ["macros"] }
mlx-rs = { version = "0.21.0", features = ["metal", "accelerate", "safetensors"] }
tinygrad = "0.1.0"
llama-cpp-2 = { version = "0.1.93", features = ["metal", "native", "sampler"] }
glob = "0.3.2"
[build-dependencies]
tonic-build = "0.12.3"

View File

@ -0,0 +1,111 @@
use crate::topology::DeviceFlops;
use phf::phf_map;
/// Scale factor every table entry is multiplied by.
///
/// It is currently `1.00`, so each figure is stored exactly as written.
/// NOTE(review): presumably the figures are teraFLOPS, going by the name — confirm.
const TFLOPS: f64 = 1.00;

/// Compile-time lookup table from a chip / GPU name to its [`DeviceFlops`]
/// throughput figures (`fp32`, `fp16`, `int8`).
///
/// Keys are matched exactly (case-sensitive) by `phf::Map::get`.
pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
    // Source: https://www.cpu-monkey.com
    // Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
    // M chips
    "Apple M1" => DeviceFlops { fp32: 2.29*TFLOPS, fp16: 4.58*TFLOPS, int8: 9.16*TFLOPS },
    "Apple M1 Pro" => DeviceFlops { fp32: 5.30*TFLOPS, fp16: 10.60*TFLOPS, int8: 21.20*TFLOPS },
    "Apple M1 Max" => DeviceFlops { fp32: 10.60*TFLOPS, fp16: 21.20*TFLOPS, int8: 42.40*TFLOPS },
    "Apple M1 Ultra" => DeviceFlops { fp32: 21.20*TFLOPS, fp16: 42.40*TFLOPS, int8: 84.80*TFLOPS },
    "Apple M2" => DeviceFlops { fp32: 3.55*TFLOPS, fp16: 7.10*TFLOPS, int8: 14.20*TFLOPS },
    "Apple M2 Pro" => DeviceFlops { fp32: 5.68*TFLOPS, fp16: 11.36*TFLOPS, int8: 22.72*TFLOPS },
    "Apple M2 Max" => DeviceFlops { fp32: 13.49*TFLOPS, fp16: 26.98*TFLOPS, int8: 53.96*TFLOPS },
    "Apple M2 Ultra" => DeviceFlops { fp32: 26.98*TFLOPS, fp16: 53.96*TFLOPS, int8: 107.92*TFLOPS },
    "Apple M3" => DeviceFlops { fp32: 3.55*TFLOPS, fp16: 7.10*TFLOPS, int8: 14.20*TFLOPS },
    "Apple M3 Pro" => DeviceFlops { fp32: 4.97*TFLOPS, fp16: 9.94*TFLOPS, int8: 19.88*TFLOPS },
    "Apple M3 Max" => DeviceFlops { fp32: 14.20*TFLOPS, fp16: 28.40*TFLOPS, int8: 56.80*TFLOPS },
    "Apple M4" => DeviceFlops { fp32: 4.26*TFLOPS, fp16: 8.52*TFLOPS, int8: 17.04*TFLOPS },
    "Apple M4 Pro" => DeviceFlops { fp32: 5.72*TFLOPS, fp16: 11.44*TFLOPS, int8: 22.88*TFLOPS },
    "Apple M4 Max" => DeviceFlops { fp32: 18.03*TFLOPS, fp16: 36.07*TFLOPS, int8: 72.14*TFLOPS },
    // A chips
    "Apple A13 Bionic" => DeviceFlops { fp32: 0.69*TFLOPS, fp16: 1.38*TFLOPS, int8: 2.76*TFLOPS },
    "Apple A14 Bionic" => DeviceFlops { fp32: 0.75*TFLOPS, fp16: 1.50*TFLOPS, int8: 3.00*TFLOPS },
    "Apple A15 Bionic" => DeviceFlops { fp32: 1.37*TFLOPS, fp16: 2.74*TFLOPS, int8: 5.48*TFLOPS },
    "Apple A16 Bionic" => DeviceFlops { fp32: 1.79*TFLOPS, fp16: 3.58*TFLOPS, int8: 7.16*TFLOPS },
    "Apple A17 Pro" => DeviceFlops { fp32: 2.15*TFLOPS, fp16: 4.30*TFLOPS, int8: 8.60*TFLOPS },
    // NVIDIA GPUs
    // RTX 40 series
    "NVIDIA GEFORCE RTX 4090" => DeviceFlops { fp32: 82.58*TFLOPS, fp16: 165.16*TFLOPS, int8: 330.32*TFLOPS },
    "NVIDIA GEFORCE RTX 4080" => DeviceFlops { fp32: 48.74*TFLOPS, fp16: 97.48*TFLOPS, int8: 194.96*TFLOPS },
    "NVIDIA GEFORCE RTX 4080 SUPER" => DeviceFlops { fp32: 52.0*TFLOPS, fp16: 104.0*TFLOPS, int8: 208.0*TFLOPS },
    "NVIDIA GEFORCE RTX 4070 TI SUPER" => DeviceFlops { fp32: 40.0*TFLOPS, fp16: 80.0*TFLOPS, int8: 160.0*TFLOPS },
    "NVIDIA GEFORCE RTX 4070 TI" => DeviceFlops { fp32: 39.43*TFLOPS, fp16: 78.86*TFLOPS, int8: 157.72*TFLOPS },
    "NVIDIA GEFORCE RTX 4070 SUPER" => DeviceFlops { fp32: 30.0*TFLOPS, fp16: 60.0*TFLOPS, int8: 120.0*TFLOPS },
    "NVIDIA GEFORCE RTX 4070" => DeviceFlops { fp32: 29.0*TFLOPS, fp16: 58.0*TFLOPS, int8: 116.0*TFLOPS },
    "NVIDIA GEFORCE RTX 4060 TI 16GB" => DeviceFlops { fp32: 22.0*TFLOPS, fp16: 44.0*TFLOPS, int8: 88.0*TFLOPS },
    "NVIDIA GEFORCE RTX 4060 TI" => DeviceFlops { fp32: 22.0*TFLOPS, fp16: 44.0*TFLOPS, int8: 88.0*TFLOPS },
    // RTX 30 series
    "NVIDIA GEFORCE RTX 3050" => DeviceFlops { fp32: 9.11*TFLOPS, fp16: 18.22*TFLOPS, int8: 36.44*TFLOPS },
    "NVIDIA GEFORCE RTX 3060" => DeviceFlops { fp32: 13.0*TFLOPS, fp16: 26.0*TFLOPS, int8: 52.0*TFLOPS },
    "NVIDIA GEFORCE RTX 3060 TI" => DeviceFlops { fp32: 16.2*TFLOPS, fp16: 32.4*TFLOPS, int8: 64.8*TFLOPS },
    "NVIDIA GEFORCE RTX 3070" => DeviceFlops { fp32: 20.3*TFLOPS, fp16: 40.6*TFLOPS, int8: 81.2*TFLOPS },
    "NVIDIA GEFORCE RTX 3070 TI" => DeviceFlops { fp32: 21.8*TFLOPS, fp16: 43.6*TFLOPS, int8: 87.2*TFLOPS },
    "NVIDIA GEFORCE RTX 3080 (10 GB)" => DeviceFlops { fp32: 29.8*TFLOPS, fp16: 59.6*TFLOPS, int8: 119.2*TFLOPS },
    "NVIDIA GEFORCE RTX 3080 (12 GB)" => DeviceFlops { fp32: 30.6*TFLOPS, fp16: 61.2*TFLOPS, int8: 122.4*TFLOPS },
    "NVIDIA GEFORCE RTX 3080 TI" => DeviceFlops { fp32: 34.1*TFLOPS, fp16: 68.2*TFLOPS, int8: 136.4*TFLOPS },
    "NVIDIA GEFORCE RTX 3090" => DeviceFlops { fp32: 35.6*TFLOPS, fp16: 71.2*TFLOPS, int8: 142.4*TFLOPS },
    "NVIDIA GEFORCE RTX 3090 TI" => DeviceFlops { fp32: 40.0*TFLOPS, fp16: 80.0*TFLOPS, int8: 160.0*TFLOPS },
    // RTX 20 series
    "NVIDIA GEFORCE RTX 2060" => DeviceFlops { fp32: 6.45*TFLOPS, fp16: 12.9*TFLOPS, int8: 25.8*TFLOPS },
    "NVIDIA GEFORCE RTX 2060 SUPER" => DeviceFlops { fp32: 7.2*TFLOPS, fp16: 14.4*TFLOPS, int8: 28.8*TFLOPS },
    "NVIDIA GEFORCE RTX 2070" => DeviceFlops { fp32: 7.46*TFLOPS, fp16: 14.93*TFLOPS, int8: 29.86*TFLOPS },
    "NVIDIA GEFORCE RTX 2070 SUPER" => DeviceFlops { fp32: 9.06*TFLOPS, fp16: 18.12*TFLOPS, int8: 36.24*TFLOPS },
    "NVIDIA GEFORCE RTX 2080" => DeviceFlops { fp32: 10.07*TFLOPS, fp16: 20.14*TFLOPS, int8: 40.28*TFLOPS },
    // int8 follows the 4x-fp32 pattern of the other GeForce rows; it previously
    // repeated the plain RTX 2080 figure (40.28).
    "NVIDIA GEFORCE RTX 2080 TI" => DeviceFlops { fp32: 13.45*TFLOPS, fp16: 26.9*TFLOPS, int8: 53.8*TFLOPS },
    "NVIDIA GEFORCE RTX 2080 SUPER" => DeviceFlops { fp32: 11.15*TFLOPS, fp16: 22.30*TFLOPS, int8: 44.60*TFLOPS },
    "NVIDIA TITAN RTX" => DeviceFlops { fp32: 16.31*TFLOPS, fp16: 32.62*TFLOPS, int8: 65.24*TFLOPS },
    // GTX 10 series
    "NVIDIA GEFORCE GTX 1050 TI" => DeviceFlops { fp32: 2.0*TFLOPS, fp16: 4.0*TFLOPS, int8: 8.0*TFLOPS },
    "NVIDIA GEFORCE GTX 1070" => DeviceFlops { fp32: 6.463*TFLOPS, fp16: 0.101*TFLOPS, int8: 25.852*TFLOPS },
    "NVIDIA GEFORCE GTX 1080" => DeviceFlops { fp32: 8.873*TFLOPS, fp16: 0.138*TFLOPS, int8: 35.492*TFLOPS },
    "NVIDIA GEFORCE GTX 1080 TI" => DeviceFlops { fp32: 11.34*TFLOPS, fp16: 0.177*TFLOPS, int8: 45.36*TFLOPS },
    // GTX 16 series
    // NOTE(review): this key is mixed-case ("GeForce") while every other NVIDIA
    // GeForce key is upper-case; confirm which spelling callers actually pass.
    "NVIDIA GeForce GTX 1660 TI" => DeviceFlops { fp32: 4.8*TFLOPS, fp16: 9.6*TFLOPS, int8: 19.2*TFLOPS },
    // QUADRO RTX Ampere series
    "NVIDIA RTX A2000" => DeviceFlops { fp32: 7.99*TFLOPS, fp16: 7.99*TFLOPS, int8: 31.91*TFLOPS },
    "NVIDIA RTX A4000" => DeviceFlops { fp32: 19.17*TFLOPS, fp16: 19.17*TFLOPS, int8: 76.68*TFLOPS },
    "NVIDIA RTX A4500" => DeviceFlops { fp32: 23.65*TFLOPS, fp16: 23.65*TFLOPS, int8: 94.6*TFLOPS },
    "NVIDIA RTX A5000" => DeviceFlops { fp32: 27.8*TFLOPS, fp16: 27.8*TFLOPS, int8: 111.2*TFLOPS },
    "NVIDIA RTX A6000" => DeviceFlops { fp32: 38.71*TFLOPS, fp16: 38.71*TFLOPS, int8: 154.84*TFLOPS },
    // NVIDIA Ada Lovelace Architecture-Based
    "NVIDIA RTX 4000 ADA GENERATION" => DeviceFlops { fp32: 26.7*TFLOPS, fp16: 26.7*TFLOPS, int8: 258.0*TFLOPS },
    // Common Server GPUs
    "NVIDIA A40 48GB PCIE" => DeviceFlops { fp32: 37.4*TFLOPS, fp16: 149.7*TFLOPS, int8: 299.3*TFLOPS },
    "NVIDIA A100 40GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
    "NVIDIA A800 40GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
    "NVIDIA A100 80GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
    "NVIDIA A800 80GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
    "NVIDIA A100 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
    "NVIDIA A800 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
    // AMD GPUs
    // RX 6000 series
    "AMD Radeon RX 6900 XT" => DeviceFlops { fp32: 23.04*TFLOPS, fp16: 46.08*TFLOPS, int8: 92.16*TFLOPS },
    "AMD Radeon RX 6800 XT" => DeviceFlops { fp32: 20.74*TFLOPS, fp16: 41.48*TFLOPS, int8: 82.96*TFLOPS },
    "AMD Radeon RX 6800" => DeviceFlops { fp32: 16.17*TFLOPS, fp16: 32.34*TFLOPS, int8: 64.68*TFLOPS },
    "AMD Radeon RX 6700 XT" => DeviceFlops { fp32: 13.21*TFLOPS, fp16: 26.42*TFLOPS, int8: 52.84*TFLOPS },
    "AMD Radeon RX 6700" => DeviceFlops { fp32: 11.4*TFLOPS, fp16: 22.8*TFLOPS, int8: 45.6*TFLOPS },
    "AMD Radeon RX 6600 XT" => DeviceFlops { fp32: 10.6*TFLOPS, fp16: 21.2*TFLOPS, int8: 42.4*TFLOPS },
    "AMD Radeon RX 6600" => DeviceFlops { fp32: 8.93*TFLOPS, fp16: 17.86*TFLOPS, int8: 35.72*TFLOPS },
    "AMD Radeon RX 6500 XT" => DeviceFlops { fp32: 5.77*TFLOPS, fp16: 11.54*TFLOPS, int8: 23.08*TFLOPS },
    "AMD Radeon RX 6400" => DeviceFlops { fp32: 3.57*TFLOPS, fp16: 7.14*TFLOPS, int8: 14.28*TFLOPS },
    // RX 7000 series
    "AMD Radeon RX 7900 XTX" => DeviceFlops { fp32: 61.4*TFLOPS, fp16: 122.8*TFLOPS, int8: 245.6*TFLOPS },
    "AMD Radeon RX 7900 XT" => DeviceFlops { fp32: 53.4*TFLOPS, fp16: 106.8*TFLOPS, int8: 213.6*TFLOPS },
    "AMD Radeon RX 7800 XT" => DeviceFlops { fp32: 42.6*TFLOPS, fp16: 85.2*TFLOPS, int8: 170.4*TFLOPS },
    "AMD Radeon RX 7700 XT" => DeviceFlops { fp32: 34.2*TFLOPS, fp16: 68.4*TFLOPS, int8: 136.8*TFLOPS },
    "AMD Radeon RX 7600" => DeviceFlops { fp32: 21.5*TFLOPS, fp16: 43.0*TFLOPS, int8: 86.0*TFLOPS },
    "AMD Radeon RX 7500" => DeviceFlops { fp32: 16.2*TFLOPS, fp16: 32.4*TFLOPS, int8: 64.8*TFLOPS },
};
/// Looks up the FLOPS figures recorded for `chip` in `CHIP_FLOPS`.
///
/// The name is tried verbatim first, then as `"Laptop GPU {chip}"` and finally as
/// `"{chip} Laptop GPU"`. Returns `None` when none of the three spellings is present.
pub fn look_up(chip: &str) -> Option<DeviceFlops> {
    let exact = CHIP_FLOPS.get(chip);
    if exact.is_some() {
        return exact.cloned();
    }

    let prefixed = format!("Laptop GPU {}", chip);
    let by_prefix = CHIP_FLOPS.get(&prefixed);
    if by_prefix.is_some() {
        return by_prefix.cloned();
    }

    let suffixed = format!("{} Laptop GPU", chip);
    CHIP_FLOPS.get(&suffixed).cloned()
}

View File

@ -1,11 +1,17 @@
use crate::network::get_broadcast_creation_info;
use crate::orchestration::PeerHandle;
use crate::topology::DeviceCapabilities;
use serde::{Deserialize, Serialize};
use socket2::{Domain, Protocol, Socket, Type};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tracing::{debug, info};
use uuid::Uuid;
mod broadcast;
mod udp_listen;
@ -47,24 +53,49 @@ pub struct NodeInfo {
pub device_capabilities: DeviceCapabilities,
}
impl Default for NodeInfo {
    /// Builds the configuration a node uses when nothing is overridden.
    ///
    /// The node id is a freshly generated v4 UUID and the device capabilities come from
    /// `DeviceCapabilities::determine()`; every other field is a fixed value.
    fn default() -> Self {
        let node_id = Uuid::new_v4().to_string();
        let device_capabilities = DeviceCapabilities::determine();

        Self {
            node_id,
            discovery_listen_port: 5678,
            broadcast_port: 5678,
            // 2.5 seconds between presence broadcasts.
            broadcast_interval: Duration::from_millis(2500),
            grpc_port: 49152,
            allowed_peer_ids: None,
            allowed_interfaces: None,
            discovery_timeout: Duration::from_secs(30),
            device_capabilities,
        }
    }
}
/// Handle to a running UDP discovery instance: the background tasks it spawned plus the
/// peer table those tasks share with the rest of the node.
#[derive(Debug)]
pub struct UdpDiscovery {
// Configuration this instance was created from.
node_info: NodeInfo,
// Task spawned in `new` that runs `broadcast::listen_all`.
discovery_handle: JoinHandle<()>,
// First of the two task handles returned by `udp_listen::manage_discovery`.
presence_handle: JoinHandle<()>,
// Second of the two task handles returned by `udp_listen::manage_discovery`.
peer_manager_handle: JoinHandle<()>,
/// Peer table that is also handed to `udp_listen::manage_discovery`; lock it to read the
/// current peers.
pub peers: Arc<Mutex<HashMap<String, PeerHandle>>>,
}
impl UdpDiscovery {
pub fn new(node_info: NodeInfo) -> Self {
let broadcast_creation_info = get_broadcast_creation_info();
let discovery_handle = tokio::spawn(broadcast::listen_all(node_info.clone(), broadcast_creation_info));
let (presence_handle, peer_manager_handle) = udp_listen::manage_discovery(node_info.clone());
info!("Found addresses: {:?}", broadcast_creation_info);
let peers = Arc::new(Mutex::new(HashMap::new()));
let discovery_handle = tokio::spawn(broadcast::listen_all(
node_info.clone(),
broadcast_creation_info,
));
let (presence_handle, peer_manager_handle) =
udp_listen::manage_discovery(node_info.clone(), peers.clone());
UdpDiscovery {
node_info,
discovery_handle,
presence_handle,
peer_manager_handle,
peers,
}
}

View File

@ -1,15 +1,13 @@
use crate::discovery::{DiscoveryMessage, NodeInfo};
use crate::orchestration::PeerHandle;
use crate::topology::DeviceCapabilities;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::Duration;
use system_configuration::sys::libc::disconnectx;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::select;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tonic::transport::Error;
use tracing::{debug, error, info};
async fn listen_for_discovery(
@ -119,8 +117,10 @@ async fn handle_new_peer(
peers.insert(message.node_id, new_peer);
}
pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>) {
let mut peers: HashMap<String, PeerHandle> = HashMap::new();
pub fn manage_discovery(
node_info: NodeInfo,
peers: Arc<Mutex<HashMap<String, PeerHandle>>>,
) -> (JoinHandle<()>, JoinHandle<()>) {
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<(SocketAddr, DiscoveryMessage)>();
// TODO: How do we handle killing this?
@ -128,11 +128,13 @@ pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>)
let peer_manager_handle = tokio::spawn(async move {
loop {
let action = select! {
let action: Action = select! {
_ = tokio::time::sleep(node_info.discovery_timeout) => Action::HealthChecks,
Some((addr, message)) = rx.recv() => Action::NewPeer(addr, message),
};
let mut peers = peers.lock().await;
match action {
Action::NewPeer(addr, message) => handle_new_peer(&mut peers, addr, message).await,
Action::HealthChecks => perform_health_checks(&mut peers).await,

20
src/inference.rs Normal file
View File

@ -0,0 +1,20 @@
use std::collections::HashMap;
use crate::node_service::{InferenceState, Tensor};
use crate::Shard;
/// Placeholder for the component that runs model inference for this node.
///
/// It currently carries no state; the commented-out field sketches a planned cache.
#[derive(Debug)]
pub struct InferenceEngine {
// state_cache: HashMap<String, _>,
}
impl InferenceEngine {
    /// Intended entry point for running inference on an incoming tensor for `shard`.
    ///
    /// # Panics
    ///
    /// Always panics: the implementation has not been written yet.
    pub(crate) fn infer_tensor(
        &self,
        request_id: String,
        shard: Shard,
        tensor: Option<Tensor>,
        inference_state: Option<InferenceState>,
    ) -> Tensor {
        // An empty body evaluates to `()` and is rejected for a `Tensor` return type, so the
        // stub is spelled out the same way as the other unfinished functions in this crate.
        let _ = (request_id, shard, tensor, inference_state);
        todo!("tensor inference is not implemented yet")
    }
}

160
src/llama_module.rs Normal file
View File

@ -0,0 +1,160 @@
use crate::Shard;
use mlx_rs::builder::Builder;
use mlx_rs::macros::ModuleParameters;
use mlx_rs::module::Module;
use mlx_rs::nn::{Embedding, RmsNorm, RmsNormBuilder};
use mlx_rs::Array;
use serde::Deserialize;
use std::collections::HashMap;
use std::env::args;
use mlx_rs::ops::zeros;
/// Hyper-parameters consumed by `LlamaModel::new`.
// NOTE(review): `ModuleParameters` is derived on this plain configuration struct — confirm
// that is intended rather than only on `LlamaModel`.
#[derive(Debug, Deserialize, ModuleParameters)]
struct ModelArgs {
// First argument to `Embedding::new`.
vocab_size: i32,
// Second argument to `Embedding::new` and the dimension given to `RmsNormBuilder::new`.
hidden_size: i32,
// Number of entries created in `LlamaModel::layers` and `LlamaModel::cache`.
num_hidden_layers: i32,
// Epsilon passed to the RMS norm builder.
rms_norm_eps: f32,
}
/// Marker for one layer slot of a sharded model.
///
/// `LlamaModel::new` assigns `TransformerBlock` to layer indices inside the shard's
/// `start_layer..=end_layer` range and `IdentityBlock` to every other index.
#[derive(Debug)]
enum ShardedLayer {
TransformerBlock,
IdentityBlock,
}
/// Llama model restricted to the layers of a single shard.
#[derive(Debug, ModuleParameters)]
struct LlamaModel {
// Hyper-parameters the model was built from.
args: ModelArgs,
// Layer range this instance is responsible for.
shard: Shard,
// One entry per hidden layer; see `ShardedLayer` for how the variant is chosen.
layers: Vec<ShardedLayer>,
// Token embedding, applied in `forward` only when the shard starts at layer 0.
embed_tokens: Embedding,
// Final RMS norm, applied in `forward` only when the shard ends at the last layer.
norm: RmsNorm,
// One slot per hidden layer, all `None` after `new`.
// NOTE(review): `create_attention_mask` expects `&[Option<HashMap<String, i32>>]`, which
// does not match this element type — confirm which representation is intended.
cache: Vec<Option<Array>>,
}
impl LlamaModel {
fn new(args: ModelArgs, shard: Shard) -> Self {
let embed_tokens = Embedding::new(args.vocab_size, args.hidden_size).unwrap();
let layers = (0..(args.num_hidden_layers as u32)).map(|i| {
if shard.start_layer <= i && i <= shard.end_layer {
ShardedLayer::TransformerBlock
} else {
ShardedLayer::IdentityBlock
}
}).collect::<Vec<_>>();
let norm = RmsNormBuilder::new(args.hidden_size)
.eps(args.rms_norm_eps)
.build()
.unwrap();
Self {
cache: vec![None; args.num_hidden_layers as usize],
args,
shard,
layers,
embed_tokens,
norm
}
}
}
impl Module<Array> for LlamaModel {
type Output = Array;
type Error = mlx_rs::error::Exception;
fn forward(&mut self, input: Array) -> Result<Self::Output, Self::Error> {
let h = if self.shard.is_first_layer() {
self.embed_tokens.forward(&input)?
} else {
input
};
let mut mask = if h.ndim() > 1 && h.shape()[1] > 1 {
Some(create_attention_mask(&h, &self.cache)?)
} else {
None
};
let h = self.layers.iter_mut().zip(self.cache.iter_mut())
.fold(h, |h, (layer, c)| {
layer.forward(&h, mask.as_ref(), c)?
});
let h = if self.shard.is_last_layer() {
self.norm.forward(&h)?
} else {
h
};
Ok(h)
}
fn training_mode(&mut self, mode: bool) {
todo!()
}
}
/// Builds the attention mask for the hidden states `h`.
///
/// The sequence length is read from `h.shape()[1]`. The offset and optional window size are
/// taken from the first cache entry (`"offset"` and `"max_size"` keys) when one is present,
/// otherwise the offset is 0 and there is no window. The mask is cast to `h`'s dtype.
///
/// # Panics
///
/// Panics if the first cache entry exists but has no `"offset"` key.
fn create_attention_mask(h: &Array, cache: &[Option<HashMap<String, i32>>]) -> Result<Array, mlx_rs::error::Exception> {
    let t = h.shape()[1];
    if t <= 1 {
        // NOTE(review): confirm that `zeros` returns an `Array` here rather than a `Result`.
        return Ok(zeros(&[0])); // Return empty array when T <= 1
    }

    let (window_size, offset) = match cache.first() {
        Some(Some(first)) => {
            let offset = *first.get("offset").unwrap();
            match first.get("max_size") {
                Some(max_size) => (Some(*max_size), i32::min(*max_size, offset)),
                None => (None, offset),
            }
        }
        _ => (None, 0),
    };

    create_causal_mask(t, offset, window_size, None)?.as_dtype(h.dtype())
}
/// Builds an additive causal mask for `n` query positions that follow `offset` already
/// processed positions.
///
/// Key positions cover `0..offset + n`; query positions cover `offset..offset + n`. An entry
/// is masked when the query index is lower than the key index, when `window_size` is given
/// and the query index is greater than the key index plus the window, or when `lengths` is
/// given and the key index is at or beyond the corresponding length. Masked entries are
/// scaled by `-1e9`.
///
/// # Errors
///
/// Returns any error raised by the underlying array operations.
fn create_causal_mask(
    n: i32,
    offset: i32,
    window_size: Option<i32>,
    lengths: Option<&Array>,
) -> Result<Array, mlx_rs::error::Exception> {
    let rinds = Array::arange(0, offset + n, 1)?;
    let linds = if offset > 0 {
        // Query rows start after the cached prefix; starting at 0 here would make this
        // branch identical to the `else` branch and give the mask the wrong row count.
        Array::arange(offset, offset + n, 1)?
    } else {
        rinds.clone()
    };
    let linds = linds.reshape(&[-1, 1])?;
    let rinds = rinds.reshape(&[1, -1])?;

    let mut mask = linds.lt(&rinds)?;

    if let Some(w) = window_size {
        // Add on a borrow: `rinds` is still needed for the length mask below.
        let window_mask = linds.gt(&(&rinds + w))?;
        mask = mask.logical_or(&window_mask)?;
    }

    if let Some(l) = lengths {
        let l = l.reshape(&[-1, 1, 1, 1])?;
        let length_mask = rinds.greater_equal(&l)?;
        mask = mask.logical_or(&length_mask)?;
    }

    mask.multiply(-1e9)
}

73
src/llama_test.rs Normal file
View File

@ -0,0 +1,73 @@
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::sampling::LlamaSampler;
/// Manual smoke test for the llama.cpp bindings.
///
/// Loads the model whose path is the first command-line argument, feeds it a fixed chat
/// prompt and greedily samples tokens until the end-of-stream token appears or the position
/// counter passes 64. Each sampled token's bytes are written to stdout.
///
/// # Panics
///
/// Panics when the model path is missing or when loading, tokenising, decoding or writing
/// fails.
fn test() {
    use std::io::Write;

    let model_path = std::env::args().nth(1).expect("Please specify model path");
    let backend = LlamaBackend::init().unwrap();
    let params = LlamaModelParams::default();
    let prompt =
        "<|im_start|>user\nHello! how are you?<|im_end|>\n<|im_start|>assistant\n".to_string();

    let model =
        LlamaModel::load_from_file(&backend, model_path, &params).expect("unable to load model");
    let ctx_params = LlamaContextParams::default();
    let mut ctx = model
        .new_context(&backend, ctx_params)
        .expect("unable to create the llama_context");

    let tokens_list = model
        .str_to_token(&prompt, AddBos::Always)
        .unwrap_or_else(|_| panic!("failed to tokenize {prompt}"));
    let n_len = 64;

    // create a llama_batch with size 512
    // we use this object to submit token data for decoding
    let mut batch = LlamaBatch::new(512, 1);
    let last_index = tokens_list.len() as i32 - 1;
    for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
        // llama_decode will output logits only for the last token of the prompt
        let is_last = i == last_index;
        batch.add(token, i, &[0], is_last).unwrap();
    }
    ctx.decode(&mut batch).expect("llama_decode() failed");

    let mut n_cur = batch.n_tokens();
    let mut sampler = LlamaSampler::greedy();
    let mut stdout = std::io::stdout();

    while n_cur <= n_len {
        // sample the next token
        let token = sampler.sample(&ctx, batch.n_tokens() - 1);
        sampler.accept(token);

        // is it an end of stream?
        if token == model.token_eos() {
            eprintln!();
            break;
        }

        // Raw bytes are written as-is so that a multi-byte character split across two
        // tokens still reaches the terminal intact.
        let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap();
        stdout
            .write_all(&output_bytes)
            .expect("failed to write generated text");
        stdout.flush().expect("failed to flush generated text");

        batch.clear();
        batch.add(token, n_cur, &[0], true).unwrap();

        n_cur += 1;
        ctx.decode(&mut batch).expect("failed to eval");
    }
}

View File

@ -1,32 +1,100 @@
mod topology;
mod orchestration;
mod device_capability_data;
mod discovery;
mod inference;
mod network;
mod orchestration;
mod partitioning;
mod topology;
mod llama_test;
mod module_loading;
mod llama_module;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tonic::{transport::Server, Request, Response, Status};
use crate::discovery::{NodeInfo, UdpDiscovery};
use crate::inference::InferenceEngine;
use crate::node_service::{
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto,
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse,
InferenceState, Loss, PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor,
Topology as TopologyProto,
};
use node_service::node_service_server::{NodeService, NodeServiceServer};
use node_service::TensorRequest;
use std::collections::HashSet;
use topology::Topology;
use uuid::Uuid;
pub mod node_service {
tonic::include_proto!("node_service"); // The string specified here must match the proto package name
}
/// State of one node in the cluster; this is the type that implements `NodeService`.
#[derive(Debug)]
struct Node {
// Identity and configuration of this node.
node_info: NodeInfo,
// Most recent topology; replaced by `update_topology`.
current_topology: Topology,
// Running discovery instance; its `peers` table is read when collecting the topology.
udp_discovery: UdpDiscovery,
// Engine that `process_tensor` delegates to.
inference_engine: InferenceEngine,
}
impl Node {
    /// Runs this node's share of `base_shard` on `tensor` and hands the output to
    /// `process_inference_result`.
    ///
    /// The shard actually executed is the one the current topology assigns to this node.
    #[tracing::instrument]
    pub(crate) async fn process_tensor(
        &self,
        base_shard: Shard,
        tensor: Option<Tensor>,
        request_id: String,
        inference_state: Option<InferenceState>,
    ) -> Tensor {
        let shard = self
            .current_topology
            .get_shard_for_node(base_shard, &self.node_info.node_id);

        // `infer_tensor` takes its arguments by value, so it receives clones; the originals
        // are still needed for the follow-up call below.
        let result: Tensor = self.inference_engine.infer_tensor(
            request_id.clone(),
            shard.clone(),
            tensor,
            inference_state.clone(),
        );

        // NOTE(review): `process_inference_result` is not defined in the visible code.
        self.process_inference_result(shard, result, request_id, inference_state)
    }

    /// Intended entry point for processing a text prompt; not implemented yet.
    ///
    /// # Panics
    ///
    /// Always panics after resolving this node's shard.
    #[tracing::instrument]
    pub async fn process_prompt(
        &self,
        base_shard: Shard,
        prompt: String,
        request_id: String,
        inference_state: Option<InferenceState>,
    ) {
        let shard = self
            .current_topology
            .get_shard_for_node(base_shard, &self.node_info.node_id);

        todo!();
        // The python code is a little weird wrt return types here
        // if shard.is_first_layer() {
        //     let result = self
        //         .inference_engine
        //         .infer_prompt(request_id, shard, prompt, inference_state)
        //         .await;
        //     self.process_inference_result(shard, result, request_id, inference_state)
        // } else {
        //     self.forward_prompt(shard, prompt, request_id, inference_state)
        // }
    }
}
impl Default for Node {
fn default() -> Self {
let node_info = NodeInfo::default();
Self {
node_info: node_info.clone(),
current_topology: Topology::default(),
udp_discovery: UdpDiscovery::new(node_info),
}
}
}
@ -42,11 +110,36 @@ enum OpaqueStatus {
#[derive(Debug, Deserialize, Serialize, Clone)]
struct Shard {
pub model_id: String,
pub start_layer: i32,
pub end_layer: i32,
pub n_layers: i32,
pub start_layer: u32,
pub end_layer: u32,
#[serde(rename = "n_layers")]
pub total_layers: u32,
}
impl Shard {
    /// Returns `true` when the shard begins at layer 0.
    pub fn is_first_layer(&self) -> bool {
        self.start_layer == 0
    }

    /// Returns `true` when the shard ends at the model's final layer.
    ///
    /// A model with zero layers has no final layer, so this returns `false` instead of
    /// underflowing on `total_layers - 1`.
    pub fn is_last_layer(&self) -> bool {
        self.total_layers.checked_sub(1) == Some(self.end_layer)
    }

    /// Number of layers in the inclusive range `start_layer..=end_layer`.
    ///
    /// An inverted range (`end_layer < start_layer`) counts as empty and yields 0 rather than
    /// underflowing.
    pub fn len(&self) -> u32 {
        self.end_layer
            .checked_sub(self.start_layer)
            .map_or(0, |span| span.saturating_add(1))
    }
}
impl From<node_service::Shard> for Shard {
fn from(proto: node_service::Shard) -> Self {
Self {
model_id: proto.model_id,
start_layer: proto.start_layer as u32,
end_layer: proto.end_layer as u32,
total_layers: proto.n_layers as u32,
}
}
}
#[derive(Debug, Deserialize, Serialize, Clone)]
struct NodeStatus {
@ -132,14 +225,36 @@ impl NodeService for Node {
&self,
request: Request<PromptRequest>,
) -> Result<Response<Tensor>, Status> {
todo!()
let request = request.into_inner();
let request_id = request
.request_id
.unwrap_or_else(|| Uuid::new_v4().to_string());
let result = self.process_prompt(
request
.shard
.expect("No shard given. ExoPy does not allow this")
.into(),
request.prompt,
request_id,
request.inference_state,
);
todo!();
}
async fn send_tensor(
&self,
request: Request<TensorRequest>,
) -> Result<Response<Tensor>, Status> {
todo!()
let request = request.into_inner();
let shard = request.shard.expect("No shard provided").into();
let request_id = request.request_id.expect("No request id provided");
let result =
self.process_tensor(shard, request.tensor, request_id, request.inference_state);
Ok(Response::new(result.into()))
}
async fn send_example(
@ -149,11 +264,19 @@ impl NodeService for Node {
todo!()
}
// TODO: Why aren't we using the request?
async fn collect_topology(
&self,
request: Request<CollectTopologyRequest>,
) -> Result<Response<TopologyProto>, Status> {
todo!()
let request = request.into_inner();
let max_depth = request.max_depth as u8;
let visited = request.visited;
self.update_topology_inner(max_depth, visited.into_iter().collect())
.await;
Ok(Response::new(self.current_topology.clone().into()))
}
async fn send_result(
@ -186,13 +309,49 @@ impl NodeService for Node {
}
}
impl Node {
    /// Rebuilds `current_topology` by walking the peer graph up to four hops deep.
    async fn update_topology(&mut self) {
        let overall_max_depth = 4;
        let visited: HashSet<String> = HashSet::new();
        self.current_topology = self.update_topology_inner(overall_max_depth, visited).await;
    }

    /// Collects the topology visible from this node.
    ///
    /// This node and every currently known peer are recorded together with an edge from this
    /// node to the peer. Peers that are not yet in `visited` are additionally asked for their
    /// own topology, with `max_depth` reduced by one, and the part describing that peer is
    /// merged in. No peer is queried once `max_depth` has reached 0.
    async fn update_topology_inner(&self, max_depth: u8, mut visited: HashSet<String>) -> Topology {
        let mut new_topology = Topology::default();
        new_topology.update_node(
            self.node_info.node_id.clone(),
            self.node_info.device_capabilities.clone(),
        );

        // Mark ourselves as visited so that peers do not call straight back into this node.
        visited.insert(self.node_info.node_id.clone());

        // Work on a snapshot: the lock must not stay held while waiting on peer RPCs,
        // otherwise two nodes querying each other would block one another's handlers.
        let peers: Vec<_> = self
            .udp_discovery
            .peers
            .lock()
            .await
            .values()
            .cloned()
            .collect();

        for peer in &peers {
            new_topology.update_node(peer.node_id.clone(), peer.device_capabilities.clone());
            new_topology.update_edge(
                self.node_info.node_id.clone(),
                peer.node_id.clone(),
                peer.description.clone(),
            );

            // `insert` returns `true` only when the id was not present before. Testing
            // `contains` after inserting could never succeed, so no peer was ever queried.
            let first_visit = visited.insert(peer.node_id.clone());

            // Checking the depth first also keeps `max_depth - 1` from underflowing.
            if first_visit && max_depth > 0 {
                let topology = peer.collect_topology(visited.clone(), max_depth - 1).await;
                new_topology.merge_restricted(&peer.node_id, topology);
            }
        }

        new_topology
    }
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// install global collector configured based on RUST_LOG env var.
tracing_subscriber::fmt::init();
let grpc_addr = "[::1]:50051".parse()?;
let node = Node::default();
let node: Node = Node::default();
// TODO: Also implement discovery

63
src/module_loading.rs Normal file
View File

@ -0,0 +1,63 @@
use std::collections::HashMap;
use std::path::Path;
use serde_json::Value;
use crate::Shard;
/// Reads the model configuration found in `model_path` as a JSON object.
///
/// `config.json` is used when it exists; otherwise `model_index.json` is read instead.
///
/// # Panics
///
/// Panics if the chosen file cannot be read or does not contain a JSON object.
fn load_config(
    model_path: &Path,
) -> serde_json::Map<String, serde_json::Value> {
    let preferred = model_path.join("config.json");
    let source = if preferred.exists() {
        preferred
    } else {
        model_path.join("model_index.json")
    };

    let raw = std::fs::read_to_string(source).unwrap();
    serde_json::from_str(&raw).unwrap()
}
/// Loads the configuration and weight files for `shard` of the model stored in `model_path`.
///
/// The on-disk config is overlaid with `model_config`, a `"shard"` entry describing the
/// requested layer range is added, and every `model*.safetensors` file in the directory is
/// loaded in file-name order. The function is unfinished and ends in `todo!()`.
///
/// # Panics
///
/// Panics on any I/O, glob, JSON or weight-loading failure, and unconditionally at the end.
async fn load_model_shard(
    model_path: &Path,
    shard: Shard,
    lazy: bool,
    model_config: serde_json::Map<String, serde_json::Value>,
) {
    // NOTE(review): `lazy` is not used yet.
    let _ = lazy;

    let mut config = load_config(model_path);
    config.extend(model_config);

    let model_name = model_path.file_name().unwrap().to_str().unwrap();

    // Assigning through `config["shard"]` goes through `Map`'s `IndexMut`, which panics when
    // the key is missing; `insert` adds the entry or replaces an existing one.
    config.insert(
        "shard".to_string(),
        serde_json::json!({
            "model_id": model_name,
            "start_layer": shard.start_layer,
            "end_layer": shard.end_layer,
            "n_layers": shard.total_layers,
        }),
    );

    let mut weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
        .unwrap()
        .collect::<Result<Vec<_>, _>>()
        .unwrap();

    // NOTE(review): `ModelArgsRs` and `LlamaModelRs` are not defined in the visible code.
    let model_args = serde_json::from_value::<ModelArgsRs>(Value::Object(config)).unwrap();
    let model = LlamaModelRs::new(model_args).unwrap();

    weight_files.sort_by_key(|x| x.file_name().unwrap().to_str().unwrap().to_string());

    let mut weights = HashMap::new();
    for weight_file in weight_files {
        weights.extend(mlx_rs::Array::load_safetensors(&weight_file).unwrap());
    }

    todo!();
}
// Placeholder test for `load_model_shard`.
// NOTE(review): the call below passes no arguments although `load_model_shard` takes four,
// and the future returned by the `async fn` is never awaited — this stub is unfinished.
#[test]
fn test_load_llama() {
load_model_shard(
)
}

View File

@ -1,13 +1,9 @@
use crate::topology::DeviceCapabilities;
use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig};
use serde::{Deserialize, Serialize};
use socket2::{Domain, Protocol, Socket, Type};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
use std::net::Ipv4Addr;
use system_configuration::core_foundation::array::CFArray;
use system_configuration::network_configuration::{get_interfaces, SCNetworkInterface, SCNetworkInterfaceType};
use system_configuration::sys::network_configuration::SCNetworkInterfaceRef;
use tokio::net::UdpSocket;
use system_configuration::network_configuration::{
get_interfaces, SCNetworkInterface, SCNetworkInterfaceType,
};
#[derive(Copy, Clone, Debug)]
pub enum InterfaceType {
@ -76,7 +72,8 @@ pub fn get_broadcast_creation_info() -> Vec<BroadcastCreationInfo> {
for (bind_address, broadcast_address) in map {
let interface_type = if cfg!(target_os = "macos") {
get_sc_interface_type(&interface.name, &sc_interfaces).unwrap_or_else(|| determine_interface_type(&interface.name))
get_sc_interface_type(&interface.name, &sc_interfaces)
.unwrap_or_else(|| determine_interface_type(&interface.name))
} else {
determine_interface_type(&interface.name)
};
@ -94,7 +91,10 @@ pub fn get_broadcast_creation_info() -> Vec<BroadcastCreationInfo> {
}
#[cfg(target_os = "macos")]
fn get_sc_interface_type(name: &str, sc_interfaces: &CFArray<SCNetworkInterface>) -> Option<InterfaceType> {
fn get_sc_interface_type(
name: &str,
sc_interfaces: &CFArray<SCNetworkInterface>,
) -> Option<InterfaceType> {
sc_interfaces.iter().find_map(|sc_if| {
sc_if
.bsd_name()
@ -110,9 +110,14 @@ fn get_sc_interface_type(name: &str, sc_interfaces: &CFArray<SCNetworkInterface>
fn determine_interface_type(name: &str) -> InterfaceType {
// Fallback to interface name pattern matching
if name.starts_with("docker") || name.starts_with("br-") || name.starts_with("veth")
|| name.starts_with("cni") || name.starts_with("flannel") || name.starts_with("calico")
|| name.starts_with("weave") || name.contains("bridge")
if name.starts_with("docker")
|| name.starts_with("br-")
|| name.starts_with("veth")
|| name.starts_with("cni")
|| name.starts_with("flannel")
|| name.starts_with("calico")
|| name.starts_with("weave")
|| name.contains("bridge")
{
InterfaceType::ContainerVirtual
} else if name.starts_with("lo") {
@ -122,13 +127,20 @@ fn determine_interface_type(name: &str) -> InterfaceType {
} else if (name.starts_with("eth") || name.starts_with("en")) && !matches!(name, "en0" | "en1")
{
InterfaceType::Ethernet
} else if name.starts_with("wlan") || name.starts_with("wifi") || name.starts_with("wl")
|| matches!(name, "en0" | "en1")
} else if name.starts_with("wlan")
|| name.starts_with("wifi")
|| name.starts_with("wl")
|| matches!(name, "en0" | "en1")
{
InterfaceType::WiFi
} else if name.starts_with("tun") || name.starts_with("tap") || name.starts_with("vtun")
|| name.starts_with("utun") || name.starts_with("gif") || name.starts_with("stf")
|| name.starts_with("awdl") || name.starts_with("llw")
} else if name.starts_with("tun")
|| name.starts_with("tap")
|| name.starts_with("vtun")
|| name.starts_with("utun")
|| name.starts_with("gif")
|| name.starts_with("stf")
|| name.starts_with("awdl")
|| name.starts_with("llw")
{
InterfaceType::ExternalVirtual
} else {

View File

@ -1,15 +1,17 @@
use crate::node_service::node_service_client::NodeServiceClient;
use crate::node_service::HealthCheckRequest;
use crate::topology::DeviceCapabilities;
use crate::node_service::{CollectTopologyRequest, HealthCheckRequest};
use crate::topology::{DeviceCapabilities, Topology};
use std::collections::HashSet;
use std::net::SocketAddr;
use tonic::codec::CompressionEncoding;
#[derive(Debug, Clone)]
pub struct PeerHandle {
pub node_id: String,
pub address: SocketAddr,
pub address_priority: u8,
pub description: Option<String>,
pub client: tokio::sync::Mutex<NodeServiceClient<tonic::transport::Channel>>,
client: NodeServiceClient<tonic::transport::Channel>,
pub device_capabilities: DeviceCapabilities,
}
@ -31,19 +33,35 @@ impl PeerHandle {
description,
address_priority,
address,
client: tokio::sync::Mutex::new(client),
client,
device_capabilities,
})
}
/// Returns a clone of this peer's `NodeServiceClient`, so a caller can issue requests
/// through `&self` without needing mutable access to the handle.
pub fn client(&self) -> NodeServiceClient<tonic::transport::Channel> {
self.client.clone()
}
pub async fn is_healthy(&self) -> bool {
self.client
.lock()
.await
self.client()
.health_check(HealthCheckRequest::default())
.await
.ok()
.map(|x| x.into_inner().is_healthy)
.unwrap_or(false)
}
/// Asks this peer for the topology it can see.
///
/// `visited` is sent along as the list of node ids the peer should not explore again and
/// `max_depth` as the remaining hop budget.
///
/// # Panics
///
/// Panics if the RPC fails.
pub async fn collect_topology(&self, visited: HashSet<String>, max_depth: u8) -> Topology {
    let request = CollectTopologyRequest {
        // `visited` is owned, so it is consumed directly instead of being cloned first.
        visited: visited.into_iter().collect(),
        max_depth: max_depth as i32,
    };

    let response = self
        .client()
        .collect_topology(request)
        .await
        .unwrap()
        .into_inner();

    response.into()
}
}

62
src/partitioning.rs Normal file
View File

@ -0,0 +1,62 @@
use crate::Shard;
use crate::topology::Topology;
/// Strategy used to divide a model between the nodes of a topology.
pub enum PartitionStrategy {
/// Orders nodes by `(memory, node id)` and gives each one a contiguous range.
RingMemoryWeighted,
}
/// Range assigned to one node by a `PartitionStrategy`.
///
/// `shard_model_by_partition` multiplies `start` and `end` by the layer count, i.e. it treats
/// them as fractions of the model.
pub struct Partition {
/// Id of the node that owns this range.
pub node_id: String,
/// Lower bound of the range.
pub start: f32,
/// Upper bound of the range.
pub end: f32,
}
impl PartitionStrategy {
    /// Splits the interval `[0, 1]` between the nodes of `topology` in proportion to their
    /// memory.
    ///
    /// Nodes are ordered by `(memory, node id)` ascending and receive consecutive ranges; each
    /// bound is rounded to five decimal places. The bounds are fractions because
    /// `shard_model_by_partition` multiplies them by the layer count. When no node reports
    /// any memory the interval is split evenly.
    // NOTE(review): if partition order has to match another implementation of this
    // strategy, confirm that ascending is the expected sort direction.
    pub fn partition(&self, topology: &Topology) -> Vec<Partition> {
        let mut entries = topology.nodes.iter().collect::<Vec<_>>();
        entries.sort_by_key(|&(node_id, capabilities)| (capabilities.memory, node_id));

        let node_count = entries.len();
        let total_memory: u64 = entries
            .iter()
            .map(|&(_, capabilities)| capabilities.memory)
            .sum();

        let mut start = 0.0_f32;
        let mut partitions = Vec::with_capacity(node_count);

        for (node_id, capabilities) in entries {
            // Share of the whole model owned by this node. Without dividing by the total the
            // bounds would be raw memory sizes and map far beyond the last layer.
            let share = if total_memory == 0 {
                1.0 / node_count as f64
            } else {
                capabilities.memory as f64 / total_memory as f64
            };

            let end = ((start + share as f32) * 100000.0).round() / 100000.0;
            partitions.push(Partition {
                node_id: node_id.to_string(),
                start,
                end,
            });
            start = end;
        }

        partitions
    }
}
/// Maps each partition to the inclusive range of layers it covers.
///
/// A partition's bounds are fractions of `total_layers`: the start layer is
/// `round(start * total_layers)` and the end layer is `round(end * total_layers) - 1`. When
/// that end would fall before the start layer the shard is extended to the model's last
/// layer instead. One shard is returned per partition, in the same order.
pub fn shard_model_by_partition(
    partition_set: &[Partition],
    total_layers: u32,
    model_id: &str,
) -> Vec<Shard> {
    // Saturating keeps a zero-layer model from underflowing.
    let last_layer = total_layers.saturating_sub(1);

    partition_set
        .iter()
        .map(|partition| {
            let start_layer = (partition.start * total_layers as f32).round() as u32;
            let end_exclusive = (partition.end * total_layers as f32).round() as u32;

            // `end_exclusive` can round to 0; a plain `- 1` would underflow there, so that
            // case is treated like any other end that lies before the start.
            let end_layer = match end_exclusive.checked_sub(1) {
                Some(end) if end >= start_layer => end,
                _ => last_layer,
            };

            Shard {
                model_id: model_id.to_string(),
                start_layer,
                end_layer,
                total_layers,
            }
        })
        .collect()
}

View File

@ -1,11 +1,152 @@
use std::collections::HashMap;
use crate::partitioning::{shard_model_by_partition, PartitionStrategy};
use crate::{device_capability_data, node_service, Shard};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::process::Command;
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Topology {
nodes: HashMap<String, DeviceCapabilities>,
peer_graph: HashMap<String, Vec<PeerConnection>>,
active_node_id: Option<String>
pub nodes: HashMap<String, DeviceCapabilities>,
pub peer_graph: HashMap<String, Vec<PeerConnection>>,
pub active_node_id: Option<String>,
}
impl Topology {
    /// Returns the shard of `base_shard`'s model that this topology assigns to `node_id`.
    ///
    /// The topology is partitioned with `PartitionStrategy::RingMemoryWeighted`, the
    /// partitions are mapped onto `base_shard.total_layers` layers, and the shard at the
    /// node's partition position is returned.
    ///
    /// # Panics
    ///
    /// Panics if `node_id` is not part of the topology.
    pub fn get_shard_for_node(&self, base_shard: Shard, node_id: &str) -> Shard {
        let partitions = PartitionStrategy::RingMemoryWeighted.partition(self);

        // TODO: This feels like it could be a better data structure
        let index = partitions
            .iter()
            .position(|partition| partition.node_id == node_id)
            .expect("Did not find node in partition set");

        let mut shards =
            shard_model_by_partition(&partitions, base_shard.total_layers, &base_shard.model_id);

        // The vector is discarded afterwards, so the wanted shard is moved out of it.
        shards.swap_remove(index)
    }
}
impl Topology {
pub fn update_node(&mut self, node_id: String, device_capabilities: DeviceCapabilities) {
self.nodes.insert(node_id, device_capabilities);
}
pub fn update_edge(&mut self, from_id: String, to_id: String, description: Option<String>) {
let conn = PeerConnection {
from_id: from_id.clone(),
to_id,
description,
};
match self.peer_graph.get_mut(&from_id) {
None => {
self.peer_graph.insert(from_id, vec![conn]);
}
Some(existing) => {
existing.push(conn);
}
}
}
pub(crate) fn merge_restricted(&mut self, from_peer_id: &str, topology: Topology) {
if let Some(peer_capabilities) = topology.nodes.get(from_peer_id) {
self.nodes
.insert(from_peer_id.to_string(), peer_capabilities.clone());
}
self.peer_graph.extend(
topology
.peer_graph
.into_iter()
.filter(|(id, _)| id == from_peer_id),
);
}
}
impl From<node_service::Topology> for Topology {
fn from(proto: node_service::Topology) -> Self {
let nodes = proto
.nodes
.into_iter()
.map(|(k, v)| (k, v.into()))
.collect();
let peer_graph = proto
.peer_graph
.into_iter()
.map(|(from_id, connections)| {
(
from_id.clone(),
connections
.connections
.into_iter()
.map(|pc| PeerConnection {
from_id: from_id.clone(),
to_id: pc.to_id,
description: pc.description,
})
.collect(),
)
})
.collect();
Topology {
nodes,
peer_graph,
active_node_id: None,
}
}
}
// Implemented as `From` rather than `Into`: the standard blanket impl still provides
// `Topology: Into<node_service::Topology>`, so existing `.into()` call sites keep working.
impl From<Topology> for node_service::Topology {
    /// Converts the internal topology into its wire representation.
    ///
    /// `active_node_id` and each connection's `from_id` are not transmitted; the latter is
    /// implied by the key of the `peer_graph` entry.
    fn from(topology: Topology) -> Self {
        // The topology is owned here, so its contents are moved instead of cloned.
        let nodes = topology
            .nodes
            .into_iter()
            .map(|(node_id, cap)| {
                (
                    node_id,
                    node_service::DeviceCapabilities {
                        model: cap.model,
                        chip: cap.chip,
                        // NOTE(review): narrowing cast; values above `i32::MAX` are truncated.
                        memory: cap.memory as i32,
                        flops: Some(node_service::DeviceFlops {
                            fp32: cap.flops.fp32,
                            fp16: cap.flops.fp16,
                            int8: cap.flops.int8,
                        }),
                    },
                )
            })
            .collect::<HashMap<String, node_service::DeviceCapabilities>>();

        let peer_graph = topology
            .peer_graph
            .into_iter()
            .map(|(node_id, connections)| {
                (
                    node_id,
                    node_service::PeerConnections {
                        connections: connections
                            .into_iter()
                            .map(|conn| node_service::PeerConnection {
                                to_id: conn.to_id,
                                description: conn.description,
                            })
                            .collect(),
                    },
                )
            })
            .collect::<HashMap<String, node_service::PeerConnections>>();

        node_service::Topology { nodes, peer_graph }
    }
}
impl Default for Topology {
@ -20,22 +161,107 @@ impl Default for Topology {
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct DeviceCapabilities {
model: String,
chip: String,
memory: u64,
flops: DeviceFlops,
pub model: String,
pub chip: String,
pub memory: u64,
pub flops: DeviceFlops,
}
/// Top-level shape of the JSON that `DeviceCapabilities::determine` parses from
/// `system_profiler SPHardwareDataType -json`.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct SystemProfilerOutputData {
// Hardware entries; `determine` uses the first one.
#[serde(rename = "SPHardwareDataType")]
hardware: Vec<SPHardwareDataType>,
}
/// One hardware entry of the `SPHardwareDataType` report.
///
/// Only `chip_type`, `machine_name` and `physical_memory` are read by
/// `DeviceCapabilities::determine`. The remaining fields default to an empty string when the
/// report omits them, so a missing key that nobody reads no longer fails the whole parse.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct SPHardwareDataType {
    #[serde(rename = "_name", default)]
    name: String,
    #[serde(default)]
    activation_lock_status: String,
    #[serde(default)]
    boot_rom_version: String,
    /// Chip name, used to look up FLOPS data.
    chip_type: String,
    #[serde(default)]
    machine_model: String,
    /// Used as the device model.
    machine_name: String,
    #[serde(default)]
    model_number: String,
    #[serde(default)]
    number_processors: String,
    #[serde(default)]
    os_loader_version: String,
    /// Installed memory as text; `determine` splits it into an amount and a unit.
    physical_memory: String,
    #[serde(rename = "platform_UUID", default)]
    platform_uuid: String,
    #[serde(rename = "provisioning_UDID", default)]
    provisioning_udid: String,
    #[serde(default)]
    serial_number: String,
}
impl DeviceCapabilities {
    /// Probes the local machine by running `system_profiler SPHardwareDataType -json`.
    ///
    /// The first hardware entry supplies the model name, the chip name and the memory text.
    /// The memory text is split on whitespace into an amount and a unit: the amount is
    /// multiplied by 1024 when the unit is `GB` and used unchanged otherwise. Fewer than two
    /// tokens, or an unparsable amount, yield 0.
    ///
    /// # Panics
    ///
    /// Panics if the command cannot be run, its output cannot be parsed, it contains no
    /// hardware entry, or no FLOPS data is known for the chip.
    pub fn determine() -> DeviceCapabilities {
        let output = Command::new("system_profiler")
            .arg("SPHardwareDataType")
            .arg("-json")
            .output()
            .unwrap();

        let mut report =
            serde_json::from_slice::<SystemProfilerOutputData>(&output.stdout).unwrap();
        let hardware = report.hardware.remove(0);

        let tokens: Vec<&str> = hardware.physical_memory.split_ascii_whitespace().collect();
        let memory = match tokens.as_slice() {
            [amount, unit, ..] => {
                let amount = amount.parse::<u64>().unwrap_or(0);
                match *unit {
                    "GB" => amount * 1024,
                    _ => amount,
                }
            }
            _ => 0,
        };

        let chip = hardware.chip_type;
        let flops =
            device_capability_data::look_up(&chip).expect("Failed to find FLOPS data for chip");

        Self {
            flops,
            model: hardware.machine_name,
            chip,
            memory,
        }
    }
}
impl From<node_service::DeviceCapabilities> for DeviceCapabilities {
fn from(value: node_service::DeviceCapabilities) -> Self {
DeviceCapabilities {
model: value.model,
chip: value.chip,
memory: value.memory as u64,
flops: value.flops.map(|x| x.into()).unwrap_or_default(),
}
}
}
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
pub struct DeviceFlops {
fp32: u64,
fp16: u64,
int8: u64,
pub fp32: f64,
pub fp16: f64,
pub int8: f64,
}
impl From<node_service::DeviceFlops> for DeviceFlops {
    /// Copies the three per-precision FLOPS figures from the wire type unchanged.
    fn from(value: node_service::DeviceFlops) -> Self {
        let node_service::DeviceFlops { fp32, fp16, int8 } = value;
        Self { fp32, fp16, int8 }
    }
}
#[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)]
pub struct PeerConnection {
from_id: String,
to_id: String,
description: Option<String>,
pub from_id: String,
pub to_id: String,
pub description: Option<String>,
}