diff --git a/src/main.rs b/src/main.rs index 82a9db7..feb4e2f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,23 +1,30 @@ +mod topology; + use serde::{Deserialize, Serialize}; use serde_json::Value; use tonic::{transport::Server, Request, Response, Status}; use crate::node_service::{ CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss, - PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology, + PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto, }; use node_service::node_service_server::{NodeService, NodeServiceServer}; use node_service::{Shard, TensorRequest}; +use topology::Topology; pub mod node_service { tonic::include_proto!("node_service"); // The string specified here must match the proto package name } -struct Node {} +struct Node { + current_topology: Topology, +} impl Default for Node { fn default() -> Self { - Self {} + Self { + current_topology: Topology::default(), + } } } @@ -78,7 +85,7 @@ impl Node { fn on_node_status(&self, node_status: NodeStatus) { println!("Received NodeStatus: {}", node_status.status); - + // This seems to only be used for visualization so we can ignore it for now // if node_status.is_start() { // self.current_topology.active_node_id = node_status.node_id; // } else if node_status.is_end() { diff --git a/src/topology.rs b/src/topology.rs new file mode 100644 index 0000000..2f5fe66 --- /dev/null +++ b/src/topology.rs @@ -0,0 +1,41 @@ +use std::collections::HashMap; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct Topology { + nodes: HashMap, + peer_graph: HashMap>, + active_node_id: Option +} + +impl Default for Topology { + fn default() -> Self { + Self { + nodes: HashMap::new(), + peer_graph: HashMap::new(), + active_node_id: None, + } + } +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct DeviceCapabilities { + model: String, + chip: String, + memory: u64, + flops: DeviceFlops, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct DeviceFlops { + fp32: u64, + fp16: u64, + int8: u64, +} + +#[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)] +pub struct PeerConnection { + from_id: String, + to_id: String, + description: Option, +}