Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 134 additions & 4 deletions ostool-server/src/serial/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use futures_util::{Sink, SinkExt, StreamExt};
use serde::Deserialize;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::task::JoinHandle;
use tokio_serial::SerialPortBuilderExt;
use tokio_serial::{ClearBuffer, SerialPort, SerialPortBuilderExt};

use crate::{
config::BoardConfig,
Expand Down Expand Up @@ -234,6 +234,8 @@ async fn run_serial_ws_inner(

let result =
finalize_power_linked_session(state, &board, power_linked, power_on_task, result).await;
let mut port = serial_rx.unsplit(serial_tx);
let result = preserve_result_after_serial_cleanup(&session_id, result, &mut port).await;
let _ = state
.request_session_stop(&session_id, crate::session::SessionStopReason::SerialClosed)
.await;
Expand Down Expand Up @@ -308,11 +310,55 @@ async fn write_serial_payload(
Ok(())
}

#[async_trait::async_trait]
trait SerialQueueCleanup {
async fn flush_output(&mut self) -> std::io::Result<()>;
fn clear_all_buffers(&mut self) -> std::io::Result<()>;
}

#[async_trait::async_trait]
impl SerialQueueCleanup for tokio_serial::SerialStream {
async fn flush_output(&mut self) -> std::io::Result<()> {
AsyncWriteExt::flush(self).await
}

fn clear_all_buffers(&mut self) -> std::io::Result<()> {
self.clear(ClearBuffer::All).map_err(std::io::Error::from)
}
}

async fn cleanup_serial_queue_before_close<T>(port: &mut T) -> anyhow::Result<()>
where
T: SerialQueueCleanup + ?Sized,
{
port.flush_output()
.await
.context("failed to flush serial output before close")?;
port.clear_all_buffers()
.context("failed to clear serial buffers before close")?;
Ok(())
}

async fn preserve_result_after_serial_cleanup<T, P>(
session_id: &str,
result: anyhow::Result<T>,
port: &mut P,
) -> anyhow::Result<T>
where
P: SerialQueueCleanup + ?Sized,
{
if let Err(err) = cleanup_serial_queue_before_close(port).await {
log::warn!("session `{session_id}` failed to clean serial queue before close: {err:#}");
}
result
}

#[cfg(test)]
mod tests {
use std::{
fs,
fs, io,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
time::Duration,
};
Expand All @@ -322,8 +368,9 @@ mod tests {
use tempfile::tempdir;

use super::{
ClientControlMessage, cleanup_power_link, finalize_power_linked_session,
send_power_on_failure_and_close,
ClientControlMessage, SerialQueueCleanup, cleanup_power_link,
cleanup_serial_queue_before_close, finalize_power_linked_session,
preserve_result_after_serial_cleanup, send_power_on_failure_and_close,
};
use crate::{
build_app_state,
Expand All @@ -340,6 +387,33 @@ mod tests {
messages: Vec<Message>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
enum CleanupEvent {
Flush,
ClearAll,
}

struct RecordingSerialCleanup {
events: Arc<Mutex<Vec<CleanupEvent>>>,
clear_result: io::Result<()>,
}

#[async_trait::async_trait]
impl SerialQueueCleanup for RecordingSerialCleanup {
async fn flush_output(&mut self) -> io::Result<()> {
self.events.lock().unwrap().push(CleanupEvent::Flush);
Ok(())
}

fn clear_all_buffers(&mut self) -> io::Result<()> {
self.events.lock().unwrap().push(CleanupEvent::ClearAll);
self.clear_result
.as_ref()
.map(|_| ())
.map_err(|err| io::Error::new(err.kind(), err.to_string()))
}
}

impl Sink<Message> for VecSink {
type Error = ();

Expand Down Expand Up @@ -376,6 +450,62 @@ mod tests {
assert_eq!(message.kind, "close");
}

#[tokio::test]
async fn serial_cleanup_flushes_before_clearing_all_buffers() {
let events = Arc::new(Mutex::new(Vec::new()));
let mut cleanup = RecordingSerialCleanup {
events: events.clone(),
clear_result: Ok(()),
};

cleanup_serial_queue_before_close(&mut cleanup)
.await
.unwrap();

assert_eq!(
events.lock().unwrap().as_slice(),
&[CleanupEvent::Flush, CleanupEvent::ClearAll]
);
}

#[tokio::test]
async fn serial_cleanup_reports_clear_failures() {
let events = Arc::new(Mutex::new(Vec::new()));
let mut cleanup = RecordingSerialCleanup {
events: events.clone(),
clear_result: Err(io::Error::new(io::ErrorKind::Other, "clear failed")),
};

let err = cleanup_serial_queue_before_close(&mut cleanup)
.await
.unwrap_err();

assert!(err.to_string().contains("failed to clear serial buffers"));
assert_eq!(
events.lock().unwrap().as_slice(),
&[CleanupEvent::Flush, CleanupEvent::ClearAll]
);
}

#[tokio::test]
async fn serial_cleanup_failure_preserves_original_session_error() {
let events = Arc::new(Mutex::new(Vec::new()));
let mut cleanup = RecordingSerialCleanup {
events,
clear_result: Err(io::Error::new(io::ErrorKind::Other, "clear failed")),
};

let err = preserve_result_after_serial_cleanup::<(), _>(
"session-1",
Err(anyhow::anyhow!("websocket failed")),
&mut cleanup,
)
.await
.unwrap_err();

assert_eq!(err.to_string(), "websocket failed");
}

async fn test_state(root: &std::path::Path) -> crate::AppState {
let config_path = root.join(".ostool-server.toml");
let config = ServerConfig {
Expand Down