Add CTDB Phase 1-5: TDB storage + Node management + Control protocol + IP manager + Recovery

This commit is contained in:
Warren
2026-06-22 14:21:39 +08:00
parent a8d81f2a9c
commit 64709ec529
6 changed files with 2347 additions and 0 deletions

View File

@@ -0,0 +1,379 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::RwLock;
use super::node::NodeId;
/// Identifier of a public IP entry; wraps the raw `u32` id.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct PublicIpId(pub u32);
/// One public IP address together with its ownership bookkeeping.
#[derive(Clone, Debug)]
pub struct PublicIpEntry {
    /// Identifier of this entry.
    pub id: PublicIpId,
    /// The address itself.
    pub ip: IpAddr,
    /// Name of the network interface associated with the address.
    pub interface: String,
    /// Node currently holding the address, if any.
    pub owner: Option<NodeId>,
    /// Node that held the address before the most recent ownership change, if any.
    pub previous_owner: Option<NodeId>,
}

impl PublicIpEntry {
    /// Builds an entry that has neither a current nor a previous owner.
    pub fn new(id: PublicIpId, ip: IpAddr, interface: &str) -> Self {
        let interface = String::from(interface);
        PublicIpEntry {
            id,
            ip,
            interface,
            owner: None,
            previous_owner: None,
        }
    }
}
/// Tracks a pool of public IP addresses and which node currently holds each.
///
/// State is split across three locks. Whenever more than one of them is held
/// at the same time they are taken in the fixed order
/// `ip_pool` -> `assignments` -> `next_id`, so methods cannot deadlock
/// against each other.
///
/// # Panics
///
/// Every method panics if one of the internal locks has been poisoned.
pub struct IpManager {
    /// Every known public IP, in insertion order.
    ip_pool: RwLock<Vec<PublicIpEntry>>,
    /// Current owner of each assigned IP; unassigned IPs have no key here.
    assignments: RwLock<HashMap<PublicIpId, NodeId>>,
    /// Identifier that the next added IP will receive.
    next_id: RwLock<u32>,
}

impl IpManager {
    /// Creates an empty manager; the first IP added receives id 0.
    pub fn new() -> Self {
        Self {
            ip_pool: RwLock::new(Vec::new()),
            assignments: RwLock::new(HashMap::new()),
            next_id: RwLock::new(0),
        }
    }

    /// Adds `ip` on `interface` to the pool (unassigned) and returns its id.
    pub fn add_ip(&self, ip: IpAddr, interface: &str) -> PublicIpId {
        let mut pool = self.ip_pool.write().unwrap();
        let mut next = self.next_id.write().unwrap();
        let id = PublicIpId(*next);
        *next += 1;
        pool.push(PublicIpEntry::new(id, ip, interface));
        id
    }

    /// Removes the IP from the pool and drops any assignment it had.
    /// Unknown ids are ignored.
    pub fn remove_ip(&self, id: PublicIpId) {
        let mut pool = self.ip_pool.write().unwrap();
        let mut assignments = self.assignments.write().unwrap();
        pool.retain(|e| e.id != id);
        assignments.remove(&id);
    }

    /// Makes `node` the owner of IP `id`.
    ///
    /// The existence check and both updates happen under one pair of write
    /// locks, so a concurrent `remove_ip` cannot slip in between them.
    ///
    /// # Errors
    ///
    /// Returns an error string if `id` is not in the pool.
    pub fn assign_ip(&self, id: PublicIpId, node: NodeId) -> Result<(), String> {
        let mut pool = self.ip_pool.write().unwrap();
        let mut assignments = self.assignments.write().unwrap();
        Self::assign_locked(&mut pool[..], &mut *assignments, id, node)
    }

    /// Assignment logic shared by the public entry points; the caller holds
    /// both write locks.
    fn assign_locked(
        pool: &mut [PublicIpEntry],
        assignments: &mut HashMap<PublicIpId, NodeId>,
        id: PublicIpId,
        node: NodeId,
    ) -> Result<(), String> {
        let entry = pool
            .iter_mut()
            .find(|e| e.id == id)
            .ok_or_else(|| format!("IP {:?} not found", id))?;
        // `insert` hands back the owner being replaced, if there was one.
        if let Some(prev) = assignments.insert(id, node) {
            entry.previous_owner = Some(prev);
        }
        entry.owner = Some(node);
        Ok(())
    }

    /// Clears the owner of IP `id`, remembering it as the previous owner.
    ///
    /// # Errors
    ///
    /// Returns an error string if the IP is not currently assigned.
    pub fn release_ip(&self, id: PublicIpId) -> Result<(), String> {
        // Pool first, then assignments: same order as every other method.
        let mut pool = self.ip_pool.write().unwrap();
        let mut assignments = self.assignments.write().unwrap();
        let prev = assignments
            .remove(&id)
            .ok_or_else(|| format!("IP {:?} not assigned", id))?;
        if let Some(entry) = pool.iter_mut().find(|e| e.id == id) {
            entry.previous_owner = Some(prev);
            entry.owner = None;
        }
        Ok(())
    }

    /// Moves every IP owned by `failed_node` to the other nodes in
    /// `active_nodes`, round-robin in ascending IP-id order.
    ///
    /// Returns only the moves that were actually applied. When no node other
    /// than `failed_node` is listed, nothing changes and the result is empty.
    pub fn reassign_on_failure(
        &self,
        failed_node: NodeId,
        active_nodes: &[NodeId],
    ) -> Vec<(PublicIpId, NodeId)> {
        let candidates: Vec<NodeId> = active_nodes
            .iter()
            .copied()
            .filter(|&n| n != failed_node)
            .collect();
        if candidates.is_empty() {
            return Vec::new();
        }

        let mut pool = self.ip_pool.write().unwrap();
        let mut assignments = self.assignments.write().unwrap();

        let mut failed_ips: Vec<PublicIpId> = assignments
            .iter()
            .filter(|(_, &owner)| owner == failed_node)
            .map(|(&id, _)| id)
            .collect();
        // HashMap iteration order is arbitrary; sort so the outcome is deterministic.
        failed_ips.sort_unstable_by_key(|id| id.0);

        let mut moved = Vec::with_capacity(failed_ips.len());
        for (ip_id, &new_owner) in failed_ips.into_iter().zip(candidates.iter().cycle()) {
            if Self::assign_locked(&mut pool[..], &mut *assignments, ip_id, new_owner).is_ok() {
                moved.push((ip_id, new_owner));
            }
        }
        moved
    }

    /// Returns the current owner of IP `id`, if it is assigned.
    pub fn get_owner(&self, id: PublicIpId) -> Option<NodeId> {
        self.assignments.read().unwrap().get(&id).copied()
    }

    /// Returns copies of all pool entries currently assigned to `node`, in pool order.
    pub fn owned_by(&self, node: NodeId) -> Vec<PublicIpEntry> {
        let pool = self.ip_pool.read().unwrap();
        let assignments = self.assignments.read().unwrap();
        pool.iter()
            .filter(|e| assignments.get(&e.id) == Some(&node))
            .cloned()
            .collect()
    }

    /// Returns copies of all pool entries that have no owner.
    pub fn unassigned(&self) -> Vec<PublicIpEntry> {
        let pool = self.ip_pool.read().unwrap();
        let assignments = self.assignments.read().unwrap();
        pool.iter()
            .filter(|e| !assignments.contains_key(&e.id))
            .cloned()
            .collect()
    }

    /// Returns a copy of every entry in the pool.
    pub fn all_ips(&self) -> Vec<PublicIpEntry> {
        self.ip_pool.read().unwrap().clone()
    }

    /// Number of IPs in the pool.
    pub fn ip_count(&self) -> usize {
        self.ip_pool.read().unwrap().len()
    }

    /// Number of IPs that currently have an owner.
    pub fn assigned_count(&self) -> usize {
        self.assignments.read().unwrap().len()
    }

