Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions include/livekit/data_track_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

#include <condition_variable>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>
#include <optional>

Expand All @@ -32,7 +30,8 @@ namespace livekit {

namespace proto {
class FfiEvent;
}
class DataTrackStreamReadResponse;
} // namespace proto

/**
* Represents a pull-based stream of frames from a remote data track.
Expand Down Expand Up @@ -111,6 +110,12 @@ class LIVEKIT_API DataTrackStream {
/// FFI event handler, called by FfiClient.
void onFfiEvent(const proto::FfiEvent& event);

/// Handle the immediate response returned by a read request.
void handleReadResponse(const proto::DataTrackStreamReadResponse& response);

/// Mark the stream failed due to an invalid FFI protocol response.
void failProtocolError(const char* message);

/// Push a received DataTrackFrame to the internal storage.
void pushFrame(DataTrackFrame&& frame);

Expand Down Expand Up @@ -143,7 +148,7 @@ class LIVEKIT_API DataTrackStream {
FfiHandle subscription_handle_;

/** FfiClient listener id for routing FfiEvent callbacks to this object. */
std::int32_t listener_id_{0};
std::int32_t listener_id_{-1};
};

} // namespace livekit
53 changes: 45 additions & 8 deletions src/data_track_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "livekit/data_track_stream.h"

#include <optional>
#include <utility>

