Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions config/client_aux.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
huri_url: ws://localhost:8000/session

topic_list: ["transcript", "question"]
topic_list: [question]

sample_rate: 16000
frame_duration: 0.030
senders:
audio:
name: audio
args:
sample_rate: 16000
frame_duration: 0.030

modules:
mic:
name: mic
args:
vad_agressiveness: 3
silence_duration: 1.5
block_duration: ${frame_duration}
block_duration: ${inputs.audio.args.frame_duration}
logging: INFO
stt:
name: stt
args:
language: "fr"
block_duration: ${frame_duration}
language: "en"
block_duration: ${inputs.audio.args.frame_duration}
logging: INFO
tag:
name: tag
Expand Down
30 changes: 30 additions & 0 deletions config/client_auxio.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
huri_url: ws://localhost:8000/session

topic_list: [question]

senders:
audio:
name: audio
args:
sample_rate: 16000
frame_duration: 0.030
text:
name: text

modules:
mic:
name: mic
args:
vad_agressiveness: 3
silence_duration: 1.5
block_duration: ${inputs.audio.args.frame_duration}
logging: INFO
stt:
name: stt
args:
language: en
block_duration: ${inputs.audio.args.frame_duration}
logging: INFO
tag:
name: tag
logging: INFO
16 changes: 13 additions & 3 deletions config/client_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,24 @@
huri_url: ws://localhost:8000/session

# List of event topic the client will receive
topic_list: ["topic1", "topic2"]
topic_list: [topic1, topic2]

# Define module custom args
# Define senders to be used and their custom args
senders:
# sender tag can be anything
example:
# sender name must be in the list of available ClientSender in Client instance (src.client_sender:get_senders)
name: my_sender
# if my_sender init with "model", "sample_rate" and "refresh_rate" params, they can be customized here
args:
refresh_rate: infinite

# Define module to be used and their custom args
modules:
# module tag can be anything
example:
# module name must be in the list of available module in HuRI's instance (src.modules.modules:get_modules)
name: my_module
# if my_module init with "model", "sample_rate" and "hello" params, they can be customized here
args:
hello: "world"
hello: world
12 changes: 12 additions & 0 deletions config/client_text.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
huri_url: ws://localhost:8000/session

topic_list: [question]

senders:
text:
name: text

modules:
tag:
name: tag
logging: INFO
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ faster-whisper
# client
sounddevice
websockets
omegaconf
omegaconf
prompt-toolkit
4 changes: 3 additions & 1 deletion src/app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from ray.serve import Application

from src.core.huri import HuRI
from src.modules.events import get_events
from src.modules.factory import bind_deployment_handles
from src.modules.modules import get_modules


def build_app() -> Application:
modules = get_modules()
handles = bind_deployment_handles(modules)
events = get_events()

app: Application = HuRI.bind(modules, handles) # type: ignore[attr-defined]
app: Application = HuRI.bind(modules, handles, events) # type: ignore[attr-defined]
return app


Expand Down
42 changes: 4 additions & 38 deletions src/client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import argparse
import asyncio
import json
from dataclasses import asdict
from typing import Dict

import numpy as np
import sounddevice as sd
import websockets
from omegaconf import OmegaConf

from src.core.client import Client
from src.core.dataclasses.config import ClientConfig


Expand All @@ -23,7 +19,7 @@ def load_client_config(path: str) -> ClientConfig:
return ClientConfig.from_dict(raw_resolved)