    /// Spreads the pool over `active_nodes` so that node loads differ by at most one.
    ///
    /// An IP is moved when it is unassigned, when its owner is not in
    /// `active_nodes`, or when its owner holds at least two more IPs than the
    /// least-loaded active node. Each move goes to the least-loaded active
    /// node (ties go to the earliest node in `active_nodes`) and the per-node
    /// counts are updated on both sides, so IPs that are already well placed
    /// stay where they are.
    ///
    /// Returns the moves that were applied, in pool order. With an empty
    /// `active_nodes` nothing changes and the result is empty.
    pub fn balance(&self, active_nodes: &[NodeId]) -> Vec<(PublicIpId, NodeId)> {
        if active_nodes.is_empty() {
            return Vec::new();
        }

        let mut pool = self.ip_pool.write().unwrap();
        let mut assignments = self.assignments.write().unwrap();

        // Load per active node; owners outside `active_nodes` are not tracked.
        let mut counts: HashMap<NodeId, usize> =
            active_nodes.iter().map(|&n| (n, 0usize)).collect();
        for owner in assignments.values() {
            if let Some(load) = counts.get_mut(owner) {
                *load += 1;
            }
        }

        let mut moves = Vec::new();
        for entry in pool.iter_mut() {
            let target = match active_nodes.iter().min_by_key(|n| counts[*n]) {
                Some(&node) => node,
                None => break,
            };
            let target_load = counts[&target];
            let owner = assignments.get(&entry.id).copied();

            let should_move = match owner {
                Some(current) => match counts.get(&current) {
                    // Active owner: move only if that strictly improves the balance.
                    Some(&load) => load > target_load + 1,
                    // Owner is not an active node.
                    None => true,
                },
                None => true,
            };
            if !should_move {
                continue;
            }

            if let Some(current) = owner {
                if let Some(load) = counts.get_mut(&current) {
                    *load -= 1;
                }
                entry.previous_owner = Some(current);
            }
            *counts.entry(target).or_insert(0) += 1;
            assignments.insert(entry.id, target);
            entry.owner = Some(target);
            moves.push((entry.id, target));
        }
        moves
    }
}
impl Default for IpManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Shorthand for an IPv4 `IpAddr` built from its four octets.
    fn ip(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        Ipv4Addr::new(a, b, c, d).into()
    }

    // The first IP added gets id 0 and shows up in the pool count.
    #[test]
    fn test_add_ip() {
        let manager = IpManager::new();
        let first = manager.add_ip(ip(192, 168, 1, 100), "eth0");
        assert_eq!(manager.ip_count(), 1);
        assert_eq!(first, PublicIpId(0));
    }

    // Consecutive additions receive distinct ids.
    #[test]
    fn test_add_multiple_ips() {
        let manager = IpManager::new();
        let a = manager.add_ip(ip(192, 168, 1, 100), "eth0");
        let b = manager.add_ip(ip(192, 168, 1, 101), "eth0");
        let c = manager.add_ip(ip(192, 168, 1, 102), "eth1");
        assert_eq!(manager.ip_count(), 3);
        assert_ne!(a, b);
        assert_ne!(b, c);
    }

    // Assigning records the owner and bumps the assigned count.
    #[test]
    fn test_assign_ip() {
        let manager = IpManager::new();
        let addr_id = manager.add_ip(ip(192, 168, 1, 100), "eth0");
        manager.assign_ip(addr_id, NodeId(0)).unwrap();
        assert_eq!(manager.get_owner(addr_id), Some(NodeId(0)));
        assert_eq!(manager.assigned_count(), 1);
    }

    // Releasing an assigned IP clears its owner.
    #[test]
    fn test_release_ip() {
        let manager = IpManager::new();
        let addr_id = manager.add_ip(ip(192, 168, 1, 100), "eth0");
        manager.assign_ip(addr_id, NodeId(0)).unwrap();
        manager.release_ip(addr_id).unwrap();
        assert_eq!(manager.get_owner(addr_id), None);
    }

    // Releasing an IP that was never assigned is an error.
    #[test]
    fn test_release_unassigned_fails() {
        let manager = IpManager::new();
        let addr_id = manager.add_ip(ip(192, 168, 1, 100), "eth0");
        assert!(manager.release_ip(addr_id).is_err());
    }

    // Assigning an id that is not in the pool is an error.
    #[test]
    fn test_assign_nonexistent_fails() {
        let manager = IpManager::new();
        assert!(manager.assign_ip(PublicIpId(999), NodeId(0)).is_err());
    }

    // `owned_by` returns exactly the IPs assigned to the queried node.
    #[test]
    fn test_owned_by() {
        let manager = IpManager::new();
        let a = manager.add_ip(ip(192, 168, 1, 100), "eth0");
        let b = manager.add_ip(ip(192, 168, 1, 101), "eth0");
        let c = manager.add_ip(ip(192, 168, 1, 102), "eth1");
        manager.assign_ip(a, NodeId(0)).unwrap();
        manager.assign_ip(b, NodeId(1)).unwrap();
        manager.assign_ip(c, NodeId(0)).unwrap();
        let held_by_zero = manager.owned_by(NodeId(0));
        assert_eq!(held_by_zero.len(), 2);
    }

    // Only IPs without an owner are reported as unassigned.
    #[test]
    fn test_unassigned() {
        let manager = IpManager::new();
        manager.add_ip(ip(192, 168, 1, 100), "eth0");
        let taken = manager.add_ip(ip(192, 168, 1, 101), "eth0");
        manager.assign_ip(taken, NodeId(0)).unwrap();
        let free = manager.unassigned();
        assert_eq!(free.len(), 1);
    }

    // All IPs of a failed node are handed to surviving nodes.
    #[test]
    fn test_reassign_on_failure() {
        let manager = IpManager::new();
        let a = manager.add_ip(ip(192, 168, 1, 100), "eth0");
        let b = manager.add_ip(ip(192, 168, 1, 101), "eth0");
        manager.assign_ip(a, NodeId(0)).unwrap();
        manager.assign_ip(b, NodeId(0)).unwrap();
        let moves = manager.reassign_on_failure(NodeId(0), &[NodeId(1), NodeId(2)]);
        assert_eq!(moves.len(), 2);
        let receivers: Vec<NodeId> = moves.iter().map(|(_, node)| *node).collect();
        assert!(receivers.contains(&NodeId(1)) || receivers.contains(&NodeId(2)));
    }

    // The failed node is never chosen as a new owner, even if listed as active.
    #[test]
    fn test_reassign_skips_failed_node() {
        let manager = IpManager::new();
        let a = manager.add_ip(ip(192, 168, 1, 100), "eth0");
        manager.assign_ip(a, NodeId(0)).unwrap();
        let moves = manager.reassign_on_failure(NodeId(0), &[NodeId(0), NodeId(1)]);
        assert_eq!(moves.len(), 1);
        assert_ne!(moves[0].1, NodeId(0));
    }

    // Removing an IP drops both the pool entry and its assignment.
    #[test]
    fn test_remove_ip() {
        let manager = IpManager::new();
        let addr_id = manager.add_ip(ip(192, 168, 1, 100), "eth0");
        manager.assign_ip(addr_id, NodeId(0)).unwrap();
        manager.remove_ip(addr_id);
        assert_eq!(manager.ip_count(), 0);
        assert_eq!(manager.assigned_count(), 0);
    }

    // Balancing four unassigned IPs over two nodes yields an even split.
    #[test]
    fn test_balance() {
        let manager = IpManager::new();
        for last_octet in 100..104 {
            manager.add_ip(ip(192, 168, 1, last_octet), "eth0");
        }
        let moves = manager.balance(&[NodeId(0), NodeId(1)]);
        assert!(!moves.is_empty());
        let on_zero = manager.owned_by(NodeId(0));
        let on_one = manager.owned_by(NodeId(1));
        assert_eq!(on_zero.len() + on_one.len(), 4);
        assert!((on_zero.len() as i32 - on_one.len() as i32).abs() <= 1);
    }

    // Balancing only ever targets nodes from the active list.
    #[test]
    fn test_balance_excludes_inactive_nodes() {
        let manager = IpManager::new();
        for last_octet in 100..103 {
            manager.add_ip(ip(192, 168, 1, last_octet), "eth0");
        }
        manager.assign_ip(PublicIpId(0), NodeId(5)).unwrap();
        let moves = manager.balance(&[NodeId(0), NodeId(1)]);
        assert!(moves
            .iter()
            .any(|(_, node)| *node == NodeId(0) || *node == NodeId(1)));
        assert!(!moves.iter().any(|(_, node)| *node == NodeId(5)));
    }

    // Re-assigning an IP remembers the node it was taken from.
    #[test]
    fn test_previous_owner() {
        let manager = IpManager::new();
        let addr_id = manager.add_ip(ip(192, 168, 1, 100), "eth0");
        manager.assign_ip(addr_id, NodeId(0)).unwrap();
        manager.assign_ip(addr_id, NodeId(1)).unwrap();
        let snapshot = manager.all_ips();
        let record = snapshot.iter().find(|e| e.id == addr_id).unwrap();
        assert_eq!(record.previous_owner, Some(NodeId(0)));
        assert_eq!(record.owner, Some(NodeId(1)));
    }
}

View File

@@ -0,0 +1,5 @@
// TDB storage (named in the commit summary; module body not shown here).
pub mod tdb;
// Node management: node ids, node states, heartbeats and node masks.
pub mod node;
// Control protocol: message header, commands, payload codecs and TCP connection.
pub mod protocol;
// Public IP pool and per-node IP assignment.
pub mod ip_manager;
// Recovery state machine and event log.
pub mod recovery;

View File

@@ -0,0 +1,353 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Identifier of a cluster node; wraps the raw `u32` id.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct NodeId(pub u32);
/// State a node can be in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NodeState {
    Up,
    Down,
    Unhealthy,
    Banned,
    Disabled,
}

impl NodeState {
    /// Returns `true` only for [`NodeState::Up`].
    pub fn is_active(&self) -> bool {
        *self == NodeState::Up
    }

    /// Upper-case label of the state, e.g. `"UP"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            NodeState::Disabled => "DISABLED",
            NodeState::Banned => "BANNED",
            NodeState::Unhealthy => "UNHEALTHY",
            NodeState::Down => "DOWN",
            NodeState::Up => "UP",
        }
    }
}
/// Everything tracked about a single node.
#[derive(Clone, Debug)]
pub struct NodeInfo {
    /// Identifier of the node.
    pub id: NodeId,
    /// Socket address of the node.
    pub addr: SocketAddr,
    /// Current state.
    pub state: NodeState,
    /// Time of the most recent heartbeat, or `None` if none was recorded.
    pub last_heartbeat: Option<Instant>,
    /// Public IPs associated with the node, as strings.
    pub public_ips: Vec<String>,
    /// Counter of state changes recorded for this node.
    pub generation: u64,
}

impl NodeInfo {
    /// Creates a node record that starts `Down`, without heartbeat or public
    /// IPs, at generation 0.
    pub fn new(id: NodeId, addr: SocketAddr) -> Self {
        NodeInfo {
            id,
            addr,
            state: NodeState::Down,
            last_heartbeat: None,
            public_ips: vec![],
            generation: 0,
        }
    }

    /// Returns `true` when a heartbeat was recorded less than `timeout` ago;
    /// a node without any heartbeat is never alive.
    pub fn is_alive(&self, timeout: Duration) -> bool {
        self.last_heartbeat
            .map_or(false, |seen| seen.elapsed() < timeout)
    }
}
/// Fixed-size set of active/inactive flags indexed by node id.
#[derive(Debug, Clone)]
pub struct NodeMask {
    /// `nodes[i]` is the flag of `NodeId(i)`.
    nodes: Vec<bool>,
}

impl NodeMask {
    /// Creates a mask with `size` slots, all inactive.
    pub fn new(size: usize) -> Self {
        Self {
            nodes: vec![false; size],
        }
    }

    /// Sets the flag of `id`. Ids at or beyond `len()` are ignored: the mask
    /// never grows.
    pub fn set(&mut self, id: NodeId, active: bool) {
        if let Some(slot) = self.nodes.get_mut(id.0 as usize) {
            *slot = active;
        }
    }

    /// Returns the flag of `id`; ids outside the mask are reported inactive.
    pub fn is_active(&self, id: NodeId) -> bool {
        self.nodes.get(id.0 as usize).copied().unwrap_or(false)
    }

    /// Ids of all active slots, in ascending order.
    pub fn active_nodes(&self) -> Vec<NodeId> {
        self.nodes
            .iter()
            .enumerate()
            .filter(|(_, &active)| active)
            .map(|(i, _)| NodeId(i as u32))
            .collect()
    }

    /// Number of slots in the mask, active or not.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` when the mask has no slots at all.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Number of active slots.
    pub fn active_count(&self) -> usize {
        self.nodes.iter().filter(|&&active| active).count()
    }
}
/// Registry of known nodes, including the local one, with heartbeat-based
/// health tracking.
pub struct NodeManager {
    /// All known nodes keyed by id.
    nodes: RwLock<HashMap<NodeId, NodeInfo>>,
    /// Id of the local node; skipped by `check_health`.
    self_id: NodeId,
    /// Heartbeat age beyond which an `Up` node is marked `Down`.
    heartbeat_timeout: Duration,
    /// Configured heartbeat interval (stored only; not read in this block).
    heartbeat_interval: Duration,
}

impl NodeManager {
    /// Creates a manager whose only node is the local one, registered as `Up`
    /// with a fresh heartbeat. Timeout is 5 s, interval 1 s.
    pub fn new(self_id: NodeId, self_addr: SocketAddr) -> Self {
        let local = NodeInfo {
            state: NodeState::Up,
            last_heartbeat: Some(Instant::now()),
            ..NodeInfo::new(self_id, self_addr)
        };
        let mut registry = HashMap::new();
        registry.insert(self_id, local);
        NodeManager {
            nodes: RwLock::new(registry),
            self_id,
            heartbeat_timeout: Duration::from_secs(5),
            heartbeat_interval: Duration::from_secs(1),
        }
    }

    /// Registers a node in its initial (`Down`) state, replacing any existing
    /// record with the same id.
    pub fn add_node(&self, id: NodeId, addr: SocketAddr) {
        self.nodes
            .write()
            .unwrap()
            .insert(id, NodeInfo::new(id, addr));
    }

    /// Forgets the node with the given id, if present.
    pub fn remove_node(&self, id: NodeId) {
        self.nodes.write().unwrap().remove(&id);
    }

    /// Records a heartbeat from `id`. A node that was `Down` becomes `Up` and
    /// its generation is bumped; other states are left as they are. Unknown
    /// ids are ignored.
    pub fn record_heartbeat(&self, id: NodeId) {
        let mut registry = self.nodes.write().unwrap();
        let node = match registry.get_mut(&id) {
            Some(found) => found,
            None => return,
        };
        node.last_heartbeat = Some(Instant::now());
        if let NodeState::Down = node.state {
            node.state = NodeState::Up;
            node.generation += 1;
        }
    }

    /// Forces the state of `id` and bumps its generation, even when the state
    /// is unchanged. Setting `Up` on a node without any heartbeat also stamps
    /// one. Unknown ids are ignored.
    pub fn set_node_state(&self, id: NodeId, state: NodeState) {
        let mut registry = self.nodes.write().unwrap();
        let node = match registry.get_mut(&id) {
            Some(found) => found,
            None => return,
        };
        node.state = state;
        let never_seen = node.last_heartbeat.is_none();
        if state == NodeState::Up && never_seen {
            node.last_heartbeat = Some(Instant::now());
        }
        node.generation += 1;
    }

