mod topology;
mod orchestration;
mod discovery;
mod network;

use serde::{Deserialize, Serialize};
use serde_json::Value;
use tonic::{transport::Server, Request, Response, Status};

use crate::node_service::{
    CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
    PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto,
};
use node_service::node_service_server::{NodeService, NodeServiceServer};
use node_service::TensorRequest;
use topology::Topology;

pub mod node_service {
    tonic::include_proto!("node_service"); // The string specified here must match the proto package name
}
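
// `include_proto!` expands to the code that `tonic_build` generates for the
// `node_service` proto package at build time. A minimal build.rs sketch, assuming
// the schema lives at proto/node_service.proto (the path is a guess, not something
// this file states):
//
//     fn main() -> Result<(), Box<dyn std::error::Error>> {
//         tonic_build::compile_protos("proto/node_service.proto")?;
//         Ok(())
//     }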

/// Per-node server state; `current_topology` is the topology this node currently knows about.
struct Node {
    current_topology: Topology,
}

impl Default for Node {
    fn default() -> Self {
        Self {
            current_topology: Topology::default(),
        }
    }
}
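
// The manual impl above is equivalent to `#[derive(Default)]` on `Node`, since the
// only field is initialized with `Topology::default()`.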

#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "type")]
enum OpaqueStatus {
    NodeStatus(NodeStatus),
    DownloadProgress(DownloadProgress),
    SupportedInferenceEngines(SupportedInferenceEngines),
}
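
// With `#[serde(tag = "type")]` the variants above are internally tagged, so the JSON
// carried in a SendOpaqueStatus request looks like (illustrative values only):
//
//     {"type": "DownloadProgress", "node_id": "node-1", "progress": {"completed": 3, "total": 8}}
//
// which deserializes to `OpaqueStatus::DownloadProgress`.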

#[derive(Debug, Deserialize, Serialize, Clone)]
struct Shard {
    pub model_id: String,
    pub start_layer: i32,
    pub end_layer: i32,
    pub n_layers: i32,
}
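
// For example, the first half of a 32-layer model could be described as
// `start_layer: 0, end_layer: 15, n_layers: 32` (assuming `end_layer` is inclusive;
// the convention is not spelled out in this file).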

#[derive(Debug, Deserialize, Serialize, Clone)]
struct NodeStatus {
    node_id: String,
    status: String,
    base_shard: Shard,
    shard: Shard,
    prompt: String,
    request_id: String,
}

impl NodeStatus {
    fn is_start(&self) -> bool {
        self.status.starts_with("start_")
    }

    fn is_end(&self) -> bool {
        self.status.starts_with("end_")
    }
}
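
// The status field is expected to carry matching "start_*" / "end_*" pairs
// (e.g. "start_process_prompt" / "end_process_prompt" — example names inferred
// from the prefixes, not taken from this file).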

#[derive(Debug, Deserialize, Serialize, Clone)]
struct DownloadProgress {
    node_id: String,
    /// Raw JSON; the shape of the progress payload is not pinned down here.
    progress: Value,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct SupportedInferenceEngines {
    node_id: String,
    engines: Vec<String>,
}

impl Node {
    fn on_opaque_status(&self, _request_id: String, status: String) {
        // Note: this panics if the payload does not match one of the `OpaqueStatus` variants.
        let status = serde_json::from_str::<OpaqueStatus>(&status).unwrap();

        match status {
            OpaqueStatus::NodeStatus(node_status) => self.on_node_status(node_status),
            OpaqueStatus::DownloadProgress(download_progress) => {
                self.on_download_progress(download_progress)
            }
            OpaqueStatus::SupportedInferenceEngines(supported_inference_engines) => {
                self.on_supported_inference_engines(supported_inference_engines)
            }
        }
    }

    fn on_node_status(&self, node_status: NodeStatus) {
        println!("Received NodeStatus: {}", node_status.status);
        // This seems to only be used for visualization so we can ignore it for now
        // if node_status.is_start() {
        //     self.current_topology.active_node_id = node_status.node_id;
        // } else if node_status.is_end() {
        //     if node_status.node_id == self.current_topology.active_node_id {
        //         self.current_topology.active_node_id = None;
        //     }
        // }
    }

    fn on_download_progress(&self, _download_progress: DownloadProgress) {
        // This is only used for visualization so we can ignore it for now
    }

    fn on_supported_inference_engines(
        &self,
        supported_inference_engines: SupportedInferenceEngines,
    ) {
        println!(
            "Received SupportedInferenceEngines: {}",
            supported_inference_engines.engines.join(", ")
        );
        // let node_id = supported_inference_engines.node_id;
        // let engines = supported_inference_engines.engines;
        // self.topology_inference_engines_pool.append(engines);
        todo!();
    }
}
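
// Note: `on_opaque_status` above is not yet reached from the gRPC handler; a likely
// wiring (a sketch, not something this file does yet) is to call it at the end of
// `send_opaque_status` below:
//
//     self.on_opaque_status(request_id, status);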

#[tonic::async_trait]
impl NodeService for Node {
    async fn send_prompt(
        &self,
        _request: Request<PromptRequest>,
    ) -> Result<Response<Tensor>, Status> {
        todo!()
    }

    async fn send_tensor(
        &self,
        _request: Request<TensorRequest>,
    ) -> Result<Response<Tensor>, Status> {
        todo!()
    }

    async fn send_example(
        &self,
        _request: Request<ExampleRequest>,
    ) -> Result<Response<Loss>, Status> {
        todo!()
    }

    async fn collect_topology(
        &self,
        _request: Request<CollectTopologyRequest>,
    ) -> Result<Response<TopologyProto>, Status> {
        todo!()
    }

    async fn send_result(
        &self,
        _request: Request<SendResultRequest>,
    ) -> Result<Response<Empty>, Status> {
        todo!()
    }

    async fn send_opaque_status(
        &self,
        request: Request<SendOpaqueStatusRequest>,
    ) -> Result<Response<Empty>, Status> {
        let request = request.into_inner();
        let request_id = request.request_id;
        let status = request.status;

        println!(
            "Received SendOpaqueStatus request: {} {}",
            request_id, status
        );
        Ok(Response::new(Empty {}))
    }

    async fn health_check(
        &self,
        _request: Request<HealthCheckRequest>,
    ) -> Result<Response<HealthCheckResponse>, Status> {
        Ok(Response::new(HealthCheckResponse { is_healthy: true }))
    }
}
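
// Once the server is running, the health check can be exercised with grpcurl,
// assuming the proto file is available locally (path and package layout are
// assumptions, not taken from this file):
//
//     grpcurl -plaintext -proto proto/node_service.proto \
//         -d '{}' '[::1]:50051' node_service.NodeService/HealthCheck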

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // install global collector configured based on RUST_LOG env var.
    tracing_subscriber::fmt::init();

    let grpc_addr = "[::1]:50051".parse()?;
    let node = Node::default();

    // TODO: Also implement discovery

    Server::builder()
        .add_service(NodeServiceServer::new(node))
        .serve(grpc_addr)
        .await?;

    Ok(())
}
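
// A small sanity check for the status payload shapes above; the JSON values are
// illustrative, not captured from a real node.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deserializes_supported_inference_engines() {
        let json = r#"{"type": "SupportedInferenceEngines", "node_id": "node-1", "engines": ["mlx", "tinygrad"]}"#;
        let status: OpaqueStatus = serde_json::from_str(json).expect("payload should deserialize");
        match status {
            OpaqueStatus::SupportedInferenceEngines(s) => {
                assert_eq!(s.node_id, "node-1");
                assert_eq!(s.engines, vec!["mlx", "tinygrad"]);
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }
}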