#include "data_track.pb.h"
Expand All @@ -27,6 +28,19 @@ namespace livekit {

using proto::FfiEvent;

namespace {

constexpr char kMissingReadResponseError[] = "DataTrackStream::read: FFI response missing data_track_stream_read";

std::optional<SubscribeDataTrackError> terminalErrorFromEos(const proto::DataTrackStreamEOS& eos) {
if (!eos.has_error()) {
return std::nullopt;
}
return SubscribeDataTrackError::fromProto(eos.error());
}

} // namespace

DataTrackStream::~DataTrackStream() { close(); }

void DataTrackStream::init(FfiHandle subscription_handle) {
Expand All @@ -36,6 +50,9 @@ void DataTrackStream::init(FfiHandle subscription_handle) {
}

bool DataTrackStream::read(DataTrackFrame& out) {
proto::DataTrackStreamReadResponse read_response;
bool missing_read_response = false;

{
const std::scoped_lock<std::mutex> lock(mutex_);
if (closed_ || eof_) {
Expand All @@ -50,9 +67,21 @@ bool DataTrackStream::read(DataTrackFrame& out) {
proto::FfiRequest req;
auto* msg = req.mutable_data_track_stream_read();
msg->set_stream_handle(subscription_handle);
FfiClient::instance().sendRequest(req);
const proto::FfiResponse resp = FfiClient::instance().sendRequest(req);
if (!resp.has_data_track_stream_read()) {
missing_read_response = true;
} else {
read_response = resp.data_track_stream_read();
}
}

if (missing_read_response) {
failProtocolError(kMissingReadResponseError);
return false;
}

handleReadResponse(read_response);
Comment thread
stephen-derosa marked this conversation as resolved.

std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return frame_.has_value() || eof_ || closed_; });

Expand Down Expand Up @@ -113,13 +142,21 @@ void DataTrackStream::onFfiEvent(const FfiEvent& event) {
DataTrackFrame frame = DataTrackFrame::fromOwnedInfo(fr);
pushFrame(std::move(frame));
} else if (dts.has_eos()) {
std::optional<SubscribeDataTrackError> error;
const auto& eos = dts.eos();
if (eos.has_error()) {
error = SubscribeDataTrackError::fromProto(eos.error());
}
pushEos(std::move(error));
pushEos(terminalErrorFromEos(dts.eos()));
}
}

void DataTrackStream::handleReadResponse(const proto::DataTrackStreamReadResponse& response) {
if (!response.has_eos_event()) {
return;
}
pushEos(terminalErrorFromEos(response.eos_event()));
}

void DataTrackStream::failProtocolError(const char* message) {
LK_LOG_ERROR("{}", message);
pushEos(SubscribeDataTrackError{SubscribeDataTrackErrorCode::PROTOCOL_ERROR, message});
close();
}

void DataTrackStream::pushFrame(DataTrackFrame&& frame) {
Expand All @@ -141,7 +178,7 @@ void DataTrackStream::pushFrame(DataTrackFrame&& frame) {
void DataTrackStream::pushEos(std::optional<SubscribeDataTrackError> error) {
{
const std::scoped_lock<std::mutex> lock(mutex_);
if (eof_) {
if (closed_ || eof_) {
return;
}
eof_ = true;
Expand Down
47 changes: 47 additions & 0 deletions src/tests/unit/test_data_track_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ class DataTrackStreamTest : public ::testing::Test {

static void pushEvent(DataTrackStream& stream, const proto::FfiEvent& event) { stream.onFfiEvent(event); }

static void handleReadResponse(DataTrackStream& stream, const proto::DataTrackStreamReadResponse& response) {
stream.handleReadResponse(response);
}

static void failProtocolError(DataTrackStream& stream, const char* message) { stream.failProtocolError(message); }

static proto::FfiEvent makeEosEvent(std::optional<proto::SubscribeDataTrackErrorCode> code = std::nullopt,
const std::string& message = {}) {
proto::FfiEvent event;
Expand All @@ -47,6 +53,18 @@ class DataTrackStreamTest : public ::testing::Test {
return event;
}

static proto::DataTrackStreamReadResponse makeEosReadResponse(
std::optional<proto::SubscribeDataTrackErrorCode> code = std::nullopt, const std::string& message = {}) {
proto::DataTrackStreamReadResponse response;
auto* eos = response.mutable_eos_event();
if (code.has_value()) {
auto* error = eos->mutable_error();
error->set_code(code.value());
error->set_message(message);
}
return response;
}

static proto::FfiEvent makeAudioStreamEvent() {
proto::FfiEvent event;
event.mutable_audio_stream_event()->set_stream_handle(0);
Expand All @@ -71,6 +89,15 @@ TEST_F(DataTrackStreamTest, TerminalErrorEmptyForNormalEos) {
EXPECT_FALSE(stream->terminalError().has_value());
}

TEST_F(DataTrackStreamTest, ReadResponseNormalEosEndsStreamWithoutTerminalError) {
auto stream = makeStream();
handleReadResponse(*stream, makeEosReadResponse());

DataTrackFrame frame;
EXPECT_FALSE(stream->read(frame));
EXPECT_FALSE(stream->terminalError().has_value());
}

TEST_F(DataTrackStreamTest, TerminalErrorStoredForSubscribeFailureEos) {
auto stream = makeStream();
pushEvent(*stream, makeEosEvent(proto::SUBSCRIBE_DATA_TRACK_ERROR_CODE_UNPUBLISHED,
Expand All @@ -85,6 +112,26 @@ TEST_F(DataTrackStreamTest, TerminalErrorStoredForSubscribeFailureEos) {
EXPECT_EQ(error->message, "track unpublished before subscription completed");
}

TEST_F(DataTrackStreamTest, ReadResponseSubscribeFailureEosStoresTerminalError) {
auto stream = makeStream();
handleReadResponse(*stream, makeEosReadResponse(proto::SUBSCRIBE_DATA_TRACK_ERROR_CODE_UNPUBLISHED,
"track unpublished before read completed"));

DataTrackFrame frame;
EXPECT_FALSE(stream->read(frame));
expectTerminalError(*stream, SubscribeDataTrackErrorCode::UNPUBLISHED, "track unpublished before read completed");
}

TEST_F(DataTrackStreamTest, ProtocolErrorClosesStreamAndStoresTerminalError) {
auto stream = makeStream();

EXPECT_NO_THROW(failProtocolError(*stream, "malformed FFI response"));

DataTrackFrame frame;
EXPECT_FALSE(stream->read(frame));
expectTerminalError(*stream, SubscribeDataTrackErrorCode::PROTOCOL_ERROR, "malformed FFI response");
}

TEST_F(DataTrackStreamTest, CloseBeforeEosSuppressesLaterTerminalError) {
auto stream = makeStream();
stream->close();
Expand Down
Loading