    /// Returns a copy of the record for `id`, if known.
    pub fn get_node(&self, id: NodeId) -> Option<NodeInfo> {
        let registry = self.nodes.read().unwrap();
        registry.get(&id).cloned()
    }

    /// Returns copies of all node records.
    pub fn all_nodes(&self) -> Vec<NodeInfo> {
        let registry = self.nodes.read().unwrap();
        registry.values().cloned().collect()
    }

    /// Returns copies of all records whose state is `Up`.
    pub fn active_nodes(&self) -> Vec<NodeInfo> {
        let registry = self.nodes.read().unwrap();
        registry
            .values()
            .filter(|info| info.state == NodeState::Up)
            .cloned()
            .collect()
    }

    /// Marks every remote `Up` node as `Down` when it has no heartbeat or its
    /// last heartbeat is older than the timeout, bumping its generation.
    /// Returns the transitions made; the local node is never touched.
    pub fn check_health(&self) -> Vec<(NodeId, NodeState)> {
        let limit = self.heartbeat_timeout;
        let local = self.self_id;
        let mut registry = self.nodes.write().unwrap();
        let transitions: Vec<(NodeId, NodeState)> = registry
            .iter_mut()
            .filter(|(id, _)| **id != local)
            .filter_map(|(id, node)| {
                if node.state != NodeState::Up {
                    return None;
                }
                let expired = match node.last_heartbeat {
                    Some(seen) => seen.elapsed() > limit,
                    None => true,
                };
                if !expired {
                    return None;
                }
                node.state = NodeState::Down;
                node.generation += 1;
                Some((*id, NodeState::Down))
            })
            .collect();
        transitions
    }

    /// Id of the local node.
    pub fn self_id(&self) -> NodeId {
        self.self_id
    }

    /// Number of known nodes, the local one included.
    pub fn node_count(&self) -> usize {
        self.nodes.read().unwrap().len()
    }

    /// Builds a mask sized to the highest known node id plus one, with a slot
    /// set for each node whose state is active.
    pub fn build_nodemask(&self) -> NodeMask {
        let registry = self.nodes.read().unwrap();
        let highest = registry.keys().map(|key| key.0).max().unwrap_or(0) as usize;
        let mut mask = NodeMask::new(highest + 1);
        registry
            .iter()
            .for_each(|(id, info)| mask.set(*id, info.state.is_active()));
        mask
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddrV4};

    /// Loopback socket address on the given port.
    fn addr(port: u16) -> SocketAddr {
        SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into()
    }

    // The local node is registered as Up with a heartbeat.
    #[test]
    fn test_node_creation() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        let local = manager.get_node(NodeId(0)).unwrap();
        assert_eq!(local.state, NodeState::Up);
        assert!(local.last_heartbeat.is_some());
    }

    // Adding and removing a node changes the node count accordingly.
    #[test]
    fn test_add_remove_node() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        manager.add_node(NodeId(1), addr(4001));
        assert_eq!(manager.node_count(), 2);
        manager.remove_node(NodeId(1));
        assert_eq!(manager.node_count(), 1);
    }

    // A heartbeat brings a freshly added (Down) node Up.
    #[test]
    fn test_heartbeat_updates() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        manager.add_node(NodeId(1), addr(4001));
        assert_eq!(manager.get_node(NodeId(1)).unwrap().state, NodeState::Down);
        manager.record_heartbeat(NodeId(1));
        assert_eq!(manager.get_node(NodeId(1)).unwrap().state, NodeState::Up);
    }

    // A heartbeat older than the timeout makes check_health mark the node Down.
    #[test]
    fn test_health_check_timeout() {
        let original = NodeManager::new(NodeId(0), addr(4000));
        original.add_node(NodeId(1), addr(4001));
        original.record_heartbeat(NodeId(1));
        assert_eq!(original.get_node(NodeId(1)).unwrap().state, NodeState::Up);
        std::thread::sleep(Duration::from_millis(100));
        // Rebuild the manager over the same node table with a timeout shorter
        // than the time just slept.
        let snapshot = original.nodes.read().unwrap().clone();
        let impatient = NodeManager {
            nodes: RwLock::new(snapshot),
            self_id: NodeId(0),
            heartbeat_timeout: Duration::from_millis(50),
            heartbeat_interval: Duration::from_millis(10),
        };
        let transitions = impatient.check_health();
        assert!(transitions.contains(&(NodeId(1), NodeState::Down)));
    }

    // Direct NodeMask manipulation: set, query, count and length.
    #[test]
    fn test_nodemask() {
        let mut mask = NodeMask::new(5);
        for raw in [0u32, 2, 4].iter() {
            mask.set(NodeId(*raw), true);
        }
        assert!(mask.is_active(NodeId(0)));
        assert!(!mask.is_active(NodeId(1)));
        assert_eq!(mask.active_count(), 3);
        assert_eq!(mask.len(), 5);
    }

    // The mask built by the manager marks exactly the Up nodes.
    #[test]
    fn test_build_nodemask() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        manager.add_node(NodeId(1), addr(4001));
        manager.add_node(NodeId(2), addr(4002));
        manager.record_heartbeat(NodeId(1));
        let mask = manager.build_nodemask();
        assert!(mask.is_active(NodeId(0)));
        assert!(mask.is_active(NodeId(1)));
        assert!(!mask.is_active(NodeId(2)));
    }

    // Every state has its upper-case label.
    #[test]
    fn test_node_state_string() {
        assert_eq!(NodeState::Up.as_str(), "UP");
        assert_eq!(NodeState::Down.as_str(), "DOWN");
        assert_eq!(NodeState::Unhealthy.as_str(), "UNHEALTHY");
        assert_eq!(NodeState::Banned.as_str(), "BANNED");
        assert_eq!(NodeState::Disabled.as_str(), "DISABLED");
    }

    // A forced state is visible on the next lookup.
    #[test]
    fn test_set_node_state() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        manager.add_node(NodeId(1), addr(4001));
        manager.set_node_state(NodeId(1), NodeState::Banned);
        assert_eq!(manager.get_node(NodeId(1)).unwrap().state, NodeState::Banned);
    }

    // Forcing a state bumps the generation counter.
    #[test]
    fn test_generation_increment() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        manager.add_node(NodeId(1), addr(4001));
        let before = manager.get_node(NodeId(1)).unwrap().generation;
        manager.set_node_state(NodeId(1), NodeState::Down);
        let after = manager.get_node(NodeId(1)).unwrap().generation;
        assert!(after > before);
    }

    // Only Up nodes (local + the one with a heartbeat) count as active.
    #[test]
    fn test_active_nodes() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        manager.add_node(NodeId(1), addr(4001));
        manager.add_node(NodeId(2), addr(4002));
        manager.record_heartbeat(NodeId(1));
        let up_nodes = manager.active_nodes();
        assert_eq!(up_nodes.len(), 2);
    }

    // A node is alive only once it has a recent heartbeat.
    #[test]
    fn test_is_alive() {
        let fresh = NodeInfo::new(NodeId(0), addr(4000));
        assert!(!fresh.is_alive(Duration::from_secs(5)));
        let mut seen = fresh;
        seen.last_heartbeat = Some(Instant::now());
        assert!(seen.is_alive(Duration::from_secs(5)));
    }
}

View File

@@ -0,0 +1,515 @@
use byteorder::{BigEndian, LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Cursor, Read, Write};
use std::net::TcpStream;
// Magic number carried in every header; its hex bytes 0x43 0x54 0x44 0x42 are ASCII "CTDB".
pub const CTDB_MAGIC: u32 = 0x43544442;
// Protocol version written into new headers and required by `CtdbHeader::is_valid`.
pub const CTDB_VERSION: u32 = 1;
// Encoded header size in bytes: four u32 fields (16 bytes) plus one u64 length (8 bytes).
pub const CTDB_HEADER_SIZE: usize = 24;
/// Command codes carried in the `command` field of a message header.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u32)]
pub enum CtdbCommand {
    Connect = 1,
    Disconnect = 2,
    Ping = 3,
    Pong = 4,
    GetDb = 10,
    Fetch = 11,
    Store = 12,
    Delete = 13,
    Keys = 14,
    SetNodeMask = 20,
    GetNodeMask = 21,
    NodeStatus = 22,
    TakeIp = 30,
    ReleaseIp = 31,
    Monitor = 40,
    Recovery = 50,
    RecoveryDone = 51,
    /// Fallback for any code that is not one of the commands above.
    Unknown = 0xFFFF,
}

impl CtdbCommand {
    /// Maps a raw command code to its variant; every unrecognised code
    /// yields [`CtdbCommand::Unknown`].
    pub fn from_u32(v: u32) -> Self {
        // Every variant except the `Unknown` fallback.
        const KNOWN: &[CtdbCommand] = &[
            CtdbCommand::Connect,
            CtdbCommand::Disconnect,
            CtdbCommand::Ping,
            CtdbCommand::Pong,
            CtdbCommand::GetDb,
            CtdbCommand::Fetch,
            CtdbCommand::Store,
            CtdbCommand::Delete,
            CtdbCommand::Keys,
            CtdbCommand::SetNodeMask,
            CtdbCommand::GetNodeMask,
            CtdbCommand::NodeStatus,
            CtdbCommand::TakeIp,
            CtdbCommand::ReleaseIp,
            CtdbCommand::Monitor,
            CtdbCommand::Recovery,
            CtdbCommand::RecoveryDone,
        ];
        KNOWN
            .iter()
            .copied()
            .find(|cmd| (*cmd as u32) == v)
            .unwrap_or(CtdbCommand::Unknown)
    }
}
/// Status codes carried in the `status` field of a message header.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CtdbStatus {
    Success = 0,
    Error = 1,
    NotFound = 2,
    Exists = 3,
    Corrupt = 4,
    Timeout = 5,
    NotActive = 6,
}

impl CtdbStatus {
    /// Maps a raw status code to its variant; every code that is not listed
    /// in the enum decodes as [`CtdbStatus::Error`].
    pub fn from_u32(v: u32) -> Self {
        // Every variant except `Error`, which doubles as the fallback.
        const SPECIFIC: &[CtdbStatus] = &[
            CtdbStatus::Success,
            CtdbStatus::NotFound,
            CtdbStatus::Exists,
            CtdbStatus::Corrupt,
            CtdbStatus::Timeout,
            CtdbStatus::NotActive,
        ];
        SPECIFIC
            .iter()
            .copied()
            .find(|status| (*status as u32) == v)
            .unwrap_or(CtdbStatus::Error)
    }
}
#[derive(Debug, Clone)]
pub struct CtdbHeader {
pub magic: u32,
pub version: u32,
pub command: u32,
pub status: u32,
pub length: u64,
}
impl CtdbHeader {
pub fn new(command: CtdbCommand, status: CtdbStatus, length: u64) -> Self {
Self {
magic: CTDB_MAGIC,
version: CTDB_VERSION,
command: command as u32,
status: status as u32,
length,
}
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(CTDB_HEADER_SIZE);
buf.extend_from_slice(&self.magic.to_le_bytes());
buf.extend_from_slice(&self.version.to_le_bytes());
buf.extend_from_slice(&self.command.to_le_bytes());
buf.extend_from_slice(&self.status.to_le_bytes());
buf.extend_from_slice(&self.length.to_le_bytes());
buf
}
pub fn from_bytes(buf: &[u8]) -> Result<Self, CtdbProtoError> {
if buf.len() < CTDB_HEADER_SIZE {
return Err(CtdbProtoError::HeaderTooShort);
}
let magic = u32::from_le_bytes(buf[0..4].try_into().unwrap());
let version = u32::from_le_bytes(buf[4..8].try_into().unwrap());
let command = u32::from_le_bytes(buf[8..12].try_into().unwrap());
let status = u32::from_le_bytes(buf[12..16].try_into().unwrap());
let length = u64::from_le_bytes(buf[16..24].try_into().unwrap());
Ok(Self {
magic,
version,
command,
status,
length,
})
}
pub fn is_valid(&self) -> bool {
self.magic == CTDB_MAGIC && self.version == CTDB_VERSION
}
}
/// A complete protocol message: header plus payload bytes.
#[derive(Debug, Clone)]
pub struct CtdbMessage {
    /// Header; `header.length` is the payload length in bytes.
    pub header: CtdbHeader,
    /// Payload bytes that follow the header on the wire.
    pub payload: Vec<u8>,
}

