Add basic process tensor

This commit is contained in:
Joshua Coles 2025-02-12 15:00:43 +00:00
parent a631a1c0a9
commit 39ac5a86dd

View File

@ -33,6 +33,24 @@ struct Node {
}
impl Node {
/// Feeds `tensor` through the inference engine for this node's shard and post-processes
/// the engine's output.
///
/// The shard is obtained by asking the current topology for this node's portion of
/// `base_shard`. The engine output is then handed to `process_inference_result`, and
/// whatever that returns is the result of this method.
///
/// `request_id` and `inference_state` are forwarded to both the engine call and the
/// post-processing call.
#[tracing::instrument]
pub(crate) async fn process_tensor(
    &self,
    base_shard: Shard,
    tensor: Option<Tensor>,
    request_id: String,
    inference_state: Option<InferenceState>,
) -> Tensor {
    let shard = self
        .current_topology
        .get_shard_for_node(base_shard, &self.node_info.node_id);

    // Both calls below take `request_id` and `inference_state` by value. The first call
    // receives clones so the second call can still consume the originals; without this
    // the owned `String` would be used after it had been moved.
    // NOTE(review): `shard` is also passed by value twice, which only works if it is
    // `Copy` or a reference — confirm against `get_shard_for_node`'s return type.
    // NOTE(review): nothing is awaited in this `async fn`; if `infer_tensor` or
    // `process_inference_result` are async they need `.await` — confirm.
    let output = self.inference_engine.infer_tensor(
        request_id.clone(),
        shard,
        tensor,
        inference_state.clone(),
    );

    self.process_inference_result(shard, output, request_id, inference_state)
}
#[tracing::instrument]
pub async fn process_prompt(
&self,
@ -46,6 +64,8 @@ impl Node {
.get_shard_for_node(base_shard, &self.node_info.node_id);
todo!();
// The python code is a little weird wrt return types here
// if shard.is_first_layer() {
// let result = self
// .inference_engine
@ -218,7 +238,14 @@ impl NodeService for Node {
&self,
request: Request<TensorRequest>,
) -> Result<Response<Tensor>, Status> {
todo!()
let request = request.into_inner();
let shard = request.shard.expect("No shard provided").into();
let request_id = request.request_id.expect("No request id provided");
let result =
self.process_tensor(shard, request.tensor, request_id, request.inference_state);
Ok(Response::new(result.into()))
}
async fn send_example(