async def stream_audio():
async def launch_client():
parser = argparse.ArgumentParser(description="Client config")
parser.add_argument(
"--config",
Expand All @@ -34,38 +30,8 @@ async def stream_audio():
args = parser.parse_args()
config = load_client_config(args.config)

FRAME_SIZE = int(config.sample_rate * config.frame_duration)
async with websockets.connect(config.huri_url) as ws:
print("Connected to server")

await ws.send(json.dumps(asdict(config)))

async def receive(ws: websockets.ClientConnection):
while True:
text = await ws.recv()
print("received:", text)

async def send(ws: websockets.ClientConnection):
loop = asyncio.get_running_loop()

queue: asyncio.Queue = asyncio.Queue()

def callback(indata: np.ndarray, frames, time, status):
loop.call_soon_threadsafe(queue.put_nowait, indata.copy())

with sd.InputStream(
samplerate=config.sample_rate,
channels=1,
dtype="int16",
callback=callback,
blocksize=FRAME_SIZE,
):
while True:
chunk = await queue.get()
await ws.send(chunk.tobytes())

await asyncio.gather(receive(ws), send(ws))
await Client(config=config).run()


if __name__ == "__main__":
asyncio.run(stream_audio())
asyncio.run(launch_client())
44 changes: 44 additions & 0 deletions src/core/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import asyncio
import json
from dataclasses import asdict
from typing import Dict, List, Type

import websockets

from src.core.dataclasses.config import ClientConfig

from .client_senders import ClientSender, get_senders


class Client:
"""Client is init with a Config, and connects to HuRI using websockets"""

def __init__(
self,
config: ClientConfig,
senders_dict: Dict[str, Type[ClientSender]] = get_senders(),
):
self.config = config
self.senders_dict = senders_dict

async def _receive_loop(self, ws: websockets.ClientConnection):
while True:
text = await ws.recv()
print("<<", text)
await asyncio.sleep(0.1)

async def run(self):
async with websockets.connect(self.config.huri_url) as ws:
print("Connected to server")

senders: List[ClientSender] = [
self.senders_dict[config.name](ws=ws, **config.args)
for config in self.config.senders.values()
]

await ws.send(json.dumps(asdict(self.config)))

await asyncio.gather(
*(sender.input_loop() for sender in senders),
self._receive_loop(ws),
)
94 changes: 94 additions & 0 deletions src/core/client_senders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import asyncio
import json
import struct
from dataclasses import asdict
from typing import Dict, Type

import numpy as np
import sounddevice as sd
import websockets
from prompt_toolkit import PromptSession
from prompt_toolkit.patch_stdout import patch_stdout

from src.core.events import EventData
from src.modules.speech_to_text.events import Sentence


class ClientSender:
"""This class abstract sending data to HuRI.

output_type: is the topic that the ClientSender will send.
Data structure must match event topic.

Class derived from ClientSender must implement input_loop,
and use ClientSender.send to send data to HuRI. It can be EventData or bytes
"""

output_type: str

def __init__(self, ws: websockets.ClientConnection):
self.ws = ws

async def input_loop(self):
raise NotImplementedError

async def send(self, topic: str, data: EventData | bytes):
packet: str | bytes
if isinstance(data, EventData):
packet = json.dumps({"topic": topic, "data": asdict(data)})
else:
topic_bytes = topic.encode()

packet = struct.pack("!H", len(topic_bytes)) + topic_bytes + data

await self.ws.send(packet)


class AudioSender(ClientSender):
output_type = "audio"

def __init__(
self, sample_rate: int = 16000, frame_duration: float = 0.030, **kwargs
):
super().__init__(**kwargs)

self.sample_rate = sample_rate
self.frame_size = int(sample_rate * frame_duration)

async def input_loop(self):
loop = asyncio.get_running_loop()

queue: asyncio.Queue[np.ndarray] = asyncio.Queue()

def callback(indata: np.ndarray, frames, time, status):
loop.call_soon_threadsafe(queue.put_nowait, indata.copy())

with sd.InputStream(
samplerate=self.sample_rate,
channels=1,
dtype="int16",
callback=callback,
blocksize=self.frame_size,
):
while True:
chunk = await queue.get()
await self.send(self.output_type, chunk.tobytes())


class TextSender(ClientSender):
output_type = "question"

def __init__(self, **kwargs):
super().__init__(**kwargs)

async def input_loop(self):
session: PromptSession = PromptSession()
while True:
with patch_stdout():
text = await session.prompt_async(">> ")

await self.send(self.output_type, Sentence(text))


def get_senders() -> Dict[str, Type[ClientSender]]:
return {"audio": AudioSender, "text": TextSender}
23 changes: 19 additions & 4 deletions src/core/dataclasses/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,39 @@ def from_dict(self, raw: dict) -> "ModuleConfig":
)


@dataclass
class ClientSenderConfig:
name: str
args: Mapping[str, Any]

@classmethod
def from_dict(self, raw: dict) -> "ClientSenderConfig":
return self(
name=raw["name"],
args=raw.get("args", {}),
)


@dataclass
class ClientConfig:
huri_url: str
topic_list: List[str]
sample_rate: float
frame_duration: float
senders: Dict[str, ClientSenderConfig]
modules: Dict[str, ModuleConfig]

@classmethod
def from_dict(cls, raw: Dict) -> "ClientConfig":
senders = {
sender_id: ClientSenderConfig.from_dict(mod_raw)
for sender_id, mod_raw in raw.get("senders", {}).items()
}
modules = {
module_id: ModuleConfig.from_dict(mod_raw)
for module_id, mod_raw in raw.get("modules", {}).items()
}
return cls(
huri_url=raw["huri_url"],
topic_list=raw["topic_list"],
sample_rate=raw["sample_rate"],
frame_duration=raw["frame_duration"],
senders=senders,
modules=modules,
)
Loading