impl CtdbMessage {
    /// Builds a message whose header length matches `payload`.
    pub fn new(command: CtdbCommand, status: CtdbStatus, payload: Vec<u8>) -> Self {
        let length = payload.len() as u64;
        Self {
            header: CtdbHeader::new(command, status, length),
            payload,
        }
    }

    /// Encodes the header followed by the payload.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = self.header.to_bytes();
        buf.extend_from_slice(&self.payload);
        buf
    }

    /// Decodes one message from the start of `buf`.
    ///
    /// Exactly `header.length` payload bytes are taken, so the returned
    /// payload always agrees with the header; bytes after the declared
    /// payload are ignored.
    ///
    /// # Errors
    ///
    /// * `HeaderTooShort` - `buf` does not hold a full header.
    /// * `InvalidMagic` - the header's magic or version does not match.
    /// * `InvalidPayload` - `buf` holds fewer payload bytes than the header declares.
    pub fn from_bytes(buf: &[u8]) -> Result<Self, CtdbProtoError> {
        let header = CtdbHeader::from_bytes(buf)?;
        if !header.is_valid() {
            return Err(CtdbProtoError::InvalidMagic);
        }
        // `CtdbHeader::from_bytes` succeeded, so `buf` holds at least a full header.
        let body = &buf[CTDB_HEADER_SIZE..];
        if (body.len() as u64) < header.length {
            return Err(CtdbProtoError::InvalidPayload);
        }
        // The cast cannot truncate: `header.length <= body.len()`.
        let payload = body[..header.length as usize].to_vec();
        Ok(Self { header, payload })
    }

    /// Command of this message, decoded from the raw header field.
    pub fn command(&self) -> CtdbCommand {
        CtdbCommand::from_u32(self.header.command)
    }

    /// Status of this message, decoded from the raw header field.
    pub fn status(&self) -> CtdbStatus {
        CtdbStatus::from_u32(self.header.status)
    }
}
/// Errors produced while encoding, decoding or transporting protocol messages.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CtdbProtoError {
    /// Fewer bytes than a full header were supplied.
    HeaderTooShort,
    /// The header failed validation.
    InvalidMagic,
    /// An underlying I/O operation failed; the original error is not kept.
    IoError,
    /// A payload (or response) did not have the expected shape.
    InvalidPayload,
}

impl std::fmt::Display for CtdbProtoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match *self {
            CtdbProtoError::InvalidPayload => "invalid payload",
            CtdbProtoError::IoError => "I/O error",
            CtdbProtoError::InvalidMagic => "invalid magic number",
            CtdbProtoError::HeaderTooShort => "header too short",
        };
        f.write_str(text)
    }
}

impl std::error::Error for CtdbProtoError {}

// Any I/O failure collapses into the single `IoError` variant.
impl From<io::Error> for CtdbProtoError {
    fn from(_source: io::Error) -> Self {
        CtdbProtoError::IoError
    }
}
/// Encoders and decoders for message payloads.
///
/// Every variable-length field is written as a little-endian `u32` byte
/// count followed by the bytes themselves. Decoders treat their input as
/// untrusted: declared lengths are validated against the bytes actually
/// present before anything is sliced or allocated.
pub mod payload {
    use super::CtdbProtoError;

    /// Size in bytes of a little-endian `u32` length/count prefix.
    const PREFIX: usize = 4;

    /// Appends `bytes` to `out` as a length-prefixed field.
    fn write_field(out: &mut Vec<u8>, bytes: &[u8]) {
        // The prefix is a u32; longer inputs are not representable on the wire.
        debug_assert!(bytes.len() as u64 <= u32::MAX as u64);
        out.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
        out.extend_from_slice(bytes);
    }

    /// Reads the little-endian `u32` at byte offset `at`.
    fn read_u32(buf: &[u8], at: usize) -> Result<u32, CtdbProtoError> {
        let end = at
            .checked_add(PREFIX)
            .ok_or(CtdbProtoError::InvalidPayload)?;
        let raw = buf.get(at..end).ok_or(CtdbProtoError::InvalidPayload)?;
        let mut word = [0u8; PREFIX];
        word.copy_from_slice(raw);
        Ok(u32::from_le_bytes(word))
    }

    /// Reads the length-prefixed field starting at `at`; returns the field
    /// bytes and the offset just past the field.
    fn read_field(buf: &[u8], at: usize) -> Result<(&[u8], usize), CtdbProtoError> {
        let len = read_u32(buf, at)? as usize;
        // `read_u32` succeeded, so `at + PREFIX` is within `buf` and cannot overflow.
        let start = at + PREFIX;
        let end = start
            .checked_add(len)
            .ok_or(CtdbProtoError::InvalidPayload)?;
        let bytes = buf.get(start..end).ok_or(CtdbProtoError::InvalidPayload)?;
        Ok((bytes, end))
    }

    /// Encodes a key/value pair as two consecutive length-prefixed fields.
    pub fn encode_kv(key: &[u8], value: &[u8]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(2 * PREFIX + key.len() + value.len());
        write_field(&mut buf, key);
        write_field(&mut buf, value);
        buf
    }

    /// Decodes a payload written by [`encode_kv`]; trailing bytes are ignored.
    ///
    /// # Errors
    ///
    /// `InvalidPayload` when the payload is shorter than its length prefixes declare.
    pub fn decode_kv(payload: &[u8]) -> Result<(Vec<u8>, Vec<u8>), CtdbProtoError> {
        let (key, next) = read_field(payload, 0)?;
        let (value, _) = read_field(payload, next)?;
        Ok((key.to_vec(), value.to_vec()))
    }

    /// Encodes a key as one length-prefixed field.
    pub fn encode_key(key: &[u8]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(PREFIX + key.len());
        write_field(&mut buf, key);
        buf
    }

    /// Decodes a payload written by [`encode_key`]; trailing bytes are ignored.
    ///
    /// # Errors
    ///
    /// `InvalidPayload` when the payload is shorter than its length prefix declares.
    pub fn decode_key(payload: &[u8]) -> Result<Vec<u8>, CtdbProtoError> {
        let (key, _) = read_field(payload, 0)?;
        Ok(key.to_vec())
    }

    /// Encodes a node id as four little-endian bytes.
    pub fn encode_node_id(id: u32) -> Vec<u8> {
        id.to_le_bytes().to_vec()
    }

    /// Decodes a node id from the first four bytes of `payload`.
    ///
    /// # Errors
    ///
    /// `InvalidPayload` when fewer than four bytes are present.
    pub fn decode_node_id(payload: &[u8]) -> Result<u32, CtdbProtoError> {
        read_u32(payload, 0)
    }

    /// Encodes a list of node ids as a `u32` count followed by the ids.
    pub fn encode_nodemask(active: &[u32]) -> Vec<u8> {
        debug_assert!(active.len() as u64 <= u32::MAX as u64);
        let mut buf = Vec::with_capacity(PREFIX + active.len() * PREFIX);
        buf.extend_from_slice(&(active.len() as u32).to_le_bytes());
        for &id in active {
            buf.extend_from_slice(&id.to_le_bytes());
        }
        buf
    }

    /// Decodes a payload written by [`encode_nodemask`]; trailing bytes are ignored.
    ///
    /// The declared count is checked against the payload size before any
    /// memory is reserved, so a forged count cannot trigger a huge allocation.
    ///
    /// # Errors
    ///
    /// `InvalidPayload` when the payload holds fewer ids than the count declares.
    pub fn decode_nodemask(payload: &[u8]) -> Result<Vec<u32>, CtdbProtoError> {
        let count = read_u32(payload, 0)? as usize;
        let needed = count
            .checked_mul(PREFIX)
            .ok_or(CtdbProtoError::InvalidPayload)?;
        // `read_u32` succeeded, so the payload is at least `PREFIX` bytes long.
        let body = payload[PREFIX..]
            .get(..needed)
            .ok_or(CtdbProtoError::InvalidPayload)?;
        Ok(body
            .chunks_exact(PREFIX)
            .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect())
    }

    /// Encodes an IP string and an interface name as two length-prefixed fields.
    pub fn encode_ip(ip: &str, interface: &str) -> Vec<u8> {
        let mut buf = Vec::with_capacity(2 * PREFIX + ip.len() + interface.len());
        write_field(&mut buf, ip.as_bytes());
        write_field(&mut buf, interface.as_bytes());
        buf
    }

