// exo-rs/src/main.rs

mod device_capability_data;
mod discovery;
mod network;
mod orchestration;
mod partitioning;
mod topology;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tonic::{transport::Server, Request, Response, Status};
use crate::discovery::{NodeInfo, UdpDiscovery};
use crate::node_service::{
    CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse,
    InferenceState, Loss, PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor,
    Topology as TopologyProto,
};
use node_service::node_service_server::{NodeService, NodeServiceServer};
use node_service::TensorRequest;
use std::collections::HashSet;
use topology::Topology;
use uuid::Uuid;

pub mod node_service {
    tonic::include_proto!("node_service"); // The string specified here must match the proto package name.
}

#[derive(Debug)]
struct Node {
    node_info: NodeInfo,
    current_topology: Topology,
    udp_discovery: UdpDiscovery,
}

impl Node {
    #[tracing::instrument]
    pub async fn process_prompt(
        &self,
        base_shard: Shard,
        prompt: String,
        request_id: String,
        inference_state: Option<InferenceState>,
    ) {
        let shard = self
            .current_topology
            .get_shard_for_node(base_shard, &self.node_info.node_id);
        todo!();
        // if shard.is_first_layer() {
        //     let result = self
        //         .inference_engine
        //         .infer_prompt(request_id, shard, prompt, inference_state)
        //         .await;
        //     self.process_inference_result(shard, result, request_id, inference_state)
        // } else {
        //     self.forward_prompt(shard, prompt, request_id, inference_state)
        // }
    }
}

impl Default for Node {
    fn default() -> Self {
        let node_info = NodeInfo::default();
        Self {
            node_info: node_info.clone(),
            current_topology: Topology::default(),
            udp_discovery: UdpDiscovery::new(node_info),
        }
    }
}

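// `#[serde(tag = "type")]` makes this internally tagged: the variant name is
// carried in a "type" field next to the variant's own fields, e.g.
// `{"type": "DownloadProgress", "node_id": "...", "progress": {...}}`
// (illustrative payload; see the tests at the bottom of this file).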
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "type")]
enum OpaqueStatus {
    NodeStatus(NodeStatus),
    DownloadProgress(DownloadProgress),
    SupportedInferenceEngines(SupportedInferenceEngines),
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct Shard {
    pub model_id: String,
    /// First layer of this shard (inclusive).
    pub start_layer: u32,
    /// Last layer of this shard (inclusive).
    pub end_layer: u32,
    /// Total number of layers in the model; `n_layers` on the wire.
    #[serde(rename = "n_layers")]
    pub total_layers: u32,
}

impl Shard {
    pub fn is_first_layer(&self) -> bool {
        self.start_layer == 0
    }

    pub fn is_last_layer(&self) -> bool {
        self.end_layer == self.total_layers - 1
    }

    /// Number of layers covered by this shard.
    pub fn len(&self) -> u32 {
        self.end_layer - self.start_layer + 1
    }
}

impl From<node_service::Shard> for Shard {
    fn from(proto: node_service::Shard) -> Self {
        Self {
            model_id: proto.model_id,
            start_layer: proto.start_layer as u32,
            end_layer: proto.end_layer as u32,
            total_layers: proto.n_layers as u32,
        }
    }
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct NodeStatus {
    node_id: String,
    status: String,
    base_shard: Shard,
    shard: Shard,
    prompt: String,
    request_id: String,
}

impl NodeStatus {
    fn is_start(&self) -> bool {
        self.status.starts_with("start_")
    }

    fn is_end(&self) -> bool {
        self.status.starts_with("end_")
    }
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct DownloadProgress {
    node_id: String,
    progress: Value,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct SupportedInferenceEngines {
    node_id: String,
    engines: Vec<String>,
}

impl Node {
    fn on_opaque_status(&self, _request_id: String, status: String) {
        let status = serde_json::from_str::<OpaqueStatus>(&status).unwrap();
        match status {
            OpaqueStatus::NodeStatus(node_status) => self.on_node_status(node_status),
            OpaqueStatus::DownloadProgress(download_progress) => {
                self.on_download_progress(download_progress)
            }
            OpaqueStatus::SupportedInferenceEngines(supported_inference_engines) => {
                self.on_supported_inference_engines(supported_inference_engines)
            }
        }
    }

    fn on_node_status(&self, node_status: NodeStatus) {
        println!("Received NodeStatus: {}", node_status.status);
        // This seems to only be used for visualization, so we can ignore it for now.
        // if node_status.is_start() {
        //     self.current_topology.active_node_id = node_status.node_id;
        // } else if node_status.is_end() {
        //     if node_status.node_id == self.current_topology.active_node_id {
        //         self.current_topology.active_node_id = None;
        //     }
        // }
    }

    fn on_download_progress(&self, _download_progress: DownloadProgress) {
        // This is only used for visualization, so we can ignore it for now.
    }

    fn on_supported_inference_engines(
        &self,
        supported_inference_engines: SupportedInferenceEngines,
    ) {
        println!(
            "Received SupportedInferenceEngines: {}",
            supported_inference_engines.engines.join(", ")
        );
        // let node_id = supported_inference_engines.node_id;
        // let engines = supported_inference_engines.engines;
        // self.topology_inference_engines_pool.append(engines);
        todo!();
    }
}

#[tonic::async_trait]
impl NodeService for Node {
    async fn send_prompt(
        &self,
        request: Request<PromptRequest>,
    ) -> Result<Response<Tensor>, Status> {
        let request = request.into_inner();
        let request_id = request
            .request_id
            .unwrap_or_else(|| Uuid::new_v4().to_string());
        // `process_prompt` is async; without `.await` the returned future
        // would be silently dropped and never run.
        self.process_prompt(
            request
                .shard
                .expect("No shard given. ExoPy does not allow this")
                .into(),
            request.prompt,
            request_id,
            request.inference_state,
        )
        .await;
        todo!();
    }

    async fn send_tensor(
        &self,
        _request: Request<TensorRequest>,
    ) -> Result<Response<Tensor>, Status> {
        todo!()
    }

    async fn send_example(
        &self,
        _request: Request<ExampleRequest>,
    ) -> Result<Response<Loss>, Status> {
        todo!()
    }

    async fn collect_topology(
        &self,
        request: Request<CollectTopologyRequest>,
    ) -> Result<Response<TopologyProto>, Status> {
        let request = request.into_inner();
        let max_depth = request.max_depth as u8;
        let visited = request.visited;
        // Return the freshly collected topology rather than the cached one;
        // `&self` means we cannot update `current_topology` here.
        let topology = self
            .update_topology_inner(max_depth, visited.into_iter().collect())
            .await;
        Ok(Response::new(topology.into()))
    }

    async fn send_result(
        &self,
        _request: Request<SendResultRequest>,
    ) -> Result<Response<Empty>, Status> {
        todo!()
    }

    async fn send_opaque_status(
        &self,
        request: Request<SendOpaqueStatusRequest>,
    ) -> Result<Response<Empty>, Status> {
        let request = request.into_inner();
        let request_id = request.request_id;
        let status = request.status;
        println!(
            "Received SendOpaqueStatus request: {} {}",
            request_id, status
        );
        // Dispatch to the handlers above; note that `on_opaque_status`
        // currently panics on statuses it cannot parse or has not implemented.
        self.on_opaque_status(request_id, status);
        Ok(Response::new(Empty {}))
    }

    async fn health_check(
        &self,
        _request: Request<HealthCheckRequest>,
    ) -> Result<Response<HealthCheckResponse>, Status> {
        Ok(Response::new(HealthCheckResponse { is_healthy: true }))
    }
}
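
// Client-side sketch: tonic generates a matching client module from the same
// proto, so a health check from another process could look roughly like the
// following (illustrative only; assumes the server above is running and that
// `HealthCheckRequest` has no required fields):
//
//     use node_service::node_service_client::NodeServiceClient;
//
//     let mut client = NodeServiceClient::connect("http://[::1]:50051").await?;
//     let reply = client.health_check(HealthCheckRequest::default()).await?;
//     assert!(reply.into_inner().is_healthy);
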
impl Node {
    async fn update_topology(&mut self) {
        let overall_max_depth = 4;
        let visited: HashSet<String> = HashSet::new();
        self.current_topology = self
            .update_topology_inner(overall_max_depth, visited)
            .await;
    }

    async fn update_topology_inner(
        &self,
        max_depth: u8,
        mut visited: HashSet<String>,
    ) -> Topology {
        let mut new_topology = Topology::default();
        new_topology.update_node(
            self.node_info.node_id.clone(),
            self.node_info.device_capabilities.clone(),
        );
        for peer in self.udp_discovery.peers.lock().await.values() {
            new_topology.update_node(peer.node_id.clone(), peer.device_capabilities.clone());
            new_topology.update_edge(
                self.node_info.node_id.clone(),
                peer.node_id.clone(),
                peer.description.clone(),
            );
            // `insert` returns true only for peers not seen before. The
            // previous insert-then-`!contains` check was always false, so
            // the recursion never ran. `saturating_sub` keeps a `max_depth`
            // of 0 from underflowing.
            if visited.insert(peer.node_id.clone()) {
                let topology = peer
                    .collect_topology(visited.clone(), max_depth.saturating_sub(1))
                    .await;
                new_topology.merge_restricted(&peer.node_id, topology);
            }
        }
        new_topology
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Install the global tracing subscriber, configured via the RUST_LOG env var.
    tracing_subscriber::fmt::init();
    let grpc_addr = "[::1]:50051".parse()?;
    let node = Node::default();
    // TODO: Also implement discovery
    Server::builder()
        .add_service(NodeServiceServer::new(node))
        .serve(grpc_addr)
        .await?;
    Ok(())
}