Initial start

This commit is contained in:
Joshua Coles 2025-02-12 06:27:56 +00:00
parent 108f454c52
commit 09fe616b44
5 changed files with 1524 additions and 2 deletions

1216
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -4,3 +4,11 @@ version = "0.1.0"
edition = "2021"
[dependencies]
# Protobuf runtime used by the tonic-generated message types.
prost = "0.13.4"
# Serialization for the opaque-status JSON payloads.
serde = { version = "1.0.217", features = ["derive"] }
serde_json = "1.0.138"
# Async runtime; "full" enables macros, net, time, etc. for the server.
tokio = { version = "1.43.0", features = ["full"] }
# gRPC server/client framework.
tonic = "0.12.3"
[build-dependencies]
# Compiles node_service.proto into Rust at build time (see build.rs).
tonic-build = "0.12.3"

4
build.rs Normal file
View File

@ -0,0 +1,4 @@
/// Build script: generate the tonic/prost gRPC bindings from
/// `node_service.proto` before the crate itself is compiled.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Propagate any codegen failure so the build aborts with its cause.
    Ok(tonic_build::compile_protos("node_service.proto")?)
}

116
node_service.proto Normal file
View File

@ -0,0 +1,116 @@
syntax = "proto3";
package node_service;
// RPC surface exposed by every node in the cluster.
service NodeService {
// Run inference from a text prompt on the given shard.
rpc SendPrompt (PromptRequest) returns (Tensor) {}
// Forward an intermediate activation tensor to the next shard.
rpc SendTensor (TensorRequest) returns (Tensor) {}
// Submit a training example; returns the loss (and optionally grads).
rpc SendExample (ExampleRequest) returns (Loss) {}
// Recursively gather the cluster topology (bounded by max_depth).
rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
// Deliver the final result tokens/tensor for a request.
rpc SendResult (SendResultRequest) returns (Empty) {}
// Deliver a free-form JSON status blob (see OpaqueStatus on the Rust side).
rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
}
// A contiguous slice [start_layer, end_layer] of a model's n_layers.
message Shard {
string model_id = 1;
int32 start_layer = 2;
int32 end_layer = 3;
int32 n_layers = 4;
}
message PromptRequest {
Shard shard = 1;
string prompt = 2;
optional string request_id = 3;
optional InferenceState inference_state = 4;
}
message TensorRequest {
Shard shard = 1;
Tensor tensor = 2;
optional string request_id = 3;
optional InferenceState inference_state = 4;
}
message ExampleRequest {
Shard shard = 1;
Tensor example = 2;
Tensor target = 3;
Tensor length = 4;
// When true, apply the gradient update; otherwise just evaluate.
bool train = 5;
optional string request_id = 6;
}
message Loss {
float loss = 1;
// Present only when the caller needs gradients back.
optional Tensor grads = 2;
}
// Raw tensor bytes plus enough metadata (shape, dtype) to reinterpret them.
message Tensor {
bytes tensor_data = 1;
repeated int32 shape = 2;
string dtype = 3;
}
message TensorList {
repeated Tensor tensors = 1;
}
// Carries model-specific state between shards during inference.
message InferenceState {
map<string, Tensor> tensor_data = 1;
map<string, TensorList> tensor_list_data = 2;
// Catch-all for non-tensor state, serialized as JSON.
string other_data_json = 3;
}
message CollectTopologyRequest {
// Node ids already visited, to stop cycles during the recursive walk.
repeated string visited = 1;
int32 max_depth = 2;
}
// Node capabilities plus the peer connection graph, keyed by node id.
message Topology {
map<string, DeviceCapabilities> nodes = 1;
map<string, PeerConnections> peer_graph = 2;
}
message PeerConnection {
string to_id = 1;
optional string description = 2;
}
message PeerConnections {
repeated PeerConnection connections = 1;
}
// Peak throughput by precision, in FLOPS.
message DeviceFlops {
double fp32 = 1;
double fp16 = 2;
double int8 = 3;
}
message DeviceCapabilities {
string model = 1;
string chip = 2;
// NOTE(review): unit is not specified here — presumably bytes or MB; confirm against senders.
int32 memory = 3;
DeviceFlops flops = 4;
}
message SendResultRequest {
string request_id = 1;
repeated int32 result = 2;
optional Tensor tensor = 3;
bool is_finished = 4;
}
message SendOpaqueStatusRequest {
string request_id = 1;
// JSON-encoded status payload; schema is defined by the application layer.
string status = 2;
}
message HealthCheckRequest {}
message HealthCheckResponse {
bool is_healthy = 1;
}
message Empty {}

View File

@ -1,3 +1,181 @@
fn main() {
println!("Hello, world!");
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tonic::{transport::Server, Request, Response, Status};
use crate::node_service::{
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology,
};
use node_service::node_service_server::{NodeService, NodeServiceServer};
use node_service::{Shard, TensorRequest};
/// Rust bindings generated by tonic-build from `node_service.proto` (see build.rs).
pub mod node_service {
    tonic::include_proto!("node_service"); // The string specified here must match the proto package name
}
/// Server-side state for this node. Currently stateless; the commented-out
/// topology fields in the handlers suggest state will be added here later.
///
/// The hand-written `Default` impl was replaced by the derive — it produced
/// the same value and the derive is the idiomatic zero-boilerplate form.
#[derive(Debug, Default)]
struct Node {}
/// The JSON payload carried by `SendOpaqueStatusRequest.status`.
/// Internally tagged: the JSON object's "type" field names the variant.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "type")]
enum OpaqueStatus {
    NodeStatus(NodeStatus),
    DownloadProgress(DownloadProgress),
    SupportedInferenceEngines(SupportedInferenceEngines),
}
/// Progress report for a single node working on a request.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct NodeStatus {
    node_id: String,
    // Stage string; prefixed "start_" / "end_" (see is_start / is_end below).
    status: String,
    base_shard: Shard,
    shard: Shard,
    prompt: String,
    request_id: String,
}
impl NodeStatus {
    /// True when this status marks the beginning of a stage ("start_*").
    fn is_start(&self) -> bool {
        self.status.strip_prefix("start_").is_some()
    }

