Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 63 additions & 11 deletions ostool/src/board/serial_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::sync::{
atomic::{AtomicBool, Ordering},
};

use std::io::ErrorKind;

use anyhow::Context as _;
use futures::{SinkExt, StreamExt};
use tokio::{
Expand Down Expand Up @@ -44,12 +46,12 @@ pub async fn connect_serial_stream(
while let Some(message) = ws_stream.next().await {
match message.context("serial websocket read failed")? {
Message::Binary(bytes) => {
tokio::io::AsyncWriteExt::write_all(&mut bridge_tx, &bytes)
.await
.context("failed to write serial websocket bytes")?;
tokio::io::AsyncWriteExt::flush(&mut bridge_tx)
if write_bridge_bytes(&mut bridge_tx, &bytes)
.await
.context("failed to flush serial websocket bytes")?;
.context("failed to write serial websocket bytes")?
{
break;
}
}
Message::Text(text) => {
if let Ok(control) = serde_json::from_str::<ServerControlMessage>(&text) {
Expand All @@ -64,12 +66,12 @@ pub async fn connect_serial_stream(
}
}

tokio::io::AsyncWriteExt::write_all(&mut bridge_tx, text.as_bytes())
if write_bridge_bytes(&mut bridge_tx, text.as_bytes())
.await
.context("failed to write text serial websocket payload")?;
tokio::io::AsyncWriteExt::flush(&mut bridge_tx)
.await
.context("failed to flush text serial websocket payload")?;
.context("failed to write text serial websocket payload")?
{
break;
}
}
Message::Close(_) => {
if locally_closed.load(Ordering::SeqCst) {
Expand Down Expand Up @@ -125,6 +127,22 @@ pub async fn connect_serial_stream(
))
}

async fn write_bridge_bytes<W>(writer: &mut W, bytes: &[u8]) -> anyhow::Result<bool>
where
W: tokio::io::AsyncWrite + Unpin,
{
match tokio::io::AsyncWriteExt::write_all(writer, bytes).await {
Ok(()) => {}
Err(err) if err.kind() == ErrorKind::BrokenPipe => return Ok(true),
Err(err) => return Err(err.into()),
}
match tokio::io::AsyncWriteExt::flush(writer).await {
Ok(()) => Ok(false),
Err(err) if err.kind() == ErrorKind::BrokenPipe => Ok(true),
Err(err) => Err(err.into()),
}
}

impl SerialStreamTasks {
pub async fn shutdown(self) -> anyhow::Result<()> {
let write_result = self.write_task.await;
Expand Down Expand Up @@ -204,7 +222,7 @@ mod tests {

use tokio::{sync::Notify, task::JoinHandle};

use super::SerialStreamTasks;
use super::{SerialStreamTasks, write_bridge_bytes};

#[tokio::test]
async fn shutdown_waits_for_writer_before_reader() {
Expand Down Expand Up @@ -240,4 +258,38 @@ mod tests {
.await
.unwrap();
}

#[tokio::test]
async fn shutdown_allows_reader_to_finish_after_local_consumer_closed() {
let (mut writer, reader) = tokio::io::duplex(1);
drop(reader);

let read_task: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move {
if write_bridge_bytes(&mut writer, b"late console output").await? {
return Ok(());
}
anyhow::bail!("bridge writer unexpectedly stayed open")
});
let write_task: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move { Ok(()) });

SerialStreamTasks {
read_task,
write_task,
}
.shutdown()
.await
.unwrap();
}

#[tokio::test]
async fn bridge_writer_treats_closed_local_consumer_as_done() {
let (mut writer, reader) = tokio::io::duplex(1);
drop(reader);

assert!(
write_bridge_bytes(&mut writer, b"late console output")
.await
.unwrap()
);
}
}