diff --git a/.cirrus.yml b/.cirrus.yml deleted file mode 100644 index 685a697..0000000 --- a/.cirrus.yml +++ /dev/null @@ -1,33 +0,0 @@ -compute_engine_instance: - image_project: rocky-linux-cloud - image: family/rocky-linux-9 - -task: - env: - PATH: /root/.cargo/bin:${PATH} - CODECOV_TOKEN: ENCRYPTED[5608a167c2ad93fcac429e78e49661525794539aad86af2553c6eb0d1f3dd583f75a9bb4a2864d761bb336f8eec5c68d] - - prepare_script: - - dnf install -y cmake librdmacm libibverbs gcc clang - - rdma link add rxe_eth0 type rxe netdev eth0 - rust_script: - - curl https://sh.rustup.rs -sSf --output rustup.sh - - sh rustup.sh -y - - curl -LsSf https://get.nexte.st/latest/linux | tar zxf - -C ${CARGO_HOME:-~/.cargo}/bin - - curl --proto '=https' --tlsv1.2 -sSf https://just.systems/install.sh | bash -s -- --to ${CARGO_HOME:-~/.cargo}/bin - - cargo install cargo-llvm-cov - codecov_script: - - curl -Os https://cli.codecov.io/latest/linux/codecov - - chmod +x codecov - rdma_core_script: - - dnf install -y git libnl3-devel libudev-devel make pkgconfig valgrind-devel - - git clone https://github.com/linux-rdma/rdma-core.git - - ./rdma-core/build.sh - test_script: - - export LD_LIBRARY_PATH=./rdma-core/build/lib - - just test-basic-with-cov - - just test-rc-pingpong-with-cov - - just test-cmtime-with-cov - - just generate-cov - - sed -i 's#/tmp/cirrus-ci-build/##g' lcov.info - - ./codecov --verbose upload-process --disable-search --fail-on-error --git-service github -f ./lcov.info diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d985fad..29c88f6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,3 +39,45 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true file: lcov.info + + test-rxe: + runs-on: ubuntu-latest + steps: + - name: Checkout repository code + uses: actions/checkout@v4 + - name: Install tools required + uses: taiki-e/install-action@v2 + with: + tool: just,cargo-nextest,cargo-llvm-cov + - name: Clone rxe kmod + run: | + uname -a + git clone https://github.com/pizhenwei/rxe.git + make -C rxe + sudo insmod ./rxe/rdma_rxe.ko + ip add + sudo rdma link add rxe_eth0 type rxe netdev eth0 + - name: Build rdma-core + run: | + sudo apt update + sudo apt install -y make pkg-config cmake libnl-3-dev libnl-route-3-dev libnl-genl-3-dev + git clone https://github.com/linux-rdma/rdma-core.git + ./rdma-core/build.sh + - name: Test with RXE + run: | + export LD_LIBRARY_PATH=./rdma-core/build/lib + cargo llvm-cov --no-report run --example ibv_devinfo + if modinfo mlx5_ib >/dev/null 2>&1 && lsmod | grep -q '^mlx5_ib'; then + sudo rmmod mlx5_ib + fi + just test-basic-with-cov + just test-rc-pingpong-with-cov + just test-cmtime-with-cov + just generate-cov + - name: Upload coverage information + uses: codecov/codecov-action@v5 + with: + fail_ci_if_error: false + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + file: lcov.info diff --git a/Cargo.toml b/Cargo.toml index a1a97a8..a58efe0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ description = "A better wrapper for using RDMA programming APIs in Rust flavor" license= "MPL-2.0" repository = "https://github.com/RDMA-Rust/sideway" readme = "README.md" -keywords = ["RDMA", "verbs", "cm", "libibverbs", "librdmacm"] +keywords = ["RDMA", "verbs", "cm", "libibverbs", "librdmacm", "ibverbs", "rdmacm"] authors = [ "Luke Yue ", "FujiZ ", diff --git a/src/rdmacm/communication_manager.rs b/src/rdmacm/communication_manager.rs index d938b3e..a32d247 100644 --- a/src/rdmacm/communication_manager.rs +++ b/src/rdmacm/communication_manager.rs @@ -166,10 +166,10 @@ use std::{io, mem::MaybeUninit, net::SocketAddr, ptr::NonNull, sync::Arc}; use os_socketaddr::OsSocketAddr; use rdma_mummy_sys::{ - ibv_qp_attr, rdma_accept, rdma_ack_cm_event, rdma_bind_addr, rdma_cm_event, rdma_cm_event_type, rdma_cm_id, - rdma_conn_param, rdma_connect, rdma_create_event_channel, rdma_create_id, rdma_destroy_event_channel, + ibv_qp_attr, ibv_qp_type, rdma_accept, rdma_ack_cm_event, rdma_bind_addr, rdma_cm_event, rdma_cm_event_type, + rdma_cm_id, rdma_conn_param, rdma_connect, rdma_create_event_channel, rdma_create_id, rdma_destroy_event_channel, rdma_destroy_id, rdma_disconnect, rdma_establish, rdma_event_channel, rdma_get_cm_event, rdma_init_qp_attr, - rdma_listen, rdma_port_space, rdma_resolve_addr, rdma_resolve_route, + rdma_listen, rdma_migrate_id, rdma_port_space, rdma_resolve_addr, rdma_resolve_route, }; use crate::ibverbs::device_context::DeviceContext; @@ -226,6 +226,7 @@ static DEVICE_LISTS: LazyLock>>> = LazyL /// An RDMA event represents an event from an RDMA event channel, reported by an [`Identifier`]. pub struct Event { event: NonNull, + event_channel: Option>, cm_id: Option>, listener_id: Option>, } @@ -238,16 +239,22 @@ pub struct EventChannel { /// An RDMA CM identifier (`rdma_cm_id`), conceptually similar to a socket, an [`Identifier`] would /// report some of the RDMA CM operations' result as an [`Event`] to its [`EventChannel`]. pub struct Identifier { - _event_channel: Arc, + // Keeps the raw rdma_cm_id's current event channel alive. The mutex is a + // Rust-side migration guard: it serializes rdma_migrate_id plus the Arc + // replacement so concurrent migrations cannot leave this lifetime anchor + // pointing at a different channel than rdma_cm_id::channel. It is not a + // general RDMA CM operation lock; other operations do not read this field. + event_channel: Mutex>, cm_id: NonNull, user_context: Mutex>>, } /// A connection paramter used for configure the communication when connecting or establishing /// datagram communication. Used in [`Identifier::connect`] and [`Identifier::accept`]. -pub struct ConnectionParameter(rdma_conn_param); +pub struct ConnectionParameter(rdma_conn_param, Vec); /// The RDMA port space. +#[derive(Debug, Clone, Copy)] pub enum PortSpace { /// Provides for any InfiniBand services (UD, UC, RC, XRC, etc.). InfiniBand = rdma_port_space::RDMA_PS_IB as isize, @@ -392,6 +399,21 @@ pub enum ListenErrorKind { Rdmacm(#[from] io::Error), } +/// Error returned by [`Identifier::migrate`] for moving an [`Identifier`] to another +/// [`EventChannel`]. +#[derive(Debug, thiserror::Error)] +#[error("failed to migrate rdma cm identifier")] +#[non_exhaustive] +pub struct MigrateError(#[from] pub MigrateErrorKind); + +/// The enum type for [`MigrateError`]. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +#[non_exhaustive] +pub enum MigrateErrorKind { + Rdmacm(#[from] io::Error), +} + /// Error returned by [`Identifier::connect`] for connecting to a remote endpoint. #[derive(Debug, thiserror::Error)] #[error("failed to connect")] @@ -517,6 +539,74 @@ impl Event { unsafe { self.event.as_ref().status } } + /// Get the private data sent by the remote peer. + /// + /// This is available for [`EventType::ConnectRequest`], + /// [`EventType::ConnectResponse`], and [`EventType::Rejected`] events, where + /// RDMA CM uses the `rdma_cm_event.param` union to carry connection or + /// datagram-service private data. + /// + /// Note that the actual amount of data transferred is transport dependent + /// and may be larger than that requested, with trailing zero as padding. + /// + /// **For `AF_IB` connected requests, Linux formats RDMA CM's + /// `struct cma_hdr` at the start of the IB CM REQ private-data area. + /// The first byte is `cma_version`, currently `0`, so byte 0 of + /// the [`EventType::ConnectRequest`] private data is overwritten + /// with `0` and is not the peer's original application byte.** + /// + /// # Returns + /// The private data slice, or an empty slice if the event has no private + /// data. + /// + /// # Example + /// ```ignore + /// match event.event_type() { + /// EventType::ConnectRequest => { + /// let data = event.private_data(); + /// if !data.is_empty() { + /// println!("Received {} bytes of private data", data.len()); + /// } + /// } + /// _ => {} + /// } + /// ``` + pub fn private_data(&self) -> &[u8] { + unsafe { + let event = self.event.as_ref(); + match self.event_type() { + EventType::ConnectRequest | EventType::ConnectResponse | EventType::Rejected => { + let Some(id) = NonNull::new(event.id) else { + return &[]; + }; + + if id.as_ref().qp_type == ibv_qp_type::IBV_QPT_UD { + let param = &event.param.ud; + Self::private_data_slice(event, param.private_data, param.private_data_len) + } else { + let param = &event.param.conn; + Self::private_data_slice(event, param.private_data, param.private_data_len) + } + }, + _ => &[], + } + } + } + + fn private_data_slice( + _event: &rdma_cm_event, private_data: *const std::ffi::c_void, private_data_len: u8, + ) -> &[u8] { + let len = private_data_len as usize; + if len == 0 || private_data.is_null() { + &[] + } else { + // SAFETY: The caller selected the active RDMA CM event union member + // for an event type that carries private data. The returned slice is + // valid until the event is acknowledged. + unsafe { std::slice::from_raw_parts(private_data.cast(), len) } + } + } + /// Acknowledge and free the communication event. /// /// # Note @@ -534,6 +624,7 @@ impl Event { return Err(AcknowledgeEventErrorKind::Rdmacm(io::Error::last_os_error()).into()); } + self.event_channel.take(); self.cm_id.take(); self.listener_id.take(); @@ -555,7 +646,7 @@ impl Drop for Event { fn new_cm_id_for_raw(event_channel: Arc, raw: *mut rdma_cm_id) -> Arc { let cm = unsafe { Arc::new(Identifier { - _event_channel: event_channel, + event_channel: Mutex::new(event_channel), cm_id: NonNull::new(raw).unwrap_unchecked(), user_context: Mutex::new(None), }) @@ -644,6 +735,7 @@ impl EventChannel { Ok(Event { event, + event_channel: Some(self.clone()), cm_id, listener_id, }) @@ -831,6 +923,38 @@ impl Identifier { Ok(()) } + /// Move this [`Identifier`] to another [`EventChannel`]. + /// + /// After a successful migration, RDMA CM events associated with this + /// identifier are reported on `channel`. `librdmacm` also moves any pending + /// events for the identifier to the new channel. + /// + /// # Note + /// + /// The underlying [`rdma_migrate_id(3)`] call may block while the current + /// event channel has unacknowledged events. Do not poll the current event + /// channel or invoke other routines on this identifier while migrating it + /// between channels. + /// + /// The C API accepts a null channel to put the ID into synchronous operation + /// mode. This safe wrapper intentionally exposes only migration to a live + /// [`EventChannel`]. + /// + /// [`rdma_migrate_id(3)`]: https://man7.org/linux/man-pages/man3/rdma_migrate_id.3.html + pub fn migrate(&self, channel: &Arc) -> Result<(), MigrateError> { + let mut event_channel = self.event_channel.lock().unwrap(); + let cm_id = self.cm_id; + let ret = unsafe { rdma_migrate_id(cm_id.as_ptr(), channel.channel.as_ptr()) }; + + if ret < 0 { + return Err(MigrateErrorKind::Rdmacm(io::Error::last_os_error()).into()); + } + + *event_channel = channel.clone(); + + Ok(()) + } + /// Get the [`DeviceContext`] associated with the [`Identifier`]. The [`DeviceContext`] is only /// available after the [`Identifier`] is bound to a specific address by [`bind_addr`] or /// [`resolve_addr`]. @@ -962,33 +1086,39 @@ impl Identifier { impl Default for ConnectionParameter { fn default() -> Self { - Self(rdma_conn_param { - private_data: null(), - private_data_len: 0, - responder_resources: 1, - initiator_depth: 1, - flow_control: 0, - retry_count: 7, - rnr_retry_count: 7, - srq: 0, - qp_num: 0, - }) + Self( + rdma_conn_param { + private_data: null(), + private_data_len: 0, + responder_resources: 1, + initiator_depth: 1, + flow_control: 0, + retry_count: 7, + rnr_retry_count: 7, + srq: 0, + qp_num: 0, + }, + Vec::new(), + ) } } impl ConnectionParameter { pub fn new() -> Self { - Self(rdma_conn_param { - private_data: null(), - private_data_len: 0, - responder_resources: 0, - initiator_depth: 0, - flow_control: 0, - retry_count: 0, - rnr_retry_count: 0, - srq: 0, - qp_num: 0, - }) + Self( + rdma_conn_param { + private_data: null(), + private_data_len: 0, + responder_resources: 0, + initiator_depth: 0, + flow_control: 0, + retry_count: 0, + rnr_retry_count: 0, + srq: 0, + qp_num: 0, + }, + Vec::new(), + ) } /// Setup the QP number of the [`Identifier`]. You should fill in this field when you are @@ -1002,16 +1132,240 @@ impl ConnectionParameter { self.0.qp_num = qp_number; self } + + /// Setup the private data to be sent with connect or accept. + /// + /// # Private data size + /// + /// This method copies the provided slice into the [`ConnectionParameter`] + /// and stores that owned buffer's pointer and length in the raw RDMA CM + /// parameter. It does not cap the length to any specific RDMA CM operation. + /// Check the operation limit + /// before calling [`Identifier::connect`] or [`Identifier::accept`], or when + /// using the lower-level [`rdma_reject(3)`] API. + /// + /// | Port space | Service type | [`connect`] | [`accept`] | [`rdma_reject(3)`] | + /// | --- | --- | ---: | ---: | ---: | + /// | [`PortSpace::Tcp`] | connected | 56 | 196 | 148 | + /// | [`PortSpace::Udp`] | datagram | 180 | 136 | 136 | + /// | [`PortSpace::InfiniBand`] | connected | 92 | 196 | 148 | + /// | [`PortSpace::InfiniBand`] | datagram | 216 | 136 | 136 | + /// + /// [`PortSpace::Tcp`] and [`PortSpace::Udp`] values are the user-visible + /// payload sizes documented by the [`rdma_connect(3)`] and + /// [`rdma_accept(3)`] man pages, plus the [`rdma_reject(3)`] sizes implied + /// by Linux's IB CM message constants and RDMA CM routing. + /// [`PortSpace::InfiniBand`] is derived from Linux CMA's + /// `id->qp_type == IB_QPT_UD` branch: connected QPs use IB CM REQ/REP/REJ + /// messages, while datagram services use SIDR REQ/REP. + /// + /// **For IB connected requests, Linux formats RDMA CM's + /// `struct cma_hdr` at the start of the IB CM REQ private-data area. + /// The first byte is `cma_version`, currently `0`, so byte 0 of + /// the [`EventType::ConnectRequest`] private data is overwritten + /// with `0`, please offset one byte when setting private data for IB.** + /// + /// [`connect`]: Identifier::connect + /// [`accept`]: Identifier::accept + /// [`rdma_connect(3)`]: https://man7.org/linux/man-pages/man3/rdma_connect.3.html + /// [`rdma_accept(3)`]: https://man7.org/linux/man-pages/man3/rdma_accept.3.html + /// [`rdma_reject(3)`]: https://man7.org/linux/man-pages/man3/rdma_reject.3.html + /// + /// [`setup_private_data`]: ConnectionParameter::setup_private_data + /// + /// # Panics + /// Panics if `data.len()` does not fit in `u8`, because + /// `rdma_conn_param.private_data_len` is an 8-bit field. + /// + /// # Example + /// ```ignore + /// let my_data = [1u8, 2, 3, 4]; + /// param.setup_private_data(&my_data); + /// id.connect(param)?; + /// ``` + pub fn setup_private_data(&mut self, data: &[u8]) -> &mut Self { + let len = + u8::try_from(data.len()).expect("ConnectionParameter private_data length is limited to u8::MAX bytes"); + self.1.clear(); + self.1.extend_from_slice(data); + self.0.private_data = if self.1.is_empty() { + null() + } else { + self.1.as_ptr().cast() + }; + self.0.private_data_len = len; + self + } + + /// Setup responder resources for the connection. + /// This is the maximum number of outstanding RDMA read/atomic operations + /// the local side will accept from the remote side. + pub fn setup_responder_resources(&mut self, resources: u8) -> &mut Self { + self.0.responder_resources = resources; + self + } + + /// Setup initiator depth for the connection. + /// This is the maximum number of outstanding RDMA read/atomic operations + /// that the local side will have pending to the remote side. + pub fn setup_initiator_depth(&mut self, depth: u8) -> &mut Self { + self.0.initiator_depth = depth; + self + } + + /// Setup retry count for the connection. + /// The number of times to retry a connection request or response. + pub fn setup_retry_count(&mut self, count: u8) -> &mut Self { + self.0.retry_count = count; + self + } + + /// Setup RNR retry count for the connection. + /// The number of times to retry a receiver-not-ready error. + pub fn setup_rnr_retry_count(&mut self, count: u8) -> &mut Self { + self.0.rnr_retry_count = count; + self + } } #[cfg(test)] mod tests { use super::*; - use polling::{Event, Events, Poller}; - use std::net::{IpAddr, SocketAddr}; + use crate::ibverbs::address::{GidEntry, GidType}; + use crate::ibverbs::completion::GenericCompletionQueue; + use crate::ibverbs::device; + use crate::ibverbs::queue_pair::{ExtendedQueuePair, QueuePair}; + use polling::{Event as PollingEvent, Events, Poller}; + use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::str::FromStr; use std::thread; + const CM_SETUP_PRIVATE_DATA_SIZE: usize = 54; + + #[derive(Clone, Copy)] + struct CmGidTestAddr { + ip: IpAddr, + scope_id: u32, + } + + impl CmGidTestAddr { + fn socket_addr(self, port: u16) -> SocketAddr { + match self.ip { + IpAddr::V4(addr) => SocketAddr::from((addr, port)), + IpAddr::V6(addr) => SocketAddr::V6(std::net::SocketAddrV6::new(addr, port, 0, self.scope_id)), + } + } + } + + fn cm_addr_from_gid_entry(gid_entry: GidEntry) -> Option { + let gid = gid_entry.gid(); + if gid.is_zero() { + return None; + } + + match gid_entry.gid_type() { + GidType::InfiniBand => {}, + GidType::RoceV2 if !gid.is_unicast_link_local() => {}, + _ => return None, + } + + let ipv6 = Ipv6Addr::from(gid); + let scope_id = if ipv6.is_unicast_link_local() { + let scope_id = gid_entry.netdev_index(); + if scope_id == 0 { + return None; + } + scope_id + } else { + 0 + }; + + Some(CmGidTestAddr { + ip: ipv6.to_ipv4_mapped().map_or(IpAddr::V6(ipv6), IpAddr::V4), + scope_id, + }) + } + + fn first_ib_or_roce_v2_gid_addr() -> Option { + let device_list = device::DeviceList::new().ok()?; + + for device in &device_list { + let Ok(ctx) = device.open() else { + continue; + }; + let Ok(gid_entries) = ctx.query_gid_table() else { + continue; + }; + + if let Some(addr) = gid_entries.into_iter().find_map(cm_addr_from_gid_entry) { + return Some(addr); + } + } + + None + } + + fn create_test_qp(id: &Identifier) -> Result { + let ctx = id + .get_device_context() + .ok_or_else(|| "RDMA CM ID has no verbs device context".to_owned())?; + let pd = ctx.alloc_pd().map_err(|err| err.to_string())?; + let cq = GenericCompletionQueue::from( + ctx.create_cq_builder() + .setup_cqe(2) + .build_ex() + .map_err(|err| err.to_string())?, + ); + + let mut qp = pd + .create_qp_builder() + .setup_max_send_wr(1) + .setup_max_send_sge(1) + .setup_max_recv_wr(1) + .setup_max_recv_sge(1) + .setup_send_cq(cq.clone()) + .setup_recv_cq(cq) + .build_ex() + .map_err(|err| err.to_string())?; + + qp.modify(&id.get_qp_attr(QueuePairState::Init).map_err(|err| err.to_string())?) + .map_err(|err| err.to_string())?; + + Ok(qp) + } + + fn move_test_qp_to_rts(id: &Identifier, qp: &mut ExtendedQueuePair) -> Result<(), String> { + qp.modify( + &id.get_qp_attr(QueuePairState::ReadyToReceive) + .map_err(|err| err.to_string())?, + ) + .map_err(|err| err.to_string())?; + qp.modify( + &id.get_qp_attr(QueuePairState::ReadyToSend) + .map_err(|err| err.to_string())?, + ) + .map_err(|err| err.to_string()) + } + + fn wait_for_cm_event( + channel: &Arc, timeout: Duration, description: &str, + ) -> Result> { + channel.set_nonblocking(true)?; + + let poller = Poller::new()?; + unsafe { poller.add(channel, PollingEvent::readable(1))? }; + + let mut events = Events::new(); + poller.wait(&mut events, Some(timeout))?; + + assert!( + !events.is_empty(), + "expected {description} to receive an RDMA CM event before timeout" + ); + + Ok(channel.get_cm_event()?) + } + #[test] fn test_cm_id_reference_count() -> Result<(), Box> { match EventChannel::new() { @@ -1029,7 +1383,7 @@ mod tests { assert_eq!(Arc::strong_count(&id), 1); - let event = channel.get_cm_event().unwrap(); + let event = wait_for_cm_event(&channel, Duration::from_secs(2), "reference count test")?; assert_eq!(Arc::strong_count(&id), 2); @@ -1057,31 +1411,15 @@ mod tests { assert_eq!(Arc::strong_count(&id), 1); - channel.set_nonblocking(true).unwrap(); - let dispatcher = thread::spawn(move || { - let poller = Poller::new().expect("Failed to create poller"); - let key = 233; - assert_eq!(Arc::strong_count(&channel), 2); - unsafe { poller.add(&channel, Event::readable(key)).unwrap() }; - let mut events = Events::new(); - events.clear(); - poller.wait(&mut events, None).unwrap(); + let event = wait_for_cm_event(&channel, Duration::from_secs(2), "event fd test").unwrap(); + assert_eq!(event.event_type(), EventType::AddressResolved); + assert_eq!(Arc::strong_count(&channel), 3); - assert_eq!(events.len(), 1); - - for ev in events.iter() { - assert_eq!(ev.key, key); - - let event = channel.get_cm_event().unwrap(); - assert_eq!(event.event_type(), EventType::AddressResolved); - assert_eq!(Arc::strong_count(&channel), 2); - - event.ack().unwrap(); - assert_eq!(Arc::strong_count(&channel), 2); - } + event.ack().unwrap(); + assert_eq!(Arc::strong_count(&channel), 2); }); let _ = id.resolve_addr( @@ -1099,6 +1437,129 @@ mod tests { } } + #[test] + fn test_event_keeps_source_channel_alive_after_identifier_migrates() -> Result<(), Box> { + match (EventChannel::new(), EventChannel::new()) { + (Ok(source_channel), Ok(migrated_channel)) => { + let id = source_channel.create_id(PortSpace::Tcp)?; + let source_channel_weak = Arc::downgrade(&source_channel); + + id.resolve_addr( + None, + SocketAddr::from((IpAddr::from_str("127.0.0.1").expect("Invalid IP address"), 0)), + Duration::new(0, 200000000), + )?; + + let event = wait_for_cm_event(&source_channel, Duration::from_secs(2), "source event channel")?; + assert_eq!(event.event_type(), EventType::AddressResolved); + assert_eq!(Arc::strong_count(&source_channel), 3); + + // Model the Rust-side lifetime state after a successful migration + // without calling rdma_migrate_id while an event is unacknowledged. + // The event must keep its retrieval channel alive even after the + // identifier's current channel anchor moves elsewhere. + *id.event_channel.lock().unwrap() = migrated_channel; + assert_eq!(Arc::strong_count(&source_channel), 2); + + drop(source_channel); + assert_eq!(source_channel_weak.strong_count(), 1); + + event.ack()?; + assert!(source_channel_weak.upgrade().is_none()); + + Ok(()) + }, + _ => Ok(()), + } + } + + #[test] + fn test_migrate_id_to_same_channel_reports_events() -> Result<(), Box> { + match EventChannel::new() { + Ok(channel) => { + let id = channel.create_id(PortSpace::Tcp)?; + let raw_channel = channel.channel.as_ptr(); + + assert_eq!(Arc::strong_count(&channel), 2); + assert_eq!(unsafe { id.cm_id.as_ref().channel }, raw_channel); + + id.migrate(&channel)?; + + assert_eq!(Arc::strong_count(&channel), 2); + assert_eq!(unsafe { id.cm_id.as_ref().channel }, raw_channel); + + id.resolve_addr( + None, + SocketAddr::from((IpAddr::from_str("127.0.0.1").expect("Invalid IP address"), 0)), + Duration::new(0, 200000000), + )?; + + let event = wait_for_cm_event(&channel, Duration::from_secs(2), "self-migrated event channel")?; + assert_eq!(event.event_type(), EventType::AddressResolved); + assert!(Arc::ptr_eq( + &event + .cm_id() + .expect("self-migrated event should carry the migrated identifier"), + &id + )); + event.ack()?; + + Ok(()) + }, + Err(_) => Ok(()), + } + } + + #[test] + fn test_migrate_id_reports_events_to_new_channel() -> Result<(), Box> { + match (EventChannel::new(), EventChannel::new()) { + (Ok(source_channel), Ok(migrated_channel)) => { + let id = source_channel.create_id(PortSpace::Tcp)?; + + assert_eq!(Arc::strong_count(&source_channel), 2); + assert_eq!(Arc::strong_count(&migrated_channel), 1); + + source_channel.set_nonblocking(true)?; + id.migrate(&migrated_channel)?; + + assert_eq!(Arc::strong_count(&source_channel), 1); + assert_eq!(Arc::strong_count(&migrated_channel), 2); + assert_eq!(unsafe { id.cm_id.as_ref().channel }, migrated_channel.channel.as_ptr()); + + id.resolve_addr( + None, + SocketAddr::from((IpAddr::from_str("127.0.0.1").expect("Invalid IP address"), 0)), + Duration::new(0, 200000000), + )?; + + let event = wait_for_cm_event(&migrated_channel, Duration::from_secs(2), "migrated event channel")?; + assert_eq!(event.event_type(), EventType::AddressResolved); + assert!(Arc::ptr_eq( + &event + .cm_id() + .expect("migrated event should carry the migrated identifier"), + &id + )); + event.ack()?; + + match source_channel.get_cm_event() { + Err(err) => match err.0 { + GetEventErrorKind::NoEvent => {}, + GetEventErrorKind::Rdmacm(err) => return Err(err.into()), + }, + Ok(event) => { + let event_type = event.event_type(); + event.ack()?; + panic!("source channel unexpectedly received migrated event {event_type:?}"); + }, + } + + Ok(()) + }, + _ => Ok(()), + } + } + #[test] fn test_bind_on_the_same_port() -> Result<(), Box> { match EventChannel::new() { @@ -1131,7 +1592,143 @@ mod tests { Ok(channel) => { let _id = channel.create_id(PortSpace::Tcp).unwrap(); - let _param = ConnectionParameter::new(); + let mut data = [0xa5; 196]; + let mut param = ConnectionParameter::new(); + param.setup_private_data(&data); + + assert_eq!(param.0.private_data, param.1.as_ptr().cast()); + assert_eq!(param.0.private_data_len, data.len() as u8); + assert_ne!(param.0.private_data, data.as_ptr().cast()); + + data[0] = 0x5a; + let stored_data = unsafe { + std::slice::from_raw_parts(param.0.private_data as *const u8, param.0.private_data_len as usize) + }; + assert_eq!(stored_data, &[0xa5; 196]); + assert_eq!(data[0], 0x5a); + + param.setup_private_data(&[]); + assert!(param.0.private_data.is_null()); + assert_eq!(param.0.private_data_len, 0); + + Ok(()) + }, + Err(_) => Ok(()), + } + } + + #[test] + fn test_connect_request_and_response_private_data() -> Result<(), Box> { + match EventChannel::new() { + Ok(channel) => { + let Some(cm_addr) = first_ib_or_roce_v2_gid_addr() else { + eprintln!( + "skipping RDMA CM private-data test: no usable IB GID or non-link-local RoCEv2 GID found" + ); + return Ok(()); + }; + + let listener = channel.create_id(PortSpace::Tcp)?; + let port = 18515; + let server_addr = cm_addr.socket_addr(port); + let client_src_addr = cm_addr.socket_addr(0); + listener.bind_addr(server_addr)?; + listener.listen(1)?; + + // Generate some data with different pattern, keep the first byte zero to make IB happy + let request_payload: [u8; CM_SETUP_PRIVATE_DATA_SIZE] = std::array::from_fn(|index| { + if index == 0 { + 0 + } else { + (index as u8).wrapping_mul(3).wrapping_add(1) + } + }); + + let response_payload: [u8; CM_SETUP_PRIVATE_DATA_SIZE] = + std::array::from_fn(|index| (index as u8).wrapping_mul(5).wrapping_add(3)); + + let server = thread::spawn(move || -> Result, String> { + let event = channel.get_cm_event().map_err(|err| err.to_string())?; + assert_eq!(event.event_type(), EventType::ConnectRequest); + + let request_private_data = event.private_data().to_vec(); + let conn_id = event + .cm_id() + .ok_or_else(|| "CONNECT_REQUEST did not provide a child CM ID".to_owned())?; + + let mut qp = create_test_qp(&conn_id)?; + move_test_qp_to_rts(&conn_id, &mut qp)?; + + let mut accept_param = ConnectionParameter::default(); + accept_param.setup_qp_number(qp.qp_number()); + accept_param.setup_private_data(&response_payload); + conn_id.accept(accept_param).map_err(|err| err.to_string())?; + event.ack().map_err(|err| err.to_string())?; + + let established = channel.get_cm_event().map_err(|err| err.to_string())?; + assert_eq!(established.event_type(), EventType::Established); + established.ack().map_err(|err| err.to_string())?; + + Ok(request_private_data) + }); + + let client_channel = EventChannel::new()?; + let client_id = client_channel.create_id(PortSpace::Tcp)?; + client_id.resolve_addr(Some(client_src_addr), server_addr, Duration::from_secs(2))?; + + let mut client_qp = None; + let response_private_data = loop { + let event = client_channel.get_cm_event()?; + match event.event_type() { + EventType::AddressResolved => { + client_id.resolve_route(Duration::from_secs(2))?; + event.ack()?; + }, + EventType::RouteResolved => { + let qp = create_test_qp(&client_id).map_err(io::Error::other)?; + let mut connect_param = ConnectionParameter::default(); + connect_param.setup_qp_number(qp.qp_number()); + connect_param.setup_private_data(&request_payload); + client_id.connect(connect_param)?; + client_qp = Some(qp); + event.ack()?; + }, + EventType::ConnectResponse => { + let data = event.private_data().to_vec(); + let qp = client_qp + .as_mut() + .expect("client QP must exist before CONNECT_RESPONSE"); + move_test_qp_to_rts(&client_id, qp).map_err(io::Error::other)?; + client_id.establish()?; + event.ack()?; + break data; + }, + event_type => panic!("unexpected client RDMA CM event: {event_type:?}"), + } + }; + + let request_private_data = server + .join() + .map_err(|panic| io::Error::other(format!("server thread panicked: {panic:?}")))? + .map_err(io::Error::other)?; + + // Valid field of private data from the RDMA CM event should be exactly the same as original data. + assert_eq!( + &request_private_data[..CM_SETUP_PRIVATE_DATA_SIZE], + request_payload.as_slice() + ); + assert_eq!( + &response_private_data[..CM_SETUP_PRIVATE_DATA_SIZE], + response_payload.as_slice() + ); + + // The left over should be all zero. + assert!(request_private_data[CM_SETUP_PRIVATE_DATA_SIZE..] + .iter() + .all(|&byte| byte == 0)); + assert!(response_private_data[CM_SETUP_PRIVATE_DATA_SIZE..] + .iter() + .all(|&byte| byte == 0)); Ok(()) }, @@ -1169,7 +1766,7 @@ mod tests { Duration::new(0, 200000000), ); - let event = channel.get_cm_event()?; + let event = wait_for_cm_event(&channel, Duration::from_secs(2), "device context test")?; assert_eq!(event.event_type(), EventType::AddressResolved); let ctx1 = id.get_device_context();