    /// Decodes a payload written by [`encode_ip`]; trailing bytes are ignored.
    /// Bytes that are not valid UTF-8 are replaced lossily rather than rejected.
    ///
    /// # Errors
    ///
    /// `InvalidPayload` when the payload is shorter than its length prefixes declare.
    pub fn decode_ip(payload: &[u8]) -> Result<(String, String), CtdbProtoError> {
        let (ip, next) = read_field(payload, 0)?;
        let (interface, _) = read_field(payload, next)?;
        Ok((
            String::from_utf8_lossy(ip).into_owned(),
            String::from_utf8_lossy(interface).into_owned(),
        ))
    }
}
pub struct CtdbConnection {
stream: TcpStream,
}
impl CtdbConnection {
pub fn new(stream: TcpStream) -> Self {
Self { stream }
}
pub fn connect(addr: &str) -> Result<Self, CtdbProtoError> {
let stream = TcpStream::connect(addr)?;
Ok(Self { stream })
}
pub fn send_message(&mut self, msg: &CtdbMessage) -> Result<(), CtdbProtoError> {
let bytes = msg.to_bytes();
self.stream.write_all(&bytes)?;
Ok(())
}
pub fn recv_message(&mut self) -> Result<CtdbMessage, CtdbProtoError> {
let mut header_buf = [0u8; CTDB_HEADER_SIZE];
self.stream.read_exact(&mut header_buf)?;
let header = CtdbHeader::from_bytes(&header_buf)?;
if !header.is_valid() {
return Err(CtdbProtoError::InvalidMagic);
}
let payload_len = header.length as usize;
let mut payload = vec![0u8; payload_len];
if payload_len > 0 {
self.stream.read_exact(&mut payload)?;
}
Ok(CtdbMessage { header, payload })
}
pub fn ping(&mut self) -> Result<(), CtdbProtoError> {
let msg = CtdbMessage::new(CtdbCommand::Ping, CtdbStatus::Success, vec![]);
self.send_message(&msg)?;
let resp = self.recv_message()?;
if resp.command() == CtdbCommand::Pong && resp.status() == CtdbStatus::Success {
Ok(())
} else {
Err(CtdbProtoError::InvalidPayload)
}
}
pub fn store(&mut self, key: &[u8], value: &[u8]) -> Result<bool, CtdbProtoError> {
let payload = payload::encode_kv(key, value);
let msg = CtdbMessage::new(CtdbCommand::Store, CtdbStatus::Success, payload);
self.send_message(&msg)?;
let resp = self.recv_message()?;
Ok(resp.status() == CtdbStatus::Success)
}
pub fn fetch(&mut self, key: &[u8]) -> Result<Vec<u8>, CtdbProtoError> {
let payload = payload::encode_key(key);
let msg = CtdbMessage::new(CtdbCommand::Fetch, CtdbStatus::Success, payload);
self.send_message(&msg)?;
let resp = self.recv_message()?;
if resp.status() == CtdbStatus::Success {
Ok(resp.payload)
} else {
Err(CtdbProtoError::InvalidPayload)
}
}
pub fn delete(&mut self, key: &[u8]) -> Result<bool, CtdbProtoError> {
let payload = payload::encode_key(key);
let msg = CtdbMessage::new(CtdbCommand::Delete, CtdbStatus::Success, payload);
self.send_message(&msg)?;
let resp = self.recv_message()?;
Ok(resp.status() == CtdbStatus::Success)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // A header survives encode/decode with all fields intact.
    #[test]
    fn test_header_roundtrip() {
        let original = CtdbHeader::new(CtdbCommand::Ping, CtdbStatus::Success, 42);
        let wire = original.to_bytes();
        let decoded = CtdbHeader::from_bytes(&wire).unwrap();
        assert_eq!(decoded.magic, CTDB_MAGIC);
        assert_eq!(decoded.version, CTDB_VERSION);
        assert_eq!(decoded.command, CtdbCommand::Ping as u32);
        assert_eq!(decoded.status, CtdbStatus::Success as u32);
        assert_eq!(decoded.length, 42);
        assert!(decoded.is_valid());
    }

    // A message survives encode/decode with command and payload intact.
    #[test]
    fn test_message_roundtrip() {
        let original = CtdbMessage::new(
            CtdbCommand::Store,
            CtdbStatus::Success,
            b"test_payload".to_vec(),
        );
        let wire = original.to_bytes();
        let decoded = CtdbMessage::from_bytes(&wire).unwrap();
        assert_eq!(decoded.command(), CtdbCommand::Store);
        assert_eq!(decoded.payload, b"test_payload");
    }

    // Known codes map to their commands, everything else to Unknown.
    #[test]
    fn test_command_from_u32() {
        assert_eq!(CtdbCommand::from_u32(1), CtdbCommand::Connect);
        assert_eq!(CtdbCommand::from_u32(3), CtdbCommand::Ping);
        assert_eq!(CtdbCommand::from_u32(0xFFFF), CtdbCommand::Unknown);
        assert_eq!(CtdbCommand::from_u32(999), CtdbCommand::Unknown);
    }

    // Known codes map to their statuses, everything else to Error.
    #[test]
    fn test_status_from_u32() {
        assert_eq!(CtdbStatus::from_u32(0), CtdbStatus::Success);
        assert_eq!(CtdbStatus::from_u32(2), CtdbStatus::NotFound);
        assert_eq!(CtdbStatus::from_u32(99), CtdbStatus::Error);
    }

    // Key/value payloads roundtrip.
    #[test]
    fn test_payload_encode_decode_kv() {
        let (key, val) = (b"mykey", b"myvalue");
        let wire = payload::encode_kv(key, val);
        let (decoded_key, decoded_val) = payload::decode_kv(&wire).unwrap();
        assert_eq!(decoded_key, key);
        assert_eq!(decoded_val, val);
    }

    // Key-only payloads roundtrip.
    #[test]
    fn test_payload_encode_decode_key() {
        let key = b"test_key";
        let wire = payload::encode_key(key);
        let decoded = payload::decode_key(&wire).unwrap();
        assert_eq!(decoded, key);
    }

    // Node-id payloads roundtrip.
    #[test]
    fn test_payload_node_id() {
        let wire = payload::encode_node_id(42);
        let decoded = payload::decode_node_id(&wire).unwrap();
        assert_eq!(decoded, 42);
    }

    // Node-mask payloads roundtrip.
    #[test]
    fn test_payload_nodemask() {
        let ids = vec![0u32, 1, 2, 3];
        let wire = payload::encode_nodemask(&ids);
        let decoded = payload::decode_nodemask(&wire).unwrap();
        assert_eq!(decoded, ids);
    }

    // IP/interface payloads roundtrip.
    #[test]
    fn test_payload_ip() {
        let wire = payload::encode_ip("192.168.1.100", "eth0");
        let (ip, iface) = payload::decode_ip(&wire).unwrap();
        assert_eq!(ip, "192.168.1.100");
        assert_eq!(iface, "eth0");
    }

    // A header with the wrong magic is not valid.
    #[test]
    fn test_invalid_magic() {
        let mut tampered = CtdbHeader::new(CtdbCommand::Ping, CtdbStatus::Success, 0);
        tampered.magic = 0xDEADBEEF;
        assert!(!tampered.is_valid());
    }

    // A message without payload roundtrips with length 0.
    #[test]
    fn test_empty_message() {
        let original = CtdbMessage::new(CtdbCommand::Connect, CtdbStatus::Success, vec![]);
        let wire = original.to_bytes();
        let decoded = CtdbMessage::from_bytes(&wire).unwrap();
        assert!(decoded.payload.is_empty());
        assert_eq!(decoded.header.length, 0);
    }

    // Buffers shorter than a header are rejected.
    #[test]
    fn test_header_too_short() {
        let outcome = CtdbHeader::from_bytes(&[0u8; 10]);
        assert!(outcome.is_err());
    }

    // A 65000-byte payload roundtrips unchanged.
    #[test]
    fn test_large_payload() {
        let big = vec![0xABu8; 65000];
        let original = CtdbMessage::new(CtdbCommand::Fetch, CtdbStatus::Success, big.clone());
        let wire = original.to_bytes();
        let decoded = CtdbMessage::from_bytes(&wire).unwrap();
        assert_eq!(decoded.payload.len(), 65000);
        assert_eq!(decoded.payload, big);
    }
}

View File

@@ -0,0 +1,417 @@
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use super::ip_manager::IpManager;
use super::node::{NodeManager, NodeId, NodeState};
/// Phase of the recovery state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RecoveryState {
    Idle,
    Initiated,
    InProgress,
    Verifying,
    Completed,
    Failed,
}

impl RecoveryState {
    /// Upper-case label of the state, e.g. `"IN_PROGRESS"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            RecoveryState::Failed => "FAILED",
            RecoveryState::Completed => "COMPLETED",
            RecoveryState::Verifying => "VERIFYING",
            RecoveryState::InProgress => "IN_PROGRESS",
            RecoveryState::Initiated => "INITIATED",
            RecoveryState::Idle => "IDLE",
        }
    }
}
/// One entry of the recovery event log.
#[derive(Clone, Debug)]
pub struct RecoveryEvent {
    /// When the event was created.
    pub timestamp: Instant,
    /// Kind of event.
    pub event_type: RecoveryEventType,
    /// Node the event refers to, if any.
    pub node_id: Option<NodeId>,
    /// Free-form description.
    pub message: String,
}
/// Kind of event recorded in the recovery log.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RecoveryEventType {
    NodeDown,
    NodeRejoin,
    IpMigration,
    RecoveryStart,
    RecoveryComplete,
    RecoveryFailed,
}
/// Holds the recovery state, a bounded event log and the cooldown bookkeeping.
pub struct RecoveryManager {
    /// Current phase of the recovery state machine.
    state: RwLock<RecoveryState>,
    /// Logged events, oldest first.
    events: RwLock<Vec<RecoveryEvent>>,
    /// Time that must have passed since the last recovery before another may start.
    recovery_cooldown: Duration,
    /// When the most recent recovery was started, if ever.
    last_recovery: RwLock<Option<Instant>>,
    /// Maximum number of events kept in the log.
    max_events: usize,
}
impl RecoveryManager {
    /// Creates an idle manager with a 30-second cooldown and a log capped
    /// at 100 events.
    pub fn new() -> Self {
        let recovery_cooldown = Duration::from_secs(30);
        let max_events = 100;
        Self {
            state: RwLock::new(RecoveryState::Idle),
            events: RwLock::new(Vec::new()),
            recovery_cooldown,
            last_recovery: RwLock::new(None),
            max_events,
        }
    }

    /// Builder-style setter that replaces the cooldown between recoveries.
    pub fn with_cooldown(self, cooldown: Duration) -> Self {
        Self {
            recovery_cooldown: cooldown,
            ..self
        }
    }

    /// Returns the current recovery state.
    pub fn state(&self) -> RecoveryState {
        let current = self.state.read().unwrap();
        *current
    }

    /// Overwrites the current recovery state.
    pub fn set_state(&self, state: RecoveryState) {
        let mut current = self.state.write().unwrap();
        *current = state;
    }

    /// Appends an event to the log, dropping the oldest entry once the log
    /// grows beyond `max_events`.
    pub fn log_event(&self, event_type: RecoveryEventType, node_id: Option<NodeId>, msg: &str) {
        // The event (and its timestamp) is built before the log lock is taken.
        let entry = RecoveryEvent {
            timestamp: Instant::now(),
            event_type,
            node_id,
            message: msg.to_owned(),
        };
        let mut log = self.events.write().unwrap();
        log.push(entry);
        let over_capacity = log.len() > self.max_events;
        if over_capacity {
            log.remove(0);
        }
    }

    /// Returns a copy of the whole event log, oldest first.
    pub fn events(&self) -> Vec<RecoveryEvent> {
        let log = self.events.read().unwrap();
        log.to_vec()
    }

    /// Returns `true` when no recovery has run yet or the cooldown since the
    /// last one has fully elapsed.
    pub fn can_recover(&self) -> bool {
        let last = self.last_recovery.read().unwrap();
        if let Some(started) = *last {
            started.elapsed() > self.recovery_cooldown
        } else {
            true
        }
    }

    /// Handles a failed node by moving its public IPs to the remaining
    /// active nodes.
    ///
    /// Returns the `(ip, new_owner)` pairs reported by the IP manager. The
    /// result is empty when the cooldown is still active (state untouched)
    /// or when no other active node exists (state becomes `Failed`).
    pub fn handle_node_failure(
        &self,
        failed_node: NodeId,
        node_mgr: &NodeManager,
        ip_mgr: &IpManager,
    ) -> Vec<(super::ip_manager::PublicIpId, NodeId)> {
        let down_msg = format!("Node {:?} marked DOWN", failed_node);
        self.log_event(RecoveryEventType::NodeDown, Some(failed_node), &down_msg);

        let cooling_down = !self.can_recover();
        if cooling_down {
            self.log_event(
                RecoveryEventType::RecoveryFailed,
                Some(failed_node),
                "Recovery skipped: cooldown active",
            );
            return Vec::new();
        }

        self.set_state(RecoveryState::Initiated);
        self.log_event(
            RecoveryEventType::RecoveryStart,
            Some(failed_node),
            "Starting recovery",
        );
        // The cooldown clock starts here, before it is known whether any
        // node can take over the addresses.
        {
            let mut last = self.last_recovery.write().unwrap();
            *last = Some(Instant::now());
        }
        self.set_state(RecoveryState::InProgress);

        // Candidate owners: every active node except the one that failed.
        let survivors: Vec<NodeId> = node_mgr
            .active_nodes()
            .iter()
            .filter_map(|node| {
                if node.id == failed_node {
                    None
                } else {
                    Some(node.id)
                }
            })
            .collect();

        if survivors.is_empty() {
            self.log_event(
                RecoveryEventType::RecoveryFailed,
                Some(failed_node),
                "No active nodes available for IP migration",
            );
            self.set_state(RecoveryState::Failed);
            return Vec::new();
        }

        let moved = ip_mgr.reassign_on_failure(failed_node, &survivors);
        for (ip_id, new_owner) in moved.iter() {
            let note = format!("IP {:?} migrated to node {:?}", ip_id, new_owner);
            self.log_event(RecoveryEventType::IpMigration, Some(*new_owner), &note);
        }

        self.set_state(RecoveryState::Completed);
        let summary = format!("Recovery complete: {} IPs migrated", moved.len());
        self.log_event(
            RecoveryEventType::RecoveryComplete,
            Some(failed_node),
            &summary,
        );
        moved
    }

