Initial start
This commit is contained in:
parent
108f454c52
commit
09fe616b44
1216
Cargo.lock
generated
1216
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -4,3 +4,11 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
prost = "0.13.4"
|
||||
serde = { version = "1.0.217", features = ["derive"] }
|
||||
serde_json = "1.0.138"
|
||||
tokio = { version = "1.43.0", features = ["full"] }
|
||||
tonic = "0.12.3"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.12.3"
|
||||
|
||||
4
build.rs
Normal file
4
build.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
tonic_build::compile_protos("node_service.proto")?;
|
||||
Ok(())
|
||||
}
|
||||
116
node_service.proto
Normal file
116
node_service.proto
Normal file
@@ -0,0 +1,116 @@
|
||||
syntax = "proto3";

package node_service;

// RPC surface exposed by every node in the cluster.
service NodeService {
  // Start inference from a text prompt on the given shard.
  rpc SendPrompt (PromptRequest) returns (Tensor) {}
  // Continue inference from an intermediate activation tensor.
  rpc SendTensor (TensorRequest) returns (Tensor) {}
  // Run a single (example, target) pair; returns the loss (and grads when training).
  rpc SendExample (ExampleRequest) returns (Loss) {}
  // Recursively gather the cluster topology up to max_depth.
  rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
  // Deliver a (possibly partial) generation result for a request.
  rpc SendResult (SendResultRequest) returns (Empty) {}
  // Forward an opaque, JSON-encoded status blob (schema owned by the application layer).
  rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
  // Liveness probe.
  rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
}

// A contiguous slice of a model's layers assigned to one node.
message Shard {
  string model_id = 1;
  int32 start_layer = 2;
  int32 end_layer = 3;
  // Total layer count of the full model, not of this slice.
  int32 n_layers = 4;
}

message PromptRequest {
  Shard shard = 1;
  string prompt = 2;
  optional string request_id = 3;
  optional InferenceState inference_state = 4;
}

message TensorRequest {
  Shard shard = 1;
  Tensor tensor = 2;
  optional string request_id = 3;
  optional InferenceState inference_state = 4;
}

message ExampleRequest {
  Shard shard = 1;
  Tensor example = 2;
  Tensor target = 3;
  Tensor length = 4;
  // When true, a backward pass is expected and Loss.grads should be populated.
  bool train = 5;
  optional string request_id = 6;
}

message Loss {
  float loss = 1;
  optional Tensor grads = 2;
}

// Raw tensor payload: flat bytes plus shape and a dtype name string.
message Tensor {
  bytes tensor_data = 1;
  repeated int32 shape = 2;
  string dtype = 3;
}

message TensorList {
  repeated Tensor tensors = 1;
}

// Serialized inference-engine state carried between hops.
message InferenceState {
  map<string, Tensor> tensor_data = 1;
  map<string, TensorList> tensor_list_data = 2;
  // Escape hatch for non-tensor state, JSON-encoded.
  string other_data_json = 3;
}

message CollectTopologyRequest {
  // Node ids already visited, to stop cycles during collection.
  repeated string visited = 1;
  int32 max_depth = 2;
}

// Snapshot of the cluster: per-node capabilities and the peer adjacency graph.
message Topology {
  map<string, DeviceCapabilities> nodes = 1;
  map<string, PeerConnections> peer_graph = 2;
}

message PeerConnection {
  string to_id = 1;
  optional string description = 2;
}

message PeerConnections {
  repeated PeerConnection connections = 1;
}

// Peak throughput per precision, unit defined by the producer (TODO confirm: TFLOPS).
message DeviceFlops {
  double fp32 = 1;
  double fp16 = 2;
  double int8 = 3;
}

message DeviceCapabilities {
  string model = 1;
  string chip = 2;
  int32 memory = 3;
  DeviceFlops flops = 4;
}

message SendResultRequest {
  string request_id = 1;
  repeated int32 result = 2;
  optional Tensor tensor = 3;
  bool is_finished = 4;
}

message SendOpaqueStatusRequest {
  string request_id = 1;
  // JSON-encoded status; schema is defined by the application layer, not this proto.
  string status = 2;
}

message HealthCheckRequest {}

message HealthCheckResponse {
  bool is_healthy = 1;
}

