// exo-rs/src/main.rs

mod topology;
mod orchestration;
mod discovery;
mod network;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tonic::{transport::Server, Request, Response, Status};
use crate::node_service::node_service_server::{NodeService, NodeServiceServer};
use crate::node_service::{
    CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
    PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, TensorRequest,
    Topology as TopologyProto,
};
use topology::Topology;
pub mod node_service {
    tonic::include_proto!("node_service"); // The string specified here must match the proto package name
}
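
// The `node_service` module above is generated at build time by prost/tonic. A minimal
// build.rs for that step would look like the sketch below (it assumes the proto file
// lives at proto/node_service.proto; adjust the path to match this repo):
//
//     fn main() -> Result<(), Box<dyn std::error::Error>> {
//         tonic_build::compile_protos("proto/node_service.proto")?;
//         Ok(())
//     }

/// Per-node server state: currently just this node's view of the cluster topology.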
struct Node {
    current_topology: Topology,
}

impl Default for Node {
    fn default() -> Self {
        Self {
            current_topology: Topology::default(),
        }
    }
}
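
/// Status messages carried as JSON strings in `SendOpaqueStatus` requests, discriminated
/// by a `"type"` field. For example (values are illustrative, not captured from a real node):
/// `{"type":"DownloadProgress","node_id":"node-1","progress":{}}`.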
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "type")]
enum OpaqueStatus {
    NodeStatus(NodeStatus),
    DownloadProgress(DownloadProgress),
    SupportedInferenceEngines(SupportedInferenceEngines),
}
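
/// A contiguous slice of a model's layers (`start_layer..=end_layer` out of `n_layers`)
/// identified by `model_id`.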
#[derive(Debug, Deserialize, Serialize, Clone)]
struct Shard {
    pub model_id: String,
    pub start_layer: i32,
    pub end_layer: i32,
    pub n_layers: i32,
}
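
/// Progress of a single request on a node; `status` is a string such as
/// `"start_..."` or `"end_..."` (see `is_start` / `is_end` below).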
#[derive(Debug, Deserialize, Serialize, Clone)]
struct NodeStatus {
    node_id: String,
    status: String,
    base_shard: Shard,
    shard: Shard,
    prompt: String,
    request_id: String,
}
impl NodeStatus {
    fn is_start(&self) -> bool {
        self.status.starts_with("start_")
    }

    fn is_end(&self) -> bool {
        self.status.starts_with("end_")
    }
}
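
/// Download progress reported by a node; the shape of `progress` is left opaque here.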
#[derive(Debug, Deserialize, Serialize, Clone)]
struct DownloadProgress {
    node_id: String,
    progress: Value,
}
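
/// Inference engines a node advertises as supported.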
#[derive(Debug, Deserialize, Serialize, Clone)]
struct SupportedInferenceEngines {
    node_id: String,
    engines: Vec<String>,
}
impl Node {
    /// Dispatch an opaque status JSON string to the matching handler.
    /// NOTE: this panics (via `unwrap`) if the payload is not valid `OpaqueStatus` JSON.
    fn on_opaque_status(&self, _request_id: String, status: String) {
        let status = serde_json::from_str::<OpaqueStatus>(&status).unwrap();
        match status {
            OpaqueStatus::NodeStatus(node_status) => self.on_node_status(node_status),
            OpaqueStatus::DownloadProgress(download_progress) => {
                self.on_download_progress(download_progress)
            }
            OpaqueStatus::SupportedInferenceEngines(supported_inference_engines) => {
                self.on_supported_inference_engines(supported_inference_engines)
            }
        }
    }

    fn on_node_status(&self, node_status: NodeStatus) {
        println!("Received NodeStatus: {}", node_status.status);
        // This seems to only be used for visualization so we can ignore it for now
        // if node_status.is_start() {
        //     self.current_topology.active_node_id = node_status.node_id;
        // } else if node_status.is_end() {
        //     if node_status.node_id == self.current_topology.active_node_id {
        //         self.current_topology.active_node_id = None;
        //     }
        // }
    }

    fn on_download_progress(&self, _download_progress: DownloadProgress) {
        // This is only used for visualization so we can ignore it for now
    }

    fn on_supported_inference_engines(
        &self,
        supported_inference_engines: SupportedInferenceEngines,
    ) {
        println!(
            "Received SupportedInferenceEngines: {}",
            supported_inference_engines.engines.join(", ")
        );
        // let node_id = supported_inference_engines.node_id;
        // let engines = supported_inference_engines.engines;
        // self.topology_inference_engines_pool.append(engines);
        todo!();
    }
}
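
// gRPC surface of the node. In this stash only `send_opaque_status` and `health_check`
// return real responses; the remaining handlers are still `todo!()` stubs.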
#[tonic::async_trait]
impl NodeService for Node {
    async fn send_prompt(
        &self,
        _request: Request<PromptRequest>,
    ) -> Result<Response<Tensor>, Status> {
        todo!()
    }

    async fn send_tensor(
        &self,
        _request: Request<TensorRequest>,
    ) -> Result<Response<Tensor>, Status> {
        todo!()
    }

    async fn send_example(
        &self,
        _request: Request<ExampleRequest>,
    ) -> Result<Response<Loss>, Status> {
        todo!()
    }

    async fn collect_topology(
        &self,
        _request: Request<CollectTopologyRequest>,
    ) -> Result<Response<TopologyProto>, Status> {
        todo!()
    }

    async fn send_result(
        &self,
        _request: Request<SendResultRequest>,
    ) -> Result<Response<Empty>, Status> {
        todo!()
    }
    async fn send_opaque_status(
        &self,
        request: Request<SendOpaqueStatusRequest>,
    ) -> Result<Response<Empty>, Status> {
        let request = request.into_inner();
        let request_id = request.request_id;
        let status = request.status;
        println!(
            "Received SendOpaqueStatus request: {} {}",
            request_id, status
        );
        // Note: the payload is only logged here; routing it through `on_opaque_status`
        // is not wired up yet in this stash.
        Ok(Response::new(Empty {}))
    }

    async fn health_check(
        &self,
        _request: Request<HealthCheckRequest>,
    ) -> Result<Response<HealthCheckResponse>, Status> {
        Ok(Response::new(HealthCheckResponse { is_healthy: true }))
    }
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // install global collector configured based on RUST_LOG env var.
    tracing_subscriber::fmt::init();

    let grpc_addr = "[::1]:50051".parse()?;
    let node = Node::default();

    // TODO: Also implement discovery
    Server::builder()
        .add_service(NodeServiceServer::new(node))
        .serve(grpc_addr)
        .await?;

    Ok(())
}
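
// A small sanity check for the tagged JSON format handled by `on_opaque_status`.
// This is a sketch: the field values below are illustrative, not captured from a real node.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_supported_inference_engines_status() {
        let json = r#"{"type":"SupportedInferenceEngines","node_id":"node-1","engines":["mlx","tinygrad"]}"#;
        let status: OpaqueStatus =
            serde_json::from_str(json).expect("payload should match the tagged enum");
        match status {
            OpaqueStatus::SupportedInferenceEngines(s) => {
                assert_eq!(s.node_id, "node-1");
                assert_eq!(s.engines, vec!["mlx", "tinygrad"]);
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }
}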