diff --git a/config/client_aux2.yaml b/config/client_aux .yaml similarity index 55% rename from config/client_aux2.yaml rename to config/client_aux .yaml index 7d7b601..d1f5195 100644 --- a/config/client_aux2.yaml +++ b/config/client_aux .yaml @@ -1,13 +1,31 @@ huri_url: ws://localhost:8000/session -topic_list: [transcript, question, rag_response] +interface_path: src.interfaces.cli_interface:cli_interface senders: audio: name: audio + topic: audio args: sample_rate: 16000 frame_duration: 0.030 + text: + name: text + topic: question + args: + sample_rate: 16000 + frame_duration: 0.030 + +hooks: + text: + name: text + topics: [question, answer] + audio: + name: audio + topics: [audio] + args: + incoming_sample_rate: ${senders.audio.args.sample_rate} + sample_rate: 44100 modules: mic: @@ -21,10 +39,8 @@ modules: args: language: en block_duration: ${senders.audio.args.frame_duration} - logging: INFO tag: name: tag - logging: INFO rag: name: rag args: diff --git a/config/client_aux.yaml b/config/client_aux.yaml deleted file mode 100644 index fe3e332..0000000 --- a/config/client_aux.yaml +++ /dev/null @@ -1,28 +0,0 @@ -huri_url: ws://localhost:8000/session - -topic_list: [question] - -senders: - audio: - name: audio - args: - sample_rate: 16000 - frame_duration: 0.030 - -modules: - mic: - name: mic - args: - vad_agressiveness: 3 - silence_duration: 1.5 - block_duration: ${inputs.audio.args.frame_duration} - logging: INFO - stt: - name: stt - args: - language: "en" - block_duration: ${inputs.audio.args.frame_duration} - logging: INFO - tag: - name: tag - logging: INFO diff --git a/config/client_auxio.yaml b/config/client_auxio.yaml deleted file mode 100644 index 8fa2a91..0000000 --- a/config/client_auxio.yaml +++ /dev/null @@ -1,25 +0,0 @@ -huri_url: ws://localhost:8000/session - -topic_list: [question] - -senders: - text: - name: text - -modules: - mic: - name: mic - args: - vad_agressiveness: 3 - silence_duration: 1.5 - block_duration: 
${senders.audio.args.frame_duration} - logging: INFO - stt: - name: stt - args: - language: en - block_duration: ${senders.audio.args.frame_duration} - logging: INFO - tag: - name: tag - logging: INFO diff --git a/config/client_template.yaml b/config/client_template.yaml index cf1627d..441f3c5 100644 --- a/config/client_template.yaml +++ b/config/client_template.yaml @@ -1,19 +1,34 @@ # HuRI websocket server url huri_url: ws://localhost:8000/session -# List of event topic the client will receive -topic_list: [topic1, topic2] +# Import path of the interface to use, written as "module.path:object_name" +interface_path: src.interfaces.cli_interface:cli_interface # Define senders to be used and their custom args senders: # sender tag can be anything example: - # sender name must be in the list of available ClientSender in Client instance (src.client_sender:get_senders) + # sender name must be in the list of available ClientSender in chosen Interface (Interface.get_senders) name: my_sender + # topic the sender will send to HuRI; it must match the sender's output_type event data structure + topic: my_event # if my_sender init with "model", "sample_rate" and "refresh_rate" params, they can be customized here args: refresh_rate: infinite +# Define hooks to be used and their custom args +hooks: + # hook tag can be anything + example: + # hook name must be in the list of available ClientHook in chosen Interface (Interface.get_hooks) + name: my_hook + # topics the hook will process from HuRI; each must match the hook's input_type event data structure + topics: [my_event, llm_response] + # if my_hook init with "model", "sample_rate" and "refresh_rate" params, they can be customized here + args: + sample_rate: 0 + no: beat + # Define module to be used and their custom args modules: # module tag can be anything diff --git a/config/client_text.yaml b/config/client_text.yaml index 8ddcaab..d2fb26f 100644 --- a/config/client_text.yaml +++ b/config/client_text.yaml @@ -1,10 +1,16 @@ huri_url: ws://localhost:8000/session -topic_list: 
[question, rag_response] +interface_path: src.interfaces.cli_interface:cli_interface senders: text: name: text + topic: question + +hooks: + text: + name: text + topics: [rag_response] modules: rag: diff --git a/src/core/client.py b/src/core/client.py index 085a0b8..6927565 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -1,14 +1,80 @@ import asyncio +import importlib import json import os +import struct +from collections import defaultdict from dataclasses import asdict -from typing import Dict, List, Optional, Type +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar import websockets from src.core.dataclasses.config import ClientConfig +from src.core.events import EventData -from .client_senders import ClientSender, get_senders +T = TypeVar("T", bound=EventData | bytes) + + +class ClientSender(Generic[T]): + """This class abstract sending data to HuRI. + + output_type: is the event data structure that the ClientSender will send. + It can be EventData or bytes, and must match event topic it send. + + Class derived from ClientSender must implement input_loop, + and use ClientSender.send to send data to HuRI. + + `singletton` is available to access shared ressources. 
+ """ + + output_type: Type[T] + + def __init__(self, topic: str, singletton: Any, **_): + """ + :topic: topic sent to HuRI + :singletton: allow to get shared ressources""" + self.topic = topic + self.singletton = singletton + + async def input_loop(self, ws: websockets.ClientConnection): + raise NotImplementedError + + async def _send_bytes(self, ws: websockets.ClientConnection, data: bytes): + topic_bytes = self.topic.encode() + packet = struct.pack("!H", len(topic_bytes)) + topic_bytes + data + + await ws.send(packet) + + async def _send_event_data(self, ws: websockets.ClientConnection, data: EventData): + packet = json.dumps({"topic": self.topic, "data": asdict(data)}) + + await ws.send(packet) + + async def send(self, ws: websockets.ClientConnection, data: T): + if isinstance(data, bytes): + await self._send_bytes(ws, data) + else: + await self._send_event_data(ws, data) + + +class ClientHook(Generic[T]): + """This class abstract processing data from HuRI. + + input_type: is the event data structure that the ClientHook will process. + It can be EventData or bytes, and must match event topic it react to. + + Class derived from ClientHook must implement hook. + + `singletton` is available to access and modifies shared ressources. 
+ """ + + input_type: Type[T] + + def __init__(self, singletton: Any, **_): + self.singletton = singletton + + async def hook(self, data: T): + raise NotImplementedError class Client: @@ -18,11 +84,33 @@ def __init__( self, config: ClientConfig, user_id_file: str = os.path.expanduser("~/.huri_user_id"), - senders_dict: Dict[str, Type[ClientSender]] = get_senders(), ): self.config = config + + module_path, object_name = self.config.interface_path.split(":", 1) + + module = importlib.import_module(module_path) + interface = getattr(module, object_name) + + available_senders = interface.get_senders() + self.senders: List[ClientSender] = [ + available_senders[sender.name]( + topic=sender.topic, singletton=interface.singletton, **sender.args + ) + for sender in self.config.senders.values() + ] + + available_hooks = interface.get_hooks() + self.hooks: Dict[str, List[ClientHook]] = defaultdict(list) + for hook in self.config.hooks.values(): + for topic in hook.topics: + self.hooks[topic].append( + available_hooks[hook.name]( + singletton=interface.singletton, **hook.args + ) + ) + self.user_id_file = user_id_file - self.senders_dict = senders_dict def _load_user_id(self) -> Optional[str]: if os.path.exists(self.user_id_file): @@ -37,9 +125,22 @@ def _save_user_id(self, _user_id: str): async def _receive_loop(self, ws: websockets.ClientConnection): try: while True: - text = await ws.recv() - print("<<", text) - await asyncio.sleep(0.1) + msg = await ws.recv() + + if isinstance(msg, bytes): + topic_len = struct.unpack("!H", msg[:2])[0] + + topic = msg[2 : 2 + topic_len].decode() + data = msg[2 + topic_len :] + else: + event = json.loads(msg) + topic = event["topic"] + data = event["data"] + + for hook in self.hooks[topic]: + if not isinstance(data, bytes): + data = hook.input_type(**data) + asyncio.create_task(hook.hook(data)) except (asyncio.CancelledError, websockets.ConnectionClosedOK): pass @@ -50,11 +151,6 @@ async def run(self): self.config.user_id = 
self._load_user_id() - senders: List[ClientSender] = [ - self.senders_dict[config.name](ws=ws, **config.args) - for config in self.config.senders.values() - ] - await ws.send(json.dumps(asdict(self.config))) init_msg = json.loads(await ws.recv()) @@ -63,9 +159,9 @@ async def run(self): self._save_user_id(user_id) print(f"Session started with _user_id: {user_id}") - receive_task = asyncio.create_task(self._receive_loop(ws)) + receive_task = asyncio.create_task(self._receive_loop(ws=ws)) await asyncio.gather( - *(sender.input_loop() for sender in senders), + *(sender.input_loop(ws=ws) for sender in self.senders), ) receive_task.cancel() diff --git a/src/core/client_senders.py b/src/core/client_senders.py deleted file mode 100644 index 03301a6..0000000 --- a/src/core/client_senders.py +++ /dev/null @@ -1,102 +0,0 @@ -import asyncio -import json -import struct -from dataclasses import asdict -from typing import Dict, Type - -import numpy as np -import sounddevice as sd -import websockets -from prompt_toolkit import PromptSession -from prompt_toolkit.patch_stdout import patch_stdout - -from src.core.events import EventData -from src.modules.speech_to_text.events import Sentence - - -class ClientSender: - """This class abstract sending data to HuRI. - - output_type: is the topic that the ClientSender will send. - Data structure must match event topic. - - Class derived from ClientSender must implement input_loop, - and use ClientSender.send to send data to HuRI. 
It can be EventData or bytes - """ - - output_type: str - - def __init__(self, ws: websockets.ClientConnection): - self.ws = ws - - async def input_loop(self): - raise NotImplementedError - - async def send(self, topic: str, data: EventData | bytes): - packet: str | bytes - if isinstance(data, EventData): - packet = json.dumps({"topic": topic, "data": asdict(data)}) - else: - topic_bytes = topic.encode() - - packet = struct.pack("!H", len(topic_bytes)) + topic_bytes + data - - await self.ws.send(packet) - - -class AudioSender(ClientSender): - output_type = "audio" - - def __init__( - self, sample_rate: int = 16000, frame_duration: float = 0.030, **kwargs - ): - super().__init__(**kwargs) - - self.sample_rate = sample_rate - self.frame_size = int(sample_rate * frame_duration) - - async def input_loop(self): - loop = asyncio.get_running_loop() - - queue: asyncio.Queue[np.ndarray] = asyncio.Queue() - - def callback(indata: np.ndarray, frames, time, status): - loop.call_soon_threadsafe(queue.put_nowait, indata.copy()) - - with sd.InputStream( - samplerate=self.sample_rate, - channels=1, - dtype="int16", - callback=callback, - blocksize=self.frame_size, - ): - while True: - chunk = await queue.get() - await self.send(self.output_type, chunk.tobytes()) - - -class TextSender(ClientSender): - output_type = "question" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - async def input_loop(self): - print("'\\exit' or CTRL+D/C to exit.") - session: PromptSession = PromptSession() - try: - while True: - with patch_stdout(): - text = await session.prompt_async(">> ") - if text == "\\exit": - return - await self.send(self.output_type, Sentence(text)) - - except (EOFError, KeyboardInterrupt): - pass - finally: - print("TextSender Exited...") - - -def get_senders() -> Dict[str, Type[ClientSender]]: - return {"audio": AudioSender, "text": TextSender} diff --git a/src/core/dataclasses/config.py b/src/core/dataclasses/config.py index aea111f..f515026 100644 --- 
a/src/core/dataclasses/config.py +++ b/src/core/dataclasses/config.py @@ -15,15 +15,32 @@ def from_dict(self, raw: dict) -> "ModuleConfig": ) +@dataclass +class ClientHookConfig: + name: str + topics: List[str] + args: Mapping[str, Any] + + @classmethod + def from_dict(self, raw: dict) -> "ClientHookConfig": + return self( + name=raw["name"], + topics=raw["topics"], + args=raw.get("args", {}), + ) + + @dataclass class ClientSenderConfig: name: str + topic: str args: Mapping[str, Any] @classmethod def from_dict(self, raw: dict) -> "ClientSenderConfig": return self( name=raw["name"], + topic=raw["topic"], args=raw.get("args", {}), ) @@ -32,15 +49,20 @@ def from_dict(self, raw: dict) -> "ClientSenderConfig": class ClientConfig: user_id: Optional[str] huri_url: str - topic_list: List[str] + interface_path: str + hooks: Dict[str, ClientHookConfig] senders: Dict[str, ClientSenderConfig] modules: Dict[str, ModuleConfig] @classmethod def from_dict(cls, raw: Dict) -> "ClientConfig": + hooks = { + hook_id: ClientHookConfig.from_dict(hok_raw) + for hook_id, hok_raw in raw.get("hooks", {}).items() + } senders = { - sender_id: ClientSenderConfig.from_dict(mod_raw) - for sender_id, mod_raw in raw.get("senders", {}).items() + sender_id: ClientSenderConfig.from_dict(snd_raw) + for sender_id, snd_raw in raw.get("senders", {}).items() } modules = { module_id: ModuleConfig.from_dict(mod_raw) @@ -49,7 +71,8 @@ def from_dict(cls, raw: Dict) -> "ClientConfig": return cls( user_id=None, huri_url=raw["huri_url"], - topic_list=raw["topic_list"], + interface_path=raw["interface_path"], + hooks=hooks, senders=senders, modules=modules, ) diff --git a/src/core/huri.py b/src/core/huri.py index 5fa8038..f5e4eeb 100644 --- a/src/core/huri.py +++ b/src/core/huri.py @@ -32,7 +32,7 @@ def __init__( self, modules: Dict[str, Type[Module]], handles: Dict[str, handle.DeploymentHandle], - events: Dict[str, Type[EventData]], + events: Dict[str, Type[EventData | bytes]], ) -> None: self.module_factory = 
ModuleFactory(handles) self.event_factory = EventDataFactory() @@ -80,9 +80,12 @@ async def run_session(self, ws: WebSocket): user_id = client_config_raw.get("user_id") or str(uuid.uuid4()) - senders: List[Module] = [ - Sender(ws, topic) for topic in client_config.topic_list + topic_list = [ + topic + for hook_config in client_config.hooks.values() + for topic in hook_config.topics ] + senders: List[Module] = [Sender(ws, topic) for topic in topic_list] modules: List[Module] = ( self.module_factory.create_from_config(user_id, client_config.modules) + senders @@ -112,7 +115,7 @@ async def receive_loop(session: Session, ws: WebSocket): msg_text = msg["text"] event = json.loads(msg_text) topic = event["topic"] - data = event["data"] + data = event["data"] # TODO client/server one function data = self.event_factory.create(topic, data) diff --git a/src/core/interface.py b/src/core/interface.py new file mode 100644 index 0000000..fb2b7ac --- /dev/null +++ b/src/core/interface.py @@ -0,0 +1,22 @@ +from typing import Any, Dict, Type + +from .client import ClientHook, ClientSender + + +class Interface: + """This class abstract defining specific Client senders and hooks. + + `self.singletton`: allow hooks to modifies shared ressources, + and comes from the used interface. + + Class derived from Interface must implement get_senders and get_hooks. 
+ """ + + def __init__(self, singletton: Any): + self.singletton = singletton + + def get_senders(self) -> Dict[str, Type[ClientSender]]: + raise NotImplementedError + + def get_hooks(self) -> Dict[str, Type[ClientHook]]: + raise NotImplementedError diff --git a/src/interfaces/__init__.py b/src/interfaces/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/interfaces/cli_interface.py b/src/interfaces/cli_interface.py new file mode 100644 index 0000000..2634d07 --- /dev/null +++ b/src/interfaces/cli_interface.py @@ -0,0 +1,124 @@ +import asyncio +from typing import Dict, Type + +import numpy as np +import sounddevice as sd +from prompt_toolkit import PromptSession +from prompt_toolkit.patch_stdout import patch_stdout +from scipy.signal import resample + +from src.core.client import ClientHook, ClientSender +from src.core.interface import Interface +from src.modules.rag.events import RAGResult +from src.modules.speech_to_text.events import Sentence + + +class AudioSender(ClientSender[bytes]): + def __init__( + self, sample_rate: int = 16000, frame_duration: float = 0.030, **kwargs + ): + super().__init__(**kwargs) + + self.sample_rate = sample_rate + self.frame_size = int(sample_rate * frame_duration) + + async def input_loop(self, ws): + loop = asyncio.get_running_loop() + + queue: asyncio.Queue[np.ndarray] = asyncio.Queue() + + def callback(indata: np.ndarray, frames, time, status): + loop.call_soon_threadsafe(queue.put_nowait, indata.copy()) + + with sd.InputStream( + samplerate=self.sample_rate, + channels=1, + dtype="int16", + callback=callback, + blocksize=self.frame_size, + ): + while True: + chunk = await queue.get() + await self.send(ws, chunk.tobytes()) + + +class TextSender(ClientSender[Sentence]): + output_type = Sentence + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def input_loop(self, ws): + print("'\\exit' or CTRL+D/C to exit.") + session: PromptSession = PromptSession() + try: + while True: + with 
patch_stdout(): + text = await session.prompt_async(">> ") + if text == "\\exit": + return + await self.send(ws, Sentence(text)) + + except (EOFError, KeyboardInterrupt): + pass + finally: + print("TextSender Exited...") + + +class AudioHook(ClientHook[bytes]): + input_type = bytes + + def __init__(self, sample_rate=48000, incoming_sample_rate=16000, **kwargs): + super().__init__(**kwargs) + + print("Speaker:", sd.query_devices(kind="output")) + + self.incoming_sample_rate = incoming_sample_rate + self.sample_rate = sample_rate + self.stream = sd.OutputStream( + samplerate=sample_rate, + channels=1, + dtype="int16", + ) + self.stream.start() + + self.resample_function = ( + self._resample if sample_rate != incoming_sample_rate else lambda x: x + ) + + def _resample(self, audio: np.ndarray): + return resample( + audio, + int(len(audio) * self.sample_rate / self.incoming_sample_rate), + ).astype(np.int16) + + async def hook(self, data: bytes): + audio = np.frombuffer(data, dtype=np.int16) + + audio = self.resample_function(audio) + + self.stream.write(audio.reshape(-1, 1)) + + +class TextHook(ClientHook[RAGResult]): + input_type = RAGResult + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def hook(self, data: RAGResult): + print("<<", data.answer) + + +class CLIInterface(Interface): + def __init__(self): + super().__init__(singletton=None) + + def get_senders(self) -> Dict[str, Type[ClientSender]]: + return {"audio": AudioSender, "text": TextSender} + + def get_hooks(self) -> Dict[str, Type[ClientHook]]: + return {"audio": AudioHook, "text": TextHook} + + +cli_interface = CLIInterface() diff --git a/src/modules/speech_to_text/speech_to_text.py b/src/modules/speech_to_text/speech_to_text.py index 1300dd3..63bf060 100644 --- a/src/modules/speech_to_text/speech_to_text.py +++ b/src/modules/speech_to_text/speech_to_text.py @@ -83,13 +83,15 @@ async def process(self, voice: Voice) -> Optional[Transcript]: self.pending_silence = False 
processing_audio = np.concatenate(processing_chunks, axis=0) - segments, _ = self.model_faster.transcribe( - processing_audio, - language=self.language, - beam_size=1, # faster for realtime - ) - - current_text = " ".join([seg.text for seg in segments]).strip() + def transcribe_text(): + segments, _ = self.model_faster.transcribe( + processing_audio, + language=self.language, + beam_size=1, + ) + return " ".join(seg.text for seg in segments).strip() + + current_text = await asyncio.to_thread(transcribe_text) processed_size = self.window_size - self.step_size async with self.lock: diff --git a/src/modules/utils/sender.py b/src/modules/utils/sender.py index f09b0ba..a9fc2fa 100644 --- a/src/modules/utils/sender.py +++ b/src/modules/utils/sender.py @@ -1,3 +1,4 @@ +import struct from dataclasses import asdict from fastapi import WebSocket @@ -23,8 +24,8 @@ def __init__(self, ws: WebSocket, type: str): async def process(self, data: EventData | bytes): if isinstance(data, bytes): - await self.ws.send_bytes(data) - elif isinstance(data, EventData): - await self.ws.send_json(asdict(data)) + topic_bytes = self.input_type.encode() + packet = struct.pack("!H", len(topic_bytes)) + topic_bytes + data + await self.ws.send_bytes(packet) else: - await self.ws.send_text(data) + await self.ws.send_json({"topic": self.input_type, "data": asdict(data)})