Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions common/credential-verification/src/bandwidth_storage_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ impl BandwidthStorageManager {
}

async fn expire_bandwidth(&mut self) -> Result<()> {
let _sync_guard = self.client_bandwidth.sync_guard().await;
self.storage.reset_bandwidth(self.client_id).await?;
self.client_bandwidth.expire_bandwidth().await;
Ok(())
Expand All @@ -127,13 +128,15 @@ impl BandwidthStorageManager {
#[instrument(level = "trace", skip_all)]
pub async fn sync_storage_bandwidth(&mut self) -> Result<()> {
trace!("syncing client bandwidth with the underlying storage");
let updated = self
.storage
.increase_bandwidth(
self.client_id,
self.client_bandwidth.delta_since_sync().await,
)
.await?;
let _sync_guard = self.client_bandwidth.sync_guard().await;
let delta = self.client_bandwidth.take_delta_since_sync().await;
let updated = match self.storage.increase_bandwidth(self.client_id, delta).await {
Ok(updated) => updated,
Err(err) => {
self.client_bandwidth.restore_delta_since_sync(delta).await;
return Err(err.into());
}
};

self.client_bandwidth
.resync_bandwidth_with_storage(updated)
Expand Down
91 changes: 88 additions & 3 deletions common/credential-verification/src/client_bandwidth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use nym_credentials_interface::AvailableBandwidth;
use std::sync::Arc;
use std::time::Duration;
use time::OffsetDateTime;
use tokio::sync::RwLock;
use tokio::sync::{Mutex, OwnedMutexGuard, RwLock};

const DEFAULT_CLIENT_BANDWIDTH_MAX_FLUSHING_RATE: Duration = Duration::from_secs(5 * 60); // 5 minutes
const DEFAULT_CLIENT_BANDWIDTH_MAX_DELTA_FLUSHING_AMOUNT: i64 = 5 * 1024 * 1024; // 5MB
Expand All @@ -33,6 +33,7 @@ impl Default for BandwidthFlushingBehaviourConfig {
#[derive(Debug, Clone)]
pub struct ClientBandwidth {
inner: Arc<RwLock<ClientBandwidthInner>>,
sync_lock: Arc<Mutex<()>>,
}

#[derive(Debug)]
Expand All @@ -56,6 +57,7 @@ impl ClientBandwidth {
bytes_at_last_sync: bandwidth.bytes,
bytes_delta_since_sync: 0,
})),
sync_lock: Arc::new(Mutex::new(())),
}
}

Expand All @@ -77,9 +79,27 @@ impl ClientBandwidth {
self.inner.read().await.bandwidth.bytes
}

#[cfg(test)]
pub(crate) async fn delta_since_sync(&self) -> i64 {
self.inner.read().await.bytes_delta_since_sync
}

pub(crate) async fn sync_guard(&self) -> OwnedMutexGuard<()> {
self.sync_lock.clone().lock_owned().await
}

pub(crate) async fn take_delta_since_sync(&self) -> i64 {
let mut guard = self.inner.write().await;
let delta = guard.bytes_delta_since_sync;
guard.bytes_delta_since_sync = 0;
delta
}

pub(crate) async fn restore_delta_since_sync(&self, delta: i64) {
let mut guard = self.inner.write().await;
guard.bytes_delta_since_sync += delta;
}

pub(crate) async fn expiration(&self) -> OffsetDateTime {
self.inner.read().await.bandwidth.expiration
}
Expand Down Expand Up @@ -115,9 +135,74 @@ impl ClientBandwidth {
pub(crate) async fn resync_bandwidth_with_storage(&self, stored: i64) {
let mut guard = self.inner.write().await;

guard.bandwidth.bytes = stored;
guard.bytes_at_last_sync = stored;
guard.bytes_delta_since_sync = 0;
guard.bandwidth.bytes = stored + guard.bytes_delta_since_sync;
guard.last_synced = OffsetDateTime::now_utc();
}
}