    /// True when this status marks the end of a stage ("end_*").
    fn is_end(&self) -> bool {
        self.status.strip_prefix("end_").is_some()
    }
}
/// Model-download progress for a node. The payload shape is not pinned down
/// here, so it is kept as raw JSON.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct DownloadProgress {
    node_id: String,
    progress: Value,
}
/// The inference engines (by name) a node advertises support for.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct SupportedInferenceEngines {
    node_id: String,
    engines: Vec<String>,
}
impl Node {
    /// Decode an opaque JSON status payload and dispatch to the matching handler.
    ///
    /// The payload arrives over the network, so a malformed body is logged and
    /// dropped instead of panicking — the previous `unwrap()` would have taken
    /// the whole server down on any bad input.
    fn on_opaque_status(&self, request_id: String, status: String) {
        let status = match serde_json::from_str::<OpaqueStatus>(&status) {
            Ok(status) => status,
            Err(err) => {
                eprintln!(
                    "Ignoring malformed opaque status for request {request_id}: {err}"
                );
                return;
            }
        };
        match status {
            OpaqueStatus::NodeStatus(node_status) => self.on_node_status(node_status),
            OpaqueStatus::DownloadProgress(download_progress) => {
                self.on_download_progress(download_progress)
            }
            OpaqueStatus::SupportedInferenceEngines(supported_inference_engines) => {
                self.on_supported_inference_engines(supported_inference_engines)
            }
        }
    }

    /// Log a node's stage transition. Topology tracking is sketched below but
    /// not implemented yet (Node has no state to store it in).
    fn on_node_status(&self, node_status: NodeStatus) {
        println!("Received NodeStatus: {}", node_status.status);
        // if node_status.is_start() {
        //     self.current_topology.active_node_id = node_status.node_id;
        // } else if node_status.is_end() {
        //     if node_status.node_id == self.current_topology.active_node_id {
        //         self.current_topology.active_node_id = None;
        //     }
        // }
    }

    /// Intentionally a no-op: download progress is only used for visualization.
    fn on_download_progress(&self, _download_progress: DownloadProgress) {
        // This is only used for visualization so we can ignore it for now
    }

    /// Log a node's advertised engines. Pooling them per-node (sketched below)
    /// is still unimplemented, hence the trailing `todo!()`.
    fn on_supported_inference_engines(
        &self,
        supported_inference_engines: SupportedInferenceEngines,
    ) {
        println!(
            "Received SupportedInferenceEngines: {}",
            supported_inference_engines.engines.join(", ")
        );
        // let node_id = supported_inference_engines.node_id;
        // let engines = supported_inference_engines.engines;
        // self.topology_inference_engines_pool.append(engines);
        todo!();
    }
}
#[tonic::async_trait]
impl NodeService for Node {
    /// Run inference on a text prompt for the given shard. Not implemented yet.
    async fn send_prompt(
        &self,
        request: Request<PromptRequest>,
    ) -> Result<Response<Tensor>, Status> {
        todo!()
    }

    /// Accept an intermediate activation tensor from a peer shard. Stub.
    async fn send_tensor(
        &self,
        request: Request<TensorRequest>,
    ) -> Result<Response<Tensor>, Status> {
        todo!()
    }

    /// Accept a training example and return the loss. Stub.
    async fn send_example(
        &self,
        request: Request<ExampleRequest>,
    ) -> Result<Response<Loss>, Status> {
        todo!()
    }

    /// Walk the cluster and report its topology. Stub.
    async fn collect_topology(
        &self,
        request: Request<CollectTopologyRequest>,
    ) -> Result<Response<Topology>, Status> {
        todo!()
    }

    /// Receive the final result for a request. Stub.
    async fn send_result(
        &self,
        request: Request<SendResultRequest>,
    ) -> Result<Response<Empty>, Status> {
        todo!()
    }

    /// Receive an opaque JSON status blob and acknowledge it.
    async fn send_opaque_status(
        &self,
        request: Request<SendOpaqueStatusRequest>,
    ) -> Result<Response<Empty>, Status> {
        // BUG FIX: `into_inner()` consumes the Request, so the original code —
        // which called it twice, once per field — was a use-after-move compile
        // error. Consume the request exactly once and destructure it.
        let SendOpaqueStatusRequest { request_id, status } = request.into_inner();
        println!(
            "Received SendOpaqueStatus request: {} {}",
            request_id, status
        );
        Ok(Response::new(Empty {}))
    }

    /// Liveness probe; this node always reports healthy.
    async fn health_check(
        &self,
        request: Request<HealthCheckRequest>,
    ) -> Result<Response<HealthCheckResponse>, Status> {
        Ok(Response::new(HealthCheckResponse { is_healthy: true }))
    }
}
/// Entry point: serve the NodeService gRPC API on the IPv6 loopback.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let addr = "[::1]:50051".parse()?;
    // TODO: Also implement discovery
    let service = NodeServiceServer::new(Node::default());
    Server::builder().add_service(service).serve(addr).await?;
    Ok(())
}