From 39ac5a86dd2c7955ce4c3944f6113f9c8ab52cd0 Mon Sep 17 00:00:00 2001
From: Joshua Coles
Date: Wed, 12 Feb 2025 15:00:43 +0000
Subject: [PATCH] Add basic process tensor

---
 src/main.rs | 29 ++++++++++++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/src/main.rs b/src/main.rs
index 50c27f8..a5a10d8 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -33,6 +33,24 @@ struct Node {
 }
 
 impl Node {
+    #[tracing::instrument]
+    pub(crate) async fn process_tensor(
+        &self,
+        base_shard: Shard,
+        tensor: Option<Tensor>,
+        request_id: String,
+        inference_state: Option<String>,
+    ) -> Tensor {
+        let shard = self
+            .current_topology
+            .get_shard_for_node(base_shard, &self.node_info.node_id);
+
+        let result = self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state);
+        let result = self.process_inference_result(shard, result, request_id, inference_state);
+
+        result
+    }
+
     #[tracing::instrument]
     pub async fn process_prompt(
         &self,
@@ -46,6 +64,8 @@ impl Node {
             .get_shard_for_node(base_shard, &self.node_info.node_id);
 
         todo!();
+        // The python code is a little weird wrt return types here
+
         // if shard.is_first_layer() {
         //     let result = self
         //         .inference_engine
@@ -218,7 +238,14 @@ impl NodeService for Node {
         &self,
         request: Request<TensorRequest>,
     ) -> Result<Response<Tensor>, Status> {
-        todo!()
+        let request = request.into_inner();
+        let shard = request.shard.expect("No shard provided").into();
+        let request_id = request.request_id.expect("No request id provided");
+
+        let result =
+            self.process_tensor(shard, request.tensor, request_id, request.inference_state).await;
+
+        Ok(Response::new(result.into()))
     }
 
     async fn send_example(
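
The handler above leans on .into() conversions between the tonic-generated wire types and the node's internal types (the shard coming in, the tensor going out). Below is a minimal sketch of the kind of From impl that call shape assumes; both struct definitions and all field names here are illustrative stand-ins, not the actual types in src/main.rs or the generated proto module.

// Sketch only: stand-in for the generated proto shard type.
mod proto {
    #[derive(Clone, Debug)]
    pub struct Shard {
        pub model_id: String,
        pub start_layer: i32,
        pub end_layer: i32,
        pub n_layers: i32,
    }
}

// Sketch only: stand-in for the node's internal shard type.
#[derive(Clone, Debug)]
struct Shard {
    model_id: String,
    start_layer: i32,
    end_layer: i32,
    n_layers: i32,
}

impl From<proto::Shard> for Shard {
    // Map the wire representation onto the internal one field by field,
    // which is what request.shard.expect(...).into() relies on.
    fn from(s: proto::Shard) -> Self {
        Shard {
            model_id: s.model_id,
            start_layer: s.start_layer,
            end_layer: s.end_layer,
            n_layers: s.n_layers,
        }
    }
}

fn main() {
    let wire = proto::Shard {
        model_id: "example-model".into(),
        start_layer: 0,
        end_layer: 15,
        n_layers: 32,
    };
    // Same call shape as in the send_tensor handler: proto value -> internal value.
    let shard: Shard = wire.into();
    println!("{shard:?}");
}

A symmetric impl in the other direction (internal tensor to proto tensor) is what Response::new(result.into()) would need if the two tensor types differ; if they are the same type, the blanket Into implementation already covers it.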