diff --git a/common/credential-verification/src/bandwidth_storage_manager.rs b/common/credential-verification/src/bandwidth_storage_manager.rs index 9a78b9f95e2..be03dfffa58 100644 --- a/common/credential-verification/src/bandwidth_storage_manager.rs +++ b/common/credential-verification/src/bandwidth_storage_manager.rs @@ -102,6 +102,7 @@ impl BandwidthStorageManager { } async fn expire_bandwidth(&mut self) -> Result<()> { + let _sync_guard = self.client_bandwidth.sync_guard().await; self.storage.reset_bandwidth(self.client_id).await?; self.client_bandwidth.expire_bandwidth().await; Ok(()) @@ -127,13 +128,15 @@ impl BandwidthStorageManager { #[instrument(level = "trace", skip_all)] pub async fn sync_storage_bandwidth(&mut self) -> Result<()> { trace!("syncing client bandwidth with the underlying storage"); - let updated = self - .storage - .increase_bandwidth( - self.client_id, - self.client_bandwidth.delta_since_sync().await, - ) - .await?; + let _sync_guard = self.client_bandwidth.sync_guard().await; + let delta = self.client_bandwidth.take_delta_since_sync().await; + let updated = match self.storage.increase_bandwidth(self.client_id, delta).await { + Ok(updated) => updated, + Err(err) => { + self.client_bandwidth.restore_delta_since_sync(delta).await; + return Err(err.into()); + } + }; self.client_bandwidth .resync_bandwidth_with_storage(updated) diff --git a/common/credential-verification/src/client_bandwidth.rs b/common/credential-verification/src/client_bandwidth.rs index 0a0acba8c1e..3acdbfd31b9 100644 --- a/common/credential-verification/src/client_bandwidth.rs +++ b/common/credential-verification/src/client_bandwidth.rs @@ -6,7 +6,7 @@ use nym_credentials_interface::AvailableBandwidth; use std::sync::Arc; use std::time::Duration; use time::OffsetDateTime; -use tokio::sync::RwLock; +use tokio::sync::{Mutex, OwnedMutexGuard, RwLock}; const DEFAULT_CLIENT_BANDWIDTH_MAX_FLUSHING_RATE: Duration = Duration::from_secs(5 * 60); // 5 minutes const DEFAULT_CLIENT_BANDWIDTH_MAX_DELTA_FLUSHING_AMOUNT: i64 = 5 * 1024 * 1024; // 5MB @@ -33,6 +33,7 @@ impl Default for BandwidthFlushingBehaviourConfig { #[derive(Debug, Clone)] pub struct ClientBandwidth { inner: Arc>, + sync_lock: Arc>, } #[derive(Debug)] @@ -56,6 +57,7 @@ impl ClientBandwidth { bytes_at_last_sync: bandwidth.bytes, bytes_delta_since_sync: 0, })), + sync_lock: Arc::new(Mutex::new(())), } } @@ -77,9 +79,27 @@ impl ClientBandwidth { self.inner.read().await.bandwidth.bytes } + #[cfg(test)] pub(crate) async fn delta_since_sync(&self) -> i64 { self.inner.read().await.bytes_delta_since_sync } + + pub(crate) async fn sync_guard(&self) -> OwnedMutexGuard<()> { + self.sync_lock.clone().lock_owned().await + } + + pub(crate) async fn take_delta_since_sync(&self) -> i64 { + let mut guard = self.inner.write().await; + let delta = guard.bytes_delta_since_sync; + guard.bytes_delta_since_sync = 0; + delta + } + + pub(crate) async fn restore_delta_since_sync(&self, delta: i64) { + let mut guard = self.inner.write().await; + guard.bytes_delta_since_sync += delta; + } + pub(crate) async fn expiration(&self) -> OffsetDateTime { self.inner.read().await.bandwidth.expiration } @@ -115,9 +135,74 @@ impl ClientBandwidth { pub(crate) async fn resync_bandwidth_with_storage(&self, stored: i64) { let mut guard = self.inner.write().await; - guard.bandwidth.bytes = stored; guard.bytes_at_last_sync = stored; - guard.bytes_delta_since_sync = 0; + guard.bandwidth.bytes = stored + guard.bytes_delta_since_sync; guard.last_synced = OffsetDateTime::now_utc(); } } + +#[cfg(test)] +mod tests { + use super::*; + + fn test_bandwidth(bytes: i64) -> ClientBandwidth { + ClientBandwidth::new(AvailableBandwidth { + bytes, + expiration: OffsetDateTime::UNIX_EPOCH, + }) + } + + #[tokio::test] + async fn resync_preserves_delta_accumulated_during_storage_sync() { + let bandwidth = test_bandwidth(1_000); + + bandwidth.decrease_bandwidth(100).await; + let reserved_delta = bandwidth.take_delta_since_sync().await; + assert_eq!(reserved_delta, -100); + assert_eq!(bandwidth.delta_since_sync().await, 0); + + bandwidth.decrease_bandwidth(50).await; + bandwidth.resync_bandwidth_with_storage(900).await; + + assert_eq!(bandwidth.available().await, 850); + assert_eq!(bandwidth.delta_since_sync().await, -50); + } + + #[tokio::test] + async fn failed_sync_restores_reserved_delta() { + let bandwidth = test_bandwidth(1_000); + + bandwidth.decrease_bandwidth(100).await; + let reserved_delta = bandwidth.take_delta_since_sync().await; + bandwidth.decrease_bandwidth(25).await; + bandwidth.restore_delta_since_sync(reserved_delta).await; + + assert_eq!(bandwidth.available().await, 875); + assert_eq!(bandwidth.delta_since_sync().await, -125); + } + + #[tokio::test] + async fn old_read_only_sync_could_apply_the_same_delta_twice() { + let old_behaviour = test_bandwidth(1_000); + old_behaviour.decrease_bandwidth(100).await; + + let old_first_sync_delta = old_behaviour.delta_since_sync().await; + let old_second_sync_delta = old_behaviour.delta_since_sync().await; + let old_stored = 1_000 + old_first_sync_delta + old_second_sync_delta; + + assert_eq!(old_first_sync_delta, -100); + assert_eq!(old_second_sync_delta, -100); + assert_eq!(old_stored, 800); + + let new_behaviour = test_bandwidth(1_000); + new_behaviour.decrease_bandwidth(100).await; + + let new_first_sync_delta = new_behaviour.take_delta_since_sync().await; + let new_second_sync_delta = new_behaviour.take_delta_since_sync().await; + let new_stored = 1_000 + new_first_sync_delta + new_second_sync_delta; + + assert_eq!(new_first_sync_delta, -100); + assert_eq!(new_second_sync_delta, 0); + assert_eq!(new_stored, 900); + } +} diff --git a/common/gateway-storage/src/wireguard_peers.rs b/common/gateway-storage/src/wireguard_peers.rs index c999d093589..e7bdd953856 100644 --- a/common/gateway-storage/src/wireguard_peers.rs +++ b/common/gateway-storage/src/wireguard_peers.rs @@ -25,20 +25,23 @@ impl WgPeerManager { /// /// * `peer`: peer information needed by wireguard interface. pub(crate) async fn insert_peer(&self, peer: &WireguardPeer) -> Result<(), sqlx::Error> { + let psk = peer.psk.as_deref(); sqlx::query!( r#" - INSERT OR IGNORE INTO wireguard_peer(public_key, allowed_ips, client_id) - VALUES (?, ?, ?); + INSERT OR IGNORE INTO wireguard_peer(public_key, allowed_ips, client_id, psk) + VALUES (?, ?, ?, ?); UPDATE wireguard_peer - SET allowed_ips = ?, client_id = ? + SET allowed_ips = ?, client_id = ?, psk = ? WHERE public_key = ? "#, peer.public_key, peer.allowed_ips, peer.client_id, + psk, peer.allowed_ips, peer.client_id, + psk, peer.public_key, ) .execute(&self.connection_pool) @@ -124,3 +127,60 @@ impl WgPeerManager { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::models::WireguardPeer; + use defguard_wireguard_rs::{host::Peer, key::Key, net::IpAddrMask}; + use std::net::Ipv4Addr; + + fn test_peer(public_key: Key, psk: Key) -> Peer { + let mut peer = Peer::new(public_key); + peer.allowed_ips = vec![IpAddrMask::new(Ipv4Addr::new(10, 0, 0, 2).into(), 32)]; + peer.preshared_key = Some(psk); + peer + } + + #[sqlx::test(migrations = "./migrations")] + async fn insert_peer_persists_psk_on_insert_and_update( + pool: sqlx::SqlitePool, + ) -> Result<(), Box> { + sqlx::query("INSERT INTO clients (id, client_type) VALUES (?, ?)") + .bind(1_i64) + .bind("entry_wireguard") + .execute(&pool) + .await?; + + let manager = WgPeerManager::new(pool); + let public_key = Key::new([1; 32]); + + let first_psk = Key::new([2; 32]); + let first_psk_hex = first_psk.to_lower_hex(); + let first_peer = test_peer(public_key.clone(), first_psk.clone()); + manager + .insert_peer(&WireguardPeer::from_defguard_peer(first_peer.clone(), 1)?) + .await?; + + let retrieved = manager + .retrieve_peer(&first_peer.public_key.to_string()) + .await? + .expect("peer should be present after insert"); + assert_eq!(retrieved.psk.as_deref(), Some(first_psk_hex.as_str())); + + let second_psk = Key::new([3; 32]); + let second_psk_hex = second_psk.to_lower_hex(); + let second_peer = test_peer(public_key, second_psk.clone()); + manager + .insert_peer(&WireguardPeer::from_defguard_peer(second_peer.clone(), 1)?) + .await?; + + let retrieved = manager + .retrieve_peer(&second_peer.public_key.to_string()) + .await? + .expect("peer should be present after update"); + assert_eq!(retrieved.psk.as_deref(), Some(second_psk_hex.as_str())); + + Ok(()) + } +} diff --git a/common/wireguard/src/peer_controller/mod.rs b/common/wireguard/src/peer_controller/mod.rs index 19f936b93b1..17a62265ec9 100644 --- a/common/wireguard/src/peer_controller/mod.rs +++ b/common/wireguard/src/peer_controller/mod.rs @@ -35,7 +35,7 @@ use std::{ }; use tokio::sync::{RwLock, mpsc}; use tokio_stream::{StreamExt, wrappers::IntervalStream}; -use tracing::{debug, error, info, trace}; +use tracing::{debug, error, info, trace, warn}; #[cfg(feature = "mock")] pub mod mock; @@ -212,13 +212,15 @@ impl PeerController { pub async fn remove_peer(&mut self, key: &Key) -> Result<()> { nym_metrics::inc!("wg_peer_removal_attempts"); + let stored_peer = self.handle_query_peer_by_key(key).await?; + self.ecash_verifier .storage() .remove_wireguard_peer(&key.to_string()) .await?; self.bw_storage_managers.remove(key); - if let Ok(Some(peer)) = self.handle_query_peer_by_key(key).await + if let Some(peer) = stored_peer && let Some(ip_pair) = allocated_ip_pair(&peer) { self.ip_pool.release(ip_pair) @@ -274,6 +276,64 @@ impl PeerController { )); }; + if let Some(existing_bandwidth_storage_manager) = + self.bw_storage_managers.get(&peer.public_key).cloned() + { + warn!( + "received duplicate add request for active WireGuard peer {}; refreshing existing peer handle state", + peer.public_key + ); + + let old_allowed_ips = existing_bandwidth_storage_manager.allowed_ips().await; + let old_ip_pair = allocated_ip_pair(&Peer { + allowed_ips: old_allowed_ips, + ..Default::default() + }); + let ip_pair_changed = old_ip_pair != Some(ip_pair); + if ip_pair_changed { + self.ip_pool.confirm_allocation(ip_pair)?; + } + + if let Err(e) = self.wg_api.configure_peer(peer) { + if ip_pair_changed { + self.ip_pool.release(ip_pair); + } + nym_metrics::inc!("wg_peer_addition_failed"); + nym_metrics::inc!("wg_config_errors_total"); + return Err(e.into()); + } + + { + let mut guard = existing_bandwidth_storage_manager.inner().write().await; + guard + .sync_storage_bandwidth() + .await + .map_err(|err| Error::Internal(format!("failed to sync bandwidth: {err}")))?; + *guard = Self::generate_bandwidth_manager( + self.ecash_verifier.storage(), + &peer.public_key, + ) + .await?; + } + + existing_bandwidth_storage_manager + .set_allowed_ips(peer.allowed_ips.clone()) + .await; + + if let Ok(host_information) = self.wg_api.read_interface_data() { + *self.host_information.write().await = host_information; + } + + if ip_pair_changed { + if let Some(old_ip_pair) = old_ip_pair { + self.ip_pool.release(old_ip_pair); + } + } + + nym_metrics::inc!("wg_peer_addition_success"); + return Ok(()); + } + // Try to configure WireGuard peer if let Err(e) = self.wg_api.configure_peer(peer) { nym_metrics::inc!("wg_peer_addition_failed"); @@ -288,7 +348,14 @@ impl PeerController { )), peer.allowed_ips.clone(), ); - let cached_peer_manager = CachedPeerManager::new(peer); + let mut cached_peer_manager = CachedPeerManager::new(peer); + let host_information = self.wg_api.read_interface_data().ok(); + if let Some(kernel_peer) = host_information + .as_ref() + .and_then(|host_information| host_information.peers.get(&peer.public_key)) + { + cached_peer_manager.update(kernel_peer.into()); + } let mut handle = PeerHandle::new( peer.public_key.clone(), self.host_information.clone(), @@ -298,15 +365,15 @@ impl PeerController { self.upgrade_mode.clone(), &self.shutdown_token, ); + let public_key = peer.public_key.clone(); + self.ip_pool.confirm_allocation(ip_pair)?; + self.bw_storage_managers - .insert(peer.public_key.clone(), bandwidth_storage_manager); + .insert(public_key.clone(), bandwidth_storage_manager); // try to immediately update the host information, to eliminate races - if let Ok(host_information) = self.wg_api.read_interface_data() { + if let Some(host_information) = host_information { *self.host_information.write().await = host_information; } - let public_key = peer.public_key.clone(); - - self.ip_pool.confirm_allocation(ip_pair)?; tokio::spawn(async move { handle.run().await; @@ -339,16 +406,18 @@ impl PeerController { } async fn ip_to_key(&self, ip: IpAddr) -> Result> { - Ok(self - .bw_storage_managers - .iter() - .find_map(|(key, bw_manager)| { - bw_manager - .allowed_ips() - .iter() - .find(|ip_mask| ip_mask.address == ip) - .and(Some(key.clone())) - })) + for (key, bw_manager) in &self.bw_storage_managers { + if bw_manager + .allowed_ips() + .await + .iter() + .any(|ip_mask| ip_mask.address == ip) + { + return Ok(Some(key.clone())); + } + } + + Ok(None) } async fn handle_query_peer_by_key(&self, key: &Key) -> Result> { @@ -750,6 +819,36 @@ pub async fn stop_controller(mut shutdown_manager: nym_task::ShutdownManager) { #[cfg(all(test, feature = "mock"))] mod tests { use super::*; + use defguard_wireguard_rs::net::IpAddrMask; + use nym_credentials_interface::TicketType; + use nym_gateway_storage::traits::BandwidthGatewayStorage; + + async fn allocate_ip_pair(request_tx: &mpsc::Sender) -> IpPair { + let (response_tx, response_rx) = oneshot::channel(); + request_tx + .send(PeerControlRequest::PreAllocateIpPair { response_tx }) + .await + .unwrap(); + response_rx.await.unwrap().unwrap() + } + + async fn add_peer(request_tx: &mpsc::Sender, peer: Peer) { + let (response_tx, response_rx) = oneshot::channel(); + request_tx + .send(PeerControlRequest::AddPeer { peer, response_tx }) + .await + .unwrap(); + response_rx.await.unwrap().unwrap() + } + + fn peer_with_ip_pair(public_key: Key, ip_pair: IpPair) -> Peer { + let mut peer = Peer::new(public_key); + peer.allowed_ips = vec![ + IpAddrMask::new(ip_pair.ipv4.into(), 32), + IpAddrMask::new(ip_pair.ipv6.into(), 128), + ]; + peer + } #[tokio::test] async fn start_and_stop() { @@ -757,4 +856,34 @@ mod tests { let (_, shutdown_manager) = start_controller(request_tx.clone(), request_rx); stop_controller(shutdown_manager).await; } + + #[tokio::test] + async fn duplicate_add_for_active_peer_is_idempotent() { + let (request_tx, request_rx) = mpsc::channel(4); + let (storage, shutdown_manager) = start_controller(request_tx.clone(), request_rx); + + let public_key = Key::new([7; 32]); + let first_ip_pair = allocate_ip_pair(&request_tx).await; + let first_peer = peer_with_ip_pair(public_key.clone(), first_ip_pair); + let client_id = storage + .insert_wireguard_peer(&first_peer, TicketType::V1WireguardEntry.into()) + .await + .unwrap(); + storage.create_bandwidth_entry(client_id).await.unwrap(); + + add_peer(&request_tx, first_peer.clone()).await; + add_peer(&request_tx, first_peer).await; + + let second_ip_pair = allocate_ip_pair(&request_tx).await; + let updated_peer = peer_with_ip_pair(public_key, second_ip_pair); + let reused_client_id = storage + .insert_wireguard_peer(&updated_peer, TicketType::V1WireguardEntry.into()) + .await + .unwrap(); + assert_eq!(client_id, reused_client_id); + + add_peer(&request_tx, updated_peer).await; + + stop_controller(shutdown_manager).await; + } } diff --git a/common/wireguard/src/peer_handle.rs b/common/wireguard/src/peer_handle.rs index b5dd5d83fc0..4c91bf82348 100644 --- a/common/wireguard/src/peer_handle.rs +++ b/common/wireguard/src/peer_handle.rs @@ -20,7 +20,7 @@ use tracing::{debug, error, trace, warn}; #[derive(Clone)] pub(crate) struct SharedBandwidthStorageManager { inner: Arc>, - allowed_ips: Vec, + allowed_ips: Arc>>, } impl SharedBandwidthStorageManager { @@ -28,15 +28,22 @@ impl SharedBandwidthStorageManager { inner: Arc>, allowed_ips: Vec, ) -> Self { - Self { inner, allowed_ips } + Self { + inner, + allowed_ips: Arc::new(RwLock::new(allowed_ips)), + } } pub(crate) fn inner(&self) -> &RwLock { &self.inner } - pub(crate) fn allowed_ips(&self) -> &[IpAddrMask] { - &self.allowed_ips + pub(crate) async fn allowed_ips(&self) -> Vec { + self.allowed_ips.read().await.clone() + } + + pub(crate) async fn set_allowed_ips(&self, allowed_ips: Vec) { + *self.allowed_ips.write().await = allowed_ips; } } diff --git a/common/wireguard/src/peer_storage_manager.rs b/common/wireguard/src/peer_storage_manager.rs index 00c439350de..6a848ed48e5 100644 --- a/common/wireguard/src/peer_storage_manager.rs +++ b/common/wireguard/src/peer_storage_manager.rs @@ -113,3 +113,27 @@ impl PeerInformation { .unwrap_or(i64::MAX) } } + +#[cfg(test)] +mod tests { + use super::*; + + const GIB: u64 = 1024 * 1024 * 1024; + + #[test] + fn seeding_cache_from_empty_peer_recounts_all_existing_kernel_traffic() { + let registration_peer = Peer::default(); + let cached_from_registration = PeerInformation::from(®istration_peer); + let kernel_peer = PeerInformation { + rx_bytes: 37 * GIB, + tx_bytes: 6 * GIB, + }; + + let old_spent = kernel_peer.consumed_kernel_bandwidth(&cached_from_registration); + assert_eq!(old_spent, (43 * GIB) as i64); + + let cached_from_kernel = kernel_peer; + let fixed_spent = kernel_peer.consumed_kernel_bandwidth(&cached_from_kernel); + assert_eq!(fixed_spent, 0); + } +} diff --git a/gateway/src/node/wireguard/new_peer_registration/lp.rs b/gateway/src/node/wireguard/new_peer_registration/lp.rs index 9e7a45d057f..6bf01ec2c04 100644 --- a/gateway/src/node/wireguard/new_peer_registration/lp.rs +++ b/gateway/src/node/wireguard/new_peer_registration/lp.rs @@ -20,19 +20,23 @@ impl PeerRegistrator { peer: PeerPublicKey, psk: Key, ) -> Result<(), GatewayWireguardError> { - // 1. check if the peer is currently being handled - if self.peer_manager.check_active_peer(peer).await? { - // 2. if so, force disconnect it (as we're handling new request from the same peer) - self.peer_manager.remove_peer(peer).await?; - } + let active_peer = if self.peer_manager.check_active_peer(peer).await? { + self.peer_manager.query_peer(peer).await? + } else { + None + }; - // 3. update the on-disk PSK let encoded_psk = psk.to_lower_hex(); self.ecash_verifier .storage() .update_peer_psk(&peer.to_string(), Some(&encoded_psk)) .await?; + if let Some(mut active_peer) = active_peer { + active_peer.preshared_key = Some(psk); + self.peer_manager.add_peer(active_peer).await?; + } + Ok(()) }