diff --git a/Cargo.lock b/Cargo.lock index e4057f75c..31144fc41 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3162,6 +3162,10 @@ dependencies = [ "futures", "hex", "hmac", + "http", + "http-body-util", + "hyper", + "hyper-util", "ipnet", "landlock", "libc", diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index 2cd362023..d5e813931 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -11,13 +11,14 @@ use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, Provider, ProviderResponse, - RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, - ServiceStatus, UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + Provider, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, }; use rcgen::{ BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, @@ -298,6 +299,8 @@ impl OpenShell for TestOpenShell { tokio_stream::wrappers::ReceiverStream>; type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream>; + type ConnectSupervisorStream = + tokio_stream::wrappers::ReceiverStream>; async fn watch_sandbox( &self, @@ -423,6 +426,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // ── TLS helpers ────────────────────────────────────────────────────── diff --git a/crates/openshell-cli/tests/mtls_integration.rs b/crates/openshell-cli/tests/mtls_integration.rs index 5d04239bf..c98b7eae4 100644 --- a/crates/openshell-cli/tests/mtls_integration.rs +++ b/crates/openshell-cli/tests/mtls_integration.rs @@ -200,6 +200,9 @@ impl OpenShell for TestOpenShell { >; type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream>; + type ConnectSupervisorStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; async fn watch_sandbox( &self, @@ -325,6 +328,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } fn build_ca() -> (Certificate, KeyPair) { diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index c5476afee..1d1323371 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -7,13 +7,14 @@ use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, Provider, ProviderResponse, - RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, - ServiceStatus, UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + Provider, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, }; use rcgen::{ BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, @@ -252,6 +253,8 @@ impl OpenShell for TestOpenShell { tokio_stream::wrappers::ReceiverStream>; type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream>; + type ConnectSupervisorStream = + tokio_stream::wrappers::ReceiverStream>; async fn watch_sandbox( &self, @@ -377,6 +380,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } fn install_rustls_provider() { diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index d5d39f082..e4c658b7b 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -8,14 +8,14 @@ use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, PlatformEvent, ProviderResponse, - RevokeSshSessionRequest, RevokeSshSessionResponse, Sandbox, SandboxPhase, SandboxResponse, - SandboxStreamEvent, ServiceStatus, UpdateProviderRequest, WatchSandboxRequest, - sandbox_stream_event, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + PlatformEvent, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, Sandbox, + SandboxPhase, SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, + UpdateProviderRequest, WatchSandboxRequest, sandbox_stream_event, }; use rcgen::{ BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, @@ -242,6 +242,8 @@ impl OpenShell for TestOpenShell { tokio_stream::wrappers::ReceiverStream>; type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream>; + type ConnectSupervisorStream = + tokio_stream::wrappers::ReceiverStream>; async fn watch_sandbox( &self, @@ -403,6 +405,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } fn install_rustls_provider() { diff --git a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs index fbadec4c3..7824d141a 100644 --- a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs +++ b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs @@ -8,12 +8,13 @@ use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, ProviderResponse, Sandbox, SandboxResponse, - SandboxStreamEvent, ServiceStatus, UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + ProviderResponse, Sandbox, SandboxResponse, SandboxStreamEvent, ServiceStatus, + SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, }; use rcgen::{ BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, @@ -210,6 +211,8 @@ impl OpenShell for TestOpenShell { tokio_stream::wrappers::ReceiverStream>; type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream>; + type ConnectSupervisorStream = + tokio_stream::wrappers::ReceiverStream>; async fn watch_sandbox( &self, @@ -335,6 +338,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // ── helpers ─────────────────────────────────────────────────────────── diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 440703af5..3e0240d0f 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -5,7 +5,7 @@ use crate::config::KubernetesComputeConfig; use futures::{Stream, StreamExt, TryStreamExt}; -use k8s_openapi::api::core::v1::{Event as KubeEventObj, Node, Pod}; +use k8s_openapi::api::core::v1::{Event as KubeEventObj, Node}; use kube::api::{Api, ApiResource, DeleteParams, ListParams, PostParams}; use kube::core::gvk::GroupVersionKind; use kube::core::{DynamicObject, ObjectMeta}; @@ -15,12 +15,10 @@ use openshell_core::proto::compute::v1::{ DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, DriverSandboxSpec as SandboxSpec, DriverSandboxStatus as SandboxStatus, DriverSandboxTemplate as SandboxTemplate, - GetCapabilitiesResponse, ResolveSandboxEndpointResponse, SandboxEndpoint, - WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, - WatchSandboxesSandboxEvent, sandbox_endpoint, watch_sandboxes_event, + GetCapabilitiesResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, + WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, watch_sandboxes_event, }; use std::collections::BTreeMap; -use std::net::IpAddr; use std::pin::Pin; use std::time::Duration; use tokio::sync::mpsc; @@ -271,21 +269,6 @@ impl KubernetesComputeDriver { &self.config.ssh_handshake_secret } - async fn agent_pod_ip(&self, pod_name: &str) -> Result, KubeError> { - let api: Api = Api::namespaced(self.client.clone(), &self.config.namespace); - match api.get(pod_name).await { - Ok(pod) => { - let ip = pod - .status - .and_then(|status| status.pod_ip) - .and_then(|ip| ip.parse().ok()); - Ok(ip) - } - Err(KubeError::Api(err)) if err.code == 404 => Ok(None), - Err(err) => Err(err), - } - } - pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { let name = sandbox.name.as_str(); info!( @@ -407,52 +390,6 @@ impl KubernetesComputeDriver { } } - pub async fn resolve_sandbox_endpoint( - &self, - sandbox: &Sandbox, - ) -> Result { - if let Some(status) = sandbox.status.as_ref() - && !status.instance_id.is_empty() - { - match self.agent_pod_ip(&status.instance_id).await { - Ok(Some(ip)) => { - return Ok(ResolveSandboxEndpointResponse { - endpoint: Some(SandboxEndpoint { - target: Some(sandbox_endpoint::Target::Ip(ip.to_string())), - port: u32::from(self.config.ssh_port), - }), - }); - } - Ok(None) => { - return Err(KubernetesDriverError::Precondition( - "sandbox agent pod IP is not available".to_string(), - )); - } - Err(err) => { - return Err(KubernetesDriverError::Message(format!( - "failed to resolve agent pod IP: {err}" - ))); - } - } - } - - if sandbox.name.is_empty() { - return Err(KubernetesDriverError::Precondition( - "sandbox has no name".to_string(), - )); - } - - Ok(ResolveSandboxEndpointResponse { - endpoint: Some(SandboxEndpoint { - target: Some(sandbox_endpoint::Target::Host(format!( - "{}.{}.svc.cluster.local", - sandbox.name, self.config.namespace - ))), - port: u32::from(self.config.ssh_port), - }), - }) - } - pub async fn watch_sandboxes(&self) -> Result { let namespace = self.config.namespace.clone(); let sandbox_api = self.api(); diff --git a/crates/openshell-driver-kubernetes/src/grpc.rs b/crates/openshell-driver-kubernetes/src/grpc.rs index 2c5a94467..75e131d41 100644 --- a/crates/openshell-driver-kubernetes/src/grpc.rs +++ b/crates/openshell-driver-kubernetes/src/grpc.rs @@ -5,8 +5,7 @@ use futures::{Stream, StreamExt}; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, GetCapabilitiesRequest, GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, - ListSandboxesRequest, ListSandboxesResponse, ResolveSandboxEndpointRequest, - ResolveSandboxEndpointResponse, StopSandboxRequest, StopSandboxResponse, + ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesEvent, WatchSandboxesRequest, compute_driver_server::ComputeDriver, }; @@ -128,21 +127,6 @@ impl ComputeDriver for ComputeDriverService { Ok(Response::new(DeleteSandboxResponse { deleted })) } - async fn resolve_sandbox_endpoint( - &self, - request: Request, - ) -> Result, Status> { - let sandbox = request - .into_inner() - .sandbox - .ok_or_else(|| Status::invalid_argument("sandbox is required"))?; - self.driver - .resolve_sandbox_endpoint(&sandbox) - .await - .map(Response::new) - .map_err(status_from_driver_error) - } - type WatchSandboxesStream = Pin> + Send + 'static>>; diff --git a/crates/openshell-sandbox/Cargo.toml b/crates/openshell-sandbox/Cargo.toml index 541784ee6..b21b1948f 100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -51,8 +51,15 @@ rcgen = { workspace = true } webpki-roots = { workspace = true } # HTTP +hyper = { workspace = true } +hyper-util = { workspace = true } +http = "1" +http-body-util = "0.1" bytes = { workspace = true } +# UUID +uuid = { workspace = true } + # Encoding base64 = { workspace = true } diff --git a/crates/openshell-sandbox/src/grpc_client.rs b/crates/openshell-sandbox/src/grpc_client.rs index 5503637ee..09e7b607d 100644 --- a/crates/openshell-sandbox/src/grpc_client.rs +++ b/crates/openshell-sandbox/src/grpc_client.rs @@ -74,6 +74,11 @@ async fn connect_channel(endpoint: &str) -> Result { .wrap_err("failed to connect to OpenShell server") } +/// Create a channel to the OpenShell server (public for use by supervisor_session). +pub async fn connect_channel_pub(endpoint: &str) -> Result { + connect_channel(endpoint).await +} + /// Connect to the OpenShell server (mTLS or plaintext based on endpoint scheme). async fn connect(endpoint: &str) -> Result> { let channel = connect_channel(endpoint).await?; diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index b81dd4a6c..76da6bb3f 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -21,6 +21,7 @@ pub mod proxy; mod sandbox; mod secrets; mod ssh; +mod supervisor_session; use miette::{IntoDiagnostic, Result}; #[cfg(target_os = "linux")] @@ -676,6 +677,21 @@ pub async fn run_sandbox( } } + // Spawn the persistent supervisor session if we have a gateway endpoint + // and sandbox identity. The session provides relay channels for SSH + // connect and ExecSandbox through the gateway. + if let (Some(endpoint), Some(id)) = (openshell_endpoint.as_ref(), sandbox_id.as_ref()) { + // The SSH listen address was consumed above, so we use the configured + // SSH port (default 2222) for loopback connections from the relay. + let ssh_port = std::env::var("OPENSHELL_SSH_PORT") + .ok() + .and_then(|p| p.parse::().ok()) + .unwrap_or(2222); + + supervisor_session::spawn(endpoint.clone(), id.clone(), ssh_port); + info!("supervisor session task spawned"); + } + #[cfg(target_os = "linux")] let mut handle = ProcessHandle::spawn( program, diff --git a/crates/openshell-sandbox/src/supervisor_session.rs b/crates/openshell-sandbox/src/supervisor_session.rs new file mode 100644 index 000000000..2b571df08 --- /dev/null +++ b/crates/openshell-sandbox/src/supervisor_session.rs @@ -0,0 +1,351 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Persistent supervisor-to-gateway session. +//! +//! Maintains a long-lived `ConnectSupervisor` bidirectional gRPC stream to the +//! gateway. When the gateway sends `RelayOpen`, the supervisor opens a reverse +//! HTTP CONNECT tunnel back to the gateway and bridges it to the local SSH +//! daemon. The supervisor is a dumb byte bridge — it has no protocol awareness +//! of the SSH or NSSH1 bytes flowing through the tunnel. + +use std::time::Duration; + +use openshell_core::proto::open_shell_client::OpenShellClient; +use openshell_core::proto::{ + GatewayMessage, SupervisorHeartbeat, SupervisorHello, SupervisorMessage, gateway_message, + supervisor_message, +}; +use tokio::sync::mpsc; +use tonic::transport::Channel; +use tracing::{info, warn}; + +use crate::grpc_client; + +const INITIAL_BACKOFF: Duration = Duration::from_secs(1); +const MAX_BACKOFF: Duration = Duration::from_secs(30); + +/// Spawn the supervisor session task. +/// +/// The task runs for the lifetime of the sandbox process, reconnecting with +/// exponential backoff on failures. +pub fn spawn( + endpoint: String, + sandbox_id: String, + ssh_listen_port: u16, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(run_session_loop(endpoint, sandbox_id, ssh_listen_port)) +} + +async fn run_session_loop(endpoint: String, sandbox_id: String, ssh_listen_port: u16) { + let mut backoff = INITIAL_BACKOFF; + let mut attempt: u64 = 0; + + loop { + attempt += 1; + + match run_single_session(&endpoint, &sandbox_id, ssh_listen_port).await { + Ok(()) => { + info!(sandbox_id = %sandbox_id, "supervisor session ended cleanly"); + break; + } + Err(e) => { + warn!( + sandbox_id = %sandbox_id, + attempt = attempt, + backoff_ms = backoff.as_millis() as u64, + error = %e, + "supervisor session failed, reconnecting" + ); + tokio::time::sleep(backoff).await; + backoff = (backoff * 2).min(MAX_BACKOFF); + } + } + } +} + +async fn run_single_session( + endpoint: &str, + sandbox_id: &str, + ssh_listen_port: u16, +) -> Result<(), Box> { + // Connect to the gateway. + let channel = grpc_client::connect_channel_pub(endpoint) + .await + .map_err(|e| format!("connect failed: {e}"))?; + let mut client = OpenShellClient::new(channel.clone()); + + // Create the outbound message stream. + let (tx, rx) = mpsc::channel::(64); + let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); + + // Send hello as the first message. + let instance_id = uuid::Uuid::new_v4().to_string(); + tx.send(SupervisorMessage { + payload: Some(supervisor_message::Payload::Hello(SupervisorHello { + sandbox_id: sandbox_id.to_string(), + instance_id: instance_id.clone(), + })), + }) + .await + .map_err(|_| "failed to queue hello")?; + + // Open the bidirectional stream. + let response = client + .connect_supervisor(outbound) + .await + .map_err(|e| format!("connect_supervisor RPC failed: {e}"))?; + let mut inbound = response.into_inner(); + + // Wait for SessionAccepted. + let accepted = match inbound.message().await? { + Some(msg) => match msg.payload { + Some(gateway_message::Payload::SessionAccepted(a)) => a, + Some(gateway_message::Payload::SessionRejected(r)) => { + return Err(format!("session rejected: {}", r.reason).into()); + } + _ => return Err("expected SessionAccepted or SessionRejected".into()), + }, + None => return Err("stream closed before session accepted".into()), + }; + + let heartbeat_secs = accepted.heartbeat_interval_secs.max(5); + info!( + sandbox_id = %sandbox_id, + session_id = %accepted.session_id, + instance_id = %instance_id, + heartbeat_secs = heartbeat_secs, + "supervisor session established" + ); + + // Main loop: receive gateway messages + send heartbeats. + let mut heartbeat_interval = + tokio::time::interval(Duration::from_secs(u64::from(heartbeat_secs))); + heartbeat_interval.tick().await; // skip immediate tick + + loop { + tokio::select! { + msg = inbound.message() => { + match msg { + Ok(Some(msg)) => { + handle_gateway_message( + &msg, + sandbox_id, + &endpoint, + ssh_listen_port, + &channel, + ).await; + } + Ok(None) => { + info!(sandbox_id = %sandbox_id, "supervisor session: gateway closed stream"); + return Ok(()); + } + Err(e) => { + return Err(format!("stream error: {e}").into()); + } + } + } + _ = heartbeat_interval.tick() => { + let hb = SupervisorMessage { + payload: Some(supervisor_message::Payload::Heartbeat( + SupervisorHeartbeat {}, + )), + }; + if tx.send(hb).await.is_err() { + return Err("outbound channel closed".into()); + } + } + } + } +} + +async fn handle_gateway_message( + msg: &GatewayMessage, + sandbox_id: &str, + endpoint: &str, + ssh_listen_port: u16, + _channel: &Channel, +) { + match &msg.payload { + Some(gateway_message::Payload::Heartbeat(_)) => { + // Gateway heartbeat — nothing to do. + } + Some(gateway_message::Payload::RelayOpen(open)) => { + let channel_id = open.channel_id.clone(); + let endpoint = endpoint.to_string(); + let sandbox_id = sandbox_id.to_string(); + + info!( + sandbox_id = %sandbox_id, + channel_id = %channel_id, + "supervisor session: relay open request, spawning bridge" + ); + + tokio::spawn(async move { + if let Err(e) = handle_relay_open(&channel_id, &endpoint, ssh_listen_port).await { + warn!( + sandbox_id = %sandbox_id, + channel_id = %channel_id, + error = %e, + "supervisor session: relay bridge failed" + ); + } + }); + } + Some(gateway_message::Payload::RelayClose(close)) => { + info!( + sandbox_id = %sandbox_id, + channel_id = %close.channel_id, + reason = %close.reason, + "supervisor session: relay close from gateway" + ); + } + _ => { + warn!(sandbox_id = %sandbox_id, "supervisor session: unexpected gateway message"); + } + } +} + +/// Handle a RelayOpen by opening a reverse HTTP CONNECT to the gateway and +/// bridging it to the local SSH daemon. +async fn handle_relay_open( + channel_id: &str, + endpoint: &str, + ssh_listen_port: u16, +) -> Result<(), Box> { + // Build the relay URL from the gateway endpoint. + // The endpoint is like "https://gateway:8080" or "http://gateway:8080". + let relay_url = format!("{endpoint}/relay/{channel_id}"); + + // Open a reverse HTTP CONNECT to the gateway's relay endpoint. + let mut relay_stream = open_reverse_connect(&relay_url).await?; + + // Connect to the local SSH daemon on loopback. + let mut ssh_conn = tokio::net::TcpStream::connect(("127.0.0.1", ssh_listen_port)).await?; + + info!(channel_id = %channel_id, "relay bridge: connected to local SSH daemon, bridging"); + + // Bridge the relay stream to the local SSH connection. + // The gateway sends NSSH1 preface + SSH bytes through the relay. + // The SSH daemon receives them as if the gateway connected directly. + let _ = tokio::io::copy_bidirectional(&mut relay_stream, &mut ssh_conn).await; + + Ok(()) +} + +/// Open an HTTP CONNECT tunnel to the given URL and return the upgraded stream. +/// +/// This uses a raw hyper HTTP/1.1 client to send a CONNECT request and upgrade +/// the connection to a raw byte stream. +async fn open_reverse_connect( + url: &str, +) -> Result< + hyper_util::rt::TokioIo, + Box, +> { + let uri: http::Uri = url.parse()?; + let host = uri.host().ok_or("missing host")?; + let port = uri + .port_u16() + .unwrap_or(if uri.scheme_str() == Some("https") { + 443 + } else { + 80 + }); + let authority = format!("{host}:{port}"); + let path = uri.path().to_string(); + let use_tls = uri.scheme_str() == Some("https"); + + // Connect TCP. + let tcp = tokio::net::TcpStream::connect(&authority).await?; + tcp.set_nodelay(true)?; + + if use_tls { + // Build TLS connector using the same env-var certs as the gRPC client. + let tls_stream = connect_tls(tcp, host).await?; + send_connect_request(tls_stream, &authority, &path).await + } else { + send_connect_request(tcp, &authority, &path).await + } +} + +async fn send_connect_request( + io: IO, + authority: &str, + path: &str, +) -> Result< + hyper_util::rt::TokioIo, + Box, +> +where + IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, +{ + use http::Method; + + let (mut sender, conn) = + hyper::client::conn::http1::handshake(hyper_util::rt::TokioIo::new(io)).await?; + + // Spawn the connection driver. + tokio::spawn(async move { + if let Err(e) = conn.with_upgrades().await { + warn!(error = %e, "relay CONNECT connection driver error"); + } + }); + + let req = http::Request::builder() + .method(Method::CONNECT) + .uri(path) + .header(http::header::HOST, authority) + .body(http_body_util::Empty::::new())?; + + let resp = sender.send_request(req).await?; + + if resp.status() != http::StatusCode::OK + && resp.status() != http::StatusCode::SWITCHING_PROTOCOLS + { + return Err(format!("relay CONNECT failed: {}", resp.status()).into()); + } + + let upgraded = hyper::upgrade::on(resp).await?; + Ok(hyper_util::rt::TokioIo::new(upgraded)) +} + +/// Connect TLS using the same cert env vars as the gRPC client. +async fn connect_tls( + tcp: tokio::net::TcpStream, + host: &str, +) -> Result< + tokio_rustls::client::TlsStream, + Box, +> { + use rustls::pki_types::ServerName; + use std::sync::Arc; + + let ca_path = std::env::var("OPENSHELL_TLS_CA")?; + let cert_path = std::env::var("OPENSHELL_TLS_CERT")?; + let key_path = std::env::var("OPENSHELL_TLS_KEY")?; + + let ca_pem = std::fs::read(&ca_path)?; + let cert_pem = std::fs::read(&cert_path)?; + let key_pem = std::fs::read(&key_path)?; + + let mut root_store = rustls::RootCertStore::empty(); + for cert in rustls_pemfile::certs(&mut ca_pem.as_slice()) { + root_store.add(cert?)?; + } + + let certs: Vec<_> = + rustls_pemfile::certs(&mut cert_pem.as_slice()).collect::>()?; + let key = + rustls_pemfile::private_key(&mut key_pem.as_slice())?.ok_or("no private key found")?; + + let config = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_client_auth_cert(certs, key)?; + + let connector = tokio_rustls::TlsConnector::from(Arc::new(config)); + let server_name = ServerName::try_from(host.to_string())?; + let tls_stream = connector.connect(server_name, tcp).await?; + + Ok(tls_stream) +} diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 846782c65..181a5a819 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -13,9 +13,8 @@ use openshell_core::proto::compute::v1::{ CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, ListSandboxesRequest, - ResolveSandboxEndpointRequest, ResolveSandboxEndpointResponse, ValidateSandboxCreateRequest, - WatchSandboxesEvent, WatchSandboxesRequest, compute_driver_server::ComputeDriver, - sandbox_endpoint, watch_sandboxes_event, + ValidateSandboxCreateRequest, WatchSandboxesEvent, WatchSandboxesRequest, + compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use openshell_core::proto::{ PlatformEvent, Sandbox, SandboxCondition, SandboxPhase, SandboxSpec, SandboxStatus, @@ -26,7 +25,6 @@ use openshell_driver_kubernetes::{ }; use prost::Message; use std::fmt; -use std::net::IpAddr; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; @@ -55,12 +53,6 @@ pub enum ComputeError { Message(String), } -#[derive(Debug)] -pub enum ResolvedEndpoint { - Ip(IpAddr, u16), - Host(String, u16), -} - #[derive(Clone)] pub struct ComputeRuntime { driver: SharedComputeDriver, @@ -243,29 +235,6 @@ impl ComputeRuntime { Ok(deleted) } - pub async fn resolve_sandbox_endpoint( - &self, - sandbox: &Sandbox, - ) -> Result { - let driver_sandbox = driver_sandbox_from_public(sandbox); - self.driver - .resolve_sandbox_endpoint(Request::new(ResolveSandboxEndpointRequest { - sandbox: Some(driver_sandbox), - })) - .await - .map(|response| response.into_inner()) - .map_err(|status| match status.code() { - Code::FailedPrecondition => { - Status::failed_precondition(status.message().to_string()) - } - _ => Status::internal(status.message().to_string()), - }) - .and_then(|response| { - resolved_endpoint_from_response(&response) - .map_err(|err| Status::internal(err.to_string())) - }) - } - pub fn spawn_watchers(&self) { let runtime = Arc::new(self.clone()); let watch_runtime = runtime.clone(); @@ -813,30 +782,6 @@ fn decode_sandbox_record(record: &ObjectRecord) -> Result { Sandbox::decode(record.payload.as_slice()).map_err(|e| e.to_string()) } -fn resolved_endpoint_from_response( - response: &ResolveSandboxEndpointResponse, -) -> Result { - let endpoint = response - .endpoint - .as_ref() - .ok_or_else(|| ComputeError::Message("compute driver returned no endpoint".to_string()))?; - let port = u16::try_from(endpoint.port) - .map_err(|_| ComputeError::Message("compute driver returned invalid port".to_string()))?; - - match endpoint.target.as_ref() { - Some(sandbox_endpoint::Target::Ip(ip)) => ip - .parse() - .map(|ip| ResolvedEndpoint::Ip(ip, port)) - .map_err(|e| ComputeError::Message(format!("invalid endpoint IP: {e}"))), - Some(sandbox_endpoint::Target::Host(host)) => { - Ok(ResolvedEndpoint::Host(host.clone(), port)) - } - None => Err(ComputeError::Message( - "compute driver returned endpoint without target".to_string(), - )), - } -} - fn public_status_from_driver(status: &DriverSandboxStatus) -> SandboxStatus { SandboxStatus { sandbox_name: status.sandbox_name.clone(), @@ -929,8 +874,7 @@ mod tests { use futures::stream; use openshell_core::proto::compute::v1::{ CreateSandboxResponse, DeleteSandboxResponse, GetCapabilitiesResponse, GetSandboxRequest, - GetSandboxResponse, ResolveSandboxEndpointResponse, SandboxEndpoint, StopSandboxRequest, - StopSandboxResponse, ValidateSandboxCreateResponse, sandbox_endpoint, + GetSandboxResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateResponse, }; use std::sync::Arc; @@ -938,7 +882,6 @@ mod tests { struct TestDriver { listed_sandboxes: Vec, current_sandboxes: Vec, - resolve_precondition: Option, } #[tonic::async_trait] @@ -1031,24 +974,6 @@ mod tests { })) } - async fn resolve_sandbox_endpoint( - &self, - _request: Request, - ) -> Result, Status> { - if let Some(message) = &self.resolve_precondition { - return Err(Status::failed_precondition(message.clone())); - } - - Ok(tonic::Response::new(ResolveSandboxEndpointResponse { - endpoint: Some(SandboxEndpoint { - target: Some(sandbox_endpoint::Target::Host( - "sandbox.default.svc.cluster.local".to_string(), - )), - port: 2222, - }), - })) - } - async fn watch_sandboxes( &self, _request: Request, @@ -1322,23 +1247,6 @@ mod tests { ); } - #[tokio::test] - async fn resolve_sandbox_endpoint_preserves_precondition_errors() { - let runtime = test_runtime(Arc::new(TestDriver { - resolve_precondition: Some("sandbox agent pod IP is not available".to_string()), - ..Default::default() - })) - .await; - - let err = runtime - .resolve_sandbox_endpoint(&sandbox_record("sb-1", "sandbox-a", SandboxPhase::Ready)) - .await - .expect_err("endpoint resolution should preserve failed-precondition errors"); - - assert_eq!(err.code(), Code::FailedPrecondition); - assert_eq!(err.message(), "sandbox agent pod IP is not available"); - } - #[tokio::test] async fn reconcile_store_with_backend_applies_driver_snapshot() { let runtime = test_runtime(Arc::new(TestDriver { diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index af60897d1..8a5516c6b 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -14,10 +14,10 @@ use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, EditDraftChunkRequest, EditDraftChunkResponse, ExecSandboxEvent, ExecSandboxRequest, - GetDraftHistoryRequest, GetDraftHistoryResponse, GetDraftPolicyRequest, GetDraftPolicyResponse, - GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxLogsRequest, GetSandboxLogsResponse, - GetSandboxPolicyStatusRequest, GetSandboxPolicyStatusResponse, + GatewayMessage, GetDraftHistoryRequest, GetDraftHistoryResponse, GetDraftPolicyRequest, + GetDraftPolicyResponse, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, + GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxLogsRequest, + GetSandboxLogsResponse, GetSandboxPolicyStatusRequest, GetSandboxPolicyStatusResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, ListSandboxesRequest, @@ -25,11 +25,12 @@ use openshell_core::proto::{ RejectDraftChunkRequest, RejectDraftChunkResponse, ReportPolicyStatusRequest, ReportPolicyStatusResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, - UndoDraftChunkRequest, UndoDraftChunkResponse, UpdateConfigRequest, UpdateConfigResponse, - UpdateProviderRequest, WatchSandboxRequest, open_shell_server::OpenShell, + SupervisorMessage, UndoDraftChunkRequest, UndoDraftChunkResponse, UpdateConfigRequest, + UpdateConfigResponse, UpdateProviderRequest, WatchSandboxRequest, open_shell_server::OpenShell, }; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; +use std::pin::Pin; use std::sync::Arc; use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status}; @@ -383,6 +384,18 @@ impl OpenShell for OpenShellService { ) -> Result, Status> { policy::handle_get_draft_history(&self.state, request).await } + + // --- Supervisor session --- + + type ConnectSupervisorStream = + Pin> + Send + 'static>>; + + async fn connect_supervisor( + &self, + request: Request>, + ) -> Result, Status> { + crate::supervisor_session::handle_connect_supervisor(&self.state, request).await + } } // --------------------------------------------------------------------------- diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 8e5930826..bdda63d6a 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -22,13 +22,11 @@ use openshell_core::proto::{ use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession}; use prost::Message; use std::sync::Arc; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status}; -use tracing::{debug, info, warn}; +use tracing::{info, warn}; use russh::ChannelMsg; use russh::client::AuthResult; @@ -438,7 +436,16 @@ pub(super) async fn handle_exec_sandbox( return Err(Status::failed_precondition("sandbox is not ready")); } - let (target_host, target_port) = resolve_sandbox_exec_target(state, &sandbox).await?; + // Open a relay channel through the supervisor session. Use a 15s + // session-wait timeout — enough to cover a transient supervisor + // reconnect, but shorter than `/connect/ssh` since `ExecSandbox` is + // typically called during normal operation (not right after create). + let (channel_id, relay_rx) = state + .supervisor_sessions + .open_relay(&sandbox.id, std::time::Duration::from_secs(15)) + .await + .map_err(|e| Status::unavailable(format!("supervisor relay failed: {e}")))?; + let command_str = build_remote_exec_command(&req) .map_err(|e| Status::invalid_argument(format!("command construction failed: {e}")))?; let stdin_payload = req.stdin; @@ -449,11 +456,32 @@ pub(super) async fn handle_exec_sandbox( let (tx, rx) = mpsc::channel::>(256); tokio::spawn(async move { - if let Err(err) = stream_exec_over_ssh( + // Wait for the supervisor's reverse CONNECT to deliver the relay stream. + let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx) + .await + { + Ok(Ok(stream)) => stream, + Ok(Err(_)) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay channel dropped"); + let _ = tx + .send(Err(Status::unavailable("relay channel dropped"))) + .await; + return; + } + Err(_) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay open timed out"); + let _ = tx + .send(Err(Status::deadline_exceeded("relay open timed out"))) + .await; + return; + } + }; + + if let Err(err) = stream_exec_over_relay( tx.clone(), &sandbox_id, - &target_host, - target_port, + &channel_id, + relay_stream, &command_str, stdin_payload, timeout_seconds, @@ -584,16 +612,6 @@ fn resolve_gateway(config: &openshell_core::Config) -> (String, u16) { (host, port) } -async fn resolve_sandbox_exec_target( - state: &ServerState, - sandbox: &Sandbox, -) -> Result<(String, u16), Status> { - match state.compute.resolve_sandbox_endpoint(sandbox).await? { - crate::compute::ResolvedEndpoint::Ip(ip, port) => Ok((ip.to_string(), port)), - crate::compute::ResolvedEndpoint::Host(host, port) => Ok((host, port)), - } -} - /// Shell-escape a value for embedding in a POSIX shell command. /// /// Wraps unsafe values in single quotes with the standard `'\''` idiom for @@ -646,34 +664,18 @@ fn build_remote_exec_command(req: &ExecSandboxRequest) -> Result Ok(result) } -/// Maximum number of attempts when establishing the SSH transport to a sandbox. -const SSH_CONNECT_MAX_ATTEMPTS: u32 = 6; - -/// Initial backoff duration between SSH connection retries. -const SSH_CONNECT_INITIAL_BACKOFF: std::time::Duration = std::time::Duration::from_millis(250); - -/// Maximum backoff duration between SSH connection retries. -const SSH_CONNECT_MAX_BACKOFF: std::time::Duration = std::time::Duration::from_secs(2); - -/// Returns `true` if the gRPC status represents a transient SSH connection error. -fn is_retryable_ssh_error(status: &Status) -> bool { - if status.code() != tonic::Code::Internal { - return false; - } - let msg = status.message(); - msg.contains("Connection reset by peer") - || msg.contains("Connection refused") - || msg.contains("failed to establish ssh transport") - || msg.contains("failed to connect to ssh proxy") - || msg.contains("failed to start ssh proxy") -} - +/// Execute a command over an SSH transport relayed through a supervisor session. +/// +/// This is the relay equivalent of `stream_exec_over_ssh`. Instead of dialing a +/// sandbox endpoint directly, the SSH transport runs over a `DuplexStream` that +/// is bridged to the supervisor's local SSH daemon via a reverse HTTP CONNECT +/// tunnel. #[allow(clippy::too_many_arguments)] -async fn stream_exec_over_ssh( +async fn stream_exec_over_relay( tx: mpsc::Sender>, sandbox_id: &str, - target_host: &str, - target_port: u16, + channel_id: &str, + relay_stream: tokio::io::DuplexStream, command: &str, stdin_payload: Vec, timeout_seconds: u32, @@ -683,96 +685,53 @@ async fn stream_exec_over_ssh( let command_preview: String = command.chars().take(120).collect(); info!( sandbox_id = %sandbox_id, - target_host = %target_host, - target_port, + channel_id = %channel_id, command_len = command.len(), stdin_len = stdin_payload.len(), command_preview = %command_preview, - "ExecSandbox command started" + "ExecSandbox (relay): command started" ); - let (exit_code, proxy_task) = { - let mut last_err: Option = None; - - let mut result = None; - for attempt in 0..SSH_CONNECT_MAX_ATTEMPTS { - if attempt > 0 { - let backoff = (SSH_CONNECT_INITIAL_BACKOFF * 2u32.pow(attempt - 1)) - .min(SSH_CONNECT_MAX_BACKOFF); - warn!( - sandbox_id = %sandbox_id, - attempt = attempt + 1, - backoff_ms = %backoff.as_millis(), - error = %last_err.as_ref().unwrap(), - "Retrying SSH transport establishment" - ); - tokio::time::sleep(backoff).await; - } - - let (local_proxy_port, proxy_task) = match start_single_use_ssh_proxy( - target_host, - target_port, - handshake_secret, - ) + let (local_proxy_port, proxy_task) = + start_single_use_ssh_proxy_over_relay(relay_stream, handshake_secret) .await - { - Ok(v) => v, - Err(e) => { - last_err = Some(Status::internal(format!("failed to start ssh proxy: {e}"))); - continue; - } - }; - - let exec = run_exec_with_russh( - local_proxy_port, - command, - stdin_payload.clone(), - request_tty, - tx.clone(), - ); + .map_err(|e| Status::internal(format!("failed to start relay proxy: {e}")))?; + + let exec = run_exec_with_russh( + local_proxy_port, + command, + stdin_payload, + request_tty, + tx.clone(), + ); - let exec_result = if timeout_seconds == 0 { - exec.await - } else if let Ok(r) = tokio::time::timeout( - std::time::Duration::from_secs(u64::from(timeout_seconds)), - exec, - ) - .await - { - r - } else { - let _ = tx - .send(Ok(ExecSandboxEvent { - payload: Some(openshell_core::proto::exec_sandbox_event::Payload::Exit( - ExecSandboxExit { exit_code: 124 }, - )), - })) - .await; - let _ = proxy_task.await; - return Ok(()); - }; - - match exec_result { - Ok(exit_code) => { - result = Some((exit_code, proxy_task)); - break; - } - Err(status) => { - let _ = proxy_task.await; - if is_retryable_ssh_error(&status) && attempt + 1 < SSH_CONNECT_MAX_ATTEMPTS { - last_err = Some(status); - continue; - } - return Err(status); - } - } - } + let exec_result = if timeout_seconds == 0 { + exec.await + } else if let Ok(r) = tokio::time::timeout( + std::time::Duration::from_secs(u64::from(timeout_seconds)), + exec, + ) + .await + { + r + } else { + let _ = tx + .send(Ok(ExecSandboxEvent { + payload: Some(openshell_core::proto::exec_sandbox_event::Payload::Exit( + ExecSandboxExit { exit_code: 124 }, + )), + })) + .await; + let _ = proxy_task.await; + return Ok(()); + }; - result.ok_or_else(|| { - last_err.unwrap_or_else(|| { - Status::internal("ssh connection failed after exhausting retries") - }) - })? + let exit_code = match exec_result { + Ok(code) => code, + Err(status) => { + let _ = proxy_task.await; + return Err(status); + } }; let _ = proxy_task.await; @@ -788,6 +747,75 @@ async fn stream_exec_over_ssh( Ok(()) } +/// Create a localhost SSH proxy that bridges to a relay DuplexStream. +/// +/// The proxy sends the NSSH1 handshake preface through the relay (which flows +/// to the supervisor and on to the embedded SSH daemon), waits for "OK", then +/// bridges the russh client connection with the relay stream. +async fn start_single_use_ssh_proxy_over_relay( + relay_stream: tokio::io::DuplexStream, + handshake_secret: &str, +) -> Result<(u16, tokio::task::JoinHandle<()>), Box> { + let listener = TcpListener::bind(("127.0.0.1", 0)).await?; + let port = listener.local_addr()?.port(); + let handshake_secret = handshake_secret.to_string(); + + let task = tokio::spawn(async move { + let Ok((mut client_conn, _)) = listener.accept().await else { + warn!("SSH relay proxy: failed to accept local connection"); + return; + }; + + let (mut relay_read, mut relay_write) = tokio::io::split(relay_stream); + + // Send NSSH1 handshake through the relay to the SSH daemon. + let Ok(preface) = build_preface(&uuid::Uuid::new_v4().to_string(), &handshake_secret) + else { + warn!("SSH relay proxy: failed to build handshake preface"); + return; + }; + if let Err(e) = + tokio::io::AsyncWriteExt::write_all(&mut relay_write, preface.as_bytes()).await + { + warn!(error = %e, "SSH relay proxy: failed to send handshake preface"); + return; + } + + // Read handshake response from the relay. + let mut response_buf = Vec::new(); + loop { + let mut byte = [0u8; 1]; + match tokio::io::AsyncReadExt::read(&mut relay_read, &mut byte).await { + Ok(0) => break, + Ok(_) => { + if byte[0] == b'\n' { + break; + } + response_buf.push(byte[0]); + if response_buf.len() > 1024 { + break; + } + } + Err(e) => { + warn!(error = %e, "SSH relay proxy: failed to read handshake response"); + return; + } + } + } + let response = String::from_utf8_lossy(&response_buf); + if response.trim() != "OK" { + warn!(response = %response.trim(), "SSH relay proxy: handshake rejected"); + return; + } + + // Reunite the split halves for copy_bidirectional. + let mut relay = relay_read.unsplit(relay_write); + let _ = tokio::io::copy_bidirectional(&mut client_conn, &mut relay).await; + }); + + Ok((port, task)) +} + #[derive(Debug, Clone, Copy)] struct SandboxSshClientHandler; @@ -914,98 +942,6 @@ async fn run_exec_with_russh( Ok(exit_code.unwrap_or(1)) } -/// Check whether an IP address is safe to use as an SSH proxy target. -fn is_safe_ssh_proxy_target(ip: std::net::IpAddr) -> bool { - match ip { - std::net::IpAddr::V4(v4) => !v4.is_loopback() && !v4.is_link_local(), - std::net::IpAddr::V6(v6) => { - if v6.is_loopback() { - return false; - } - if let Some(v4) = v6.to_ipv4_mapped() { - return !v4.is_loopback() && !v4.is_link_local(); - } - true - } - } -} - -async fn start_single_use_ssh_proxy( - target_host: &str, - target_port: u16, - handshake_secret: &str, -) -> Result<(u16, tokio::task::JoinHandle<()>), Box> { - let listener = TcpListener::bind(("127.0.0.1", 0)).await?; - let port = listener.local_addr()?.port(); - let target_host = target_host.to_string(); - let handshake_secret = handshake_secret.to_string(); - - let task = tokio::spawn(async move { - let Ok((mut client_conn, _)) = listener.accept().await else { - warn!("SSH proxy: failed to accept local connection"); - return; - }; - - let addr_str = format!("{target_host}:{target_port}"); - let resolved = match tokio::net::lookup_host(&addr_str).await { - Ok(mut addrs) => { - if let Some(addr) = addrs.next() { - addr - } else { - warn!(target_host = %target_host, "SSH proxy: DNS resolution returned no addresses"); - return; - } - } - Err(e) => { - warn!(target_host = %target_host, error = %e, "SSH proxy: DNS resolution failed"); - return; - } - }; - - if !is_safe_ssh_proxy_target(resolved.ip()) { - warn!( - target_host = %target_host, - resolved_ip = %resolved.ip(), - "SSH proxy: target resolved to blocked IP range (loopback or link-local)" - ); - return; - } - - debug!( - target_host = %target_host, - resolved_ip = %resolved.ip(), - target_port, - "SSH proxy: connecting to validated target" - ); - - let Ok(mut sandbox_conn) = TcpStream::connect(resolved).await else { - warn!(target_host = %target_host, resolved_ip = %resolved.ip(), target_port, "SSH proxy: failed to connect to sandbox"); - return; - }; - let Ok(preface) = build_preface(&uuid::Uuid::new_v4().to_string(), &handshake_secret) - else { - warn!("SSH proxy: failed to build handshake preface"); - return; - }; - if let Err(e) = sandbox_conn.write_all(preface.as_bytes()).await { - warn!(error = %e, "SSH proxy: failed to send handshake preface"); - return; - } - let mut response = String::new(); - if let Err(e) = read_line(&mut sandbox_conn, &mut response).await { - warn!(error = %e, "SSH proxy: failed to read handshake response"); - return; - } - if response.trim() != "OK" { - warn!(response = %response.trim(), "SSH proxy: handshake rejected by sandbox"); - return; - } - let _ = tokio::io::copy_bidirectional(&mut client_conn, &mut sandbox_conn).await; - }); - - Ok((port, task)) -} - fn build_preface( token: &str, secret: &str, @@ -1023,29 +959,6 @@ fn build_preface( Ok(format!("NSSH1 {token} {timestamp} {nonce} {signature}\n")) } -async fn read_line( - stream: &mut TcpStream, - buf: &mut String, -) -> Result<(), Box> { - let mut bytes = Vec::new(); - loop { - let mut byte = [0_u8; 1]; - let n = stream.read(&mut byte).await?; - if n == 0 { - break; - } - if byte[0] == b'\n' { - break; - } - bytes.push(byte[0]); - if bytes.len() > 1024 { - break; - } - } - *buf = String::from_utf8_lossy(&bytes).to_string(); - Ok(()) -} - fn hmac_sha256(key: &[u8], data: &[u8]) -> String { use hmac::{Hmac, Mac}; use sha2::Sha256; @@ -1161,59 +1074,6 @@ mod tests { assert!(build_remote_exec_command(&req).is_err()); } - // ---- is_safe_ssh_proxy_target ---- - - #[test] - fn ssh_proxy_target_allows_pod_network_ips() { - use std::net::{IpAddr, Ipv4Addr}; - assert!(is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 10, 0, 0, 5 - )))); - assert!(is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 172, 16, 0, 1 - )))); - assert!(is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 192, 168, 1, 100 - )))); - } - - #[test] - fn ssh_proxy_target_blocks_loopback() { - use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - assert!(!is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 127, 0, 0, 1 - )))); - assert!(!is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 127, 0, 0, 2 - )))); - assert!(!is_safe_ssh_proxy_target(IpAddr::V6(Ipv6Addr::LOCALHOST))); - } - - #[test] - fn ssh_proxy_target_blocks_link_local() { - use std::net::{IpAddr, Ipv4Addr}; - assert!(!is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 169, 254, 169, 254 - )))); - assert!(!is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 169, 254, 0, 1 - )))); - } - - #[test] - fn ssh_proxy_target_blocks_ipv4_mapped_ipv6_loopback() { - use std::net::IpAddr; - let ip: IpAddr = "::ffff:127.0.0.1".parse().unwrap(); - assert!(!is_safe_ssh_proxy_target(ip)); - } - - #[test] - fn ssh_proxy_target_blocks_ipv4_mapped_ipv6_link_local() { - use std::net::IpAddr; - let ip: IpAddr = "::ffff:169.254.169.254".parse().unwrap(); - assert!(!is_safe_ssh_proxy_target(ip)); - } - // ---- petname / generate_name ---- #[test] diff --git a/crates/openshell-server/src/http.rs b/crates/openshell-server/src/http.rs index afe7edc1b..aefe4181b 100644 --- a/crates/openshell-server/src/http.rs +++ b/crates/openshell-server/src/http.rs @@ -49,6 +49,7 @@ pub fn health_router() -> Router { pub fn http_router(state: Arc) -> Router { health_router() .merge(crate::ssh_tunnel::router(state.clone())) + .merge(crate::relay::router(state.clone())) .merge(crate::ws_tunnel::router(state.clone())) .merge(crate::auth::router(state)) } diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index a8d820b4d..cbef28b0e 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -16,9 +16,11 @@ mod http; mod inference; mod multiplex; mod persistence; +mod relay; mod sandbox_index; mod sandbox_watch; mod ssh_tunnel; +pub(crate) mod supervisor_session; mod tls; pub mod tracing_bus; mod ws_tunnel; @@ -73,6 +75,9 @@ pub struct ServerState { /// set/delete operation, including the precedence check on sandbox /// mutations that reads global state. pub settings_mutex: tokio::sync::Mutex<()>, + + /// Registry of active supervisor sessions and pending relay channels. + pub supervisor_sessions: supervisor_session::SupervisorSessionRegistry, } fn is_benign_tls_handshake_failure(error: &std::io::Error) -> bool { @@ -103,6 +108,7 @@ impl ServerState { ssh_connections_by_token: Mutex::new(HashMap::new()), ssh_connections_by_sandbox: Mutex::new(HashMap::new()), settings_mutex: tokio::sync::Mutex::new(()), + supervisor_sessions: supervisor_session::SupervisorSessionRegistry::new(), } } } @@ -148,6 +154,7 @@ pub async fn run_server(config: Config, tracing_log_bus: TracingLogBus) -> Resul state.compute.spawn_watchers(); ssh_tunnel::spawn_session_reaper(store.clone(), std::time::Duration::from_secs(3600)); + supervisor_session::spawn_relay_reaper(state.clone(), std::time::Duration::from_secs(30)); // Create the multiplexed service let service = MultiplexService::new(state.clone()); diff --git a/crates/openshell-server/src/relay.rs b/crates/openshell-server/src/relay.rs new file mode 100644 index 000000000..662fe4d99 --- /dev/null +++ b/crates/openshell-server/src/relay.rs @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! HTTP CONNECT relay endpoint for supervisor-initiated reverse tunnels. +//! +//! When the gateway sends a `RelayOpen` message over the supervisor's gRPC +//! session, the supervisor opens `CONNECT /relay/{channel_id}` back to this +//! endpoint. The gateway then bridges the supervisor's upgraded stream with +//! the client's SSH tunnel or exec proxy. + +use axum::{ + Router, extract::Path, extract::State, http::Method, response::IntoResponse, routing::any, +}; +use http::StatusCode; +use hyper::upgrade::OnUpgrade; +use hyper_util::rt::TokioIo; +use std::sync::Arc; +use tokio::io::AsyncWriteExt; +use tracing::{info, warn}; + +use crate::ServerState; + +pub fn router(state: Arc) -> Router { + Router::new() + .route("/relay/{channel_id}", any(relay_connect)) + .with_state(state) +} + +async fn relay_connect( + State(state): State>, + Path(channel_id): Path, + req: hyper::Request, +) -> impl IntoResponse { + if req.method() != Method::CONNECT { + return StatusCode::METHOD_NOT_ALLOWED.into_response(); + } + + // Claim the pending relay. This consumes the entry — it cannot be reused. + let supervisor_stream = match state.supervisor_sessions.claim_relay(&channel_id) { + Ok(stream) => stream, + Err(_) => { + warn!(channel_id = %channel_id, "relay: unknown or expired channel"); + return StatusCode::NOT_FOUND.into_response(); + } + }; + + info!(channel_id = %channel_id, "relay: supervisor connected, upgrading"); + + // Upgrade the HTTP connection to a raw byte stream and bridge it to + // the DuplexStream that connects to the gateway-side waiter. + let on_upgrade: OnUpgrade = hyper::upgrade::on(req); + tokio::spawn(async move { + match on_upgrade.await { + Ok(upgraded) => { + let mut upgraded = TokioIo::new(upgraded); + let mut supervisor = supervisor_stream; + let _ = tokio::io::copy_bidirectional(&mut upgraded, &mut supervisor).await; + let _ = AsyncWriteExt::shutdown(&mut upgraded).await; + } + Err(e) => { + warn!(channel_id = %channel_id, error = %e, "relay: upgrade failed"); + } + } + }); + + StatusCode::SWITCHING_PROTOCOLS.into_response() +} diff --git a/crates/openshell-server/src/ssh_tunnel.rs b/crates/openshell-server/src/ssh_tunnel.rs index 536513ccd..8b7d6b48d 100644 --- a/crates/openshell-server/src/ssh_tunnel.rs +++ b/crates/openshell-server/src/ssh_tunnel.rs @@ -6,15 +6,12 @@ use axum::{Router, extract::State, http::Method, response::IntoResponse, routing::any}; use http::StatusCode; use hyper::Request; -use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; use openshell_core::proto::{Sandbox, SandboxPhase, SshSession}; use prost::Message; -use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; use tracing::{info, warn}; use uuid::Uuid; @@ -23,7 +20,6 @@ use crate::persistence::{ObjectId, ObjectName, ObjectType, Store}; const HEADER_SANDBOX_ID: &str = "x-sandbox-id"; const HEADER_TOKEN: &str = "x-sandbox-token"; -const PREFACE_MAGIC: &str = "NSSH1"; /// Maximum concurrent SSH tunnel connections per session token. const MAX_CONNECTIONS_PER_TOKEN: u32 = 3; @@ -100,19 +96,23 @@ async fn ssh_connect( return StatusCode::PRECONDITION_FAILED.into_response(); } - let connect_target = match state.compute.resolve_sandbox_endpoint(&sandbox).await { - Ok(crate::compute::ResolvedEndpoint::Ip(ip, port)) => { - ConnectTarget::Ip(SocketAddr::new(ip, port)) - } - Ok(crate::compute::ResolvedEndpoint::Host(host, port)) => ConnectTarget::Host(host, port), - Err(status) if status.code() == tonic::Code::FailedPrecondition => { - return StatusCode::PRECONDITION_FAILED.into_response(); - } - Err(err) => { - warn!(error = %err, "Failed to resolve sandbox endpoint"); + // Open a relay channel through the supervisor session. Use a generous + // 30s session-wait timeout because `/connect/ssh` is typically called + // immediately after `sandbox create`, so we need to cover the supervisor's + // initial TLS + gRPC handshake on a cold-started pod. The old + // direct-connect path tolerated ~34s here for similar reasons. + let (channel_id, relay_rx) = match state + .supervisor_sessions + .open_relay(&sandbox_id, Duration::from_secs(30)) + .await + { + Ok(pair) => pair, + Err(status) => { + warn!(sandbox_id = %sandbox_id, error = %status.message(), "SSH tunnel: supervisor session not available"); return StatusCode::BAD_GATEWAY.into_response(); } }; + // Enforce per-token concurrent connection limit. { let mut counts = state.ssh_connections_by_token.lock().unwrap(); @@ -150,20 +150,97 @@ async fn ssh_connect( let upgrade = hyper::upgrade::on(req); tokio::spawn(async move { - match upgrade.await { - Ok(mut upgraded) => { - if let Err(err) = handle_tunnel( - &mut upgraded, - connect_target, - &token_clone, - &handshake_secret, + // Wait for the supervisor's reverse CONNECT to arrive and claim the relay. + let relay_stream = match tokio::time::timeout(Duration::from_secs(10), relay_rx).await { + Ok(Ok(stream)) => stream, + Ok(Err(_)) => { + warn!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay channel dropped"); + decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); + decrement_connection_count( + &state_clone.ssh_connections_by_sandbox, &sandbox_id_clone, - ) - .await - { - warn!(error = %err, "SSH tunnel failure"); + ); + return; + } + Err(_) => { + warn!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay open timed out"); + decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); + decrement_connection_count( + &state_clone.ssh_connections_by_sandbox, + &sandbox_id_clone, + ); + return; + } + }; + + // Send NSSH1 handshake through the relay to the SSH daemon before + // bridging the client's SSH bytes. The relay carries bytes to the + // supervisor which bridges them to the local SSH daemon on loopback. + let (mut relay_read, mut relay_write) = tokio::io::split(relay_stream); + let preface = match build_preface(&token_clone, &handshake_secret) { + Ok(p) => p, + Err(e) => { + warn!(error = %e, "SSH tunnel: failed to build NSSH1 preface"); + decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); + decrement_connection_count( + &state_clone.ssh_connections_by_sandbox, + &sandbox_id_clone, + ); + return; + } + }; + if let Err(e) = relay_write.write_all(preface.as_bytes()).await { + warn!(error = %e, "SSH tunnel: failed to send NSSH1 preface through relay"); + decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); + decrement_connection_count(&state_clone.ssh_connections_by_sandbox, &sandbox_id_clone); + return; + } + + // Read handshake response from the SSH daemon through the relay. + let mut response_buf = Vec::new(); + loop { + let mut byte = [0u8; 1]; + match relay_read.read(&mut byte).await { + Ok(0) => break, + Ok(_) => { + if byte[0] == b'\n' { + break; + } + response_buf.push(byte[0]); + if response_buf.len() > 1024 { + break; + } + } + Err(e) => { + warn!(error = %e, "SSH tunnel: failed to read NSSH1 response from relay"); + decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); + decrement_connection_count( + &state_clone.ssh_connections_by_sandbox, + &sandbox_id_clone, + ); + return; } } + } + let response = String::from_utf8_lossy(&response_buf); + if response.trim() != "OK" { + warn!(response = %response.trim(), "SSH tunnel: NSSH1 handshake rejected by sandbox"); + decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); + decrement_connection_count(&state_clone.ssh_connections_by_sandbox, &sandbox_id_clone); + return; + } + + info!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: NSSH1 handshake OK, bridging client"); + + // Reunite the split relay halves and bridge with the client's upgraded stream. + let mut relay = relay_read.unsplit(relay_write); + + match upgrade.await { + Ok(upgraded) => { + let mut upgraded = TokioIo::new(upgraded); + let _ = tokio::io::copy_bidirectional(&mut upgraded, &mut relay).await; + let _ = AsyncWriteExt::shutdown(&mut upgraded).await; + } Err(err) => { warn!(error = %err, "SSH upgrade failed"); } @@ -177,90 +254,6 @@ async fn ssh_connect( StatusCode::OK.into_response() } -async fn handle_tunnel( - upgraded: &mut Upgraded, - target: ConnectTarget, - token: &str, - secret: &str, - sandbox_id: &str, -) -> Result<(), Box> { - // The sandbox pod may not be network-reachable immediately after the CRD - // reports Ready (DNS propagation, pod IP assignment, SSH server startup). - // Retry the TCP connection with exponential backoff. - let mut upstream = None; - let mut last_err = None; - let delays = [ - Duration::from_millis(100), - Duration::from_millis(250), - Duration::from_millis(500), - Duration::from_secs(1), - Duration::from_secs(2), - Duration::from_secs(5), - Duration::from_secs(10), - Duration::from_secs(15), - ]; - let target_desc = match &target { - ConnectTarget::Ip(addr) => format!("{addr}"), - ConnectTarget::Host(host, port) => format!("{host}:{port}"), - }; - info!(sandbox_id = %sandbox_id, target = %target_desc, "SSH tunnel: connecting to sandbox"); - for (attempt, delay) in std::iter::once(&Duration::ZERO) - .chain(delays.iter()) - .enumerate() - { - if !delay.is_zero() { - info!(sandbox_id = %sandbox_id, attempt = attempt + 1, delay_ms = delay.as_millis() as u64, "SSH tunnel: retrying TCP connect"); - tokio::time::sleep(*delay).await; - } - let result = match &target { - ConnectTarget::Ip(addr) => TcpStream::connect(addr).await, - ConnectTarget::Host(host, port) => TcpStream::connect((host.as_str(), *port)).await, - }; - match result { - Ok(stream) => { - info!( - sandbox_id = %sandbox_id, - attempts = attempt + 1, - "SSH tunnel: TCP connected to sandbox" - ); - upstream = Some(stream); - break; - } - Err(err) => { - info!(sandbox_id = %sandbox_id, attempt = attempt + 1, error = %err, "SSH tunnel: TCP connect failed"); - last_err = Some(err); - } - } - } - let mut upstream = upstream.ok_or_else(|| { - let err = last_err.unwrap(); - format!("failed to connect to sandbox after retries: {err}") - })?; - upstream.set_nodelay(true)?; - info!(sandbox_id = %sandbox_id, "SSH tunnel: sending NSSH1 handshake preface"); - let preface = build_preface(token, secret)?; - upstream.write_all(preface.as_bytes()).await?; - - info!(sandbox_id = %sandbox_id, "SSH tunnel: waiting for handshake response"); - let mut response = String::new(); - read_line(&mut upstream, &mut response).await?; - info!(sandbox_id = %sandbox_id, response = %response.trim(), "SSH tunnel: handshake response received"); - if response.trim() != "OK" { - return Err("sandbox handshake rejected".into()); - } - - info!(sandbox_id = %sandbox_id, "SSH tunnel established"); - let mut upgraded = TokioIo::new(upgraded); - // Discard the result entirely – connection-close errors are expected when - // the SSH session ends and do not represent a failure worth propagating. - let _ = tokio::io::copy_bidirectional(&mut upgraded, &mut upstream).await; - // Gracefully shut down the write-half of the upgraded connection so the - // client receives a clean EOF instead of a TCP RST. This gives SSH time - // to read any remaining protocol data (e.g. exit-status) from its buffer. - let _ = AsyncWriteExt::shutdown(&mut upgraded).await; - Ok(()) -} - fn header_value(headers: &http::HeaderMap, name: &str) -> Result { let value = headers .get(name) @@ -275,6 +268,8 @@ fn header_value(headers: &http::HeaderMap, name: &str) -> Result Result<(), Box> { - let mut bytes = Vec::new(); - loop { - let mut byte = [0u8; 1]; - let n = stream.read(&mut byte).await?; - if n == 0 { - break; - } - if byte[0] == b'\n' { - break; - } - bytes.push(byte[0]); - if bytes.len() > 1024 { - break; - } - } - *buf = String::from_utf8_lossy(&bytes).to_string(); - Ok(()) -} - fn hmac_sha256(key: &[u8], data: &[u8]) -> String { use hmac::{Hmac, Mac}; use sha2::Sha256; @@ -345,11 +317,6 @@ impl ObjectName for SshSession { } } -enum ConnectTarget { - Ip(SocketAddr), - Host(String, u16), -} - /// Decrement a connection count entry, removing it if it reaches zero. fn decrement_connection_count( counts: &std::sync::Mutex>, diff --git a/crates/openshell-server/src/supervisor_session.rs b/crates/openshell-server/src/supervisor_session.rs new file mode 100644 index 000000000..d79540d42 --- /dev/null +++ b/crates/openshell-server/src/supervisor_session.rs @@ -0,0 +1,801 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use tokio::sync::{mpsc, oneshot}; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{Request, Response, Status}; +use tracing::{info, warn}; +use uuid::Uuid; + +use openshell_core::proto::{ + GatewayMessage, RelayOpen, SessionAccepted, SupervisorMessage, gateway_message, + supervisor_message, +}; + +use crate::ServerState; + +const HEARTBEAT_INTERVAL_SECS: u32 = 15; +const RELAY_PENDING_TIMEOUT: Duration = Duration::from_secs(10); +/// Initial backoff between session-availability polls in `wait_for_session`. +const SESSION_WAIT_INITIAL_BACKOFF: Duration = Duration::from_millis(100); +/// Maximum backoff between session-availability polls in `wait_for_session`. +const SESSION_WAIT_MAX_BACKOFF: Duration = Duration::from_secs(2); + +// --------------------------------------------------------------------------- +// Session registry +// --------------------------------------------------------------------------- + +/// A live supervisor session handle. +struct LiveSession { + #[allow(dead_code)] + sandbox_id: String, + /// Uniquely identifies this session instance. Used by cleanup to avoid + /// removing a session that has since been superseded by a reconnect. + session_id: String, + tx: mpsc::Sender, + #[allow(dead_code)] + connected_at: Instant, +} + +/// Holds a oneshot sender that will deliver the upgraded relay stream. +type RelayStreamSender = oneshot::Sender; + +/// Registry of active supervisor sessions and pending relay channels. +#[derive(Default)] +pub struct SupervisorSessionRegistry { + /// sandbox_id -> live session handle. + sessions: Mutex>, + /// channel_id -> oneshot sender for the reverse CONNECT stream. + pending_relays: Mutex>, +} + +struct PendingRelay { + sender: RelayStreamSender, + created_at: Instant, +} + +impl std::fmt::Debug for SupervisorSessionRegistry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let session_count = self.sessions.lock().unwrap().len(); + let pending_count = self.pending_relays.lock().unwrap().len(); + f.debug_struct("SupervisorSessionRegistry") + .field("sessions", &session_count) + .field("pending_relays", &pending_count) + .finish() + } +} + +impl SupervisorSessionRegistry { + pub fn new() -> Self { + Self::default() + } + + /// Register a live supervisor session for the given sandbox. + /// + /// Returns the previous session's sender (if any) so the caller can close it. + fn register( + &self, + sandbox_id: String, + session_id: String, + tx: mpsc::Sender, + ) -> Option> { + let mut sessions = self.sessions.lock().unwrap(); + let previous = sessions.remove(&sandbox_id).map(|s| s.tx); + sessions.insert( + sandbox_id.clone(), + LiveSession { + sandbox_id, + session_id, + tx, + connected_at: Instant::now(), + }, + ); + previous + } + + /// Remove the session for a sandbox. + fn remove(&self, sandbox_id: &str) { + self.sessions.lock().unwrap().remove(sandbox_id); + } + + /// Remove the session only if its `session_id` matches the one we are + /// cleaning up. Returns `true` if the entry was removed. + /// + /// This guards against the supersede race: an old session's task may + /// finish long after a new session has taken its place. The old task's + /// cleanup must not evict the new registration. + fn remove_if_current(&self, sandbox_id: &str, session_id: &str) -> bool { + let mut sessions = self.sessions.lock().unwrap(); + let is_current = sessions + .get(sandbox_id) + .is_some_and(|s| s.session_id == session_id); + if is_current { + sessions.remove(sandbox_id); + } + is_current + } + + /// Look up the sender for a supervisor session, waiting up to `timeout` + /// for it to appear if absent. + /// + /// Uses exponential backoff (100ms → 2s) while polling the sessions map. + async fn wait_for_session( + &self, + sandbox_id: &str, + timeout: Duration, + ) -> Result, Status> { + let deadline = Instant::now() + timeout; + let mut backoff = SESSION_WAIT_INITIAL_BACKOFF; + + loop { + if let Some(tx) = self.lookup_session(sandbox_id) { + return Ok(tx); + } + if Instant::now() + backoff > deadline { + return Err(Status::unavailable("supervisor session not connected")); + } + tokio::time::sleep(backoff).await; + backoff = (backoff * 2).min(SESSION_WAIT_MAX_BACKOFF); + } + } + + fn lookup_session(&self, sandbox_id: &str) -> Option> { + self.sessions + .lock() + .unwrap() + .get(sandbox_id) + .map(|s| s.tx.clone()) + } + + /// Open a relay channel and return a receiver for the supervisor-side + /// stream. + /// + /// Sends `RelayOpen` over the supervisor's gRPC session and returns a + /// oneshot receiver that resolves once the supervisor opens its reverse + /// HTTP CONNECT to `/relay/{channel_id}`. + /// + /// If the session is not currently registered, this method waits up to + /// `session_wait_timeout` for it to appear. A session may be temporarily + /// absent for several reasons — all of which look identical from here: + /// + /// - startup race: the sandbox just reported Ready but the supervisor's + /// `ConnectSupervisor` gRPC handshake hasn't completed yet + /// - transient disconnect: the session was up but got dropped (network + /// blip, gateway restart, supervisor restart) and the supervisor is + /// in its reconnect backoff loop + /// + /// Callers pick the timeout based on how much patience the caller needs. + /// A first `sandbox connect` right after `sandbox create` may need to + /// wait for the supervisor's initial TLS + gRPC handshake (tens of + /// seconds on a slow cluster), while mid-lifetime calls typically just + /// need to cover a short reconnect window. + pub async fn open_relay( + &self, + sandbox_id: &str, + session_wait_timeout: Duration, + ) -> Result<(String, oneshot::Receiver), Status> { + let tx = self + .wait_for_session(sandbox_id, session_wait_timeout) + .await?; + + let channel_id = Uuid::new_v4().to_string(); + + // Register the pending relay before sending RelayOpen to avoid a race. + let (relay_tx, relay_rx) = oneshot::channel(); + { + let mut pending = self.pending_relays.lock().unwrap(); + pending.insert( + channel_id.clone(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now(), + }, + ); + } + + let msg = GatewayMessage { + payload: Some(gateway_message::Payload::RelayOpen(RelayOpen { + channel_id: channel_id.clone(), + })), + }; + + if tx.send(msg).await.is_err() { + // Session dropped between our lookup and send. + self.pending_relays.lock().unwrap().remove(&channel_id); + return Err(Status::unavailable("supervisor session disconnected")); + } + + Ok((channel_id, relay_rx)) + } + + /// Claim a pending relay channel. Called by the /relay/{channel_id} HTTP handler + /// when the supervisor's reverse CONNECT arrives. + /// + /// Returns the DuplexStream half that the supervisor side should read/write. + pub fn claim_relay(&self, channel_id: &str) -> Result { + let pending = { + let mut map = self.pending_relays.lock().unwrap(); + map.remove(channel_id) + .ok_or_else(|| Status::not_found("unknown or expired relay channel"))? + }; + + if pending.created_at.elapsed() > RELAY_PENDING_TIMEOUT { + return Err(Status::deadline_exceeded("relay channel timed out")); + } + + // Create a duplex stream pair: one end for the gateway bridge, one for + // the supervisor HTTP CONNECT handler. + let (gateway_stream, supervisor_stream) = tokio::io::duplex(64 * 1024); + + // Send the gateway-side stream to the waiter (ssh_tunnel or exec handler). + if pending.sender.send(gateway_stream).is_err() { + return Err(Status::internal("relay requester dropped")); + } + + Ok(supervisor_stream) + } + + /// Remove all pending relays that have exceeded the timeout. + pub fn reap_expired_relays(&self) { + let mut map = self.pending_relays.lock().unwrap(); + map.retain(|_, pending| pending.created_at.elapsed() <= RELAY_PENDING_TIMEOUT); + } + + /// Clean up all state for a sandbox (session + pending relays). + pub fn cleanup_sandbox(&self, sandbox_id: &str) { + self.remove(sandbox_id); + } +} + +/// Spawn a background task that periodically reaps expired pending relay +/// entries. +/// +/// Pending entries are normally consumed either when the supervisor opens its +/// reverse CONNECT (via `claim_relay`) or by the gateway-side waiter timing +/// out. If neither happens — e.g., the supervisor crashed after acknowledging +/// `RelayOpen` but before dialing back — the entry would otherwise sit in the +/// map indefinitely. This sweeper bounds that leak. +pub fn spawn_relay_reaper(state: Arc, interval: Duration) { + tokio::spawn(async move { + loop { + tokio::time::sleep(interval).await; + state.supervisor_sessions.reap_expired_relays(); + } + }); +} + +// --------------------------------------------------------------------------- +// ConnectSupervisor gRPC handler +// --------------------------------------------------------------------------- + +pub async fn handle_connect_supervisor( + state: &Arc, + request: Request>, +) -> Result< + Response< + Pin> + Send + 'static>>, + >, + Status, +> { + let mut inbound = request.into_inner(); + + // Step 1: Wait for SupervisorHello. + let hello = match inbound.message().await? { + Some(msg) => match msg.payload { + Some(supervisor_message::Payload::Hello(hello)) => hello, + _ => return Err(Status::invalid_argument("expected SupervisorHello")), + }, + None => return Err(Status::invalid_argument("stream closed before hello")), + }; + + let sandbox_id = hello.sandbox_id.clone(); + if sandbox_id.is_empty() { + return Err(Status::invalid_argument("sandbox_id is required")); + } + + let session_id = Uuid::new_v4().to_string(); + info!( + sandbox_id = %sandbox_id, + session_id = %session_id, + instance_id = %hello.instance_id, + "supervisor session: accepted" + ); + + // Step 2: Create the outbound channel and register the session. + let (tx, rx) = mpsc::channel::(64); + if let Some(_previous_tx) = + state + .supervisor_sessions + .register(sandbox_id.clone(), session_id.clone(), tx.clone()) + { + info!( + sandbox_id = %sandbox_id, + session_id = %session_id, + "supervisor session: superseded previous session" + ); + } + + // Step 3: Send SessionAccepted. + let accepted = GatewayMessage { + payload: Some(gateway_message::Payload::SessionAccepted(SessionAccepted { + session_id: session_id.clone(), + heartbeat_interval_secs: HEARTBEAT_INTERVAL_SECS, + })), + }; + if tx.send(accepted).await.is_err() { + // Only evict ourselves — a faster reconnect may already have + // superseded this registration. + state + .supervisor_sessions + .remove_if_current(&sandbox_id, &session_id); + return Err(Status::internal("failed to send session accepted")); + } + + // Step 4: Spawn the session loop that reads inbound messages. + let state_clone = Arc::clone(state); + let sandbox_id_clone = sandbox_id.clone(); + tokio::spawn(async move { + run_session_loop( + &state_clone, + &sandbox_id_clone, + &session_id, + &tx, + &mut inbound, + ) + .await; + let still_ours = state_clone + .supervisor_sessions + .remove_if_current(&sandbox_id_clone, &session_id); + if still_ours { + info!(sandbox_id = %sandbox_id_clone, session_id = %session_id, "supervisor session: ended"); + } else { + info!(sandbox_id = %sandbox_id_clone, session_id = %session_id, "supervisor session: ended (already superseded)"); + } + }); + + // Return the outbound stream. + let stream = ReceiverStream::new(rx); + let stream: Pin< + Box> + Send + 'static>, + > = Box::pin(tokio_stream::StreamExt::map(stream, Ok)); + + Ok(Response::new(stream)) +} + +async fn run_session_loop( + _state: &Arc, + sandbox_id: &str, + session_id: &str, + tx: &mpsc::Sender, + inbound: &mut tonic::Streaming, +) { + let heartbeat_interval = Duration::from_secs(u64::from(HEARTBEAT_INTERVAL_SECS)); + let mut heartbeat_timer = tokio::time::interval(heartbeat_interval); + // Skip the first immediate tick. + heartbeat_timer.tick().await; + + loop { + tokio::select! { + msg = inbound.message() => { + match msg { + Ok(Some(msg)) => { + handle_supervisor_message(sandbox_id, session_id, msg); + } + Ok(None) => { + info!(sandbox_id = %sandbox_id, session_id = %session_id, "supervisor session: stream closed by supervisor"); + break; + } + Err(e) => { + warn!(sandbox_id = %sandbox_id, session_id = %session_id, error = %e, "supervisor session: stream error"); + break; + } + } + } + _ = heartbeat_timer.tick() => { + let hb = GatewayMessage { + payload: Some(gateway_message::Payload::Heartbeat( + openshell_core::proto::GatewayHeartbeat {}, + )), + }; + if tx.send(hb).await.is_err() { + info!(sandbox_id = %sandbox_id, session_id = %session_id, "supervisor session: outbound channel closed"); + break; + } + } + } + } +} + +fn handle_supervisor_message(sandbox_id: &str, session_id: &str, msg: SupervisorMessage) { + match msg.payload { + Some(supervisor_message::Payload::Heartbeat(_)) => { + // Heartbeat received — nothing to do for now. + } + Some(supervisor_message::Payload::RelayOpenResult(result)) => { + if result.success { + info!( + sandbox_id = %sandbox_id, + session_id = %session_id, + channel_id = %result.channel_id, + "supervisor session: relay opened successfully" + ); + } else { + warn!( + sandbox_id = %sandbox_id, + session_id = %session_id, + channel_id = %result.channel_id, + error = %result.error, + "supervisor session: relay open failed" + ); + } + } + Some(supervisor_message::Payload::RelayClose(close)) => { + info!( + sandbox_id = %sandbox_id, + session_id = %session_id, + channel_id = %close.channel_id, + reason = %close.reason, + "supervisor session: relay closed by supervisor" + ); + } + _ => { + warn!( + sandbox_id = %sandbox_id, + session_id = %session_id, + "supervisor session: unexpected message type" + ); + } + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + // ---- registry: register / remove ---- + + #[test] + fn registry_register_and_lookup() { + let registry = SupervisorSessionRegistry::new(); + let (tx, _rx) = mpsc::channel(1); + + assert!( + registry + .register("sandbox-1".to_string(), "s1".to_string(), tx) + .is_none() + ); + + let sessions = registry.sessions.lock().unwrap(); + assert!(sessions.contains_key("sandbox-1")); + } + + #[test] + fn registry_supersedes_previous_session() { + let registry = SupervisorSessionRegistry::new(); + let (tx1, _rx1) = mpsc::channel(1); + let (tx2, _rx2) = mpsc::channel(1); + + assert!( + registry + .register("sandbox-1".to_string(), "s1".to_string(), tx1) + .is_none() + ); + assert!( + registry + .register("sandbox-1".to_string(), "s2".to_string(), tx2) + .is_some() + ); + } + + #[test] + fn registry_remove() { + let registry = SupervisorSessionRegistry::new(); + let (tx, _rx) = mpsc::channel(1); + registry.register("sandbox-1".to_string(), "s1".to_string(), tx); + + registry.remove("sandbox-1"); + let sessions = registry.sessions.lock().unwrap(); + assert!(!sessions.contains_key("sandbox-1")); + } + + #[test] + fn remove_if_current_removes_matching_session() { + let registry = SupervisorSessionRegistry::new(); + let (tx, _rx) = mpsc::channel(1); + registry.register("sbx".to_string(), "s1".to_string(), tx); + + assert!(registry.remove_if_current("sbx", "s1")); + assert!(!registry.sessions.lock().unwrap().contains_key("sbx")); + } + + #[test] + fn remove_if_current_ignores_stale_session_id() { + let registry = SupervisorSessionRegistry::new(); + let (tx_old, _rx_old) = mpsc::channel(1); + let (tx_new, _rx_new) = mpsc::channel(1); + + // Old session registers, then is superseded by a new session. + registry.register("sbx".to_string(), "s-old".to_string(), tx_old); + registry.register("sbx".to_string(), "s-new".to_string(), tx_new); + + // Cleanup from the old session task runs late. It must NOT evict the + // newly registered session. + assert!(!registry.remove_if_current("sbx", "s-old")); + let sessions = registry.sessions.lock().unwrap(); + assert!( + sessions.contains_key("sbx"), + "new session must still be registered" + ); + assert_eq!(sessions.get("sbx").unwrap().session_id, "s-new"); + } + + #[test] + fn remove_if_current_unknown_sandbox_is_noop() { + let registry = SupervisorSessionRegistry::new(); + assert!(!registry.remove_if_current("sbx-does-not-exist", "s1")); + } + + // ---- open_relay: happy path and wait semantics ---- + + #[tokio::test] + async fn open_relay_sends_relay_open_to_registered_session() { + let registry = SupervisorSessionRegistry::new(); + let (tx, mut rx) = mpsc::channel(4); + registry.register("sbx".to_string(), "s1".to_string(), tx); + + let (channel_id, _relay_rx) = registry + .open_relay("sbx", Duration::from_secs(1)) + .await + .expect("open_relay should succeed when session is live"); + + let msg = rx.recv().await.expect("relay open should be delivered"); + match msg.payload { + Some(gateway_message::Payload::RelayOpen(open)) => { + assert_eq!(open.channel_id, channel_id); + } + other => panic!("expected RelayOpen, got {other:?}"), + } + } + + #[tokio::test] + async fn open_relay_times_out_without_session() { + let registry = SupervisorSessionRegistry::new(); + let err = registry + .open_relay("missing", Duration::from_millis(50)) + .await + .expect_err("open_relay should time out"); + assert_eq!(err.code(), tonic::Code::Unavailable); + } + + #[tokio::test] + async fn open_relay_waits_for_session_to_appear() { + let registry = Arc::new(SupervisorSessionRegistry::new()); + let registry_for_register = Arc::clone(®istry); + + // Register the session after a small delay, shorter than the wait. + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(200)).await; + let (tx, mut rx) = mpsc::channel::(4); + // Keep the receiver alive so the send in open_relay succeeds. + tokio::spawn(async move { while rx.recv().await.is_some() {} }); + registry_for_register.register("sbx".to_string(), "s1".to_string(), tx); + }); + + let result = registry.open_relay("sbx", Duration::from_secs(2)).await; + assert!( + result.is_ok(), + "open_relay should succeed when session arrives mid-wait: {result:?}" + ); + } + + #[tokio::test] + async fn open_relay_fails_when_session_receiver_dropped() { + let registry = SupervisorSessionRegistry::new(); + let (tx, rx) = mpsc::channel::(4); + registry.register("sbx".to_string(), "s1".to_string(), tx); + + // Simulate the supervisor's stream going away between lookup and send: + // the receiver held by `ReceiverStream` is dropped. + drop(rx); + + let err = registry + .open_relay("sbx", Duration::from_secs(1)) + .await + .expect_err("open_relay should fail when mpsc is closed"); + assert_eq!(err.code(), tonic::Code::Unavailable); + // The pending-relay entry must have been cleaned up on failure. + assert!(registry.pending_relays.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn open_relay_uses_newest_session_after_supersede() { + let registry = SupervisorSessionRegistry::new(); + let (tx_old, mut rx_old) = mpsc::channel::(4); + let (tx_new, mut rx_new) = mpsc::channel(4); + + // Hold a clone of the old sender so supersede doesn't close the old + // channel — that way try_recv distinguishes "no message sent" from + // "channel closed". + let _tx_old_alive = tx_old.clone(); + + registry.register("sbx".to_string(), "s-old".to_string(), tx_old); + registry.register("sbx".to_string(), "s-new".to_string(), tx_new); + + let (_channel_id, _relay_rx) = registry + .open_relay("sbx", Duration::from_secs(1)) + .await + .expect("open_relay should succeed"); + + let msg = rx_new + .recv() + .await + .expect("new session should receive RelayOpen"); + assert!(matches!( + msg.payload, + Some(gateway_message::Payload::RelayOpen(_)) + )); + + // The old session must have received no messages — the channel is + // still open but empty. + use tokio::sync::mpsc::error::TryRecvError; + match rx_old.try_recv() { + Err(TryRecvError::Empty) => {} + other => panic!("expected Empty on superseded session, got {other:?}"), + } + } + + // ---- claim_relay: expiry, drop, wiring ---- + + #[test] + fn claim_relay_unknown_channel() { + let registry = SupervisorSessionRegistry::new(); + let err = registry.claim_relay("nonexistent").expect_err("should err"); + assert_eq!(err.code(), tonic::Code::NotFound); + } + + #[test] + fn claim_relay_success() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, _relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-1".to_string(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now(), + }, + ); + + let result = registry.claim_relay("ch-1"); + assert!(result.is_ok()); + assert!(!registry.pending_relays.lock().unwrap().contains_key("ch-1")); + } + + #[test] + fn claim_relay_expired_returns_deadline_exceeded() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, _relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-old".to_string(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now() - Duration::from_secs(60), + }, + ); + + let err = registry + .claim_relay("ch-old") + .expect_err("expired entry must fail"); + assert_eq!(err.code(), tonic::Code::DeadlineExceeded); + // Entry must have been consumed regardless. + assert!( + !registry + .pending_relays + .lock() + .unwrap() + .contains_key("ch-old") + ); + } + + #[test] + fn claim_relay_receiver_dropped_returns_internal() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, relay_rx) = oneshot::channel::(); + drop(relay_rx); // Gateway-side waiter has given up already. + registry.pending_relays.lock().unwrap().insert( + "ch-1".to_string(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now(), + }, + ); + + let err = registry + .claim_relay("ch-1") + .expect_err("should err when receiver is gone"); + assert_eq!(err.code(), tonic::Code::Internal); + } + + #[tokio::test] + async fn claim_relay_connects_both_ends() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, relay_rx) = oneshot::channel::(); + registry.pending_relays.lock().unwrap().insert( + "ch-io".to_string(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now(), + }, + ); + + let mut supervisor_side = registry.claim_relay("ch-io").expect("claim should succeed"); + let mut gateway_side = relay_rx.await.expect("gateway side should receive stream"); + + // Supervisor side writes → gateway side reads. + supervisor_side.write_all(b"hello").await.unwrap(); + let mut buf = [0u8; 5]; + gateway_side.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"hello"); + + // Gateway side writes → supervisor side reads. + gateway_side.write_all(b"world").await.unwrap(); + let mut buf = [0u8; 5]; + supervisor_side.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"world"); + } + + // ---- reap_expired_relays ---- + + #[test] + fn reap_expired_relays_removes_old_entries() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, _relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-old".to_string(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now() - Duration::from_secs(60), + }, + ); + + registry.reap_expired_relays(); + assert!( + !registry + .pending_relays + .lock() + .unwrap() + .contains_key("ch-old") + ); + } + + #[test] + fn reap_expired_relays_keeps_fresh_entries() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, _relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-fresh".to_string(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now(), + }, + ); + + registry.reap_expired_relays(); + assert!( + registry + .pending_relays + .lock() + .unwrap() + .contains_key("ch-fresh") + ); + } +} diff --git a/crates/openshell-server/tests/auth_endpoint_integration.rs b/crates/openshell-server/tests/auth_endpoint_integration.rs index 7c6545873..cd2abe157 100644 --- a/crates/openshell-server/tests/auth_endpoint_integration.rs +++ b/crates/openshell-server/tests/auth_endpoint_integration.rs @@ -528,6 +528,9 @@ impl openshell_core::proto::open_shell_server::OpenShell for TestOpenShell { type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream< Result, >; + type ConnectSupervisorStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; async fn watch_sandbox( &self, @@ -663,6 +666,13 @@ impl openshell_core::proto::open_shell_server::OpenShell for TestOpenShell { { Err(tonic::Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented in test")) + } } /// Test 7: Plaintext server (no TLS) accepts both gRPC and HTTP. diff --git a/crates/openshell-server/tests/edge_tunnel_auth.rs b/crates/openshell-server/tests/edge_tunnel_auth.rs index 22f08434d..a5d6a88e9 100644 --- a/crates/openshell-server/tests/edge_tunnel_auth.rs +++ b/crates/openshell-server/tests/edge_tunnel_auth.rs @@ -37,13 +37,14 @@ use hyper_util::{ use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, ProviderResponse, RevokeSshSessionRequest, - RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, - UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -186,6 +187,7 @@ impl OpenShell for TestOpenShell { type WatchSandboxStream = ReceiverStream>; type ExecSandboxStream = ReceiverStream>; + type ConnectSupervisorStream = ReceiverStream>; async fn watch_sandbox( &self, @@ -307,6 +309,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // --------------------------------------------------------------------------- diff --git a/crates/openshell-server/tests/multiplex_integration.rs b/crates/openshell-server/tests/multiplex_integration.rs index 1957c5b87..8b93b0989 100644 --- a/crates/openshell-server/tests/multiplex_integration.rs +++ b/crates/openshell-server/tests/multiplex_integration.rs @@ -11,13 +11,14 @@ use hyper_util::{ use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, ProviderResponse, RevokeSshSessionRequest, - RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, - UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -154,6 +155,7 @@ impl OpenShell for TestOpenShell { type WatchSandboxStream = ReceiverStream>; type ExecSandboxStream = ReceiverStream>; + type ConnectSupervisorStream = ReceiverStream>; async fn watch_sandbox( &self, @@ -275,6 +277,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } #[tokio::test] diff --git a/crates/openshell-server/tests/multiplex_tls_integration.rs b/crates/openshell-server/tests/multiplex_tls_integration.rs index 98d5d6256..4d77e8cae 100644 --- a/crates/openshell-server/tests/multiplex_tls_integration.rs +++ b/crates/openshell-server/tests/multiplex_tls_integration.rs @@ -13,13 +13,14 @@ use hyper_util::{ use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, ProviderResponse, RevokeSshSessionRequest, - RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, - UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -167,6 +168,7 @@ impl OpenShell for TestOpenShell { type WatchSandboxStream = ReceiverStream>; type ExecSandboxStream = ReceiverStream>; + type ConnectSupervisorStream = ReceiverStream>; async fn watch_sandbox( &self, @@ -288,6 +290,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } /// PKI bundle: CA cert, server cert+key, client cert+key. diff --git a/crates/openshell-server/tests/ws_tunnel_integration.rs b/crates/openshell-server/tests/ws_tunnel_integration.rs index 54a7354c8..705e9de49 100644 --- a/crates/openshell-server/tests/ws_tunnel_integration.rs +++ b/crates/openshell-server/tests/ws_tunnel_integration.rs @@ -40,13 +40,14 @@ use hyper_util::{ use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, ProviderResponse, RevokeSshSessionRequest, - RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, - UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -180,6 +181,7 @@ impl OpenShell for TestOpenShell { type WatchSandboxStream = ReceiverStream>; type ExecSandboxStream = ReceiverStream>; + type ConnectSupervisorStream = ReceiverStream>; async fn watch_sandbox( &self, @@ -301,6 +303,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // --------------------------------------------------------------------------- diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 53b0ac27d..68af695e5 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -38,10 +38,6 @@ service ComputeDriver { // Tear down platform resources for a sandbox. rpc DeleteSandbox(DeleteSandboxRequest) returns (DeleteSandboxResponse); - // Resolve the current endpoint for sandbox exec/SSH transport. - rpc ResolveSandboxEndpoint(ResolveSandboxEndpointRequest) - returns (ResolveSandboxEndpointResponse); - // Stream sandbox observations from the platform. rpc WatchSandboxes(WatchSandboxesRequest) returns (stream WatchSandboxesEvent); } @@ -238,27 +234,6 @@ message DeleteSandboxResponse { bool deleted = 1; } -message ResolveSandboxEndpointRequest { - // Sandbox to resolve for exec or SSH connectivity. - DriverSandbox sandbox = 1; -} - -message SandboxEndpoint { - oneof target { - // Direct IP address for the sandbox endpoint. - string ip = 1; - // DNS host name for the sandbox endpoint. - string host = 2; - } - // TCP port for the sandbox endpoint. - uint32 port = 3; -} - -message ResolveSandboxEndpointResponse { - // Current endpoint the gateway should use to reach the sandbox. - SandboxEndpoint endpoint = 1; -} - message WatchSandboxesRequest {} message WatchSandboxesSandboxEvent { diff --git a/proto/openshell.proto b/proto/openshell.proto index 0ee1e8904..53812c977 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -91,6 +91,14 @@ service OpenShell { // Push sandbox supervisor logs to the server (client-streaming). rpc PushSandboxLogs(stream PushSandboxLogsRequest) returns (PushSandboxLogsResponse); + // Persistent supervisor-to-gateway session (bidirectional streaming). + // + // The supervisor opens this stream at startup and keeps it alive for the + // sandbox lifetime. The gateway uses it to coordinate relay channels for + // SSH connect and ExecSandbox. SSH bytes flow over separate reverse HTTP + // CONNECT tunnels, not over this stream. + rpc ConnectSupervisor(stream SupervisorMessage) returns (stream GatewayMessage); + // Watch a sandbox and stream updates. // // This stream can include: @@ -704,6 +712,87 @@ message GetSandboxLogsResponse { uint32 buffer_total = 2; } +// --------------------------------------------------------------------------- +// Supervisor session messages +// --------------------------------------------------------------------------- + +// Envelope for supervisor-to-gateway messages on the ConnectSupervisor stream. +message SupervisorMessage { + oneof payload { + SupervisorHello hello = 1; + SupervisorHeartbeat heartbeat = 2; + RelayOpenResult relay_open_result = 3; + RelayClose relay_close = 4; + } +} + +// Envelope for gateway-to-supervisor messages on the ConnectSupervisor stream. +message GatewayMessage { + oneof payload { + SessionAccepted session_accepted = 1; + SessionRejected session_rejected = 2; + GatewayHeartbeat heartbeat = 3; + RelayOpen relay_open = 4; + RelayClose relay_close = 5; + } +} + +// Supervisor identifies itself and the sandbox it manages. +message SupervisorHello { + // Sandbox ID this supervisor manages. + string sandbox_id = 1; + // Supervisor instance ID (e.g. boot id or process epoch). + string instance_id = 2; +} + +// Gateway accepts the supervisor session. +message SessionAccepted { + // Gateway-assigned session ID for this connection. + string session_id = 1; + // Recommended heartbeat interval in seconds. + uint32 heartbeat_interval_secs = 2; +} + +// Gateway rejects the supervisor session. +message SessionRejected { + // Human-readable rejection reason. + string reason = 1; +} + +// Supervisor heartbeat. +message SupervisorHeartbeat {} + +// Gateway heartbeat. +message GatewayHeartbeat {} + +// Gateway requests the supervisor to open a relay channel. +// +// On receiving this, the supervisor should open a reverse HTTP CONNECT +// to the gateway's /relay/{channel_id} endpoint and bridge it to the +// local SSH daemon. +message RelayOpen { + // Gateway-allocated channel identifier (UUID). + string channel_id = 1; +} + +// Supervisor reports the result of a relay open request. +message RelayOpenResult { + // Channel identifier from the RelayOpen request. + string channel_id = 1; + // True if the relay was successfully established. + bool success = 2; + // Error message if success is false. + string error = 3; +} + +// Either side requests closure of a relay channel. +message RelayClose { + // Channel identifier to close. + string channel_id = 1; + // Optional reason for closure. + string reason = 2; +} + // --------------------------------------------------------------------------- // Service status // ---------------------------------------------------------------------------