Initial AI conversion with some work
This commit is contained in:
parent
fed48a2868
commit
70d58995ec
118
Cargo.lock
generated
118
Cargo.lock
generated
@ -198,9 +198,12 @@ dependencies = [
|
||||
"prost",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tonic",
|
||||
"tonic-build",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -460,6 +463,12 @@ version = "1.0.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.169"
|
||||
@ -532,6 +541,16 @@ version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
|
||||
|
||||
[[package]]
|
||||
name = "nu-ansi-term"
|
||||
version = "0.46.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
|
||||
dependencies = [
|
||||
"overload",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "object"
|
||||
version = "0.36.7"
|
||||
@ -547,6 +566,12 @@ version = "1.20.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
|
||||
|
||||
[[package]]
|
||||
name = "overload"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.3"
|
||||
@ -844,6 +869,15 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sharded-slab"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.4.2"
|
||||
@ -909,6 +943,36 @@ dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "2.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thread_local"
|
||||
version = "1.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.43.0"
|
||||
@ -1081,6 +1145,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"valuable",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-log"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
|
||||
dependencies = [
|
||||
"log",
|
||||
"once_cell",
|
||||
"tracing-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-subscriber"
|
||||
version = "0.3.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008"
|
||||
dependencies = [
|
||||
"nu-ansi-term",
|
||||
"sharded-slab",
|
||||
"smallvec",
|
||||
"thread_local",
|
||||
"tracing-core",
|
||||
"tracing-log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1095,6 +1185,12 @@ version = "1.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034"
|
||||
|
||||
[[package]]
|
||||
name = "valuable"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
version = "0.3.1"
|
||||
@ -1119,6 +1215,28 @@ dependencies = [
|
||||
"wit-bindgen-rt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||
dependencies = [
|
||||
"winapi-i686-pc-windows-gnu",
|
||||
"winapi-x86_64-pc-windows-gnu",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-i686-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.52.0"
|
||||
|
||||
@ -9,6 +9,9 @@ serde = { version = "1.0.217", features = ["derive"] }
|
||||
serde_json = "1.0.138"
|
||||
tokio = { version = "1.43.0", features = ["full"] }
|
||||
tonic = "0.12.3"
|
||||
thiserror = "2.0"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.12.3"
|
||||
|
||||
260
src/discovery.rs
Normal file
260
src/discovery.rs
Normal file
@ -0,0 +1,260 @@
|
||||
use crate::topology::DeviceCapabilities;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::{net::UdpSocket, sync::RwLock, time};
|
||||
use tracing::{error};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum DiscoveryError {
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("JSON serialization error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
#[error("UTF-8 conversion error: {0}")]
|
||||
Utf8(#[from] std::string::FromUtf8Error),
|
||||
|
||||
#[error("Address parse error: {0}")]
|
||||
AddrParse(#[from] std::net::AddrParseError),
|
||||
|
||||
#[error("Network interface error: {0}")]
|
||||
NetworkInterface(String),
|
||||
|
||||
#[error("Peer error: {0}")]
|
||||
Peer(String),
|
||||
}
|
||||
|
||||
// Constants
|
||||
const DEBUG: i32 = 0;
|
||||
const BROADCAST_ADDR: &str = "255.255.255.255";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct DiscoveryMessage {
|
||||
message_type: String,
|
||||
node_id: String,
|
||||
grpc_port: u16,
|
||||
device_capabilities: DeviceCapabilities,
|
||||
priority: i32,
|
||||
interface_name: String,
|
||||
interface_type: String,
|
||||
}
|
||||
|
||||
struct PeerInfo {
|
||||
peer_handle: PeerHandle,
|
||||
connected_at: Instant,
|
||||
last_seen: Instant,
|
||||
priority: i32,
|
||||
}
|
||||
|
||||
struct UdpDiscovery {
|
||||
node_id: String,
|
||||
node_port: u16,
|
||||
listen_port: u16,
|
||||
broadcast_port: u16,
|
||||
broadcast_interval: Duration,
|
||||
discovery_timeout: Duration,
|
||||
device_capabilities: DeviceCapabilities,
|
||||
allowed_node_ids: Option<Vec<String>>,
|
||||
allowed_interface_types: Option<Vec<String>>,
|
||||
known_peers: Arc<RwLock<HashMap<String, PeerInfo>>>,
|
||||
}
|
||||
|
||||
impl UdpDiscovery {
|
||||
pub fn new(
|
||||
node_id: String,
|
||||
node_port: u16,
|
||||
listen_port: u16,
|
||||
broadcast_port: u16,
|
||||
broadcast_interval: Duration,
|
||||
discovery_timeout: Duration,
|
||||
device_capabilities: DeviceCapabilities,
|
||||
allowed_node_ids: Option<Vec<String>>,
|
||||
allowed_interface_types: Option<Vec<String>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
node_id,
|
||||
node_port,
|
||||
listen_port,
|
||||
broadcast_port,
|
||||
broadcast_interval,
|
||||
discovery_timeout,
|
||||
device_capabilities,
|
||||
allowed_node_ids,
|
||||
allowed_interface_types,
|
||||
known_peers: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start(&self) -> Result<(), DiscoveryError> {
|
||||
let broadcast_task = self.task_broadcast_presence();
|
||||
let listen_task = self.task_listen_for_peers();
|
||||
let cleanup_task = self.task_cleanup_peers();
|
||||
|
||||
tokio::try_join!(broadcast_task, listen_task, cleanup_task)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpDiscovery {
|
||||
async fn task_listen_for_peers(&self) -> Result<(), DiscoveryError> {
|
||||
let socket = UdpSocket::bind(format!("0.0.0.0:{}", self.listen_port)).await?;
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
loop {
|
||||
let (len, addr) = socket.recv_from(&mut buf).await?;
|
||||
if len == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(message) = String::from_utf8(buf[..len].to_vec()) {
|
||||
if let Ok(discovery_message) = serde_json::from_str::<DiscoveryMessage>(&message) {
|
||||
self.handle_discovery_message(discovery_message, addr)
|
||||
.await
|
||||
.map_err(|e| DiscoveryError::Peer(e.to_string()))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpDiscovery {
|
||||
async fn task_broadcast_presence(&self) -> Result<(), DiscoveryError> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
socket.set_broadcast(true)?;
|
||||
|
||||
loop {
|
||||
let interfaces = get_all_ip_addresses_and_interfaces()?;
|
||||
for (addr, interface_name) in interfaces {
|
||||
let (interface_priority, interface_type) =
|
||||
get_interface_priority_and_type(&interface_name)
|
||||
.await
|
||||
.map_err(|e| DiscoveryError::NetworkInterface(e.to_string()))?;
|
||||
|
||||
let message = DiscoveryMessage {
|
||||
message_type: "discovery".to_string(),
|
||||
node_id: self.node_id.clone(),
|
||||
grpc_port: self.node_port,
|
||||
device_capabilities: self.device_capabilities.clone(),
|
||||
priority: interface_priority,
|
||||
interface_name: interface_name.clone(),
|
||||
interface_type: interface_type.clone(),
|
||||
};
|
||||
|
||||
let message_json = serde_json::to_string(&message)?;
|
||||
let broadcast_addr =
|
||||
SocketAddr::new(get_broadcast_address(&addr).parse()?, self.broadcast_port);
|
||||
|
||||
if let Err(e) = socket
|
||||
.send_to(message_json.as_bytes(), &broadcast_addr)
|
||||
.await
|
||||
{
|
||||
error!("Error broadcasting to {}: {}", broadcast_addr, e);
|
||||
}
|
||||
}
|
||||
|
||||
time::sleep(self.broadcast_interval).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpDiscovery {
|
||||
async fn task_cleanup_peers(&self) -> Result<(), DiscoveryError> {
|
||||
loop {
|
||||
let now = Instant::now();
|
||||
|
||||
// TODO: Do we really want a read then write lock or should we just take a write lock
|
||||
// to begin with?
|
||||
let peers_to_remove = {
|
||||
let mut peers_to_remove = Vec::new();
|
||||
|
||||
let peers = self.known_peers.read().await;
|
||||
for (peer_id, peer_info) in peers.iter() {
|
||||
if self.should_remove_peer(peer_info, now).await {
|
||||
peers_to_remove.push(peer_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
peers_to_remove
|
||||
};
|
||||
|
||||
{
|
||||
let mut peers = self.known_peers.write().await;
|
||||
for peer_id in peers_to_remove {
|
||||
peers.remove(&peer_id);
|
||||
}
|
||||
};
|
||||
|
||||
time::sleep(self.broadcast_interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn should_remove_peer(&self, peer_info: &PeerInfo, now: Instant) -> bool {
|
||||
let is_connected = peer_info
|
||||
.peer_handle
|
||||
.is_connected()
|
||||
.await
|
||||
.ok()
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_connected {
|
||||
return true;
|
||||
}
|
||||
|
||||
if now.duration_since(peer_info.connected_at) > self.discovery_timeout {
|
||||
return true;
|
||||
}
|
||||
|
||||
if now.duration_since(peer_info.last_seen) > self.discovery_timeout {
|
||||
return true;
|
||||
}
|
||||
|
||||
peer_info
|
||||
.peer_handle
|
||||
.health_check()
|
||||
.await
|
||||
.ok()
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_broadcast_address(ip_addr: &IpAddr) -> String {
|
||||
match ip_addr {
|
||||
IpAddr::V4(addr) => {
|
||||
let octets = addr.octets();
|
||||
format!("{}.{}.{}.255", octets[0], octets[1], octets[2])
|
||||
}
|
||||
_ => BROADCAST_ADDR.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
// You'll need to implement these functions based on your system
|
||||
fn get_all_ip_addresses_and_interfaces() -> Result<Vec<(IpAddr, String)>, DiscoveryError> {
|
||||
// Implementation needed
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
async fn get_interface_priority_and_type(
|
||||
interface_name: &str,
|
||||
) -> Result<(i32, String), DiscoveryError> {
|
||||
// Implementation needed
|
||||
Ok((0, "unknown".to_string()))
|
||||
}
|
||||
|
||||
struct PeerHandle;
|
||||
|
||||
impl PeerHandle {
|
||||
async fn is_connected(&self) -> Result<bool, DiscoveryError> {
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> Result<bool, DiscoveryError> {
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
25
src/main.rs
25
src/main.rs
@ -1,4 +1,5 @@
|
||||
mod topology;
|
||||
mod discovery;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
@ -9,7 +10,7 @@ use crate::node_service::{
|
||||
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto,
|
||||
};
|
||||
use node_service::node_service_server::{NodeService, NodeServiceServer};
|
||||
use node_service::{Shard, TensorRequest};
|
||||
use node_service::{TensorRequest};
|
||||
use topology::Topology;
|
||||
|
||||
pub mod node_service {
|
||||
@ -36,6 +37,15 @@ enum OpaqueStatus {
|
||||
SupportedInferenceEngines(SupportedInferenceEngines),
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
struct Shard {
|
||||
pub model_id: String,
|
||||
pub start_layer: i32,
|
||||
pub end_layer: i32,
|
||||
pub n_layers: i32,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
struct NodeStatus {
|
||||
node_id: String,
|
||||
@ -69,7 +79,7 @@ struct SupportedInferenceEngines {
|
||||
}
|
||||
|
||||
impl Node {
|
||||
fn on_opaque_status(&self, request_id: String, status: String) {
|
||||
fn on_opaque_status(&self, _request_id: String, status: String) {
|
||||
let status = serde_json::from_str::<OpaqueStatus>(&status).unwrap();
|
||||
|
||||
match status {
|
||||
@ -140,7 +150,7 @@ impl NodeService for Node {
|
||||
async fn collect_topology(
|
||||
&self,
|
||||
request: Request<CollectTopologyRequest>,
|
||||
) -> Result<Response<Topology>, Status> {
|
||||
) -> Result<Response<TopologyProto>, Status> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
@ -155,8 +165,10 @@ impl NodeService for Node {
|
||||
&self,
|
||||
request: Request<SendOpaqueStatusRequest>,
|
||||
) -> Result<Response<Empty>, Status> {
|
||||
let request_id = request.into_inner().request_id;
|
||||
let status = request.into_inner().status;
|
||||
let request = request.into_inner();
|
||||
let request_id = request.request_id;
|
||||
let status = request.status;
|
||||
|
||||
println!(
|
||||
"Received SendOpaqueStatus request: {} {}",
|
||||
request_id, status
|
||||
@ -174,6 +186,9 @@ impl NodeService for Node {
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// install global collector configured based on RUST_LOG env var.
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let grpc_addr = "[::1]:50051".parse()?;
|
||||
let node = Node::default();
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user