Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions config/client_aux2.yaml → config/client_aux .yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
huri_url: ws://localhost:8000/session

topic_list: [transcript, question, rag_response]
interface_path: src.interfaces.cli_interface:cli_interface

senders:
audio:
name: audio
topic: audio
args:
sample_rate: 16000
frame_duration: 0.030
text:
name: text
topic: question
args:
sample_rate: 16000
frame_duration: 0.030

hooks:
text:
name: text
topics: [question, answer]
audio:
name: audio
topics: [audio]
args:
incoming_sample_rate: ${senders.audio.args.sample_rate}
sample_rate: 44100

modules:
mic:
Expand All @@ -21,10 +39,8 @@ modules:
args:
language: en
block_duration: ${senders.audio.args.frame_duration}
logging: INFO
tag:
name: tag
logging: INFO
rag:
name: rag
args:
Expand Down
28 changes: 0 additions & 28 deletions config/client_aux.yaml

This file was deleted.

25 changes: 0 additions & 25 deletions config/client_auxio.yaml

This file was deleted.

21 changes: 18 additions & 3 deletions config/client_template.yaml
Original file line number Diff line number Diff line change
@@ -1,19 +1,34 @@
# HuRI websocket server url
huri_url: ws://localhost:8000/session

# List of event topic the client will receive
topic_list: [topic1, topic2]
# Define interface to be used's import path
interface_path: src.interfaces.cli_interface:cli_interface

# Define senders to be used and their custom args
senders:
# sender tag can be anything
example:
# sender name must be in the list of available ClientSender in Client instance (src.client_sender:get_senders)
# sender name must be in the list of available ClientSender in chosen Interface (Interface.get_senders)
name: my_sender
# topic the sender will send to HuRI, it must match output_type event data structure
topic: my_event
# if my_sender init with "model", "sample_rate" and "refresh_rate" params, they can be customized here
args:
refresh_rate: infinite

# Define hooks to be used and their custom args
hooks:
# hook tag can be anything
example:
# hook name must be in the list of available ClientHook in chosen Interface (Interface.get_senders)
name: my_hook
# topics the hook will process from HuRI, it must match input_type event data structure
topics: [my_event, llm_response]
# if my_hook init with "model", "sample_rate" and "refresh_rate" params, they can be customized here
args:
sample_rate: 0
no: beat

# Define module to be used and their custom args
modules:
# module tag can be anything
Expand Down
8 changes: 7 additions & 1 deletion config/client_text.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
huri_url: ws://localhost:8000/session

topic_list: [question, rag_response]
interface_path: src.interfaces.cli_interface:cli_interface

senders:
text:
name: text
topic: question

hooks:
text:
name: text
topics: [rag_response]

modules:
rag:
Expand Down
124 changes: 110 additions & 14 deletions src/core/client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,80 @@
import asyncio
import importlib
import json
import os
import struct
from collections import defaultdict
from dataclasses import asdict
from typing import Dict, List, Optional, Type
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar

import websockets

from src.core.dataclasses.config import ClientConfig
from src.core.events import EventData

from .client_senders import ClientSender, get_senders
T = TypeVar("T", bound=EventData | bytes)


class ClientSender(Generic[T]):
"""This class abstract sending data to HuRI.

output_type: is the event data structure that the ClientSender will send.
It can be EventData or bytes, and must match event topic it send.

Class derived from ClientSender must implement input_loop,
and use ClientSender.send to send data to HuRI.

`singletton` is available to access shared ressources.
"""

output_type: Type[T]

def __init__(self, topic: str, singletton: Any, **_):
"""
:topic: topic sent to HuRI
:singletton: allow to get shared ressources"""
self.topic = topic
self.singletton = singletton

async def input_loop(self, ws: websockets.ClientConnection):
raise NotImplementedError

async def _send_bytes(self, ws: websockets.ClientConnection, data: bytes):
topic_bytes = self.topic.encode()
packet = struct.pack("!H", len(topic_bytes)) + topic_bytes + data

await ws.send(packet)

async def _send_event_data(self, ws: websockets.ClientConnection, data: EventData):
packet = json.dumps({"topic": self.topic, "data": asdict(data)})

await ws.send(packet)

async def send(self, ws: websockets.ClientConnection, data: T):
if isinstance(data, bytes):
await self._send_bytes(ws, data)
else:
await self._send_event_data(ws, data)


class ClientHook(Generic[T]):
"""This class abstract processing data from HuRI.

input_type: is the event data structure that the ClientHook will process.
It can be EventData or bytes, and must match event topic it react to.

Class derived from ClientHook must implement hook.

`singletton` is available to access and modifies shared ressources.
"""

input_type: Type[T]

def __init__(self, singletton: Any, **_):
self.singletton = singletton

async def hook(self, data: T):
raise NotImplementedError


class Client:
Expand All @@ -18,11 +84,33 @@ def __init__(
self,
config: ClientConfig,
user_id_file: str = os.path.expanduser("~/.huri_user_id"),
senders_dict: Dict[str, Type[ClientSender]] = get_senders(),
):
self.config = config

module_path, object_name = self.config.interface_path.split(":", 1)

module = importlib.import_module(module_path)
interface = getattr(module, object_name)

available_senders = interface.get_senders()
self.senders: List[ClientSender] = [
available_senders[sender.name](
topic=sender.topic, singletton=interface.singletton, **sender.args
)
for sender in self.config.senders.values()
]

available_hooks = interface.get_hooks()
self.hooks: Dict[str, List[ClientHook]] = defaultdict(list)
for hook in self.config.hooks.values():
for topic in hook.topics:
self.hooks[topic].append(
available_hooks[hook.name](
singletton=interface.singletton, **hook.args
)
)

self.user_id_file = user_id_file
self.senders_dict = senders_dict

def _load_user_id(self) -> Optional[str]:
if os.path.exists(self.user_id_file):
Expand All @@ -37,9 +125,22 @@ def _save_user_id(self, _user_id: str):
async def _receive_loop(self, ws: websockets.ClientConnection):
try:
while True:
text = await ws.recv()
print("<<", text)
await asyncio.sleep(0.1)
msg = await ws.recv()

if isinstance(msg, bytes):
topic_len = struct.unpack("!H", msg[:2])[0]

topic = msg[2 : 2 + topic_len].decode()
data = msg[2 + topic_len :]
else:
event = json.loads(msg)
topic = event["topic"]
data = event["data"]

for hook in self.hooks[topic]:
if not isinstance(data, bytes):
data = hook.input_type(**data)
asyncio.create_task(hook.hook(data))

except (asyncio.CancelledError, websockets.ConnectionClosedOK):
pass
Expand All @@ -50,11 +151,6 @@ async def run(self):

self.config.user_id = self._load_user_id()

senders: List[ClientSender] = [
self.senders_dict[config.name](ws=ws, **config.args)
for config in self.config.senders.values()
]

await ws.send(json.dumps(asdict(self.config)))

init_msg = json.loads(await ws.recv())
Expand All @@ -63,9 +159,9 @@ async def run(self):
self._save_user_id(user_id)
print(f"Session started with _user_id: {user_id}")

receive_task = asyncio.create_task(self._receive_loop(ws))
receive_task = asyncio.create_task(self._receive_loop(ws=ws))
await asyncio.gather(
*(sender.input_loop() for sender in senders),
*(sender.input_loop(ws=ws) for sender in self.senders),
)

receive_task.cancel()
Loading