Initial start
This commit is contained in:
parent
108f454c52
commit
09fe616b44
1216
Cargo.lock
generated
1216
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -4,3 +4,11 @@ version = "0.1.0"
|
|||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
prost = "0.13.4"
|
||||||
|
serde = { version = "1.0.217", features = ["derive"] }
|
||||||
|
serde_json = "1.0.138"
|
||||||
|
tokio = { version = "1.43.0", features = ["full"] }
|
||||||
|
tonic = "0.12.3"
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
tonic-build = "0.12.3"
|
||||||
|
|||||||
4
build.rs
Normal file
4
build.rs
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
tonic_build::compile_protos("node_service.proto")?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
116
node_service.proto
Normal file
116
node_service.proto
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package node_service;
|
||||||
|
|
||||||
|
service NodeService {
|
||||||
|
rpc SendPrompt (PromptRequest) returns (Tensor) {}
|
||||||
|
rpc SendTensor (TensorRequest) returns (Tensor) {}
|
||||||
|
rpc SendExample (ExampleRequest) returns (Loss) {}
|
||||||
|
rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
|
||||||
|
rpc SendResult (SendResultRequest) returns (Empty) {}
|
||||||
|
rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
|
||||||
|
rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
message Shard {
|
||||||
|
string model_id = 1;
|
||||||
|
int32 start_layer = 2;
|
||||||
|
int32 end_layer = 3;
|
||||||
|
int32 n_layers = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message PromptRequest {
|
||||||
|
Shard shard = 1;
|
||||||
|
string prompt = 2;
|
||||||
|
optional string request_id = 3;
|
||||||
|
optional InferenceState inference_state = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TensorRequest {
|
||||||
|
Shard shard = 1;
|
||||||
|
Tensor tensor = 2;
|
||||||
|
optional string request_id = 3;
|
||||||
|
optional InferenceState inference_state = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExampleRequest {
|
||||||
|
Shard shard = 1;
|
||||||
|
Tensor example = 2;
|
||||||
|
Tensor target = 3;
|
||||||
|
Tensor length = 4;
|
||||||
|
bool train = 5;
|
||||||
|
optional string request_id = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Loss {
|
||||||
|
float loss = 1;
|
||||||
|
optional Tensor grads = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Tensor {
|
||||||
|
bytes tensor_data = 1;
|
||||||
|
repeated int32 shape = 2;
|
||||||
|
string dtype = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TensorList {
|
||||||
|
repeated Tensor tensors = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message InferenceState {
|
||||||
|
map<string, Tensor> tensor_data = 1;
|
||||||
|
map<string, TensorList> tensor_list_data = 2;
|
||||||
|
string other_data_json = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CollectTopologyRequest {
|
||||||
|
repeated string visited = 1;
|
||||||
|
int32 max_depth = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Topology {
|
||||||
|
map<string, DeviceCapabilities> nodes = 1;
|
||||||
|
map<string, PeerConnections> peer_graph = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message PeerConnection {
|
||||||
|
string to_id = 1;
|
||||||
|
optional string description = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message PeerConnections {
|
||||||
|
repeated PeerConnection connections = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message DeviceFlops {
|
||||||
|
double fp32 = 1;
|
||||||
|
double fp16 = 2;
|
||||||
|
double int8 = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message DeviceCapabilities {
|
||||||
|
string model = 1;
|
||||||
|
string chip = 2;
|
||||||
|
int32 memory = 3;
|
||||||
|
DeviceFlops flops = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message SendResultRequest {
|
||||||
|
string request_id = 1;
|
||||||
|
repeated int32 result = 2;
|
||||||
|
optional Tensor tensor = 3;
|
||||||
|
bool is_finished = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message SendOpaqueStatusRequest {
|
||||||
|
string request_id = 1;
|
||||||
|
string status = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message HealthCheckRequest {}
|
||||||
|
|
||||||
|
message HealthCheckResponse {
|
||||||
|
bool is_healthy = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Empty {}
|
||||||
182
src/main.rs
182
src/main.rs
@ -1,3 +1,181 @@
|
|||||||
fn main() {
|
use serde::{Deserialize, Serialize};
|
||||||
println!("Hello, world!");
|
use serde_json::Value;
|
||||||
|
use tonic::{transport::Server, Request, Response, Status};
|
||||||
|
|
||||||
|
use crate::node_service::{
|
||||||
|
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
|
||||||
|
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology,
|
||||||
|
};
|
||||||
|
use node_service::node_service_server::{NodeService, NodeServiceServer};
|
||||||
|
use node_service::{Shard, TensorRequest};
|
||||||
|
|
||||||
|
pub mod node_service {
|
||||||
|
tonic::include_proto!("node_service"); // The string specified here must match the proto package name
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Node {}
|
||||||
|
|
||||||
|
impl Default for Node {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
enum OpaqueStatus {
|
||||||
|
NodeStatus(NodeStatus),
|
||||||
|
DownloadProgress(DownloadProgress),
|
||||||
|
SupportedInferenceEngines(SupportedInferenceEngines),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
|
struct NodeStatus {
|
||||||
|
node_id: String,
|
||||||
|
status: String,
|
||||||
|
base_shard: Shard,
|
||||||
|
shard: Shard,
|
||||||
|
prompt: String,
|
||||||
|
request_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NodeStatus {
|
||||||
|
fn is_start(&self) -> bool {
|
||||||
|
self.status.starts_with("start_")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_end(&self) -> bool {
|
||||||
|
self.status.starts_with("end_")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
|
struct DownloadProgress {
|
||||||
|
node_id: String,
|
||||||
|
progress: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
|
struct SupportedInferenceEngines {
|
||||||
|
node_id: String,
|
||||||
|
engines: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Node {
|
||||||
|
fn on_opaque_status(&self, request_id: String, status: String) {
|
||||||
|
let status = serde_json::from_str::<OpaqueStatus>(&status).unwrap();
|
||||||
|
|
||||||
|
match status {
|
||||||
|
OpaqueStatus::NodeStatus(node_status) => self.on_node_status(node_status),
|
||||||
|
OpaqueStatus::DownloadProgress(download_progress) => {
|
||||||
|
self.on_download_progress(download_progress)
|
||||||
|
}
|
||||||
|
OpaqueStatus::SupportedInferenceEngines(supported_inference_engines) => {
|
||||||
|
self.on_supported_inference_engines(supported_inference_engines)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_node_status(&self, node_status: NodeStatus) {
|
||||||
|
println!("Received NodeStatus: {}", node_status.status);
|
||||||
|
|
||||||
|
// if node_status.is_start() {
|
||||||
|
// self.current_topology.active_node_id = node_status.node_id;
|
||||||
|
// } else if node_status.is_end() {
|
||||||
|
// if node_status.node_id == self.current_topology.active_node_id {
|
||||||
|
// self.current_topology.active_node_id = None;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_download_progress(&self, download_progress: DownloadProgress) {
|
||||||
|
// This is only used for visualization so we can ignore it for now
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_supported_inference_engines(
|
||||||
|
&self,
|
||||||
|
supported_inference_engines: SupportedInferenceEngines,
|
||||||
|
) {
|
||||||
|
println!(
|
||||||
|
"Received SupportedInferenceEngines: {}",
|
||||||
|
supported_inference_engines.engines.join(", ")
|
||||||
|
);
|
||||||
|
// let node_id = supported_inference_engines.node_id;
|
||||||
|
// let engines = supported_inference_engines.engines;
|
||||||
|
// self.topology_inference_engines_pool.append(engines);
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tonic::async_trait]
|
||||||
|
impl NodeService for Node {
|
||||||
|
async fn send_prompt(
|
||||||
|
&self,
|
||||||
|
request: Request<PromptRequest>,
|
||||||
|
) -> Result<Response<Tensor>, Status> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_tensor(
|
||||||
|
&self,
|
||||||
|
request: Request<TensorRequest>,
|
||||||
|
) -> Result<Response<Tensor>, Status> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_example(
|
||||||
|
&self,
|
||||||
|
request: Request<ExampleRequest>,
|
||||||
|
) -> Result<Response<Loss>, Status> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn collect_topology(
|
||||||
|
&self,
|
||||||
|
request: Request<CollectTopologyRequest>,
|
||||||
|
) -> Result<Response<Topology>, Status> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_result(
|
||||||
|
&self,
|
||||||
|
request: Request<SendResultRequest>,
|
||||||
|
) -> Result<Response<Empty>, Status> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_opaque_status(
|
||||||
|
&self,
|
||||||
|
request: Request<SendOpaqueStatusRequest>,
|
||||||
|
) -> Result<Response<Empty>, Status> {
|
||||||
|
let request_id = request.into_inner().request_id;
|
||||||
|
let status = request.into_inner().status;
|
||||||
|
println!(
|
||||||
|
"Received SendOpaqueStatus request: {} {}",
|
||||||
|
request_id, status
|
||||||
|
);
|
||||||
|
Ok(Response::new(Empty {}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health_check(
|
||||||
|
&self,
|
||||||
|
request: Request<HealthCheckRequest>,
|
||||||
|
) -> Result<Response<HealthCheckResponse>, Status> {
|
||||||
|
Ok(Response::new(HealthCheckResponse { is_healthy: true }))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let grpc_addr = "[::1]:50051".parse()?;
|
||||||
|
let node = Node::default();
|
||||||
|
|
||||||
|
// TODO: Also implement discovery
|
||||||
|
|
||||||
|
Server::builder()
|
||||||
|
.add_service(NodeServiceServer::new(node))
|
||||||
|
.serve(grpc_addr)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user