#[cfg(test)]
mod tests {
use super::*;

fn test_bandwidth(bytes: i64) -> ClientBandwidth {
ClientBandwidth::new(AvailableBandwidth {
bytes,
expiration: OffsetDateTime::UNIX_EPOCH,
})
}

#[tokio::test]
async fn resync_preserves_delta_accumulated_during_storage_sync() {
let bandwidth = test_bandwidth(1_000);

bandwidth.decrease_bandwidth(100).await;
let reserved_delta = bandwidth.take_delta_since_sync().await;
assert_eq!(reserved_delta, -100);
assert_eq!(bandwidth.delta_since_sync().await, 0);

bandwidth.decrease_bandwidth(50).await;
bandwidth.resync_bandwidth_with_storage(900).await;

assert_eq!(bandwidth.available().await, 850);
assert_eq!(bandwidth.delta_since_sync().await, -50);
}

#[tokio::test]
async fn failed_sync_restores_reserved_delta() {
let bandwidth = test_bandwidth(1_000);

bandwidth.decrease_bandwidth(100).await;
let reserved_delta = bandwidth.take_delta_since_sync().await;
bandwidth.decrease_bandwidth(25).await;
bandwidth.restore_delta_since_sync(reserved_delta).await;

assert_eq!(bandwidth.available().await, 875);
assert_eq!(bandwidth.delta_since_sync().await, -125);
}

#[tokio::test]
async fn old_read_only_sync_could_apply_the_same_delta_twice() {
let old_behaviour = test_bandwidth(1_000);
old_behaviour.decrease_bandwidth(100).await;

let old_first_sync_delta = old_behaviour.delta_since_sync().await;
let old_second_sync_delta = old_behaviour.delta_since_sync().await;
let old_stored = 1_000 + old_first_sync_delta + old_second_sync_delta;

assert_eq!(old_first_sync_delta, -100);
assert_eq!(old_second_sync_delta, -100);
assert_eq!(old_stored, 800);

let new_behaviour = test_bandwidth(1_000);
new_behaviour.decrease_bandwidth(100).await;

let new_first_sync_delta = new_behaviour.take_delta_since_sync().await;
let new_second_sync_delta = new_behaviour.take_delta_since_sync().await;
let new_stored = 1_000 + new_first_sync_delta + new_second_sync_delta;

assert_eq!(new_first_sync_delta, -100);
assert_eq!(new_second_sync_delta, 0);
assert_eq!(new_stored, 900);
}
}
66 changes: 63 additions & 3 deletions common/gateway-storage/src/wireguard_peers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,23 @@ impl WgPeerManager {
///
/// * `peer`: peer information needed by wireguard interface.
pub(crate) async fn insert_peer(&self, peer: &WireguardPeer) -> Result<(), sqlx::Error> {
let psk = peer.psk.as_deref();
sqlx::query!(
r#"
INSERT OR IGNORE INTO wireguard_peer(public_key, allowed_ips, client_id)
VALUES (?, ?, ?);
INSERT OR IGNORE INTO wireguard_peer(public_key, allowed_ips, client_id, psk)
VALUES (?, ?, ?, ?);

UPDATE wireguard_peer
SET allowed_ips = ?, client_id = ?
SET allowed_ips = ?, client_id = ?, psk = ?
WHERE public_key = ?
"#,
peer.public_key,
peer.allowed_ips,
peer.client_id,
psk,
peer.allowed_ips,
peer.client_id,
psk,
peer.public_key,
)
.execute(&self.connection_pool)
Expand Down Expand Up @@ -124,3 +127,60 @@ impl WgPeerManager {
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::models::WireguardPeer;
use defguard_wireguard_rs::{host::Peer, key::Key, net::IpAddrMask};
use std::net::Ipv4Addr;

fn test_peer(public_key: Key, psk: Key) -> Peer {
let mut peer = Peer::new(public_key);
peer.allowed_ips = vec![IpAddrMask::new(Ipv4Addr::new(10, 0, 0, 2).into(), 32)];
peer.preshared_key = Some(psk);
peer
}

#[sqlx::test(migrations = "./migrations")]
async fn insert_peer_persists_psk_on_insert_and_update(
pool: sqlx::SqlitePool,
) -> Result<(), Box<dyn std::error::Error>> {
sqlx::query("INSERT INTO clients (id, client_type) VALUES (?, ?)")
.bind(1_i64)
.bind("entry_wireguard")
.execute(&pool)
.await?;

let manager = WgPeerManager::new(pool);
let public_key = Key::new([1; 32]);

let first_psk = Key::new([2; 32]);
let first_psk_hex = first_psk.to_lower_hex();
let first_peer = test_peer(public_key.clone(), first_psk.clone());
manager
.insert_peer(&WireguardPeer::from_defguard_peer(first_peer.clone(), 1)?)
.await?;

let retrieved = manager
.retrieve_peer(&first_peer.public_key.to_string())
.await?
.expect("peer should be present after insert");
assert_eq!(retrieved.psk.as_deref(), Some(first_psk_hex.as_str()));

let second_psk = Key::new([3; 32]);
let second_psk_hex = second_psk.to_lower_hex();
let second_peer = test_peer(public_key, second_psk.clone());
manager
.insert_peer(&WireguardPeer::from_defguard_peer(second_peer.clone(), 1)?)
.await?;

let retrieved = manager
.retrieve_peer(&second_peer.public_key.to_string())
.await?
.expect("peer should be present after update");
assert_eq!(retrieved.psk.as_deref(), Some(second_psk_hex.as_str()));

Ok(())
}
}
Loading