    /// Handles a node rejoining the cluster by rebalancing public IPs over
    /// all currently active nodes.
    ///
    /// This path does not consult or update the recovery cooldown. Returns
    /// the `(ip, new_owner)` pairs reported by the IP manager.
    pub fn handle_node_rejoin(
        &self,
        rejoining_node: NodeId,
        node_mgr: &NodeManager,
        ip_mgr: &IpManager,
    ) -> Vec<(super::ip_manager::PublicIpId, NodeId)> {
        let rejoin_msg = format!("Node {:?} rejoining cluster", rejoining_node);
        self.log_event(
            RecoveryEventType::NodeRejoin,
            Some(rejoining_node),
            &rejoin_msg,
        );

        let members: Vec<NodeId> = node_mgr
            .active_nodes()
            .iter()
            .map(|node| node.id)
            .collect();
        self.set_state(RecoveryState::InProgress);

        let moved = ip_mgr.balance(&members);
        for (ip_id, new_owner) in moved.iter() {
            let note = format!("IP {:?} rebalanced to node {:?}", ip_id, new_owner);
            self.log_event(RecoveryEventType::IpMigration, Some(*new_owner), &note);
        }

        self.set_state(RecoveryState::Completed);
        let summary = format!("Rebalance complete: {} IPs reassigned", moved.len());
        self.log_event(
            RecoveryEventType::RecoveryComplete,
            Some(rejoining_node),
            &summary,
        );
        moved
    }

    /// Runs a health check and reacts to every reported transition:
    /// `Down` triggers failure handling, `Up` triggers a rebalance, and any
    /// other state is ignored. Returns all resulting reassignments in order.
    pub fn check_and_recover(
        &self,
        node_mgr: &NodeManager,
        ip_mgr: &IpManager,
    ) -> Vec<(super::ip_manager::PublicIpId, NodeId)> {
        let transitions = node_mgr.check_health();
        let mut collected = Vec::new();
        for (node_id, new_state) in &transitions {
            let moved = match new_state {
                NodeState::Down => self.handle_node_failure(*node_id, node_mgr, ip_mgr),
                NodeState::Up => self.handle_node_rejoin(*node_id, node_mgr, ip_mgr),
                _ => continue,
            };
            collected.extend(moved);
        }
        collected
    }

    /// Returns the manager to `Idle` and clears the cooldown timer.
    /// The event log is left as it is.
    pub fn reset(&self) {
        {
            let mut current = self.state.write().unwrap();
            *current = RecoveryState::Idle;
        }
        let mut last = self.last_recovery.write().unwrap();
        *last = None;
    }

    /// Returns copies of all logged events whose `node_id` equals `node_id`.
    pub fn events_by_node(&self, node_id: NodeId) -> Vec<RecoveryEvent> {
        let log = self.events.read().unwrap();
        let wanted = Some(node_id);
        let mut matching = Vec::new();
        for event in log.iter() {
            if event.node_id == wanted {
                matching.push(event.clone());
            }
        }
        matching
    }

    /// Returns a copy of the most recently logged event, if any.
    pub fn last_event(&self) -> Option<RecoveryEvent> {
        let log = self.events.read().unwrap();
        log.last().map(|event| event.clone())
    }
}
impl Default for RecoveryManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::super::ip_manager::PublicIpId;
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
    use std::thread;

    /// Loopback socket address on `port`.
    fn addr(port: u16) -> SocketAddr {
        SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into()
    }

    /// IPv4 address `a.b.c.d`.
    fn ip(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        Ipv4Addr::new(a, b, c, d).into()
    }

    /// True when the recovery log contains at least one event of `kind`.
    fn logged(recovery: &RecoveryManager, kind: RecoveryEventType) -> bool {
        recovery.events().iter().any(|e| e.event_type == kind)
    }

    /// Three-node cluster (0, 1, 2) with one public IP assigned to each node.
    fn setup() -> (NodeManager, IpManager, RecoveryManager) {
        let nodes = NodeManager::new(NodeId(0), addr(4000));
        nodes.add_node(NodeId(1), addr(4001));
        nodes.add_node(NodeId(2), addr(4002));
        nodes.record_heartbeat(NodeId(1));
        nodes.record_heartbeat(NodeId(2));

        let ips = IpManager::new();
        ips.add_ip(ip(192, 168, 1, 100), "eth0");
        ips.add_ip(ip(192, 168, 1, 101), "eth0");
        ips.add_ip(ip(192, 168, 1, 102), "eth1");
        let ownership = [
            (PublicIpId(0), NodeId(0)),
            (PublicIpId(1), NodeId(1)),
            (PublicIpId(2), NodeId(2)),
        ];
        for (ip_id, owner) in ownership {
            ips.assign_ip(ip_id, owner).unwrap();
        }

        (nodes, ips, RecoveryManager::new())
    }

    #[test]
    fn test_initial_state() {
        let fresh = RecoveryManager::new();
        assert_eq!(fresh.state(), RecoveryState::Idle);
        assert!(fresh.can_recover());
        assert!(fresh.events().is_empty());
    }

    #[test]
    fn test_node_failure_triggers_recovery() {
        let (nodes, ips, recovery) = setup();
        let moved = recovery.handle_node_failure(NodeId(1), &nodes, &ips);
        assert!(!moved.is_empty());
        assert_eq!(recovery.state(), RecoveryState::Completed);
        assert!(logged(&recovery, RecoveryEventType::NodeDown));
        assert!(logged(&recovery, RecoveryEventType::RecoveryStart));
        assert!(logged(&recovery, RecoveryEventType::RecoveryComplete));
    }

    #[test]
    fn test_no_active_nodes_fails() {
        // Only the local node exists, and it is the one that fails.
        let nodes = NodeManager::new(NodeId(0), addr(4000));
        let ips = IpManager::new();
        ips.add_ip(ip(192, 168, 1, 100), "eth0");
        let recovery = RecoveryManager::new();
        let moved = recovery.handle_node_failure(NodeId(0), &nodes, &ips);
        assert!(moved.is_empty());
        assert_eq!(recovery.state(), RecoveryState::Failed);
    }

    #[test]
    fn test_cooldown_prevents_rapid_recovery() {
        let (nodes, ips, recovery) = setup();
        recovery.handle_node_failure(NodeId(1), &nodes, &ips);
        assert!(!recovery.can_recover());
    }

    #[test]
    fn test_cooldown_expires() {
        let cooldown = Duration::from_millis(10);
        let recovery = RecoveryManager::new().with_cooldown(cooldown);
        *recovery.last_recovery.write().unwrap() = Some(Instant::now());
        assert!(!recovery.can_recover());
        thread::sleep(Duration::from_millis(20));
        assert!(recovery.can_recover());
    }

    #[test]
    fn test_node_rejoin_balances() {
        let (nodes, ips, first) = setup();
        first.handle_node_failure(NodeId(1), &nodes, &ips);
        thread::sleep(Duration::from_millis(10));
        let second = RecoveryManager::new();
        let moved = second.handle_node_rejoin(NodeId(1), &nodes, &ips);
        assert!(!moved.is_empty());
    }

    #[test]
    fn test_log_event() {
        let recovery = RecoveryManager::new();
        recovery.log_event(RecoveryEventType::NodeDown, Some(NodeId(5)), "test message");
        let log = recovery.events();
        assert_eq!(log.len(), 1);
        let first = &log[0];
        assert_eq!(first.node_id, Some(NodeId(5)));
        assert_eq!(first.message, "test message");
    }

    #[test]
    fn test_event_max_limit() {
        let recovery = RecoveryManager::new();
        for i in 0..200 {
            recovery.log_event(RecoveryEventType::NodeDown, Some(NodeId(i)), "msg");
        }
        assert_eq!(recovery.events().len(), 100);
    }

    #[test]
    fn test_events_by_node() {
        let recovery = RecoveryManager::new();
        recovery.log_event(RecoveryEventType::NodeDown, Some(NodeId(1)), "a");
        recovery.log_event(RecoveryEventType::NodeDown, Some(NodeId(2)), "b");
        recovery.log_event(RecoveryEventType::IpMigration, Some(NodeId(1)), "c");
        let for_node_one = recovery.events_by_node(NodeId(1));
        assert_eq!(for_node_one.len(), 2);
    }

    #[test]
    fn test_last_event() {
        let recovery = RecoveryManager::new();
        recovery.log_event(RecoveryEventType::NodeDown, Some(NodeId(0)), "first");
        recovery.log_event(RecoveryEventType::RecoveryComplete, Some(NodeId(0)), "last");
        let newest = recovery.last_event().unwrap();
        assert_eq!(newest.message, "last");
    }

    #[test]
    fn test_reset() {
        let (nodes, ips, recovery) = setup();
        recovery.handle_node_failure(NodeId(1), &nodes, &ips);
        assert_ne!(recovery.state(), RecoveryState::Idle);
        recovery.reset();
        assert_eq!(recovery.state(), RecoveryState::Idle);
        assert!(recovery.can_recover());
    }

    #[test]
    fn test_check_and_recover_no_transitions() {
        let (nodes, ips, recovery) = setup();
        let moved = recovery.check_and_recover(&nodes, &ips);
        assert!(moved.is_empty());
    }

    #[test]
    fn test_recovery_state_string() {
        let expected = [
            (RecoveryState::Idle, "IDLE"),
            (RecoveryState::InProgress, "IN_PROGRESS"),
            (RecoveryState::Failed, "FAILED"),
            (RecoveryState::Completed, "COMPLETED"),
        ];
        for (state, label) in expected {
            assert_eq!(state.as_str(), label);
        }
    }
}

View File

@@ -0,0 +1,678 @@
use std::collections::HashMap;
use std::io::{self, Read, Write, Seek, SeekFrom};
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock};
/// Magic number stored in the first four bytes of a database file; a header
/// carrying any other value is rejected as corrupt.
const TDB_MAGIC: u32 = 0x1BADFACE;
/// Format version written into newly created headers.
const TDB_VERSION: u32 = 1;
/// Hash size used by `TdbEngine::new`.
const DEFAULT_HASH_SIZE: u32 = 1024;
/// Size in bytes of the zero-padded on-disk header; record data starts at
/// this file offset.
const TDB_HEADER_SIZE: u64 = 128;
/// Per-record status flag, stored on disk as a little-endian `u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecordFlag {
    Active = 0,
    Free = 1,
    Deleted = 2,
}

impl RecordFlag {
    /// Numeric on-disk encoding of the flag (`Active` = 0, `Free` = 1,
    /// `Deleted` = 2).
    fn as_u32(&self) -> u32 {
        *self as u32
    }

    /// Decodes an on-disk flag value. Anything other than 1 or 2 decodes
    /// to `Active`.
    fn from_u32(v: u32) -> Self {
        if v == 1 {
            Self::Free
        } else if v == 2 {
            Self::Deleted
        } else {
            Self::Active
        }
    }
}
/// A key/value record together with its status flag and chain pointer.
#[derive(Debug, Clone)]
pub struct TdbRecord {
    /// Raw key bytes.
    pub key: Vec<u8>,
    /// Raw value bytes.
    pub data: Vec<u8>,
    /// Status of the record.
    pub flag: RecordFlag,
    /// Chain pointer persisted with the record; new records start at 0.
    pub hash_next: u64,
}

impl TdbRecord {
    /// Builds an `Active` record with `hash_next` set to 0.
    pub fn new(key: Vec<u8>, data: Vec<u8>) -> Self {
        let flag = RecordFlag::Active;
        let hash_next = 0;
        Self {
            key,
            data,
            flag,
            hash_next,
        }
    }

    /// The key as UTF-8 text, or `""` when the key is not valid UTF-8.
    pub fn key_str(&self) -> &str {
        match std::str::from_utf8(&self.key) {
            Ok(text) => text,
            Err(_) => "",
        }
    }

