Add PeerHandle

This commit is contained in:
Joshua Coles 2025-02-12 07:36:07 +00:00
parent 5f6f8b6a43
commit 2d02fde6c6
6 changed files with 34 additions and 367 deletions

20
Cargo.lock generated
View File

@ -169,6 +169,15 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "crc32fast"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3"
dependencies = [
"cfg-if",
]
[[package]]
name = "either"
version = "1.13.0"
@ -218,6 +227,16 @@ version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flate2"
version = "1.0.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "fnv"
version = "1.0.7"
@ -1037,6 +1056,7 @@ dependencies = [
"axum",
"base64",
"bytes",
"flate2",
"h2",
"http",
"http-body",

View File

@ -8,7 +8,7 @@ prost = "0.13.4"
serde = { version = "1.0.217", features = ["derive"] }
serde_json = "1.0.138"
tokio = { version = "1.43.0", features = ["full"] }
tonic = "0.12.3"
tonic = { version = "0.12.3", features = ["gzip"] }
thiserror = "2.0"
tracing = "0.1"
tracing-subscriber = "0.3"

View File

@ -1,262 +0,0 @@
mod udp;
use crate::topology::DeviceCapabilities;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
net::{IpAddr, SocketAddr},
sync::Arc,
time::{Duration, Instant},
};
use thiserror::Error;
use tokio::{net::UdpSocket, sync::RwLock, time};
use tracing::{error};
#[derive(Error, Debug)]
pub enum DiscoveryError {
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("JSON serialization error: {0}")]
Json(#[from] serde_json::Error),
#[error("UTF-8 conversion error: {0}")]
Utf8(#[from] std::string::FromUtf8Error),
#[error("Address parse error: {0}")]
AddrParse(#[from] std::net::AddrParseError),
#[error("Network interface error: {0}")]
NetworkInterface(String),
#[error("Peer error: {0}")]
Peer(String),
}
// Constants
const DEBUG: i32 = 0;
const BROADCAST_ADDR: &str = "255.255.255.255";
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DiscoveryMessage {
message_type: String,
node_id: String,
grpc_port: u16,
device_capabilities: DeviceCapabilities,
priority: i32,
interface_name: String,
interface_type: String,
}
struct PeerInfo {
peer_handle: PeerHandle,
connected_at: Instant,
last_seen: Instant,
priority: i32,
}
struct UdpDiscovery {
node_id: String,
node_port: u16,
listen_port: u16,
broadcast_port: u16,
broadcast_interval: Duration,
discovery_timeout: Duration,
device_capabilities: DeviceCapabilities,
allowed_node_ids: Option<Vec<String>>,
allowed_interface_types: Option<Vec<String>>,
known_peers: Arc<RwLock<HashMap<String, PeerInfo>>>,
}
impl UdpDiscovery {
pub fn new(
node_id: String,
node_port: u16,
listen_port: u16,
broadcast_port: u16,
broadcast_interval: Duration,
discovery_timeout: Duration,
device_capabilities: DeviceCapabilities,
allowed_node_ids: Option<Vec<String>>,
allowed_interface_types: Option<Vec<String>>,
) -> Self {
Self {
node_id,
node_port,
listen_port,
broadcast_port,
broadcast_interval,
discovery_timeout,
device_capabilities,
allowed_node_ids,
allowed_interface_types,
known_peers: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn start(&self) -> Result<(), DiscoveryError> {
let broadcast_task = self.task_broadcast_presence();
let listen_task = self.task_listen_for_peers();
let cleanup_task = self.task_cleanup_peers();
tokio::try_join!(broadcast_task, listen_task, cleanup_task)?;
Ok(())
}
}
impl UdpDiscovery {
async fn task_listen_for_peers(&self) -> Result<(), DiscoveryError> {
let socket = UdpSocket::bind(format!("0.0.0.0:{}", self.listen_port)).await?;
let mut buf = vec![0u8; 65535];
loop {
let (len, addr) = socket.recv_from(&mut buf).await?;
if len == 0 {
continue;
}
if let Ok(message) = String::from_utf8(buf[..len].to_vec()) {
if let Ok(discovery_message) = serde_json::from_str::<DiscoveryMessage>(&message) {
self.handle_discovery_message(discovery_message, addr)
.await
.map_err(|e| DiscoveryError::Peer(e.to_string()))?;
}
}
}
}
}
impl UdpDiscovery {
async fn task_broadcast_presence(&self) -> Result<(), DiscoveryError> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.set_broadcast(true)?;
loop {
let interfaces = get_all_ip_addresses_and_interfaces()?;
for (addr, interface_name) in interfaces {
let (interface_priority, interface_type) =
get_interface_priority_and_type(&interface_name)
.await
.map_err(|e| DiscoveryError::NetworkInterface(e.to_string()))?;
let message = DiscoveryMessage {
message_type: "discovery".to_string(),
node_id: self.node_id.clone(),
grpc_port: self.node_port,
device_capabilities: self.device_capabilities.clone(),
priority: interface_priority,
interface_name: interface_name.clone(),
interface_type: interface_type.clone(),
};
let message_json = serde_json::to_string(&message)?;
let broadcast_addr =
SocketAddr::new(get_broadcast_address(&addr).parse()?, self.broadcast_port);
if let Err(e) = socket
.send_to(message_json.as_bytes(), &broadcast_addr)
.await
{
error!("Error broadcasting to {}: {}", broadcast_addr, e);
}
}
time::sleep(self.broadcast_interval).await;
}
}
}
impl UdpDiscovery {
async fn task_cleanup_peers(&self) -> Result<(), DiscoveryError> {
loop {
let now = Instant::now();
// TODO: Do we really want a read then write lock or should we just take a write lock
// to begin with?
let peers_to_remove = {
let mut peers_to_remove = Vec::new();
let peers = self.known_peers.read().await;
for (peer_id, peer_info) in peers.iter() {
if self.should_remove_peer(peer_info, now).await {
peers_to_remove.push(peer_id.clone());
}
}
peers_to_remove
};
{
let mut peers = self.known_peers.write().await;
for peer_id in peers_to_remove {
peers.remove(&peer_id);
}
};
time::sleep(self.broadcast_interval).await;
}
}
async fn should_remove_peer(&self, peer_info: &PeerInfo, now: Instant) -> bool {
let is_connected = peer_info
.peer_handle
.is_connected()
.await
.ok()
.unwrap_or(false);
if !is_connected {
return true;
}
if now.duration_since(peer_info.connected_at) > self.discovery_timeout {
return true;
}
if now.duration_since(peer_info.last_seen) > self.discovery_timeout {
return true;
}
peer_info
.peer_handle
.health_check()
.await
.ok()
.unwrap_or(false)
}
}
fn get_broadcast_address(ip_addr: &IpAddr) -> String {
match ip_addr {
IpAddr::V4(addr) => {
let octets = addr.octets();
format!("{}.{}.{}.255", octets[0], octets[1], octets[2])
}
_ => BROADCAST_ADDR.to_string(),
}
}
// You'll need to implement these functions based on your system
fn get_all_ip_addresses_and_interfaces() -> Result<Vec<(IpAddr, String)>, DiscoveryError> {
// Implementation needed
Ok(vec![])
}
async fn get_interface_priority_and_type(
interface_name: &str,
) -> Result<(i32, String), DiscoveryError> {
// Implementation needed
Ok((0, "unknown".to_string()))
}
struct PeerHandle;
impl PeerHandle {
async fn is_connected(&self) -> Result<bool, DiscoveryError> {
Ok(true)
}
async fn health_check(&self) -> Result<bool, DiscoveryError> {
Ok(true)
}
}

View File

@ -1,98 +0,0 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use crate::topology::DeviceCapabilities;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tracing::{debug, error, info};
use crate::discovery::PeerInfo;
struct UdpDiscovery {
listen_port: u16,
listen_buffer_size: usize,
allowed_interface_types: Option<Vec<String>>,
known_peers: RwLock<HashMap<String, PeerInfo>>,
listen_handle: JoinHandle<()>,
presence_handle: JoinHandle<()>,
}
#[derive(Debug, Serialize, Deserialize)]
struct DiscoveryMessage {
#[serde(rename = "type")]
message_type: String,
node_id: String,
grpc_port: u16,
device_capabilities: DeviceCapabilities,
priority: i32,
interface_name: String,
interface_type: String,
}
impl UdpDiscovery {
async fn listen(&self) {
let listen_socket = tokio::net::UdpSocket::bind("0.0.0.0:42069").await.unwrap();
let mut buf = vec![0; self.listen_buffer_size];
loop {
// This will block waiting for a message.
// If this fails it will end the loop and the task, which is what we want.
let Ok((len, addr)) = listen_socket.recv_from(&mut buf).await.unwrap();
let Some(message) = serde_json::from_slice::<DiscoveryMessage>(&buf[..len]) else {
error!(
"Received invalid discovery message from {} {:?}",
addr,
str::from_utf8(&buf[..len])
);
continue;
};
if let Some(ref allowed_interface_types) = self.allowed_interface_types {
if !allowed_interface_types.contains(&message.interface_type) {
debug!("Ignoring message from {} because interface type {} is not allowed", addr, message.interface_type);
continue;
}
}
self.on_discovery_message(message, addr).await;
}
}
async fn on_discovery_message(&self, message: DiscoveryMessage, addr: SocketAddr) {
let known_peers = self.known_peers.write().await;
let existing = known_peers.get(&message.node_id);
if let Some(existing) = existing {
}
/*
if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
if peer_id in self.known_peers:
existing_peer_prio = self.known_peers[peer_id][3]
if existing_peer_prio >= peer_prio:
if DEBUG >= 1:
print(
f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
return
new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}",
f"{peer_interface_type} ({peer_interface_name})",
device_capabilities)
if not await new_peer_handle.health_check():
if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
return
if DEBUG >= 1: print(
f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time(), peer_prio)
else:
if not await self.known_peers[peer_id][0].health_check():
if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Removing.")
if peer_id in self.known_peers: del self.known_peers[peer_id]
return
if peer_id in self.known_peers: self.known_peers[peer_id] = (
self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
*/
}
}
}

View File

@ -1,5 +1,4 @@
mod topology;
mod discovery;
mod orchestration;
use serde::{Deserialize, Serialize};

View File

@ -1,20 +1,28 @@
use std::net::SocketAddr;
use tonic::codec::CompressionEncoding;
use crate::node_service::node_service_client::NodeServiceClient;
use crate::topology::DeviceCapabilities;
struct PeerHandle {
node_id: String,
address: SocketAddr,
description: Option<String>,
client: NodeServiceClient<tonic::transport::Channel>,
device_capabilities: DeviceCapabilities,
}
impl PeerHandle {
fn new(node_id: String, address: SocketAddr, device_capabilities: DeviceCapabilities) -> Self {
crate::node_service::node_service_client::NodeServiceClient::connect(address);
async fn new(node_id: String, address: SocketAddr, description: Option<String>, device_capabilities: DeviceCapabilities) -> Result<Self, tonic::transport::Error> {
let endpoint = format!("http://{}", address);
let client = NodeServiceClient::connect(endpoint).await?
.accept_compressed(CompressionEncoding::Gzip);
Self {
Ok(Self {
node_id,
description,
address,
client,
device_capabilities,
}
})
}
}
}