// NOTE(review): removed duplicated extraction metadata ("329 lines / 9.1 KiB / Rust") that preceded the source.
mod device_capability_data;
|
|
mod discovery;
|
|
mod network;
|
|
mod orchestration;
|
|
mod partitioning;
|
|
mod topology;
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
use tonic::{transport::Server, Request, Response, Status};
|
|
|
|
use crate::discovery::{NodeInfo, UdpDiscovery};
|
|
use crate::node_service::{
|
|
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse,
|
|
InferenceState, Loss, PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor,
|
|
Topology as TopologyProto,
|
|
};
|
|
use node_service::node_service_server::{NodeService, NodeServiceServer};
|
|
use node_service::TensorRequest;
|
|
use std::collections::HashSet;
|
|
use topology::Topology;
|
|
use uuid::Uuid;
|
|
|
|
/// gRPC message types and service traits generated by `tonic-build` from the
/// `node_service` proto package at compile time.
pub mod node_service {
    tonic::include_proto!("node_service"); // The string specified here must match the proto package name
}
|
|
|
|
/// One participant in the distributed-inference cluster.
///
/// Holds this node's own identity, the cluster topology as currently known to
/// this node, and the UDP peer-discovery handle whose `peers` map is read when
/// (re)collecting the topology.
#[derive(Debug)]
struct Node {
    /// Identity and device capabilities advertised to peers.
    node_info: NodeInfo,
    /// Latest known cluster topology (rebuilt by `update_topology`).
    current_topology: Topology,
    /// UDP-based peer discovery state.
    udp_discovery: UdpDiscovery,
}
|
|
|
|
impl Node {
    /// Process (or forward) `prompt` for the slice of the model this node owns.
    ///
    /// `base_shard` describes the overall partition the request targets; the
    /// node's concrete shard is resolved from the current topology. The body
    /// is still unimplemented (`todo!`); the commented-out scaffold below shows
    /// the intended flow — infer locally when this node holds the first layer,
    /// otherwise forward the prompt to the responsible peer.
    ///
    /// `request_id` correlates partial results across nodes;
    /// `inference_state` is presumably carried-over state from an earlier
    /// inference step — TODO confirm once the engine is wired in.
    #[tracing::instrument]
    pub async fn process_prompt(
        &self,
        base_shard: Shard,
        prompt: String,
        request_id: String,
        inference_state: Option<InferenceState>,
    ) {
        let shard = self
            .current_topology
            .get_shard_for_node(base_shard, &self.node_info.node_id);

        todo!();
        // if shard.is_first_layer() {
        //     let result = self
        //         .inference_engine
        //         .infer_prompt(request_id, shard, prompt, inference_state)
        //         .await;
        //     self.process_inference_result(shard, result, request_id, inference_state)
        // } else {
        //     self.forward_prompt(shard, prompt, request_id, inference_state)
        // }
    }
}
|
|
|
|
impl Default for Node {
|
|
fn default() -> Self {
|
|
let node_info = NodeInfo::default();
|
|
|
|
Self {
|
|
node_info: node_info.clone(),
|
|
current_topology: Topology::default(),
|
|
udp_discovery: UdpDiscovery::new(node_info),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Status payloads exchanged between nodes as opaque JSON strings.
///
/// Serialized with an internal `"type"` tag, i.e. the wire form is
/// `{"type": "NodeStatus", ...}` — presumably mirroring the format used by
/// the Python (ExoPy) peers; TODO confirm against their serializer.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "type")]
enum OpaqueStatus {
    /// A peer's stage-transition report (visualization-oriented).
    NodeStatus(NodeStatus),
    /// A peer's model-download progress report.
    DownloadProgress(DownloadProgress),
    /// The inference engines a peer can run.
    SupportedInferenceEngines(SupportedInferenceEngines),
}
|
|
|
|
/// A contiguous range of a model's layers assigned to one node.
///
/// The range is inclusive on both ends (see `Shard::len`, which computes
/// `end - start + 1`).
#[derive(Debug, Deserialize, Serialize, Clone)]
struct Shard {
    /// Identifier of the model being partitioned.
    pub model_id: String,
    /// First layer (0-based) covered by this shard.
    pub start_layer: u32,
    /// Last layer covered by this shard (inclusive).
    pub end_layer: u32,
    /// Total layer count of the whole model; serialized as `n_layers` to
    /// match the wire/proto field name.
    #[serde(rename = "n_layers")]
    pub total_layers: u32,
}
|
|
|
|
impl Shard {
|
|
pub fn is_first_layer(&self) -> bool {
|
|
self.start_layer == 0
|
|
}
|
|
|
|
pub fn is_last_layer(&self) -> bool {
|
|
self.end_layer == self.total_layers - 1
|
|
}
|
|
|
|
pub fn len(&self) -> u32 {
|
|
self.end_layer - self.start_layer + 1
|
|
}
|
|
}
|
|
|
|
impl From<node_service::Shard> for Shard {
    /// Convert the protobuf shard into the internal representation.
    ///
    /// NOTE(review): the proto layer fields are converted with `as u32`; a
    /// negative value coming off the wire would silently wrap — confirm the
    /// proto contract guarantees non-negative layer indices.
    fn from(proto: node_service::Shard) -> Self {
        let start_layer = proto.start_layer as u32;
        let end_layer = proto.end_layer as u32;
        let total_layers = proto.n_layers as u32;

        Self {
            model_id: proto.model_id,
            start_layer,
            end_layer,
            total_layers,
        }
    }
}
|
|
|
|
/// A peer's stage-transition report (used for visualization).
///
/// `status` is stringly-typed: values are expected to begin with `start_` or
/// `end_` (see `is_start` / `is_end`).
#[derive(Debug, Deserialize, Serialize, Clone)]
struct NodeStatus {
    /// Id of the reporting node.
    node_id: String,
    /// Stage marker, e.g. `start_*` / `end_*`.
    status: String,
    /// The overall model partition the request targets.
    base_shard: Shard,
    /// The reporting node's concrete shard.
    shard: Shard,
    /// Prompt associated with the stage.
    prompt: String,
    /// Correlates this report with an in-flight request.
    request_id: String,
}
|
|
|
|
impl NodeStatus {
|
|
fn is_start(&self) -> bool {
|
|
self.status.starts_with("start_")
|
|
}
|
|
|
|
fn is_end(&self) -> bool {
|
|
self.status.starts_with("end_")
|
|
}
|
|
}
|
|
|
|
/// A peer's model-download progress report (visualization-only).
#[derive(Debug, Deserialize, Serialize, Clone)]
struct DownloadProgress {
    /// Id of the downloading node.
    node_id: String,
    /// Free-form progress payload; schema is not fixed here — presumably
    /// whatever the sender's download tracker emits. TODO confirm.
    progress: Value,
}
|
|
|
|
/// The set of inference engines a peer advertises it can run.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct SupportedInferenceEngines {
    /// Id of the advertising node.
    node_id: String,
    /// Engine names as free-form strings.
    engines: Vec<String>,
}
|
|
|
|
impl Node {
|
|
fn on_opaque_status(&self, _request_id: String, status: String) {
|
|
let status = serde_json::from_str::<OpaqueStatus>(&status).unwrap();
|
|
|
|
match status {
|
|
OpaqueStatus::NodeStatus(node_status) => self.on_node_status(node_status),
|
|
OpaqueStatus::DownloadProgress(download_progress) => {
|
|
self.on_download_progress(download_progress)
|
|
}
|
|
OpaqueStatus::SupportedInferenceEngines(supported_inference_engines) => {
|
|
self.on_supported_inference_engines(supported_inference_engines)
|
|
}
|
|
}
|
|
}
|
|
|
|
fn on_node_status(&self, node_status: NodeStatus) {
|
|
println!("Received NodeStatus: {}", node_status.status);
|
|
// This seems to only be used for visualization so we can ignore it for now
|
|
// if node_status.is_start() {
|
|
// self.current_topology.active_node_id = node_status.node_id;
|
|
// } else if node_status.is_end() {
|
|
// if node_status.node_id == self.current_topology.active_node_id {
|
|
// self.current_topology.active_node_id = None;
|
|
// }
|
|
// }
|
|
}
|
|
|
|
fn on_download_progress(&self, download_progress: DownloadProgress) {
|
|
// This is only used for visualization so we can ignore it for now
|
|
}
|
|
|
|
fn on_supported_inference_engines(
|
|
&self,
|
|
supported_inference_engines: SupportedInferenceEngines,
|
|
) {
|
|
println!(
|
|
"Received SupportedInferenceEngines: {}",
|
|
supported_inference_engines.engines.join(", ")
|
|
);
|
|
// let node_id = supported_inference_engines.node_id;
|
|
// let engines = supported_inference_engines.engines;
|
|
// self.topology_inference_engines_pool.append(engines);
|
|
todo!();
|
|
}
|
|
}
|
|
|
|
#[tonic::async_trait]
|
|
impl NodeService for Node {
|
|
async fn send_prompt(
|
|
&self,
|
|
request: Request<PromptRequest>,
|
|
) -> Result<Response<Tensor>, Status> {
|
|
let request = request.into_inner();
|
|
let request_id = request
|
|
.request_id
|
|
.unwrap_or_else(|| Uuid::new_v4().to_string());
|
|
|
|
let result = self.process_prompt(
|
|
request
|
|
.shard
|
|
.expect("No shard given. ExoPy does not allow this")
|
|
.into(),
|
|
request.prompt,
|
|
request_id,
|
|
request.inference_state,
|
|
);
|
|
|
|
todo!();
|
|
}
|
|
|
|
async fn send_tensor(
|
|
&self,
|
|
request: Request<TensorRequest>,
|
|
) -> Result<Response<Tensor>, Status> {
|
|
todo!()
|
|
}
|
|
|
|
async fn send_example(
|
|
&self,
|
|
request: Request<ExampleRequest>,
|
|
) -> Result<Response<Loss>, Status> {
|
|
todo!()
|
|
}
|
|
|
|
// TODO: Why aren't we using the request?
|
|
async fn collect_topology(
|
|
&self,
|
|
request: Request<CollectTopologyRequest>,
|
|
) -> Result<Response<TopologyProto>, Status> {
|
|
let request = request.into_inner();
|
|
let max_depth = request.max_depth as u8;
|
|
let visited = request.visited;
|
|
|
|
self.update_topology_inner(max_depth, visited.into_iter().collect())
|
|
.await;
|
|
|
|
Ok(Response::new(self.current_topology.clone().into()))
|
|
}
|
|
|
|
async fn send_result(
|
|
&self,
|
|
request: Request<SendResultRequest>,
|
|
) -> Result<Response<Empty>, Status> {
|
|
todo!()
|
|
}
|
|
|
|
async fn send_opaque_status(
|
|
&self,
|
|
request: Request<SendOpaqueStatusRequest>,
|
|
) -> Result<Response<Empty>, Status> {
|
|
let request = request.into_inner();
|
|
let request_id = request.request_id;
|
|
let status = request.status;
|
|
|
|
println!(
|
|
"Received SendOpaqueStatus request: {} {}",
|
|
request_id, status
|
|
);
|
|
Ok(Response::new(Empty {}))
|
|
}
|
|
|
|
async fn health_check(
|
|
&self,
|
|
request: Request<HealthCheckRequest>,
|
|
) -> Result<Response<HealthCheckResponse>, Status> {
|
|
Ok(Response::new(HealthCheckResponse { is_healthy: true }))
|
|
}
|
|
}
|
|
|
|
impl Node {
|
|
async fn update_topology(&mut self) {
|
|
let overall_max_depth = 4;
|
|
let visited: HashSet<String> = HashSet::new();
|
|
|
|
self.current_topology = self.update_topology_inner(overall_max_depth, visited).await;
|
|
}
|
|
|
|
async fn update_topology_inner(&self, max_depth: u8, mut visited: HashSet<String>) -> Topology {
|
|
let mut new_topology = Topology::default();
|
|
|
|
new_topology.update_node(
|
|
self.node_info.node_id.clone(),
|
|
self.node_info.device_capabilities.clone(),
|
|
);
|
|
|
|
for peer in self.udp_discovery.peers.lock().await.values() {
|
|
new_topology.update_node(peer.node_id.clone(), peer.device_capabilities.clone());
|
|
new_topology.update_edge(
|
|
self.node_info.node_id.clone(),
|
|
peer.node_id.clone(),
|
|
peer.description.clone(),
|
|
);
|
|
|
|
visited.insert(peer.node_id.clone());
|
|
|
|
if !visited.contains(&peer.node_id) {
|
|
let topology = peer.collect_topology(visited.clone(), max_depth - 1).await;
|
|
new_topology.merge_restricted(&peer.node_id, topology);
|
|
}
|
|
}
|
|
|
|
new_topology
|
|
}
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
// install global collector configured based on RUST_LOG env var.
|
|
tracing_subscriber::fmt::init();
|
|
|
|
let grpc_addr = "[::1]:50051".parse()?;
|
|
let node: Node = Node::default();
|
|
|
|
// TODO: Also implement discovery
|
|
|
|
Server::builder()
|
|
.add_service(NodeServiceServer::new(node))
|
|
.serve(grpc_addr)
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|