    /// The value as UTF-8 text, or `""` when the value is not valid UTF-8.
    pub fn data_str(&self) -> &str {
        match std::str::from_utf8(&self.data) {
            Ok(text) => text,
            Err(_) => "",
        }
    }
}
/// Errors reported by the TDB storage engine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TdbError {
    /// An underlying I/O operation failed.
    IoError,
    /// The requested key is not present.
    NotFound,
    /// The database file failed validation.
    Corrupt,
    /// A record with that key already exists.
    Exists,
    /// A lock could not be obtained.
    LockFailed,
}

impl std::fmt::Display for TdbError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match *self {
            Self::IoError => "I/O error",
            Self::NotFound => "record not found",
            Self::Corrupt => "database corrupt",
            Self::Exists => "record already exists",
            Self::LockFailed => "lock failed",
        };
        f.write_str(text)
    }
}

impl std::error::Error for TdbError {}

/// Every `io::Error` collapses into [`TdbError::IoError`]; the original
/// error's details are not retained.
impl From<io::Error> for TdbError {
    fn from(_source: io::Error) -> Self {
        Self::IoError
    }
}

/// Result alias used throughout the TDB module.
pub type TdbResult<T> = Result<T, TdbError>;
/// Maps `key` to a bucket index for a table of `hash_size` buckets.
///
/// The hash is the wrapping polynomial `h = h * 31 + byte` over the key
/// bytes, reduced modulo `hash_size`.
///
/// A `hash_size` of 0 describes a table with no buckets. In that case the
/// function returns 0 instead of panicking on a modulo-by-zero; callers
/// that index a table with the result must still bounds-check it.
fn hash_key(key: &[u8], hash_size: u32) -> u32 {
    if hash_size == 0 {
        return 0;
    }
    let h = key
        .iter()
        .fold(0u32, |acc, &b| acc.wrapping_mul(31).wrapping_add(u32::from(b)));
    h % hash_size
}
/// Wrapper around a single `u64` offset.
// NOTE(review): nothing in this file constructs or reads `HashEntry`;
// confirm whether it is still needed.
#[derive(Debug, Clone)]
struct HashEntry {
offset: u64,
}
/// Fixed-size header at the start of a database file.
///
/// On disk the fields are little-endian and packed in declaration order:
/// magic (4 bytes), version (4), hash size (4), record count (8). The rest
/// of the `TDB_HEADER_SIZE`-byte header is zero padding.
struct TdbHeader {
    magic: u32,
    version: u32,
    hash_size: u32,
    record_count: u64,
}

impl TdbHeader {
    /// Number of leading header bytes that carry fields (4 + 4 + 4 + 8).
    const ENCODED_FIELDS_LEN: usize = 4 + 4 + 4 + 8;

    /// Header for an empty database with the current magic and version.
    fn new(hash_size: u32) -> Self {
        Self {
            magic: TDB_MAGIC,
            version: TDB_VERSION,
            hash_size,
            record_count: 0,
        }
    }

    /// Serializes the header into exactly `TDB_HEADER_SIZE` bytes.
    fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(TDB_HEADER_SIZE as usize);
        buf.extend_from_slice(&self.magic.to_le_bytes());
        buf.extend_from_slice(&self.version.to_le_bytes());
        buf.extend_from_slice(&self.hash_size.to_le_bytes());
        buf.extend_from_slice(&self.record_count.to_le_bytes());
        // Zero-pad up to the fixed header size.
        buf.resize(TDB_HEADER_SIZE as usize, 0);
        buf
    }

    /// Parses a header from the start of `buf`.
    ///
    /// # Errors
    ///
    /// Returns `TdbError::Corrupt` when `buf` is shorter than the encoded
    /// fields, when the magic number does not match, or when the version
    /// differs from `TDB_VERSION`. A file in another format version would
    /// otherwise be read with this version's record layout.
    fn from_bytes(buf: &[u8]) -> TdbResult<Self> {
        if buf.len() < Self::ENCODED_FIELDS_LEN {
            return Err(TdbError::Corrupt);
        }
        // The slices below have fixed lengths, so the conversions cannot fail.
        let magic = u32::from_le_bytes(buf[0..4].try_into().unwrap());
        if magic != TDB_MAGIC {
            return Err(TdbError::Corrupt);
        }
        let version = u32::from_le_bytes(buf[4..8].try_into().unwrap());
        if version != TDB_VERSION {
            return Err(TdbError::Corrupt);
        }
        let hash_size = u32::from_le_bytes(buf[8..12].try_into().unwrap());
        let record_count = u64::from_le_bytes(buf[12..20].try_into().unwrap());
        Ok(Self {
            magic,
            version,
            hash_size,
            record_count,
        })
    }
}
/// Key/value store that keeps all records in memory and, when a backing
/// file is configured, rewrites that file after each change.
pub struct TdbEngine {
/// Database header; `record_count` is updated on every store and delete.
header: RwLock<TdbHeader>,
/// All live records, keyed by their raw key bytes.
records: RwLock<HashMap<Vec<u8>, TdbRecord>>,
/// Per-bucket lists of record file offsets, filled while loading a file.
/// Lookups in this file go through `records`, not this table.
hash_table: RwLock<Vec<Vec<u64>>>,
/// Backing file, or `None` for a purely in-memory engine.
file_path: Option<std::path::PathBuf>,
/// Set when in-memory state has changed; cleared after a successful flush.
dirty: Mutex<bool>,
}
impl TdbEngine {
pub fn new() -> Self {
Self::with_hash_size(DEFAULT_HASH_SIZE)
}
pub fn with_hash_size(hash_size: u32) -> Self {
Self {
header: RwLock::new(TdbHeader::new(hash_size)),
records: RwLock::new(HashMap::new()),
hash_table: RwLock::new(vec![Vec::new(); 1]),
file_path: None,
dirty: Mutex::new(false),
}
}
pub fn open<P: AsRef<Path>>(path: P) -> TdbResult<Self> {
let path = path.as_ref().to_path_buf();
if path.exists() {
Self::load_from_file(&path)
} else {
let engine = Self::new();
let mut engine = engine;
engine.file_path = Some(path.clone());
engine.save_to_file(&path)?;
Ok(engine)
}
}
pub fn create<P: AsRef<Path>>(path: P, hash_size: u32) -> TdbResult<Self> {
let path = path.as_ref().to_path_buf();
let engine = Self::with_hash_size(hash_size);
let mut engine = engine;
engine.file_path = Some(path.clone());
engine.save_to_file(&path)?;
Ok(engine)
}
fn load_from_file(path: &Path) -> TdbResult<Self> {
let mut file = std::fs::File::open(path)?;
let mut header_buf = vec![0u8; TDB_HEADER_SIZE as usize];
file.read_exact(&mut header_buf)?;
let header = TdbHeader::from_bytes(&header_buf)?;
let hash_size = header.hash_size;
let mut hash_table = vec![Vec::new(); hash_size as usize];
let mut records = HashMap::new();
let mut pos = TDB_HEADER_SIZE;
let file_meta = file.metadata()?;
let file_size = file_meta.len();
while pos < file_size {
file.seek(SeekFrom::Start(pos))?;
let mut flag_buf = [0u8; 4];
let mut key_len_buf = [0u8; 4];
let mut data_len_buf = [0u8; 4];
let mut hash_next_buf = [0u8; 8];
if file.read_exact(&mut flag_buf).is_err() {
break;
}
let _ = file.read_exact(&mut key_len_buf);
let _ = file.read_exact(&mut data_len_buf);
let _ = file.read_exact(&mut hash_next_buf);
let flag = RecordFlag::from_u32(u32::from_le_bytes(flag_buf));
let key_len = u32::from_le_bytes(key_len_buf) as usize;
let data_len = u32::from_le_bytes(data_len_buf) as usize;
let hash_next = u64::from_le_bytes(hash_next_buf);
let mut key = vec![0u8; key_len];
let mut data = vec![0u8; data_len];
let _ = file.read_exact(&mut key);
let _ = file.read_exact(&mut data);
let rec_offset = pos;
let bucket = hash_key(&key, hash_size) as usize;
if bucket < hash_table.len() {
hash_table[bucket].push(rec_offset);
}
if flag == RecordFlag::Active {
let record = TdbRecord {
key: key.clone(),
data,
flag,
hash_next,
};
records.insert(key, record);
}
let rec_size = 4 + 4 + 4 + 8 + key_len as u64 + data_len as u64;
pos += rec_size;
}
let record_count = records.len() as u64;
let mut header = header;
header.record_count = record_count;
Ok(Self {
header: RwLock::new(header),
records: RwLock::new(records),
hash_table: RwLock::new(hash_table),
file_path: Some(path.to_path_buf()),
dirty: Mutex::new(false),
})
}
fn save_to_file(&self, path: &Path) -> TdbResult<()> {
let header = self.header.read().unwrap();
let header_bytes = header.to_bytes();
let mut file = std::fs::File::create(path)?;
file.write_all(&header_bytes)?;
let records = self.records.read().unwrap();
for (_, record) in records.iter() {
let flag = record.flag.as_u32().to_le_bytes();
let key_len = (record.key.len() as u32).to_le_bytes();
let data_len = (record.data.len() as u32).to_le_bytes();
let hash_next = record.hash_next.to_le_bytes();
file.write_all(&flag)?;
file.write_all(&key_len)?;
file.write_all(&data_len)?;
file.write_all(&hash_next)?;
file.write_all(&record.key)?;
file.write_all(&record.data)?;
}
file.flush()?;
Ok(())
}
pub fn store(&self, key: Vec<u8>, data: Vec<u8>) -> TdbResult<bool> {
let mut records = self.records.write().unwrap();
let existed = records.contains_key(&key);
let record = TdbRecord::new(key.clone(), data);
records.insert(key, record);
*self.dirty.lock().unwrap() = true;
let mut header = self.header.write().unwrap();
header.record_count = records.len() as u64;
drop(header);
drop(records);
self.try_flush()?;
Ok(!existed)
}
pub fn fetch(&self, key: &[u8]) -> TdbResult<Vec<u8>> {
let records = self.records.read().unwrap();
match records.get(key) {
Some(record) => Ok(record.data.clone()),
None => Err(TdbError::NotFound),
}
}
pub fn delete(&self, key: &[u8]) -> TdbResult<()> {
let mut records = self.records.write().unwrap();
if records.remove(key).is_some() {
*self.dirty.lock().unwrap() = true;
let mut header = self.header.write().unwrap();
header.record_count = records.len() as u64;
drop(header);
drop(records);
self.try_flush()?;
Ok(())
} else {
Err(TdbError::NotFound)
}
}
pub fn exists(&self, key: &[u8]) -> bool {
self.records.read().unwrap().contains_key(key)
}
pub fn keys(&self) -> Vec<Vec<u8>> {
self.records.read().unwrap().keys().cloned().collect()
}
pub fn record_count(&self) -> u64 {
self.header.read().unwrap().record_count
}
pub fn iter(&self) -> Vec<(Vec<u8>, Vec<u8>)> {
self.records
.read()
.unwrap()
.iter()
.map(|(k, v)| (k.clone(), v.data.clone()))
.collect()
}
pub fn try_flush(&self) -> TdbResult<()> {
let dirty = *self.dirty.lock().unwrap();
if dirty {
if let Some(ref path) = self.file_path {
self.save_to_file(path)?;
*self.dirty.lock().unwrap() = false;
}
}
Ok(())
}
pub fn flush(&self) -> TdbResult<()> {
if let Some(ref path) = self.file_path {
self.save_to_file(path)?;
*self.dirty.lock().unwrap() = false;
}
Ok(())
}
}
impl Default for TdbEngine {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time statistics produced by `TdbEngine::stats`.
#[derive(Debug, Clone)]
pub struct TdbStats {
/// Record count as tracked in the database header.
pub record_count: u64,
/// Hash size recorded in the database header.
pub hash_size: u32,
/// Number of records currently held in memory.
pub active_records: u64,
/// Sum of all value lengths, in bytes.
pub total_data_size: u64,
/// Mean key length in bytes; 0.0 when there are no records.
pub avg_key_size: f64,
/// Mean value length in bytes; 0.0 when there are no records.
pub avg_data_size: f64,
}
impl TdbEngine {
    /// Computes summary statistics over the records currently in memory.
    ///
    /// Averages are reported as 0.0 for an empty database.
    pub fn stats(&self) -> TdbStats {
        let records = self.records.read().unwrap();
        let header = self.header.read().unwrap();

        let active = records.len();
        let mut data_bytes: u64 = 0;
        for record in records.values() {
            data_bytes += record.data.len() as u64;
        }
        let mut key_bytes: u64 = 0;
        for record in records.values() {
            key_bytes += record.key.len() as u64;
        }

        // Mean of `total` over the record count, guarding the empty case.
        let mean = |total: u64| -> f64 {
            if active > 0 {
                total as f64 / active as f64
            } else {
                0.0
            }
        };

        TdbStats {
            record_count: header.record_count,
            hash_size: header.hash_size,
            active_records: active as u64,
            total_data_size: data_bytes,
            avg_key_size: mean(key_bytes),
            avg_data_size: mean(data_bytes),
        }
    }
}
/// Snapshot-based transaction over a [`TdbEngine`].
///
/// `begin` copies the engine's records. `store` and `delete` go straight
/// to the engine, so they are visible to other users of the engine
/// immediately. `rollback` replaces the engine's records with the snapshot,
/// which also discards changes made to the engine outside this transaction
/// since `begin`. Dropping an unfinished transaction rolls it back.
pub struct TdbTransaction<'a> {
    engine: &'a TdbEngine,
    /// Records as they were at `begin`; consumed by the first rollback.
    snapshot: HashMap<Vec<u8>, TdbRecord>,
    committed: bool,
    /// Set once the snapshot has been restored, so that neither a second
    /// `rollback` call nor `Drop` restores it again.
    rolled_back: bool,
}