message Empty {}
|
||||
182
src/main.rs
182
src/main.rs
@@ -1,3 +1,181 @@
|
||||
fn main() {
|
||||
println!("Hello, world!");
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tonic::{transport::Server, Request, Response, Status};
|
||||
|
||||
use crate::node_service::{
|
||||
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
|
||||
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology,
|
||||
};
|
||||
use node_service::node_service_server::{NodeService, NodeServiceServer};
|
||||
use node_service::{Shard, TensorRequest};
|
||||
|
||||
/// Generated gRPC types, client, and server traits for the proto package.
pub mod node_service {
    tonic::include_proto!("node_service"); // The string specified here must match the proto package name
}
|
||||
|
||||
/// Server-side node state. Currently empty; fields (topology, engine pool, …)
/// will be added as the handlers are implemented.
///
/// The hand-written `impl Default` was replaced by the idiomatic derive —
/// identical behavior (`Node::default()` still yields `Node {}`), less code.
#[derive(Default)]
struct Node {}
|
||||
|
||||
/// Application-level status payload carried as JSON inside
/// `SendOpaqueStatusRequest.status`. The serde `tag = "type"` attribute means
/// the JSON is internally tagged, e.g. `{"type": "NodeStatus", ...}`.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "type")]
enum OpaqueStatus {
    NodeStatus(NodeStatus),
    DownloadProgress(DownloadProgress),
    SupportedInferenceEngines(SupportedInferenceEngines),
}
|
||||
|
||||
/// A peer's start/end lifecycle notification for one request.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct NodeStatus {
    // Id of the peer that emitted the status.
    node_id: String,
    // Free-form status string; the code only distinguishes the
    // "start_*" / "end_*" prefixes (see is_start / is_end).
    status: String,
    // `Shard` is the prost-generated proto type, deserialized from JSON here.
    base_shard: Shard,
    shard: Shard,
    prompt: String,
    request_id: String,
}
|
||||
|
||||
impl NodeStatus {
|
||||
fn is_start(&self) -> bool {
|
||||
self.status.starts_with("start_")
|
||||
}
|
||||
|
||||
fn is_end(&self) -> bool {
|
||||
self.status.starts_with("end_")
|
||||
}
|
||||
}
|
||||
|
||||
/// Model-download progress report from a peer. The payload shape is not
/// fixed here (`serde_json::Value`); it is only used for visualization.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct DownloadProgress {
    node_id: String,
    progress: Value,
}
|
||||
|
||||
/// Advertisement of which inference engines a peer can run.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct SupportedInferenceEngines {
    node_id: String,
    engines: Vec<String>,
}
|
||||
|
||||
impl Node {
|
||||
fn on_opaque_status(&self, request_id: String, status: String) {
|
||||
let status = serde_json::from_str::<OpaqueStatus>(&status).unwrap();
|
||||
|
||||
match status {
|
||||
OpaqueStatus::NodeStatus(node_status) => self.on_node_status(node_status),
|
||||
OpaqueStatus::DownloadProgress(download_progress) => {
|
||||
self.on_download_progress(download_progress)
|
||||
}
|
||||
OpaqueStatus::SupportedInferenceEngines(supported_inference_engines) => {
|
||||
self.on_supported_inference_engines(supported_inference_engines)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn on_node_status(&self, node_status: NodeStatus) {
|
||||
println!("Received NodeStatus: {}", node_status.status);
|
||||
|
||||
// if node_status.is_start() {
|
||||
// self.current_topology.active_node_id = node_status.node_id;
|
||||
// } else if node_status.is_end() {
|
||||
// if node_status.node_id == self.current_topology.active_node_id {
|
||||
// self.current_topology.active_node_id = None;
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
fn on_download_progress(&self, download_progress: DownloadProgress) {
|
||||
// This is only used for visualization so we can ignore it for now
|
||||
}
|
||||
|
||||
fn on_supported_inference_engines(
|
||||
&self,
|
||||
supported_inference_engines: SupportedInferenceEngines,
|
||||
) {
|
||||
println!(
|
||||
"Received SupportedInferenceEngines: {}",
|
||||
supported_inference_engines.engines.join(", ")
|
||||
);
|
||||
// let node_id = supported_inference_engines.node_id;
|
||||
// let engines = supported_inference_engines.engines;
|
||||
// self.topology_inference_engines_pool.append(engines);
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl NodeService for Node {
|
||||
async fn send_prompt(
|
||||
&self,
|
||||
request: Request<PromptRequest>,
|
||||
) -> Result<Response<Tensor>, Status> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn send_tensor(
|
||||
&self,
|
||||
request: Request<TensorRequest>,
|
||||
) -> Result<Response<Tensor>, Status> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn send_example(
|
||||
&self,
|
||||
request: Request<ExampleRequest>,
|
||||
) -> Result<Response<Loss>, Status> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn collect_topology(
|
||||
&self,
|
||||
request: Request<CollectTopologyRequest>,
|
||||
) -> Result<Response<Topology>, Status> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn send_result(
|
||||
&self,
|
||||
request: Request<SendResultRequest>,
|
||||
) -> Result<Response<Empty>, Status> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn send_opaque_status(
|
||||
&self,
|
||||
request: Request<SendOpaqueStatusRequest>,
|
||||
) -> Result<Response<Empty>, Status> {
|
||||
let request_id = request.into_inner().request_id;
|
||||
let status = request.into_inner().status;
|
||||
println!(
|
||||
"Received SendOpaqueStatus request: {} {}",
|
||||
request_id, status
|
||||
);
|
||||
Ok(Response::new(Empty {}))
|
||||
}
|
||||
|
||||
async fn health_check(
|
||||
&self,
|
||||
request: Request<HealthCheckRequest>,
|
||||
) -> Result<Response<HealthCheckResponse>, Status> {
|
||||
Ok(Response::new(HealthCheckResponse { is_healthy: true }))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let grpc_addr = "[::1]:50051".parse()?;
|
||||
let node = Node::default();
|
||||
|
||||
// TODO: Also implement discovery
|
||||
|
||||
Server::builder()
|
||||
.add_service(NodeServiceServer::new(node))
|
||||
.serve(grpc_addr)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user