Add PeerHandle
This commit is contained in:
parent
5f6f8b6a43
commit
2d02fde6c6
20
Cargo.lock
generated
20
Cargo.lock
generated
@ -169,6 +169,15 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.13.0"
|
||||
@ -218,6 +227,16 @@ version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
|
||||
|
||||
[[package]]
|
||||
name = "flate2"
|
||||
version = "1.0.35"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
@ -1037,6 +1056,7 @@ dependencies = [
|
||||
"axum",
|
||||
"base64",
|
||||
"bytes",
|
||||
"flate2",
|
||||
"h2",
|
||||
"http",
|
||||
"http-body",
|
||||
|
||||
@ -8,7 +8,7 @@ prost = "0.13.4"
|
||||
serde = { version = "1.0.217", features = ["derive"] }
|
||||
serde_json = "1.0.138"
|
||||
tokio = { version = "1.43.0", features = ["full"] }
|
||||
tonic = "0.12.3"
|
||||
tonic = { version = "0.12.3", features = ["gzip"] }
|
||||
thiserror = "2.0"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
|
||||
@ -1,262 +0,0 @@
|
||||
mod udp;
|
||||
|
||||
use crate::topology::DeviceCapabilities;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::{net::UdpSocket, sync::RwLock, time};
|
||||
use tracing::{error};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum DiscoveryError {
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("JSON serialization error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
#[error("UTF-8 conversion error: {0}")]
|
||||
Utf8(#[from] std::string::FromUtf8Error),
|
||||
|
||||
#[error("Address parse error: {0}")]
|
||||
AddrParse(#[from] std::net::AddrParseError),
|
||||
|
||||
#[error("Network interface error: {0}")]
|
||||
NetworkInterface(String),
|
||||
|
||||
#[error("Peer error: {0}")]
|
||||
Peer(String),
|
||||
}
|
||||
|
||||
// Constants

/// Debug verbosity level.
// NOTE(review): not referenced by any Rust code visible in this module — confirm before removing.
const DEBUG: i32 = 0;

/// Broadcast destination returned by `get_broadcast_address` for non-IPv4 addresses.
const BROADCAST_ADDR: &str = "255.255.255.255";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct DiscoveryMessage {
|
||||
message_type: String,
|
||||
node_id: String,
|
||||
grpc_port: u16,
|
||||
device_capabilities: DeviceCapabilities,
|
||||
priority: i32,
|
||||
interface_name: String,
|
||||
interface_type: String,
|
||||
}
|
||||
|
||||
struct PeerInfo {
|
||||
peer_handle: PeerHandle,
|
||||
connected_at: Instant,
|
||||
last_seen: Instant,
|
||||
priority: i32,
|
||||
}
|
||||
|
||||
struct UdpDiscovery {
|
||||
node_id: String,
|
||||
node_port: u16,
|
||||
listen_port: u16,
|
||||
broadcast_port: u16,
|
||||
broadcast_interval: Duration,
|
||||
discovery_timeout: Duration,
|
||||
device_capabilities: DeviceCapabilities,
|
||||
allowed_node_ids: Option<Vec<String>>,
|
||||
allowed_interface_types: Option<Vec<String>>,
|
||||
known_peers: Arc<RwLock<HashMap<String, PeerInfo>>>,
|
||||
}
|
||||
|
||||
impl UdpDiscovery {
|
||||
pub fn new(
|
||||
node_id: String,
|
||||
node_port: u16,
|
||||
listen_port: u16,
|
||||
broadcast_port: u16,
|
||||
broadcast_interval: Duration,
|
||||
discovery_timeout: Duration,
|
||||
device_capabilities: DeviceCapabilities,
|
||||
allowed_node_ids: Option<Vec<String>>,
|
||||
allowed_interface_types: Option<Vec<String>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
node_id,
|
||||
node_port,
|
||||
listen_port,
|
||||
broadcast_port,
|
||||
broadcast_interval,
|
||||
discovery_timeout,
|
||||
device_capabilities,
|
||||
allowed_node_ids,
|
||||
allowed_interface_types,
|
||||
known_peers: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start(&self) -> Result<(), DiscoveryError> {
|
||||
let broadcast_task = self.task_broadcast_presence();
|
||||
let listen_task = self.task_listen_for_peers();
|
||||
let cleanup_task = self.task_cleanup_peers();
|
||||
|
||||
tokio::try_join!(broadcast_task, listen_task, cleanup_task)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpDiscovery {
|
||||
async fn task_listen_for_peers(&self) -> Result<(), DiscoveryError> {
|
||||
let socket = UdpSocket::bind(format!("0.0.0.0:{}", self.listen_port)).await?;
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
loop {
|
||||
let (len, addr) = socket.recv_from(&mut buf).await?;
|
||||
if len == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(message) = String::from_utf8(buf[..len].to_vec()) {
|
||||
if let Ok(discovery_message) = serde_json::from_str::<DiscoveryMessage>(&message) {
|
||||
self.handle_discovery_message(discovery_message, addr)
|
||||
.await
|
||||
.map_err(|e| DiscoveryError::Peer(e.to_string()))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpDiscovery {
|
||||
async fn task_broadcast_presence(&self) -> Result<(), DiscoveryError> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
socket.set_broadcast(true)?;
|
||||
|
||||
loop {
|
||||
let interfaces = get_all_ip_addresses_and_interfaces()?;
|
||||
for (addr, interface_name) in interfaces {
|
||||
let (interface_priority, interface_type) =
|
||||
get_interface_priority_and_type(&interface_name)
|
||||
.await
|
||||
.map_err(|e| DiscoveryError::NetworkInterface(e.to_string()))?;
|
||||
|
||||
let message = DiscoveryMessage {
|
||||
message_type: "discovery".to_string(),
|
||||
node_id: self.node_id.clone(),
|
||||
grpc_port: self.node_port,
|
||||
device_capabilities: self.device_capabilities.clone(),
|
||||
priority: interface_priority,
|
||||
interface_name: interface_name.clone(),
|
||||
interface_type: interface_type.clone(),
|
||||
};
|
||||
|
||||
let message_json = serde_json::to_string(&message)?;
|
||||
let broadcast_addr =
|
||||
SocketAddr::new(get_broadcast_address(&addr).parse()?, self.broadcast_port);
|
||||
|
||||
if let Err(e) = socket
|
||||
.send_to(message_json.as_bytes(), &broadcast_addr)
|
||||
.await
|
||||
{
|
||||
error!("Error broadcasting to {}: {}", broadcast_addr, e);
|
||||
}
|
||||
}
|
||||
|
||||
time::sleep(self.broadcast_interval).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpDiscovery {
|
||||
async fn task_cleanup_peers(&self) -> Result<(), DiscoveryError> {
|
||||
loop {
|
||||
let now = Instant::now();
|
||||
|
||||
// TODO: Do we really want a read then write lock or should we just take a write lock
|
||||
// to begin with?
|
||||
let peers_to_remove = {
|
||||
let mut peers_to_remove = Vec::new();
|
||||
|
||||
let peers = self.known_peers.read().await;
|
||||
for (peer_id, peer_info) in peers.iter() {
|
||||
if self.should_remove_peer(peer_info, now).await {
|
||||
peers_to_remove.push(peer_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
peers_to_remove
|
||||
};
|
||||
|
||||
{
|
||||
let mut peers = self.known_peers.write().await;
|
||||
for peer_id in peers_to_remove {
|
||||
peers.remove(&peer_id);
|
||||
}
|
||||
};
|
||||
|
||||
time::sleep(self.broadcast_interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn should_remove_peer(&self, peer_info: &PeerInfo, now: Instant) -> bool {
|
||||
let is_connected = peer_info
|
||||
.peer_handle
|
||||
.is_connected()
|
||||
.await
|
||||
.ok()
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_connected {
|
||||
return true;
|
||||
}
|
||||
|
||||
if now.duration_since(peer_info.connected_at) > self.discovery_timeout {
|
||||
return true;
|
||||
}
|
||||
|
||||
if now.duration_since(peer_info.last_seen) > self.discovery_timeout {
|
||||
return true;
|
||||
}
|
||||
|
||||
peer_info
|
||||
.peer_handle
|
||||
.health_check()
|
||||
.await
|
||||
.ok()
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_broadcast_address(ip_addr: &IpAddr) -> String {
|
||||
match ip_addr {
|
||||
IpAddr::V4(addr) => {
|
||||
let octets = addr.octets();
|
||||
format!("{}.{}.{}.255", octets[0], octets[1], octets[2])
|
||||
}
|
||||
_ => BROADCAST_ADDR.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
// You'll need to implement these functions based on your system
|
||||
fn get_all_ip_addresses_and_interfaces() -> Result<Vec<(IpAddr, String)>, DiscoveryError> {
|
||||
// Implementation needed
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
async fn get_interface_priority_and_type(
|
||||
interface_name: &str,
|
||||
) -> Result<(i32, String), DiscoveryError> {
|
||||
// Implementation needed
|
||||
Ok((0, "unknown".to_string()))
|
||||
}
|
||||
|
||||
struct PeerHandle;
|
||||
|
||||
impl PeerHandle {
|
||||
async fn is_connected(&self) -> Result<bool, DiscoveryError> {
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> Result<bool, DiscoveryError> {
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
@ -1,98 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use crate::topology::DeviceCapabilities;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::{debug, error, info};
|
||||
use crate::discovery::PeerInfo;
|
||||
|
||||
struct UdpDiscovery {
|
||||
listen_port: u16,
|
||||
listen_buffer_size: usize,
|
||||
allowed_interface_types: Option<Vec<String>>,
|
||||
known_peers: RwLock<HashMap<String, PeerInfo>>,
|
||||
|
||||
listen_handle: JoinHandle<()>,
|
||||
presence_handle: JoinHandle<()>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct DiscoveryMessage {
|
||||
#[serde(rename = "type")]
|
||||
message_type: String,
|
||||
node_id: String,
|
||||
grpc_port: u16,
|
||||
device_capabilities: DeviceCapabilities,
|
||||
priority: i32,
|
||||
interface_name: String,
|
||||
interface_type: String,
|
||||
}
|
||||
|
||||
impl UdpDiscovery {
|
||||
async fn listen(&self) {
|
||||
let listen_socket = tokio::net::UdpSocket::bind("0.0.0.0:42069").await.unwrap();
|
||||
let mut buf = vec![0; self.listen_buffer_size];
|
||||
|
||||
loop {
|
||||
// This will block waiting for a message.
|
||||
// If this fails it will end the loop and the task, which is what we want.
|
||||
let Ok((len, addr)) = listen_socket.recv_from(&mut buf).await.unwrap();
|
||||
|
||||
let Some(message) = serde_json::from_slice::<DiscoveryMessage>(&buf[..len]) else {
|
||||
error!(
|
||||
"Received invalid discovery message from {} {:?}",
|
||||
addr,
|
||||
str::from_utf8(&buf[..len])
|
||||
);
|
||||
continue;
|
||||
};
|
||||
|
||||
if let Some(ref allowed_interface_types) = self.allowed_interface_types {
|
||||
if !allowed_interface_types.contains(&message.interface_type) {
|
||||
debug!("Ignoring message from {} because interface type {} is not allowed", addr, message.interface_type);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
self.on_discovery_message(message, addr).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_discovery_message(&self, message: DiscoveryMessage, addr: SocketAddr) {
|
||||
let known_peers = self.known_peers.write().await;
|
||||
let existing = known_peers.get(&message.node_id);
|
||||
|
||||
if let Some(existing) = existing {
|
||||
|
||||
}
|
||||
|
||||
/*
|
||||
if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
|
||||
if peer_id in self.known_peers:
|
||||
existing_peer_prio = self.known_peers[peer_id][3]
|
||||
if existing_peer_prio >= peer_prio:
|
||||
if DEBUG >= 1:
|
||||
print(
|
||||
f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
|
||||
return
|
||||
new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}",
|
||||
f"{peer_interface_type} ({peer_interface_name})",
|
||||
device_capabilities)
|
||||
if not await new_peer_handle.health_check():
|
||||
if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
|
||||
return
|
||||
if DEBUG >= 1: print(
|
||||
f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
|
||||
self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time(), peer_prio)
|
||||
else:
|
||||
if not await self.known_peers[peer_id][0].health_check():
|
||||
if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Removing.")
|
||||
if peer_id in self.known_peers: del self.known_peers[peer_id]
|
||||
return
|
||||
if peer_id in self.known_peers: self.known_peers[peer_id] = (
|
||||
self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
|
||||
*/
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,5 +1,4 @@
|
||||
mod topology;
|
||||
mod discovery;
|
||||
mod orchestration;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@ -1,20 +1,28 @@
|
||||
use std::net::SocketAddr;
|
||||
use tonic::codec::CompressionEncoding;
|
||||
use crate::node_service::node_service_client::NodeServiceClient;
|
||||
use crate::topology::DeviceCapabilities;
|
||||
|
||||
struct PeerHandle {
|
||||
node_id: String,
|
||||
address: SocketAddr,
|
||||
description: Option<String>,
|
||||
client: NodeServiceClient<tonic::transport::Channel>,
|
||||
device_capabilities: DeviceCapabilities,
|
||||
}
|
||||
|
||||
impl PeerHandle {
|
||||
fn new(node_id: String, address: SocketAddr, device_capabilities: DeviceCapabilities) -> Self {
|
||||
crate::node_service::node_service_client::NodeServiceClient::connect(address);
|
||||
async fn new(node_id: String, address: SocketAddr, description: Option<String>, device_capabilities: DeviceCapabilities) -> Result<Self, tonic::transport::Error> {
|
||||
let endpoint = format!("http://{}", address);
|
||||
let client = NodeServiceClient::connect(endpoint).await?
|
||||
.accept_compressed(CompressionEncoding::Gzip);
|
||||
|
||||
Self {
|
||||
Ok(Self {
|
||||
node_id,
|
||||
description,
|
||||
address,
|
||||
client,
|
||||
device_capabilities,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user