impl<'a> TdbTransaction<'a> {
    /// Starts a transaction by snapshotting the engine's current records.
    pub fn begin(engine: &'a TdbEngine) -> Self {
        let snapshot = engine.records.read().unwrap().clone();
        Self {
            engine,
            snapshot,
            committed: false,
            rolled_back: false,
        }
    }

    /// Stores a record through the engine; see [`TdbEngine::store`].
    pub fn store(&self, key: Vec<u8>, data: Vec<u8>) -> TdbResult<bool> {
        self.engine.store(key, data)
    }

    /// Fetches a value through the engine; see [`TdbEngine::fetch`].
    pub fn fetch(&self, key: &[u8]) -> TdbResult<Vec<u8>> {
        self.engine.fetch(key)
    }

    /// Deletes a record through the engine; see [`TdbEngine::delete`].
    pub fn delete(&self, key: &[u8]) -> TdbResult<()> {
        self.engine.delete(key)
    }

    /// Marks the transaction as committed and flushes the engine.
    ///
    /// The transaction counts as committed even if the flush fails, so it
    /// will not be rolled back on drop.
    pub fn commit(&mut self) -> TdbResult<()> {
        self.committed = true;
        self.engine.flush()
    }

    /// Restores the engine's records to the snapshot taken at `begin` and
    /// flushes the result.
    ///
    /// Does nothing if the transaction was already committed or rolled back.
    ///
    /// # Errors
    ///
    /// Returns the flush error, if any. The in-memory state has been
    /// restored by then, and the engine stays marked dirty.
    pub fn rollback(&mut self) -> TdbResult<()> {
        if self.committed || self.rolled_back {
            return Ok(());
        }
        {
            // Same lock order as the engine's writers: records, then header.
            let mut records = self.engine.records.write().unwrap();
            *records = std::mem::take(&mut self.snapshot);
            self.engine.header.write().unwrap().record_count = records.len() as u64;
            *self.engine.dirty.lock().unwrap() = true;
        }
        // The write guards must be released before flushing: `flush` locks
        // `records` and `header` again, and re-locking an `RwLock` held by
        // the same thread deadlocks or panics.
        self.rolled_back = true;
        self.engine.flush()
    }
}

impl<'a> Drop for TdbTransaction<'a> {
    /// Rolls back a transaction that was neither committed nor already
    /// rolled back. Errors cannot be reported from `drop` and are ignored.
    fn drop(&mut self) {
        if !self.committed && !self.rolled_back {
            let _ = self.rollback();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Shorthand for storing two byte-string literals.
    fn put(engine: &TdbEngine, key: &[u8], value: &[u8]) -> bool {
        engine.store(key.to_vec(), value.to_vec()).unwrap()
    }

    #[test]
    fn test_tdb_create_and_store() {
        let db = TdbEngine::new();
        let created = put(&db, b"key1", b"value1");
        assert!(created);
        assert_eq!(db.fetch(b"key1").unwrap(), b"value1");
        assert_eq!(db.record_count(), 1);
    }

    #[test]
    fn test_tdb_overwrite() {
        let db = TdbEngine::new();
        put(&db, b"key1", b"v1");
        let created = put(&db, b"key1", b"v2");
        assert!(!created);
        assert_eq!(db.fetch(b"key1").unwrap(), b"v2");
        assert_eq!(db.record_count(), 1);
    }

    #[test]
    fn test_tdb_delete() {
        let db = TdbEngine::new();
        put(&db, b"key1", b"v1");
        db.delete(b"key1").unwrap();
        assert!(!db.exists(b"key1"));
        assert_eq!(db.record_count(), 0);
        assert!(db.fetch(b"key1").is_err());
    }

    #[test]
    fn test_tdb_not_found() {
        let db = TdbEngine::new();
        assert!(db.fetch(b"missing").is_err());
        assert!(db.delete(b"missing").is_err());
    }

    #[test]
    fn test_tdb_multiple_records() {
        let db = TdbEngine::new();
        for i in 0..100 {
            let key = format!("key{}", i).into_bytes();
            let value = format!("value{}", i).into_bytes();
            db.store(key, value).unwrap();
        }
        assert_eq!(db.record_count(), 100);
        assert_eq!(db.fetch(b"key50").unwrap(), b"value50");
    }

    #[test]
    fn test_tdb_persist_to_disk() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path().join("test.tdb");
        {
            // First handle: write two records and flush them out.
            let writer = TdbEngine::open(&db_path).unwrap();
            put(&writer, b"persistent", b"yes");
            put(&writer, b"key2", b"val2");
            writer.flush().unwrap();
        }
        // Second handle: everything must be readable again.
        let reader = TdbEngine::open(&db_path).unwrap();
        assert_eq!(reader.fetch(b"persistent").unwrap(), b"yes");
        assert_eq!(reader.fetch(b"key2").unwrap(), b"val2");
        assert_eq!(reader.record_count(), 2);
    }

    #[test]
    fn test_tdb_keys() {
        let db = TdbEngine::new();
        put(&db, b"a", b"1");
        put(&db, b"b", b"2");
        put(&db, b"c", b"3");
        let mut keys = db.keys();
        keys.sort();
        let expected = vec![b"a".to_vec(), b"b".to_vec(), b"c".to_vec()];
        assert_eq!(keys, expected);
    }

    #[test]
    fn test_tdb_iter() {
        let db = TdbEngine::new();
        put(&db, b"k1", b"d1");
        put(&db, b"k2", b"d2");
        let pairs: Vec<_> = db.iter();
        assert_eq!(pairs.len(), 2);
    }

    #[test]
    fn test_tdb_stats() {
        let db = TdbEngine::new();
        put(&db, b"key1", b"12345");
        put(&db, b"key2", b"67890");
        let stats = db.stats();
        assert_eq!(stats.active_records, 2);
        assert_eq!(stats.total_data_size, 10);
        assert_eq!(stats.avg_key_size, 4.0);
        assert_eq!(stats.avg_data_size, 5.0);
    }

    #[test]
    fn test_tdb_create_with_hash_size() {
        let db = TdbEngine::with_hash_size(256);
        let hash_size = db.header.read().unwrap().hash_size;
        assert_eq!(hash_size, 256);
    }

    #[test]
    fn test_tdb_transaction_commit() {
        let db = TdbEngine::new();
        put(&db, b"init", b"data");
        {
            let mut tx = TdbTransaction::begin(&db);
            tx.store(b"tx_key".to_vec(), b"tx_data".to_vec()).unwrap();
            tx.commit().unwrap();
        }
        assert!(db.exists(b"tx_key"));
        assert_eq!(db.fetch(b"tx_key").unwrap(), b"tx_data");
    }

    #[test]
    fn test_tdb_transaction_rollback() {
        let db = TdbEngine::new();
        put(&db, b"keep", b"yes");
        {
            let mut tx = TdbTransaction::begin(&db);
            tx.store(b"temp".to_vec(), b"no".to_vec()).unwrap();
            tx.delete(b"keep").unwrap();
            tx.rollback().unwrap();
        }
        assert!(!db.exists(b"temp"));
        assert!(db.exists(b"keep"));
    }

    #[test]
    fn test_tdb_transaction_auto_rollback_on_drop() {
        let db = TdbEngine::new();
        put(&db, b"orig", b"val");
        {
            // Never committed: dropping the transaction must undo the store.
            let tx = TdbTransaction::begin(&db);
            tx.store(b"temp".to_vec(), b"x".to_vec()).unwrap();
        }
        assert!(!db.exists(b"temp"));
        assert!(db.exists(b"orig"));
    }

    #[test]
    fn test_hash_key_distribution() {
        let hash_size = 1024u32;
        let buckets: std::collections::HashSet<u32> = (0..1000)
            .map(|i| hash_key(format!("key_{}", i).as_bytes(), hash_size))
            .collect();
        assert!(buckets.len() > 200);
    }

    #[test]
    fn test_tdb_corrupt_file_detection() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path().join("corrupt.tdb");
        std::fs::write(&db_path, b"NOT_A_TDB_FILE_GARBAGE_DATA_HERE").unwrap();
        let opened = TdbEngine::open(&db_path);
        assert!(opened.is_err());
    }

    #[test]
    fn test_record_flag_conversion() {
        let table = [
            (RecordFlag::Active, 0u32),
            (RecordFlag::Free, 1),
            (RecordFlag::Deleted, 2),
        ];
        for (flag, raw) in table {
            assert_eq!(flag.as_u32(), raw);
        }
        for (flag, raw) in table {
            assert_eq!(RecordFlag::from_u32(raw), flag);
        }
    }

    #[test]
    fn test_tdb_empty_keys_and_values() {
        let db = TdbEngine::new();
        db.store(Vec::new(), Vec::new()).unwrap();
        assert!(db.exists(&[]));
        assert_eq!(db.fetch(&[]).unwrap(), b"");
    }

    #[test]
    fn test_tdb_large_value() {
        let db = TdbEngine::new();
        let payload = vec![0x42u8; 100_000];
        db.store(b"big".to_vec(), payload.clone()).unwrap();
        let fetched = db.fetch(b"big").unwrap();
        assert_eq!(fetched.len(), 100_000);
        assert_eq!(fetched, payload);
    }

    #[test]
    fn test_tdb_binary_keys() {
        let db = TdbEngine::new();
        let key = vec![0u8, 1u8, 255u8, 128u8, 64u8];
        db.store(key.clone(), b"binary_val".to_vec()).unwrap();
        assert_eq!(db.fetch(&key).unwrap(), b"binary_val");
    }
}