Initial AI conversion with some work

This commit is contained in:
Joshua Coles 2025-02-12 07:03:39 +00:00
parent fed48a2868
commit 70d58995ec
4 changed files with 401 additions and 5 deletions

118
Cargo.lock generated
View File

@ -198,9 +198,12 @@ dependencies = [
"prost",
"serde",
"serde_json",
"thiserror",
"tokio",
"tonic",
"tonic-build",
"tracing",
"tracing-subscriber",
]
[[package]]
@ -460,6 +463,12 @@ version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "libc"
version = "0.2.169"
@ -532,6 +541,16 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]]
name = "object"
version = "0.36.7"
@ -547,6 +566,12 @@ version = "1.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "parking_lot"
version = "0.12.3"
@ -844,6 +869,15 @@ dependencies = [
"serde",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.2"
@ -909,6 +943,36 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "thiserror"
version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "thread_local"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
dependencies = [
"cfg-if",
"once_cell",
]
[[package]]
name = "tokio"
version = "1.43.0"
@ -1081,6 +1145,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008"
dependencies = [
"nu-ansi-term",
"sharded-slab",
"smallvec",
"thread_local",
"tracing-core",
"tracing-log",
]
[[package]]
@ -1095,6 +1185,12 @@ version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034"
[[package]]
name = "valuable"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "want"
version = "0.3.1"
@ -1119,6 +1215,28 @@ dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.52.0"

View File

@ -9,6 +9,9 @@ serde = { version = "1.0.217", features = ["derive"] }
serde_json = "1.0.138"
tokio = { version = "1.43.0", features = ["full"] }
tonic = "0.12.3"
thiserror = "2.0"
tracing = "0.1"
tracing-subscriber = "0.3"
[build-dependencies]
tonic-build = "0.12.3"

260
src/discovery.rs Normal file
View File

@ -0,0 +1,260 @@
use crate::topology::DeviceCapabilities;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
net::{IpAddr, SocketAddr},
sync::Arc,
time::{Duration, Instant},
};
use thiserror::Error;
use tokio::{net::UdpSocket, sync::RwLock, time};
use tracing::{error};
#[derive(Error, Debug)]
pub enum DiscoveryError {
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("JSON serialization error: {0}")]
Json(#[from] serde_json::Error),
#[error("UTF-8 conversion error: {0}")]
Utf8(#[from] std::string::FromUtf8Error),
#[error("Address parse error: {0}")]
AddrParse(#[from] std::net::AddrParseError),
#[error("Network interface error: {0}")]
NetworkInterface(String),
#[error("Peer error: {0}")]
Peer(String),
}
// Constants
const DEBUG: i32 = 0;
const BROADCAST_ADDR: &str = "255.255.255.255";
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DiscoveryMessage {
message_type: String,
node_id: String,
grpc_port: u16,
device_capabilities: DeviceCapabilities,
priority: i32,
interface_name: String,
interface_type: String,
}
struct PeerInfo {
peer_handle: PeerHandle,
connected_at: Instant,
last_seen: Instant,
priority: i32,
}
struct UdpDiscovery {
node_id: String,
node_port: u16,
listen_port: u16,
broadcast_port: u16,
broadcast_interval: Duration,
discovery_timeout: Duration,
device_capabilities: DeviceCapabilities,
allowed_node_ids: Option<Vec<String>>,
allowed_interface_types: Option<Vec<String>>,
known_peers: Arc<RwLock<HashMap<String, PeerInfo>>>,
}
impl UdpDiscovery {
pub fn new(
node_id: String,
node_port: u16,
listen_port: u16,
broadcast_port: u16,
broadcast_interval: Duration,
discovery_timeout: Duration,
device_capabilities: DeviceCapabilities,
allowed_node_ids: Option<Vec<String>>,
allowed_interface_types: Option<Vec<String>>,
) -> Self {
Self {
node_id,
node_port,
listen_port,
broadcast_port,
broadcast_interval,
discovery_timeout,
device_capabilities,
allowed_node_ids,
allowed_interface_types,
known_peers: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn start(&self) -> Result<(), DiscoveryError> {
let broadcast_task = self.task_broadcast_presence();
let listen_task = self.task_listen_for_peers();
let cleanup_task = self.task_cleanup_peers();
tokio::try_join!(broadcast_task, listen_task, cleanup_task)?;
Ok(())
}
}
impl UdpDiscovery {
async fn task_listen_for_peers(&self) -> Result<(), DiscoveryError> {
let socket = UdpSocket::bind(format!("0.0.0.0:{}", self.listen_port)).await?;
let mut buf = vec![0u8; 65535];
loop {
let (len, addr) = socket.recv_from(&mut buf).await?;
if len == 0 {
continue;
}
if let Ok(message) = String::from_utf8(buf[..len].to_vec()) {
if let Ok(discovery_message) = serde_json::from_str::<DiscoveryMessage>(&message) {
self.handle_discovery_message(discovery_message, addr)
.await
.map_err(|e| DiscoveryError::Peer(e.to_string()))?;
}
}
}
}
}
impl UdpDiscovery {
async fn task_broadcast_presence(&self) -> Result<(), DiscoveryError> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.set_broadcast(true)?;
loop {
let interfaces = get_all_ip_addresses_and_interfaces()?;
for (addr, interface_name) in interfaces {
let (interface_priority, interface_type) =
get_interface_priority_and_type(&interface_name)
.await
.map_err(|e| DiscoveryError::NetworkInterface(e.to_string()))?;
let message = DiscoveryMessage {
message_type: "discovery".to_string(),
node_id: self.node_id.clone(),
grpc_port: self.node_port,
device_capabilities: self.device_capabilities.clone(),
priority: interface_priority,
interface_name: interface_name.clone(),
interface_type: interface_type.clone(),
};
let message_json = serde_json::to_string(&message)?;
let broadcast_addr =
SocketAddr::new(get_broadcast_address(&addr).parse()?, self.broadcast_port);
if let Err(e) = socket
.send_to(message_json.as_bytes(), &broadcast_addr)
.await
{
error!("Error broadcasting to {}: {}", broadcast_addr, e);
}
}
time::sleep(self.broadcast_interval).await;
}
}
}
impl UdpDiscovery {
async fn task_cleanup_peers(&self) -> Result<(), DiscoveryError> {
loop {
let now = Instant::now();
// TODO: Do we really want a read then write lock or should we just take a write lock
// to begin with?
let peers_to_remove = {
let mut peers_to_remove = Vec::new();
let peers = self.known_peers.read().await;
for (peer_id, peer_info) in peers.iter() {
if self.should_remove_peer(peer_info, now).await {
peers_to_remove.push(peer_id.clone());
}
}
peers_to_remove
};
{
let mut peers = self.known_peers.write().await;
for peer_id in peers_to_remove {
peers.remove(&peer_id);
}
};
time::sleep(self.broadcast_interval).await;
}
}
async fn should_remove_peer(&self, peer_info: &PeerInfo, now: Instant) -> bool {
let is_connected = peer_info
.peer_handle
.is_connected()
.await
.ok()
.unwrap_or(false);
if !is_connected {
return true;
}
if now.duration_since(peer_info.connected_at) > self.discovery_timeout {
return true;
}
if now.duration_since(peer_info.last_seen) > self.discovery_timeout {
return true;
}
peer_info
.peer_handle
.health_check()
.await
.ok()
.unwrap_or(false)
}
}
fn get_broadcast_address(ip_addr: &IpAddr) -> String {
match ip_addr {
IpAddr::V4(addr) => {
let octets = addr.octets();
format!("{}.{}.{}.255", octets[0], octets[1], octets[2])
}
_ => BROADCAST_ADDR.to_string(),
}
}
// You'll need to implement these functions based on your system
fn get_all_ip_addresses_and_interfaces() -> Result<Vec<(IpAddr, String)>, DiscoveryError> {
// Implementation needed
Ok(vec![])
}
async fn get_interface_priority_and_type(
interface_name: &str,
) -> Result<(i32, String), DiscoveryError> {
// Implementation needed
Ok((0, "unknown".to_string()))
}
struct PeerHandle;
impl PeerHandle {
async fn is_connected(&self) -> Result<bool, DiscoveryError> {
Ok(true)
}
async fn health_check(&self) -> Result<bool, DiscoveryError> {
Ok(true)
}
}

View File

@ -1,4 +1,5 @@
mod topology;
mod discovery;
use serde::{Deserialize, Serialize};
use serde_json::Value;
@ -9,7 +10,7 @@ use crate::node_service::{
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto,
};
use node_service::node_service_server::{NodeService, NodeServiceServer};
use node_service::{Shard, TensorRequest};
use node_service::{TensorRequest};
use topology::Topology;
pub mod node_service {
@ -36,6 +37,15 @@ enum OpaqueStatus {
SupportedInferenceEngines(SupportedInferenceEngines),
}
#[derive(Debug, Deserialize, Serialize, Clone)]
struct Shard {
pub model_id: String,
pub start_layer: i32,
pub end_layer: i32,
pub n_layers: i32,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
struct NodeStatus {
node_id: String,
@ -69,7 +79,7 @@ struct SupportedInferenceEngines {
}
impl Node {
fn on_opaque_status(&self, request_id: String, status: String) {
fn on_opaque_status(&self, _request_id: String, status: String) {
let status = serde_json::from_str::<OpaqueStatus>(&status).unwrap();
match status {
@ -140,7 +150,7 @@ impl NodeService for Node {
async fn collect_topology(
&self,
request: Request<CollectTopologyRequest>,
) -> Result<Response<Topology>, Status> {
) -> Result<Response<TopologyProto>, Status> {
todo!()
}
@ -155,8 +165,10 @@ impl NodeService for Node {
&self,
request: Request<SendOpaqueStatusRequest>,
) -> Result<Response<Empty>, Status> {
let request_id = request.into_inner().request_id;
let status = request.into_inner().status;
let request = request.into_inner();
let request_id = request.request_id;
let status = request.status;
println!(
"Received SendOpaqueStatus request: {} {}",
request_id, status
@ -174,6 +186,9 @@ impl NodeService for Node {
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// install global collector configured based on RUST_LOG env var.
tracing_subscriber::fmt::init();
let grpc_addr = "[::1]:50051".parse()?;
let node = Node::default();