diff --git a/.gitignore b/.gitignore index 7892d64..92a7c35 100644 --- a/.gitignore +++ b/.gitignore @@ -176,6 +176,12 @@ cython_debug/ # PyPI configuration file .pypirc +# Helm +**/charts/*.tgz + # Others .trash -docs \ No newline at end of file +docs + +# HuRI client outputs +*.wav diff --git a/config/client_full.yaml b/config/client_full.yaml new file mode 100644 index 0000000..1481430 --- /dev/null +++ b/config/client_full.yaml @@ -0,0 +1,44 @@ +huri_url: ws://localhost:8000/session + +topic_list: [transcript, question, token, motion] + +senders: + audio: + name: audio + args: + sample_rate: 16000 + frame_duration: 0.030 + +modules: + mic: + name: mic + args: + vad_agressiveness: 3 + silence_duration: 1.5 + block_duration: ${senders.audio.args.frame_duration} + logging: INFO + stt: + name: stt + args: + language: en + block_duration: ${senders.audio.args.frame_duration} + logging: INFO + tag: + name: tag + logging: INFO + rag: + name: rag + args: + language: en + tone: formal + response_format: paragraph + max_length: 1024 + logging: INFO + tts: + name: tts + args: + min_clause_chars: 20 + logging: INFO + gesture: + name: gesture + logging: INFO diff --git a/config/huri.yaml b/config/huri.yaml deleted file mode 100644 index 70d2cc7..0000000 --- a/config/huri.yaml +++ /dev/null @@ -1,35 +0,0 @@ -proxy_location: EveryNode - -http_options: - host: 0.0.0.0 - port: 8000 - -logging_config: - encoding: TEXT - log_level: INFO - logs_dir: null - enable_access_log: true - additional_log_standard_attrs: [] - -services: - qdrant: - port: 6333 - image: "qdrant/qdrant:latest" - storage_volume: "qdrant_data" - ollama: - model: "mistral:7b" - image: "ollama/ollama:rocm" - gpu_devices: true - num_replicas: 1 - -applications: - - name: huri-app - route_prefix: / - import_path: src.app:app - runtime_env: { RAY_COLOR_PREFIX=1 } - deployments: - - name: HuRI - - name: RAGHandle - num_replicas: 2 - - name: OllamaService - - name: QdrantService diff --git a/deploy/Dockerfile.amd 
b/deploy/Dockerfile.amd new file mode 100644 index 0000000..786094b --- /dev/null +++ b/deploy/Dockerfile.amd @@ -0,0 +1,79 @@ +FROM rayproject/ray:2.55.1-py312 +WORKDIR /app +USER root +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + gnupg2 \ + && rm -rf /var/lib/apt/lists/* + +# Add ROCm 7.2.3 repository (Ubuntu 22.04 Jammy) +RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/keyrings/rocm.gpg \ + && printf 'deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/7.2.3 jammy main\ndeb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/7.2 jammy main\n' \ + > /etc/apt/sources.list.d/rocm.list \ + && printf 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600\n' \ + > /etc/apt/preferences.d/rocm-pin-600 + +# Install ROCm runtime + libraries +RUN apt-get update && apt-get install -y \ + rocm-hip-runtime \ + rocm-hip-libraries \ + rocm-device-libs \ + rocm-smi-lib \ + rocblas \ + hipblas \ + miopen-hip \ + rccl \ + rocsolver \ + rocfft \ + rocrand \ + hipsparse \ + && rm -rf /var/lib/apt/lists/* + +# Set ROCm environment +ENV ROCM_PATH=/opt/rocm +ENV PATH="${ROCM_PATH}/bin:${PATH}" +ENV LD_LIBRARY_PATH="${ROCM_PATH}/lib:${LD_LIBRARY_PATH}" + +USER ray + + +COPY serve_requirements.txt /app +RUN pip install --no-cache-dir -r serve_requirements.txt + +# 1. 
AMD's PyTorch built for ROCm (NOT the PyPI one — it's built for ROCm 6.2 and will silently break) +ARG ROCM_VERSION=7.2 +ARG PYTHON_VERSION=cp312 +ARG TRITON_VERSION=3.4.0+rocm7.2.0.git0cace8d2 + +RUN pip install --no-cache-dir \ + "https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/triton-${TRITON_VERSION}-${PYTHON_VERSION}-${PYTHON_VERSION}-linux_x86_64.whl" + +RUN pip install --no-cache-dir \ + --extra-index-url https://repo.radeon.com/rocm/pypi/ \ + "https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/torch-2.8.0+rocm${ROCM_VERSION}.0.lw.gitbf943426-${PYTHON_VERSION}-${PYTHON_VERSION}-linux_x86_64.whl" \ + "https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/torchaudio-2.8.0+rocm${ROCM_VERSION}.0.git6e1c7fe9-${PYTHON_VERSION}-${PYTHON_VERSION}-linux_x86_64.whl" + +RUN pip install --no-cache-dir filelock sympy networkx jinja2 fsspec numpy + +USER root + +# 3. Official CTranslate2 ROCm wheel (it's inside a zip on the releases page) +RUN apt-get update && apt-get install -y unzip curl \ + && curl -L https://github.com/OpenNMT/CTranslate2/releases/download/v4.7.1/rocm-python-wheels-Linux.zip \ + -o /tmp/ct2-rocm.zip \ + && unzip -j /tmp/ct2-rocm.zip 'temp-linux/ctranslate2-4.7.1-cp312-*manylinux*x86_64.whl' -d /tmp/ct2 \ + && pip install --no-cache-dir /tmp/ct2/ctranslate2-4.7.1-cp312-*.whl \ + && rm -rf /tmp/ct2 /tmp/ct2-rocm.zip + +USER ray + +# 4. faster-whisper +RUN pip install --no-cache-dir faster-whisper + +# 5. RAG / LLM extras (httpx, qdrant-client, sentence-transformers, …) +# Installed last so the ROCm torch wheel installed above is the resolved one. 
+COPY requirements-amd.txt /app +RUN pip install --no-cache-dir -r requirements-amd.txt + +COPY src /app/src diff --git a/deploy/Dockerfile.base b/deploy/Dockerfile.base new file mode 100644 index 0000000..f46ac35 --- /dev/null +++ b/deploy/Dockerfile.base @@ -0,0 +1,16 @@ +FROM rayproject/ray:2.55.1-py312 + +WORKDIR /app + + +USER root +RUN apt-get update && apt-get install -y \ +build-essential \ +&& rm -rf /var/lib/apt/lists/* + +USER ray + +COPY serve_requirements.txt /app +RUN pip install --no-cache-dir -r serve_requirements.txt + +COPY src /app/src diff --git a/deploy/Dockerfile.nvidia b/deploy/Dockerfile.nvidia new file mode 100644 index 0000000..f0e6e4d --- /dev/null +++ b/deploy/Dockerfile.nvidia @@ -0,0 +1,31 @@ +FROM rayproject/ray:2.55.1-py312-gpu + +WORKDIR /app + +# Full CUDA 12.1 dependency stack (CosyVoice2, faster-whisper, TensorRT, …). +# PyTorch cu121 wheels live on the PyTorch index; TensorRT wheels on the NGC index. +COPY requirements-nvidia.txt /app +RUN pip install --no-cache-dir \ + --extra-index-url https://download.pytorch.org/whl/cu121 \ + --extra-index-url https://pypi.ngc.nvidia.com \ + -r requirements-nvidia.txt + +COPY serve_requirements.txt /app +RUN pip install --no-cache-dir -r serve_requirements.txt + +USER root + +RUN apt-get update && apt-get install -y --no-install-recommends git \ + && rm -rf /var/lib/apt/lists/* + +USER ray + +# CosyVoice2 has no setup.py/pyproject.toml so it cannot be pip-installed. +# Clone at a pinned commit for supply-chain integrity and expose it via PYTHONPATH. 
+RUN git clone https://github.com/FunAudioLLM/CosyVoice.git /app/cosyvoice \ + && git -C /app/cosyvoice checkout 074ca6dc9e80a2f424f1f74b48bdd7d3fea531cc \ + && git -C /app/cosyvoice submodule update --init --recursive + +ENV PYTHONPATH="/app/cosyvoice:${PYTHONPATH:-}" + +COPY src /app/src diff --git a/deploy/examples/local_nvidia_amd/Chart.lock b/deploy/examples/local_nvidia_amd/Chart.lock new file mode 100644 index 0000000..b0e0b95 --- /dev/null +++ b/deploy/examples/local_nvidia_amd/Chart.lock @@ -0,0 +1,6 @@ +dependencies: +- name: kuberay-operator + repository: https://ray-project.github.io/kuberay-helm/ + version: 1.6.0 +digest: sha256:b9057481d9a5e2d8b8798488b0b321bbd3f6e43dcb5a9dea18b181641a63b400 +generated: "2026-05-22T17:28:37.5934885+02:00" diff --git a/deploy/examples/local_nvidia_amd/Chart.yaml b/deploy/examples/local_nvidia_amd/Chart.yaml new file mode 100644 index 0000000..c51091a --- /dev/null +++ b/deploy/examples/local_nvidia_amd/Chart.yaml @@ -0,0 +1,20 @@ +apiVersion: v2 +name: huri +description: HuRI service powered by Ray Serve on KubeRay +type: application +version: 0.1.0 +appVersion: "2.52.0" +keywords: + - ray + - kuberay + - ray-serve + - robotics + - hri +maintainers: + - name: Sentience Robotics + +dependencies: + - name: kuberay-operator + version: "1.6.0" + repository: "https://ray-project.github.io/kuberay-helm/" + condition: kuberay.install diff --git a/deploy/examples/local_nvidia_amd/templates/_helpers.tpl b/deploy/examples/local_nvidia_amd/templates/_helpers.tpl new file mode 100644 index 0000000..cb431d4 --- /dev/null +++ b/deploy/examples/local_nvidia_amd/templates/_helpers.tpl @@ -0,0 +1,72 @@ +{{/* +Expand the name of the chart. +*/}} +{{- define "huri.name" -}} +{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Create a default fully-qualified app name. +Truncated at 63 chars because some Kubernetes name fields have this limit. 
+*/}} +{{- define "huri.fullname" -}} +{{- if .Values.fullnameOverride }} +{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- $name := default .Chart.Name .Values.nameOverride }} +{{- if contains $name .Release.Name }} +{{- .Release.Name | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} +{{- end }} +{{- end }} +{{- end }} + +{{/* +Chart label: <chart name>-<chart version>. +*/}} +{{- define "huri.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Common labels applied to every resource. +*/}} +{{- define "huri.labels" -}} +helm.sh/chart: {{ include "huri.chart" . }} +{{ include "huri.selectorLabels" . }} +app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} +app.kubernetes.io/managed-by: {{ .Release.Service }} +{{- end }} + +{{/* +Selector labels (used in matchLabels / ingress backends). +*/}} +{{- define "huri.selectorLabels" -}} +app.kubernetes.io/name: {{ include "huri.name" . }} +app.kubernetes.io/instance: {{ .Release.Name }} +{{- end }} + +{{/* +Name of the KubeRay-managed serve service. +KubeRay appends "-serve-svc" to the RayService name. +*/}} +{{- define "huri.serveSvcName" -}} +{{- printf "%s-serve-svc" (include "huri.fullname" .) }} +{{- end }} + +{{/* +Name of the KubeRay-managed head service. +KubeRay appends "-head-svc" to the RayService name. +*/}} +{{- define "huri.headSvcName" -}} +{{- printf "%s-head-svc" (include "huri.fullname" .) }} +{{- end }} + +{{/* +Name of the stable dashboard service managed by this chart. +Selects the head pod via stable labels, avoiding KubeRay's random-suffix service. +*/}} +{{- define "huri.headDashboardSvcName" -}} +{{- printf "%s-head-dashboard-svc" (include "huri.fullname" .) 
}} +{{- end }} diff --git a/deploy/examples/local_nvidia_amd/templates/cosytts-model-init-job.yaml b/deploy/examples/local_nvidia_amd/templates/cosytts-model-init-job.yaml new file mode 100644 index 0000000..b0e3bc7 --- /dev/null +++ b/deploy/examples/local_nvidia_amd/templates/cosytts-model-init-job.yaml @@ -0,0 +1,103 @@ +{{- if .Values.models.cosytts.enabled }} +{{- $model := .Values.models.cosytts }} +{{- $pvcName := printf "%s-cosytts-models" (include "huri.fullname" .) }} +{{- if not (lookup "v1" "PersistentVolumeClaim" .Release.Namespace $pvcName) }} +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: {{ $pvcName }} + labels: + {{- include "huri.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-10" + "helm.sh/resource-policy": keep +spec: + accessModes: + {{- toYaml $model.pvc.accessModes | nindent 4 }} + resources: + requests: + storage: {{ $model.pvc.size }} + {{- if $model.pvc.storageClassName }} + storageClassName: {{ $model.pvc.storageClassName }} + {{- end }} +{{- end }} +--- +# Runs on every install/upgrade (pre-* hook) but exits early once the model dir +# already looks complete — so a no-op upgrade costs only a pod spin-up. Set +# $model.forceDownload=true to wipe and re-fetch (e.g. when changing versions). +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ include "huri.fullname" . }}-cosytts-init + labels: + {{- include "huri.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-5" + "helm.sh/hook-delete-policy": hook-succeeded,before-hook-creation +spec: + backoffLimit: 3 + template: + metadata: + labels: + {{- include "huri.selectorLabels" . | nindent 8 }} + spec: + restartPolicy: OnFailure + {{- with $model.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + volumes: + - name: models + persistentVolumeClaim: + claimName: {{ include "huri.fullname" . 
}}-cosytts-models + containers: + - name: cosytts-downloader + image: python:3.11-slim + command: ["/bin/sh", "-c"] + args: + - | + set -e + MODEL_DIR="{{ $model.mountPath }}/{{ $model.modelSource.modelId }}" + {{- if $model.forceDownload }} + echo "forceDownload=true — wiping $MODEL_DIR for a fresh download." + rm -rf "$MODEL_DIR" + {{- end }} + # CosyVoice3 ships cosyvoice3.yaml + llm.pt (the Qwen LM) plus a + # bundled CosyVoice-BlankEN/ dir, so a single snapshot grabs + # everything the worker needs — no separate sub-model download. + HAS_CONFIG="no" + HAS_LLM="no" + if [ -f "$MODEL_DIR/cosyvoice3.yaml" ]; then + HAS_CONFIG="yes" + fi + if [ -f "$MODEL_DIR/llm.pt" ]; then + HAS_LLM="yes" + fi + if [ "$HAS_CONFIG" = "yes" ] && [ "$HAS_LLM" = "yes" ]; then + echo "Model already present at $MODEL_DIR — skipping download." + exit 0 + fi + echo "Downloading {{ $model.modelSource.modelId }} into $MODEL_DIR …" + pip install --quiet modelscope + python - <<'PYEOF' + from modelscope import snapshot_download + snapshot_download( + "{{ $model.modelSource.modelId }}", + local_dir="{{ $model.mountPath }}/{{ $model.modelSource.modelId }}", + ) + PYEOF + echo "Download complete." + volumeMounts: + - name: models + mountPath: {{ $model.mountPath }} + resources: + requests: + cpu: "500m" + memory: "512Mi" + limits: + cpu: "2" + memory: "2Gi" +{{- end }} diff --git a/deploy/examples/local_nvidia_amd/templates/emage-model-init-job.yaml b/deploy/examples/local_nvidia_amd/templates/emage-model-init-job.yaml new file mode 100644 index 0000000..c4c4a08 --- /dev/null +++ b/deploy/examples/local_nvidia_amd/templates/emage-model-init-job.yaml @@ -0,0 +1,86 @@ +{{- if .Values.models.emage.enabled }} +{{- $model := .Values.models.emage }} +{{- $pvcName := printf "%s-emage-models" (include "huri.fullname" .) 
}} +{{- if not (lookup "v1" "PersistentVolumeClaim" .Release.Namespace $pvcName) }} +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: {{ $pvcName }} + labels: + {{- include "huri.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-10" + "helm.sh/resource-policy": keep +spec: + accessModes: + {{- toYaml $model.pvc.accessModes | nindent 4 }} + resources: + requests: + storage: {{ $model.pvc.size }} + {{- if $model.pvc.storageClassName }} + storageClassName: {{ $model.pvc.storageClassName }} + {{- end }} +{{- end }} +--- +# Runs on every install and upgrade (pre-install,pre-upgrade hook) but exits early when the model directory on the PVC is already populated. +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ include "huri.fullname" . }}-emage-init + labels: + {{- include "huri.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-5" + "helm.sh/hook-delete-policy": hook-succeeded,before-hook-creation +spec: + backoffLimit: 3 + template: + metadata: + labels: + {{- include "huri.selectorLabels" . | nindent 8 }} + spec: + restartPolicy: OnFailure + {{- with $model.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + volumes: + - name: models + persistentVolumeClaim: + claimName: {{ include "huri.fullname" . }}-emage-models + containers: + - name: emage-downloader + image: python:3.11-slim + command: ["/bin/sh", "-c"] + args: + - | + set -e + MODEL_DIR="{{ $model.mountPath }}/{{ $model.modelSource.repoId }}" + if [ -d "$MODEL_DIR" ] && [ "$(ls -A $MODEL_DIR 2>/dev/null)" ]; then + echo "Model already present at $MODEL_DIR — skipping download." 
+ exit 0 + fi + echo "Downloading {{ $model.modelSource.repoId }} into $MODEL_DIR …" + pip install --quiet huggingface_hub + python - <<'PYEOF' + from huggingface_hub import snapshot_download + snapshot_download( + "{{ $model.modelSource.repoId }}", + local_dir="{{ $model.mountPath }}/{{ $model.modelSource.repoId }}", + ) + PYEOF + echo "Download complete." + volumeMounts: + - name: models + mountPath: {{ $model.mountPath }} + resources: + requests: + cpu: "500m" + memory: "512Mi" + limits: + cpu: "2" + memory: "2Gi" +{{- end }} diff --git a/deploy/examples/local_nvidia_amd/templates/head-dashboard-svc.yaml b/deploy/examples/local_nvidia_amd/templates/head-dashboard-svc.yaml new file mode 100644 index 0000000..9b47431 --- /dev/null +++ b/deploy/examples/local_nvidia_amd/templates/head-dashboard-svc.yaml @@ -0,0 +1,16 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ include "huri.fullname" . }}-head-dashboard-svc + labels: + {{- include "huri.labels" . | nindent 4 }} +spec: + type: ClusterIP + selector: + ray.io/node-type: head + ray.io/group: headgroup + app.kubernetes.io/instance: {{ .Release.Name }} + ports: + - name: dashboard + port: 8265 + targetPort: 8265 diff --git a/deploy/examples/local_nvidia_amd/templates/ingress-dashboard.yaml b/deploy/examples/local_nvidia_amd/templates/ingress-dashboard.yaml new file mode 100644 index 0000000..dc91299 --- /dev/null +++ b/deploy/examples/local_nvidia_amd/templates/ingress-dashboard.yaml @@ -0,0 +1,31 @@ +{{- if .Values.dashboard.ingress.enabled }} +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: {{ include "huri.fullname" . }}-dashboard + labels: + {{- include "huri.labels" . | nindent 4 }} + {{- with .Values.dashboard.ingress.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + {{- if .Values.dashboard.ingress.className }} + ingressClassName: {{ .Values.dashboard.ingress.className }} + {{- end }} + {{- with .Values.dashboard.ingress.tls }} + tls: + {{- toYaml . 
| nindent 4 }} + {{- end }} + rules: + - host: {{ .Values.dashboard.ingress.host | quote }} + http: + paths: + - path: / + pathType: Prefix + backend: + service: + name: {{ include "huri.headDashboardSvcName" . }} + port: + number: 8265 +{{- end }} diff --git a/deploy/examples/local_nvidia_amd/templates/ingress.yaml b/deploy/examples/local_nvidia_amd/templates/ingress.yaml new file mode 100644 index 0000000..85f94de --- /dev/null +++ b/deploy/examples/local_nvidia_amd/templates/ingress.yaml @@ -0,0 +1,32 @@ +{{- if .Values.ingress.enabled }} +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: {{ include "huri.fullname" . }} + labels: + {{- include "huri.labels" . | nindent 4 }} + {{- with .Values.ingress.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + {{- if .Values.ingress.className }} + ingressClassName: {{ .Values.ingress.className }} + {{- end }} + {{- with .Values.ingress.tls }} + tls: + {{- toYaml . | nindent 4 }} + {{- end }} + rules: + - host: {{ .Values.ingress.host | quote }} + http: + paths: + - path: / + pathType: Prefix + backend: + service: + # KubeRay creates <RayService name>-serve-svc for the Serve endpoint. + name: {{ include "huri.serveSvcName" . }} + port: + number: 8000 +{{- end }} diff --git a/deploy/examples/local_nvidia_amd/templates/rayservice.yaml b/deploy/examples/local_nvidia_amd/templates/rayservice.yaml new file mode 100644 index 0000000..1922d2e --- /dev/null +++ b/deploy/examples/local_nvidia_amd/templates/rayservice.yaml @@ -0,0 +1,190 @@ +apiVersion: ray.io/v1 +kind: RayService +metadata: + name: {{ include "huri.fullname" . }} + labels: + {{- include "huri.labels" . 
| nindent 4 }} + annotations: + ray.io/initializing-timeout: "20m" +spec: + serveConfigV2: | +{{ .Values.ray.serveConfig | indent 4 }} + rayClusterConfig: + rayVersion: {{ .Values.ray.version | quote }} + + headGroupSpec: + serviceType: {{ .Values.head.serviceType }} + rayStartParams: + {{- toYaml .Values.head.rayStartParams | nindent 8 }} + + template: + metadata: + labels: + {{- include "huri.selectorLabels" . | nindent 12 }} + spec: + {{- with .Values.head.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 12 }} + {{- end }} + {{- with .Values.head.affinity }} + affinity: + {{- toYaml . | nindent 12 }} + {{- end }} + {{- with .Values.head.tolerations }} + tolerations: + {{- toYaml . | nindent 12 }} + {{- end }} + containers: + - name: ray-head + image: {{ .Values.image.repository }}:{{ .Values.image.tag }} + imagePullPolicy: {{ .Values.image.pullPolicy }} + ports: + - containerPort: 6379 + name: gcs-server + - containerPort: 8265 + name: dashboard + - containerPort: 10001 + name: client + - containerPort: 8000 + name: serve + resources: + {{- toYaml .Values.head.resources | nindent 16 }} + + workerGroupSpecs: + {{- range .Values.workerGroups }} + {{- $group := . }} + - replicas: {{ .replicas }} + minReplicas: {{ .minReplicas }} + maxReplicas: {{ .maxReplicas }} + groupName: {{ .groupName | quote }} + rayStartParams: + {{- toYaml .rayStartParams | nindent 10 }} + {{- if .customResources }} + resources: {{ .customResources | squote }} + {{- end }} + + template: + metadata: + labels: + {{- include "huri.selectorLabels" $ | nindent 14 }} + spec: + {{- if .hostIPC }} + hostIPC: {{ .hostIPC }} + {{- end }} + {{- if .runtimeClassName }} + # Tells containerd to invoke the nvidia/amd runtime, which mounts + # the GPU devices into the pod. Required on WSL2 + k3s. + runtimeClassName: {{ .runtimeClassName }} + {{- end }} + {{- if .podSecurityContext }} + # Pod-level security: supplementalGroups, fsGroup, etc. 
+ securityContext: + {{- toYaml .podSecurityContext | nindent 14 }} + {{- end }} + {{- with .nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 14 }} + {{- end }} + {{- with .affinity }} + affinity: + {{- toYaml . | nindent 14 }} + {{- end }} + {{- with .tolerations }} + tolerations: + {{- toYaml . | nindent 14 }} + {{- end }} + volumes: + - name: dshm + emptyDir: + medium: Memory + sizeLimit: {{ .shmSize | default "1Gi" }} + {{- if .cudaCacheHostPath }} + - name: cuda-cache + hostPath: + path: {{ .cudaCacheHostPath }} + type: DirectoryOrCreate + {{- end }} + {{- range .mountedModels }} + {{- $model := index $.Values.models . }} + {{- if $model.enabled }} + - name: model-{{ . }} + persistentVolumeClaim: + claimName: {{ include "huri.fullname" $ }}-{{ . }}-models + {{- end }} + {{- end }} + {{- if and $.Values.voiceAssets.enabled .mountVoiceAssets }} + - name: voice-assets + persistentVolumeClaim: + claimName: {{ include "huri.fullname" $ }}-voice-assets + {{- end }} + containers: + - name: ray-worker + {{- if .image }} + image: {{ .image }} + {{- else }} + image: {{ $.Values.image.repository }}:{{ $.Values.image.tag }} + {{- end }} + imagePullPolicy: {{ $.Values.image.pullPolicy }} + readinessProbe: + exec: + command: + - bash + - -c + - wget --tries 1 -T 2 -q -O- http://localhost:52365/api/local_raylet_healthz + | grep success + initialDelaySeconds: 10 + periodSeconds: 5 + failureThreshold: 10 + timeoutSeconds: 5 + env: + {{- $hasEnv := false }} + {{- if .containerEnv }} + {{- $hasEnv = true }} + {{- toYaml .containerEnv | nindent 18 }} + {{- end }} + {{- range .mountedModels }} + {{- $model := index $.Values.models . 
}} + {{- if and $model.enabled $model.env }} + {{- $hasEnv = true }} + {{- range $envKey, $envVal := $model.env }} + - name: {{ $envKey }} + value: {{ $envVal | quote }} + {{- end }} + {{- end }} + {{- end }} + {{- if and $.Values.voiceAssets.enabled $group.mountVoiceAssets }} + {{- range $envKey, $envVal := $.Values.voiceAssets.env }} + {{- $hasEnv = true }} + - name: {{ $envKey }} + value: {{ $envVal | quote }} + {{- end }} + {{- end }} + {{- if not $hasEnv }} + [] + {{- end }} + {{- if .securityContext }} + # Container-level security: seLinuxOptions, capabilities, etc. + securityContext: + {{- toYaml .securityContext | nindent 18 }} + {{- end }} + resources: + {{- toYaml .resources | nindent 18 }} + volumeMounts: + - name: dshm + mountPath: /dev/shm + {{- if .cudaCacheHostPath }} + - name: cuda-cache + mountPath: /home/ray/.nv + {{- end }} + {{- range .mountedModels }} + {{- $model := index $.Values.models . }} + {{- if $model.enabled }} + - name: model-{{ . }} + mountPath: {{ $model.mountPath }} + {{- end }} + {{- end }} + {{- if and $.Values.voiceAssets.enabled .mountVoiceAssets }} + - name: voice-assets + mountPath: {{ $.Values.voiceAssets.mountPath }} + {{- end }} + {{- end }} diff --git a/deploy/examples/local_nvidia_amd/templates/voice-assets-pvc.yaml b/deploy/examples/local_nvidia_amd/templates/voice-assets-pvc.yaml new file mode 100644 index 0000000..6456e8e --- /dev/null +++ b/deploy/examples/local_nvidia_amd/templates/voice-assets-pvc.yaml @@ -0,0 +1,28 @@ +{{- if .Values.voiceAssets.enabled }} +{{- $pvcName := printf "%s-voice-assets" (include "huri.fullname" .) }} +{{- if not (lookup "v1" "PersistentVolumeClaim" .Release.Namespace $pvcName) }} +--- +# PVC for the voice sample used by the TTS module (HURI_VOICE_SAMPLE_PATH). +# Populate after first install with: +# kubectl cp voice.wav <ray-worker-pod>:{{ .Values.voiceAssets.mountPath }}/voice.wav +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: {{ $pvcName }} + labels: + {{- include "huri.labels" . 
| nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-10" + "helm.sh/resource-policy": keep +spec: + accessModes: + {{- toYaml .Values.voiceAssets.pvc.accessModes | nindent 4 }} + resources: + requests: + storage: {{ .Values.voiceAssets.pvc.size }} + {{- if .Values.voiceAssets.pvc.storageClassName }} + storageClassName: {{ .Values.voiceAssets.pvc.storageClassName }} + {{- end }} +{{- end }} +{{- end }} diff --git a/deploy/examples/local_nvidia_amd/templates/whisper-model-init-job.yaml b/deploy/examples/local_nvidia_amd/templates/whisper-model-init-job.yaml new file mode 100644 index 0000000..879d908 --- /dev/null +++ b/deploy/examples/local_nvidia_amd/templates/whisper-model-init-job.yaml @@ -0,0 +1,86 @@ +{{- if .Values.models.whisper.enabled }} +{{- $model := .Values.models.whisper }} +{{- $pvcName := printf "%s-whisper-models" (include "huri.fullname" .) }} +{{- if not (lookup "v1" "PersistentVolumeClaim" .Release.Namespace $pvcName) }} +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: {{ $pvcName }} + labels: + {{- include "huri.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-10" + "helm.sh/resource-policy": keep +spec: + accessModes: + {{- toYaml $model.pvc.accessModes | nindent 4 }} + resources: + requests: + storage: {{ $model.pvc.size }} + {{- if $model.pvc.storageClassName }} + storageClassName: {{ $model.pvc.storageClassName }} + {{- end }} +{{- end }} +--- +# Runs on every install and upgrade (pre-install,pre-upgrade hook) but exits early when model.bin is already present on the PVC. +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ include "huri.fullname" . }}-whisper-init + labels: + {{- include "huri.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-5" + "helm.sh/hook-delete-policy": hook-succeeded,before-hook-creation +spec: + backoffLimit: 3 + template: + metadata: + labels: + {{- include "huri.selectorLabels" . 
| nindent 8 }} + spec: + restartPolicy: OnFailure + {{- with $model.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + volumes: + - name: models + persistentVolumeClaim: + claimName: {{ include "huri.fullname" . }}-whisper-models + containers: + - name: whisper-downloader + image: python:3.11-slim + command: ["/bin/sh", "-c"] + args: + - | + set -e + MODEL_DIR="{{ $model.mountPath }}/{{ $model.modelSource.repoId }}" + if [ -f "$MODEL_DIR/model.bin" ]; then + echo "Model already present at $MODEL_DIR — skipping download." + exit 0 + fi + echo "Downloading {{ $model.modelSource.repoId }} into $MODEL_DIR …" + pip install --quiet huggingface_hub + python - <<'PYEOF' + from huggingface_hub import snapshot_download + snapshot_download( + "{{ $model.modelSource.repoId }}", + local_dir="{{ $model.mountPath }}/{{ $model.modelSource.repoId }}", + ) + PYEOF + echo "Download complete." + volumeMounts: + - name: models + mountPath: {{ $model.mountPath }} + resources: + requests: + cpu: "500m" + memory: "512Mi" + limits: + cpu: "2" + memory: "1Gi" +{{- end }} diff --git a/deploy/examples/local_nvidia_amd/values.yaml b/deploy/examples/local_nvidia_amd/values.yaml new file mode 100644 index 0000000..e980c57 --- /dev/null +++ b/deploy/examples/local_nvidia_amd/values.yaml @@ -0,0 +1,360 @@ +nameOverride: "" +fullnameOverride: "" + +image: + repository: docker.pommier.dev/huri + tag: base-2.55.1 + pullPolicy: Always + +ray: + version: "2.55.1" + # Inline Ray Serve config (equivalent to config/huri.yaml). + # Change import_path or add deployments here without rebuilding the image. 
+ # + # To pin a deployment to a specific GPU vendor, set resources in + # ray_actor_options matching the resources string in the desired workerGroup's + # rayStartParams: + # + # ray_actor_options: + # num_gpus: 1 + # resources: {"GPU_TYPE_NVIDIA": 1} # → runs only on gpu-nvidia workers + # resources: {"GPU_TYPE_AMD": 1} # → runs only on gpu-amd workers + serveConfig: | + proxy_location: EveryNode + http_options: + host: 0.0.0.0 + port: 8000 + applications: + - name: huri-app + route_prefix: / + import_path: src.app:app + runtime_env: + env_vars: + RAY_COLOR_PREFIX: "1" + # Gesture sliding-window defaults. The Gesture *module* runs in the + # HuRI (CPU) actor, so these live app-wide here rather than on the + # nvidia worker. context_sec primes EMAGE for continuity across + # audio chunks; min_chunk_sec coalesces tiny TTS chunks so fewer, + # bounded inferences run (lower latency, smoother motion). Can be + # overridden per session via the gesture module `args` in the + # client config. + HURI_GESTURE_CONTEXT_SEC: "2.0" + HURI_GESTURE_MIN_CHUNK_SEC: "0.5" + # CosyVoice3 contract: "<instruction><|endofprompt|><reference transcript>". + # The reference transcript MUST come AFTER the marker, or the LM treats it as + # an instruction and intermittently speaks it (prompt leakage). + HURI_VOICE_TRANSCRIPT: "You are a helpful assistant.<|endofprompt|>Instinct creates its own oppressors and bids us rise up against them." + HURI_STT_MODEL_PATH: /models/whisper/Systran/faster-whisper-base + deployments: + # HuRI: FastAPI/WebSocket ingress + per-session router. CPU only — + # all GPU work is offloaded to handle-backed deployments below. + - name: HuRI + ray_actor_options: + num_cpus: 1 + num_gpus: 0 + # STT: shared faster-whisper actor, pinned to AMD. + - name: STT + num_replicas: 1 + ray_actor_options: + num_cpus: 1 + num_gpus: 0.5 + resources: {"GPU_TYPE_AMD": 0.5} + # RAG: embeddings (API) + LLM client. Pinned to AMD for its dependencies, no GPU needed. 
+ - name: RAGHandle + num_replicas: 1 + ray_actor_options: + num_cpus: 1 + num_gpus: 0 + resources: {"GPU_TYPE_AMD": 0.001} + user_config: + qdrant_url: "https://qdrant.pommier.lan" + llm_url: "https://llm.huri.lan" + embedding_url: "https://embedding.huri.lan" + embedding_model: "bge-large-en-v1.5-gguf-Q4_K_M" + llm_provider: "vllm" + llm_model: "Qwen3.5-4B-GGUF" + verify_ssl: false + # GPU split (manual override knob): TTS and Gesture share one NVIDIA + # GPU. num_gpus/resources are Ray *scheduling* fractions — they let + # both replicas pack onto the same device and bias the split. Audio + # (TTS) gets the lion's share so streamed speech stays low-latency; + # gesture is given the remainder. Tune these two pairs together so + # they sum to <= 1.0. To also cap gesture's actual VRAM allocation, + # set HURI_GESTURE_GPU_MEM_FRACTION on the nvidia worker (see below). + - name: TTS + ray_actor_options: + num_cpus: 1 + num_gpus: 0.8 + resources: {"GPU_TYPE_NVIDIA": 0.8} + - name: GestureGeneration + ray_actor_options: + num_cpus: 1 + num_gpus: 0.2 + resources: {"GPU_TYPE_NVIDIA": 0.2} + +head: + # ClusterIP is preferred on real clusters; use NodePort for kind/minikube/k3s. + serviceType: ClusterIP + # Pin the head to the control-plane node so it does not consume GPU memory. + # Set to {} to let the scheduler decide. + nodeSelector: + node-role.kubernetes.io/control-plane: "true" + # affinity: {} # optional Kubernetes affinity rules (nodeAffinity, podAffinity…) + tolerations: + # Required to schedule on the control-plane node (tainted by default in k3s). + - key: node-role.kubernetes.io/control-plane + operator: Exists + effect: NoSchedule + rayStartParams: + num-cpus: "2" + num-gpus: "0" + resources: + limits: + cpu: "2" + memory: "3Gi" + requests: + cpu: "2" + memory: "2Gi" + +# Worker groups: one logical "gpu" group per vendor and one "cpu" group. 
+# +# Node placement (Kubernetes level) +# ────────────────────────────────── +# nodeSelector / affinity / tolerations pin the *pod* to specific physical nodes. +# +# Ray deployment routing (Ray level) +# ──────────────────────────────────── +# rayStartParams.resources advertises named custom resources to the Ray +# scheduler. Ray Serve deployments request them via ray_actor_options.resources +# to run exclusively on a given worker group regardless of GPU vendor. +# +# Example — to route a deployment to NVIDIA workers only: +# workerGroups[gpu-nvidia].customResources: '{"GPU_TYPE_NVIDIA":1}' +# serveConfig deployment ray_actor_options.resources: {"GPU_TYPE_NVIDIA": 1} +# +# Model volumes +# ────────────── +# mountedModels lists keys from .Values.models. For each enabled model the +# template adds a PVC volume + volumeMount + env vars to the worker pod. +workerGroups: + # --- GPU workers (Nvidia) --- + - groupName: gpu-nvidia + image: docker.pommier.dev/huri:nvidia-2.55.1 + replicas: 1 + minReplicas: 1 + maxReplicas: 1 + mountVoiceAssets: true + cudaCacheHostPath: /var/cache/huri/cuda-nvidia + nodeSelector: + gpu: nvidia + # affinity: {} # optional + tolerations: [] + runtimeClassName: nvidia + customResources: '{\"GPU_TYPE_NVIDIA\":1}' + rayStartParams: + num-gpus: "1" + containerEnv: + - name: NVIDIA_VISIBLE_DEVICES + value: "all" + - name: NVIDIA_DRIVER_CAPABILITIES + value: "compute,utility" + - name: HF_HUB_DOWNLOAD_TIMEOUT + value: "10" + # Manual GPU split for gesture (see GestureGeneration above). Caps the + # EMAGE process to a fraction of GPU memory so TTS keeps the rest. "0" + # disables the cap. Keep roughly in line with its num_gpus fraction. + # Read by GestureGeneration, which runs in this worker group. + - name: HURI_GESTURE_GPU_MEM_FRACTION + value: "0.2" + # Models whose PVC will be mounted in this worker group. + # Keys must match entries under .Values.models. 
+ mountedModels: + - cosytts + - emage + resources: + limits: + cpu: "4" + memory: "14Gi" + nvidia.com/gpu: "1" + requests: + cpu: "2" + memory: "8Gi" + nvidia.com/gpu: "1" + shmSize: 2Gi + + # --- GPU workers (AMD) --- + - groupName: gpu-amd + image: docker.pommier.dev/huri:amd-2.55.1 + replicas: 1 + minReplicas: 1 + maxReplicas: 1 + mountVoiceAssets: false + nodeSelector: + gpu: amd + # affinity: {} # optional + tolerations: [] + hostIPC: true + customResources: '{\"GPU_TYPE_AMD\":1}' + podSecurityContext: + supplementalGroups: [39, 107] + rayStartParams: + num-gpus: "1" + containerEnv: + - name: HSA_OVERRIDE_GFX_VERSION + value: "11.5.1" + - name: HF_HUB_DOWNLOAD_TIMEOUT + value: "10" + securityContext: + seLinuxOptions: + type: "spc_t" + # privileged: true # Uncomment if spc_t still gets blocked by Fedora + mountedModels: + - whisper + resources: + limits: + cpu: "4" + memory: "8Gi" + amd.com/gpu: "1" + requests: + cpu: "2" + memory: "4Gi" + amd.com/gpu: "1" + + # --- CPU-only workers --- + # Handles tasks that do not need a GPU (pre/post-processing, routing, etc.). + # Set replicas: 0 to disable this group entirely. + - groupName: cpu-workers + replicas: 0 + minReplicas: 0 + maxReplicas: 0 + nodeSelector: {} + tolerations: [] + rayStartParams: + num-cpus: "2" + mountedModels: [] + resources: + limits: + cpu: "2" + memory: "4Gi" + requests: + cpu: "1" + memory: "2Gi" + shmSize: 256Mi + +# AI model definitions. +# Each key corresponds to a model that can be mounted into worker groups via +# mountedModels. The chart creates one PVC per enabled model and a pre-install +# Job that downloads the weights on first deploy (idempotent). +models: + cosytts: + enabled: true + nodeSelector: + gpu: nvidia + pvc: + # Leave storageClassName empty to use the cluster default StorageClass. + storageClassName: "" + size: 20Gi + # Use ReadWriteMany if the PVC must be shared across multiple worker pods + # (requires a CSI driver that supports RWX, e.g. NFS or Longhorn). 
+ # Use ReadWriteOnce for single-node clusters. + accessModes: + - ReadWriteOnce + # Where the model weights PVC is mounted inside the worker container. + mountPath: /models/cosytts + # Set to true to wipe the existing model dir and re-download on the next + # helm upgrade. Needed when switching model versions on a PVC that already + # holds weights (the init Job otherwise skips download when the model dir + # already looks complete). Set back to false after the redownload runs. + forceDownload: false + modelSource: + # type: modelscope | huggingface (whisper and emage below use huggingface) + type: modelscope + # ModelScope model ID — snapshot_download uses this as the sub-path + # inside mountPath, so the final path is mountPath/modelId. + modelId: FunAudioLLM/Fun-CosyVoice3-0.5B-2512 + # Env vars injected into every worker that mounts this model. + # HURI_MODEL_PATH must match mountPath/modelId (see text_to_speech.py). + env: + HURI_MODEL_PATH: /models/cosytts/FunAudioLLM/Fun-CosyVoice3-0.5B-2512 + # Path to the CosyVoice repo root containing third_party/Matcha-TTS. + HURI_COSY_DIR: /app/cosyvoice + + whisper: + enabled: true + nodeSelector: + gpu: amd + pvc: + storageClassName: "" + size: 2Gi + accessModes: + - ReadWriteOnce + mountPath: /models/whisper + modelSource: + type: huggingface + repoId: Systran/faster-whisper-base + env: + HURI_STT_MODEL_PATH: /models/whisper/Systran/faster-whisper-base + + emage: + enabled: true + nodeSelector: + gpu: nvidia + pvc: + storageClassName: "" + size: 10Gi + accessModes: + - ReadWriteOnce + # Where the EMAGE weights PVC is mounted inside the worker container. + mountPath: /models/emage + modelSource: + # type: huggingface — snapshot_download uses repoId as the sub-path + # inside mountPath, so the final path is mountPath/repoId. + type: huggingface + repoId: H-Liu1997/emage_audio + # HURI_EMAGE_REPO must match mountPath/repoId (see gesture.py).
+ env: + HURI_EMAGE_REPO: /models/emage/H-Liu1997/emage_audio + +# Voice sample asset volume. +# Populate the PVC after first install: +# kubectl cp voice.wav <gpu-nvidia-worker-pod>:/assets/voice.wav +voiceAssets: + enabled: true + pvc: + storageClassName: "" + size: 100Mi + accessModes: + - ReadWriteOnce + mountPath: /assets + env: + HURI_VOICE_SAMPLE_PATH: /assets/voice.wav + +# Ingress for the Ray Serve endpoint (port 8000). +ingress: + enabled: true + className: nginx + annotations: + nginx.ingress.kubernetes.io/proxy-read-timeout: "3600" + nginx.ingress.kubernetes.io/proxy-send-timeout: "3600" + nginx.ingress.kubernetes.io/proxy-buffering: "off" + host: huri.lan + tls: [] + # - secretName: huri-tls + # hosts: + # - huri.example.com + +# Ingress for the Ray Dashboard (port 8265). Enabled here for the .lan host; not safe to +# expose publicly without additional auth (set enabled: false in that case). +dashboard: + ingress: + enabled: true + className: nginx + annotations: {} + host: dashboard.huri.lan + tls: [] + +# Set to true to let this chart manage the KubeRay operator as a sub-chart. +# Set to false when the operator is already installed cluster-wide (typical for +# shared clusters). +kuberay: + install: false diff --git a/requirements-amd.txt b/requirements-amd.txt new file mode 100644 index 0000000..7fa9358 --- /dev/null +++ b/requirements-amd.txt @@ -0,0 +1,16 @@ +# AMD/ROCm worker extras (on top of serve_requirements.txt installed by Dockerfile.base layer +# and the ROCm torch/torchaudio wheels installed directly in Dockerfile.amd). +# +# Hosts: STT (faster-whisper, already in serve_requirements.txt) and RAG/LLM. +# Does NOT include CosyVoice2 or EMAGE — those run on the NVIDIA worker. + +# --- RAG / LLM --- +# httpx + qdrant-client moved to serve_requirements.txt so the head/base +# controller can import rag.py without falling back to None (see that file).
+sentence-transformers==3.2.1 +pypdf==5.1.0 +semantic_chunker==0.2.0 + +# transformers is pulled in by sentence-transformers; pin to a version compatible +# with the ROCm torch 2.8 wheel. +transformers==4.46.3 diff --git a/requirements-nvidia.txt b/requirements-nvidia.txt new file mode 100644 index 0000000..98a8950 --- /dev/null +++ b/requirements-nvidia.txt @@ -0,0 +1,49 @@ +# --- Shared / pinned by CosyVoice (keep these versions) --- +torch==2.3.1 +torchaudio==2.3.1 +numpy==1.26.4 +transformers==4.51.3 +diffusers==0.29.0 +omegaconf==2.3.0 +librosa==0.10.2 +soundfile==0.12.1 +hydra-core==1.3.2 # only because HyperPyYAML configs may resolve hydra refs; safe to keep +HyperPyYAML==1.2.3 +modelscope==1.20.0 +onnx==1.16.0 +onnxruntime-gpu==1.18.0 # Linux; campplus + speech_tokenizer +openai-whisper==20250625 # frontend.py: import whisper +inflect==7.3.1 +wetext==0.0.4 # text normalization fallback (ttsfrd not installed) +conformer==0.3.2 +x-transformers==2.11.24 +einops==0.8.2 +tiktoken==0.13.0 # cosyvoice/tokenizer +pyarrow==18.1.0 # imported by cli paths via dataset utils? 
actually only dataset/processor — can drop +protobuf==4.25 +pydantic==2.7.0 # transitive (transformers/fastapi), but pinning avoids drift +regex==2025.11.3 +tqdm==4.67.3 + +# --- RAG / LLM extras --- +httpx==0.27.2 +qdrant-client==1.18.0 +sentence-transformers==3.2.1 +pypdf==5.1.0 +semantic_chunker==0.2.0 + +# --- EMAGE extras --- +huggingface_hub==0.36.2 # from_pretrained +smplx==0.1.28 +pyrender==0.1.45 # fast_render top-level import +trimesh==4.12.2 +imageio==2.33.0 +# Visualization-only (skip if --visualization off): +# opencv-python==4.8.1.78 +# pytorch3d # has to be built from source for torch 2.3 / py3.12 +# torchvision +lightning==2.2.4 +gdown==5.1.0 +matplotlib==3.7.5 # ridiculous but necessary for init modules +wget==3.2 +pyworld==0.3.4 diff --git a/serve_requirements.txt b/serve_requirements.txt new file mode 100644 index 0000000..1ebe583 --- /dev/null +++ b/serve_requirements.txt @@ -0,0 +1,12 @@ +# server +numpy +click<8.2 +webrtcvad +faster-whisper + +# RAG/LLM client deps. Kept in the *base* image (not just the AMD worker) +# because the Serve controller on the head node imports rag.py to build the app; +# if these are missing there, the module-level imports fall back to None and that +# None is baked into the cloudpickled RAGHandle shipped to the AMD replica. 
+httpx==0.27.2 +qdrant-client==1.18.0 diff --git a/src/app.py b/src/app.py index 79f58db..1b314de 100644 --- a/src/app.py +++ b/src/app.py @@ -1,51 +1,15 @@ -from pathlib import Path -from typing import Any - -import yaml from ray.serve import Application from src.core.huri import HuRI from src.modules.events import get_events from src.modules.factory import bind_deployment_handles from src.modules.modules import get_modules -from src.modules.rag.docker_services import OllamaService, QdrantService - - -def load_services_config() -> Any: - config_path = Path(__file__).resolve().parents[1] / "config" / "huri.yaml" - with open(config_path) as f: - config = yaml.safe_load(f) - return config.get("services", {}) - - -def build_qdrant(config: dict) -> Any: - return QdrantService.bind( # type: ignore[attr-defined] - port=config.get("port", 6333), - image=config.get("image", "qdrant/qdrant:latest"), - storage_volume=config.get("storage_volume", "qdrant_data"), - ) - - -def build_ollama(config: dict) -> Any: - return OllamaService.options( # type: ignore[attr-defined] - num_replicas=config.get("num_replicas", 1), - ).bind( - model=config.get("model", "mistral:7b"), - image=config.get("image", "ollama/ollama:latest"), - gpu_devices=config.get("gpu_devices", False), - ) def build_app() -> Application: modules = get_modules() events = get_events() - - services_config = load_services_config() - - qdrant = build_qdrant(services_config.get("qdrant", {})) - ollama = build_ollama(services_config.get("ollama", {})) - - handles = bind_deployment_handles(modules, ollama=ollama, qdrant=qdrant) + handles = bind_deployment_handles(modules) app: Application = HuRI.bind(modules, handles, events) # type: ignore[attr-defined] return app diff --git a/src/client.py b/src/client.py index 466cc05..d6bba63 100644 --- a/src/client.py +++ b/src/client.py @@ -26,11 +26,21 @@ async def launch_client(): required=True, help="Path to Client config file (YAML)", ) + parser.add_argument( + "--save-audio", 
+ nargs="?", + const="audio_dumps", + default=None, + metavar="DIR", + help="Save streamed TTS audio to .wav files (one per utterance) in DIR " + "for quality-checking. Defaults to ./audio_dumps when the flag is given " + "without a value.", + ) args = parser.parse_args() config = load_client_config(args.config) - await Client(config=config).run() + await Client(config=config, save_audio_dir=args.save_audio).run() if __name__ == "__main__": diff --git a/src/core/client.py b/src/core/client.py index 085a0b8..2af97f2 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -1,9 +1,13 @@ import asyncio import json import os +import struct +import wave from dataclasses import asdict +from datetime import datetime from typing import Dict, List, Optional, Type +import numpy as np import websockets from src.core.dataclasses.config import ClientConfig @@ -19,11 +23,22 @@ def __init__( config: ClientConfig, user_id_file: str = os.path.expanduser("~/.huri_user_id"), senders_dict: Dict[str, Type[ClientSender]] = get_senders(), + save_audio_dir: Optional[str] = None, ): self.config = config self.user_id_file = user_id_file self.senders_dict = senders_dict + # When set, incoming audio chunks are buffered per utterance and written + # to a .wav under this directory each time an end-of-utterance marker + # arrives — handy for ear-checking what the TTS actually streamed. 
+ self.save_audio_dir = save_audio_dir + self._audio_buf: List[np.ndarray] = [] + self._audio_sr: Optional[int] = None + self._audio_idx = 0 + if save_audio_dir: + os.makedirs(save_audio_dir, exist_ok=True) + def _load_user_id(self) -> Optional[str]: if os.path.exists(self.user_id_file): with open(self.user_id_file) as f: @@ -34,15 +49,76 @@ def _save_user_id(self, _user_id: str): with open(self.user_id_file, "w") as f: f.write(_user_id) + def _collect_audio(self, samples: np.ndarray, sample_rate: int, end: bool) -> None: + if samples.size: + self._audio_buf.append(samples) + self._audio_sr = sample_rate + if end: + self._flush_audio() + + def _flush_audio(self) -> None: + if not self._audio_buf or self._audio_sr is None: + self._audio_buf = [] + return + audio = np.concatenate(self._audio_buf) + stamp = datetime.now().strftime("%Y%m%d-%H%M%S") + path = os.path.join( + self.save_audio_dir, f"utt-{self._audio_idx:03d}-{stamp}.wav" + ) + self._write_wav(path, audio, self._audio_sr) + print( + f"** saved audio: {path} ({audio.size} samples, " + f"~{audio.size / self._audio_sr:.2f}s @ {self._audio_sr}Hz)" + ) + self._audio_idx += 1 + self._audio_buf = [] + + @staticmethod + def _write_wav(path: str, audio: np.ndarray, sample_rate: int) -> None: + # float32 [-1, 1] -> 16-bit PCM, clipped to avoid wraparound on overshoot. + pcm = np.clip(audio, -1.0, 1.0) + pcm = (pcm * 32767.0).astype("H", msg[:2]) + topic = msg[2:2 + topic_len].decode() + payload = msg[2 + topic_len:] + + if topic == "audio" and len(payload) >= 13: + sample_rate, end_flag, pts = struct.unpack(">IBd", payload[:13]) + # Samples are native-endian float32 (Sender uses ndarray.tobytes()). 
+ samples = np.frombuffer(payload[13:], dtype=np.float32) + print( + f"<< audio: pts={pts:.3f}s samples={samples.size} @ {sample_rate}Hz " + f"end={bool(end_flag)}" + ) + if self.save_audio_dir: + self._collect_audio(samples, sample_rate, bool(end_flag)) + elif topic == "motion" and len(payload) >= 16: + pts, fps, n_frames = struct.unpack(">dII", payload[:16]) + print(f"<< motion: pts={pts:.3f}s frames={n_frames} @ {fps}fps") + else: + print(f"<< {topic}: bytes ({len(payload)}B)") + else: + print("<<", msg) except (asyncio.CancelledError, websockets.ConnectionClosedOK): pass + finally: + if self.save_audio_dir: + self._flush_audio() # save anything left if the stream ended mid-utterance async def run(self): async with websockets.connect(self.config.huri_url) as ws: diff --git a/src/core/client_senders.py b/src/core/client_senders.py index 03301a6..c1ae31f 100644 --- a/src/core/client_senders.py +++ b/src/core/client_senders.py @@ -45,7 +45,9 @@ async def send(self, topic: str, data: EventData | bytes): class AudioSender(ClientSender): - output_type = "audio" + # Mic frames go out on "audio_in"; the server's "audio" topic is reserved + # for TTS output streamed back to us (see MIC.input_type). 
+ output_type = "audio_in" def __init__( self, sample_rate: int = 16000, frame_duration: float = 0.030, **kwargs diff --git a/src/core/dataclasses/config.py b/src/core/dataclasses/config.py index aea111f..8cccb03 100644 --- a/src/core/dataclasses/config.py +++ b/src/core/dataclasses/config.py @@ -47,7 +47,7 @@ def from_dict(cls, raw: Dict) -> "ClientConfig": for module_id, mod_raw in raw.get("modules", {}).items() } return cls( - user_id=None, + user_id=raw.get("user_id"), huri_url=raw["huri_url"], topic_list=raw["topic_list"], senders=senders, diff --git a/src/core/events.py b/src/core/events.py index b764258..dfaf66b 100644 --- a/src/core/events.py +++ b/src/core/events.py @@ -1,9 +1,14 @@ import asyncio +import logging from collections import defaultdict from dataclasses import dataclass +import numpy as np + from .module import Module +logger = logging.getLogger("ray.serve") + @dataclass class EventData: @@ -43,7 +48,13 @@ def register(self, module: Module): self.subscribers[module.input_type].append(module) async def publish(self, event_topic, data): - for module in self.subscribers[event_topic]: + subs = self.subscribers[event_topic] + if event_topic not in ("audio_in",): # skip mic-frame spam + logger.info( + "[GRAPH] publish topic=%r subscribers=%s", + event_topic, [type(m).__name__ for m in subs], + ) + for module in subs: asyncio.create_task(self._run(module, data)) async def _run(self, module: Module, data): @@ -55,17 +66,34 @@ async def _run(self, module: Module, data): async for item in result: if item is None: continue + logger.info( + "[GRAPH] %s -> %r: %s", + type(module).__name__, module.output_type, _summarize(item), + ) await self.publish(module.output_type, item) - except Exception as e: - print(f"[ERROR] async generator in {module}: {e}") + except Exception: + logger.exception("[GRAPH] async generator failed in %s", type(module).__name__) else: try: value = await result if value is not None: + logger.info( + "[GRAPH] %s -> %r: %s", + 
type(module).__name__, module.output_type, _summarize(value), + ) await self.publish(module.output_type, value) - except Exception as e: - print(f"[ERROR] coroutine in {module}: {e}") + except Exception: + logger.exception("[GRAPH] coroutine failed in %s", type(module).__name__) + + except Exception: + logger.exception("[GRAPH] process() call failed in %s", type(module).__name__) + - except Exception as e: - print(f"[ERROR] process() call failed in {module}: {e}") +def _summarize(item) -> str: + """Short repr that avoids dumping full numpy arrays into the log.""" + cls = type(item).__name__ + data = getattr(item, "data", None) + if isinstance(data, np.ndarray): + return f"{cls}(shape={data.shape}, dtype={data.dtype})" + return f"{cls}({item!r})" diff --git a/src/modules/events.py b/src/modules/events.py index 43f6c71..cb7bd61 100644 --- a/src/modules/events.py +++ b/src/modules/events.py @@ -1,15 +1,32 @@ from typing import Dict, Type from src.core.events import EventData -from src.modules.rag.events import RAGResult from src.modules.speech_to_text.events import Sentence, Transcript, Voice +from src.modules.text_to_speech.events import Audio, Token def get_events() -> Dict[str, Type[EventData | bytes]]: - return { + events: Dict[str, Type[EventData | bytes]] = { + "audio_in": bytes, # inbound mic frames (raw int16 PCM) "audio": bytes, "voice": Voice, "transcript": Transcript, "question": Sentence, - "rag_response": RAGResult, + "token": Token, } + + # Motion lives in the gesture module — only available when EMAGE deps installed. + try: + from src.modules.gesture.gesture import Motion + except Exception: + pass + else: + events["motion"] = Motion + + # TTS output "audio" is an Audio dataclass internally, sent to the client by + # the Sender's Audio branch (never decoded inbound). Inbound mic frames use + # the separate "audio_in" topic above. The registry keeps bytes for "audio" + # only for output-type registration. Keep Audio importable for completeness. 
+ _ = Audio + + return events diff --git a/src/modules/factory.py b/src/modules/factory.py index 728aba3..fb0f58c 100644 --- a/src/modules/factory.py +++ b/src/modules/factory.py @@ -1,8 +1,10 @@ from typing import Any, Dict, List, Mapping, Type +from ray.serve import handle + from src.core.dataclasses.config import ModuleConfig from src.core.events import EventData -from src.core.module import Module, ModuleWithHandle, ModuleWithId, handle +from src.core.module import Module, ModuleWithHandle, ModuleWithId class EventDataFactory: @@ -94,7 +96,6 @@ def create_from_config( def bind_deployment_handles( modules: Dict[str, Type[Module]], - **service_handles, ) -> Dict[str, handle.DeploymentHandle]: handles: Dict[str, handle.DeploymentHandle] = {} for name, module_cls in modules.items(): @@ -106,12 +107,6 @@ def bind_deployment_handles( handle_cls = module_cls._handle_cls - if name == "rag" and service_handles: - handles[name] = handle_cls.bind( - ollama_handle=service_handles.get("ollama"), - qdrant_handle=service_handles.get("qdrant"), - ) - else: - handles[name] = handle_cls.bind() + handles[name] = handle_cls.bind() return handles diff --git a/src/modules/gesture/__init__.py b/src/modules/gesture/__init__.py new file mode 100644 index 0000000..8d8e6c2 --- /dev/null +++ b/src/modules/gesture/__init__.py @@ -0,0 +1 @@ +from .gesture import Gesture, Motion diff --git a/src/modules/gesture/emage/__init__.py b/src/modules/gesture/emage/__init__.py new file mode 100644 index 0000000..df5a071 --- /dev/null +++ b/src/modules/gesture/emage/__init__.py @@ -0,0 +1,12 @@ +from .configuration import EmageAudioConfig, EmageVAEConvConfig, EmageVQVAEConvConfig +from .modeling import EmageAudioModel, EmageVAEConv, EmageVQModel, EmageVQVAEConv + +__all__ = [ + "EmageAudioConfig", + "EmageAudioModel", + "EmageVAEConvConfig", + "EmageVAEConv", + "EmageVQVAEConvConfig", + "EmageVQVAEConv", + "EmageVQModel", +] diff --git a/src/modules/gesture/emage/configuration.py 
b/src/modules/gesture/emage/configuration.py new file mode 100644 index 0000000..ad97076 --- /dev/null +++ b/src/modules/gesture/emage/configuration.py @@ -0,0 +1,32 @@ +from omegaconf import OmegaConf +from transformers import PretrainedConfig + + +class EmageAudioConfig(PretrainedConfig): + model_type = "emage_audio" + + def __init__(self, config_obj=None, **kwargs): + if config_obj is not None: + cfg_dict = OmegaConf.to_container(config_obj, resolve=True) + kwargs.update(cfg_dict) + super().__init__(**kwargs) + + +class EmageVQVAEConvConfig(PretrainedConfig): + model_type = "emage_vqvaeconv" + + def __init__(self, config_obj=None, **kwargs): + if config_obj is not None: + cfg_dict = OmegaConf.to_container(config_obj, resolve=True) + kwargs.update(cfg_dict) + super().__init__(**kwargs) + + +class EmageVAEConvConfig(PretrainedConfig): + model_type = "emage_vaeconv" + + def __init__(self, config_obj=None, **kwargs): + if config_obj is not None: + cfg_dict = OmegaConf.to_container(config_obj, resolve=True) + kwargs.update(cfg_dict) + super().__init__(**kwargs) diff --git a/src/modules/gesture/emage/modeling.py b/src/modules/gesture/emage/modeling.py new file mode 100644 index 0000000..b4d71ab --- /dev/null +++ b/src/modules/gesture/emage/modeling.py @@ -0,0 +1,449 @@ +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PreTrainedModel + +from .configuration import EmageAudioConfig, EmageVAEConvConfig, EmageVQVAEConvConfig +from .processing import ( + MLP, + PeriodicPositionalEncoding, + Quantizer, + VQDecoderV5, + VQEncoderV5, + VQEncoderV6, + WavEncoder, + axis_angle_to_rotation_6d, + matrix_to_axis_angle, + matrix_to_rotation_6d, + recover_from_mask_ts, + rotation_6d_to_axis_angle, + rotation_6d_to_matrix, + velocity2position, + axis_angle_to_matrix, +) + + +class EmageVAEConv(PreTrainedModel): + config_class = EmageVAEConvConfig + base_model_prefix = "emage_vaeconv" + + def __init__(self, config): + 
super().__init__(config) + self.encoder = VQEncoderV5(config) + self.decoder = VQDecoderV5(config) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + rec_pose = self.decoder(pre_latent) + return {"rec_pose": rec_pose} + + +class EmageVQVAEConv(PreTrainedModel): + config_class = EmageVQVAEConvConfig + base_model_prefix = "emage_vqvaeconv" + + def __init__(self, config): + super().__init__(config) + self.encoder = VQEncoderV5(config) + self.quantizer = Quantizer(config.vae_codebook_size, config.vae_length, config.vae_quantizer_lambda) + self.decoder = VQDecoderV5(config) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(vq_latent) + return {"poses_feat": vq_latent, "embedding_loss": embedding_loss, "perplexity": perplexity, "rec_pose": rec_pose} + + def map2index(self, inputs): + pre_latent = self.encoder(inputs) + return self.quantizer.map2index(pre_latent) + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + return self.quantizer.get_codebook_entry(index) + + def decode(self, index): + z_q = self.quantizer.get_codebook_entry(index) + return self.decoder(z_q) + + def decode_from_latent(self, latent): + z_flattened = latent.contiguous().view(-1, self.quantizer.e_dim) + d = ( + torch.sum(z_flattened ** 2, dim=1, keepdim=True) + + torch.sum(self.quantizer.embedding.weight ** 2, dim=1) + - 2 * torch.matmul(z_flattened, self.quantizer.embedding.weight.t()) + ) + min_encoding_indices = torch.argmin(d, dim=1) + indices = min_encoding_indices.view(latent.shape[0], latent.shape[1]) + z_q = self.quantizer.get_codebook_entry(indices) + return self.decoder(z_q) + + +class EmageVQModel(nn.Module): + def __init__(self, face_model, upper_model, hands_model, lower_model, global_model): + super().__init__() + self.joint_mask_upper = [ + False, False, False, True, False, False, True, 
False, False, True, + False, False, True, True, True, True, True, True, True, True, + True, True, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, False, + False, False, False, False, False, + ] + self.joint_mask_lower = [ + True, True, True, False, True, True, False, True, True, False, + True, True, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, False, + False, False, False, False, False, + ] + self.vq_model_face = face_model + self.vq_model_upper = upper_model + self.vq_model_hands = hands_model + self.vq_model_lower = lower_model + self.global_motion = global_model + + def spilt_inputs(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None): + bs, t, j6 = smplx_body_rot6d.shape + smplx_body_rot6d = smplx_body_rot6d.reshape(bs, t, j6 // 6, 6) + jaw_rot6d = smplx_body_rot6d[:, :, 22:23, :].reshape(bs, t, 6) + face = torch.cat([jaw_rot6d, expression], dim=2) + upper_rot6d = smplx_body_rot6d[:, :, self.joint_mask_upper, :].reshape(bs, t, 78) + hands_rot6d = smplx_body_rot6d[:, :, 25:55, :].reshape(bs, t, 180) + lower_rot6d = smplx_body_rot6d[:, :, self.joint_mask_lower, :].reshape(bs, t, 54) + tar_contact = torch.zeros(bs, t, 4, device=smplx_body_rot6d.device) if tar_contact is None else tar_contact + tar_trans = torch.zeros(bs, t, 3, device=smplx_body_rot6d.device) if tar_trans is None else tar_trans + lower = torch.cat([lower_rot6d, tar_trans, tar_contact], dim=2) + return dict(face=face, upper=upper_rot6d, hands=hands_rot6d, lower=lower) + + def map2index(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None): + inputs = self.spilt_inputs(smplx_body_rot6d, expression, 
tar_contact=tar_contact, tar_trans=tar_trans) + return dict( + face=self.vq_model_face.map2index(inputs["face"]), + upper=self.vq_model_upper.map2index(inputs["upper"]), + hands=self.vq_model_hands.map2index(inputs["hands"]), + lower=self.vq_model_lower.map2index(inputs["lower"]), + ) + + def map2latent(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None): + inputs = self.spilt_inputs(smplx_body_rot6d, expression, tar_contact=tar_contact, tar_trans=tar_trans) + return dict( + face=self.vq_model_face.map2latent(inputs["face"]), + upper=self.vq_model_upper.map2latent(inputs["upper"]), + hands=self.vq_model_hands.map2latent(inputs["hands"]), + lower=self.vq_model_lower.map2latent(inputs["lower"]), + ) + + def decode( + self, + face_index=None, upper_index=None, hands_index=None, lower_index=None, + face_latent=None, upper_latent=None, hands_latent=None, lower_latent=None, + get_global_motion=False, ref_trans=None, + ): + for t in [face_index, upper_index, hands_index, lower_index, face_latent, upper_latent, hands_latent, lower_latent]: + if t is not None: + bs, seq = t.shape[:2] + break + + if face_index is not None: + face_mix = self.vq_model_face.decode(face_index) + face_jaw_6d, expression = face_mix[:, :, :6], face_mix[:, :, 6:] + face_jaw = rotation_6d_to_axis_angle(face_jaw_6d) + elif face_latent is not None: + face_mix = self.vq_model_face.decode_from_latent(face_latent) + face_jaw_6d, expression = face_mix[:, :, :6], face_mix[:, :, 6:] + face_jaw = rotation_6d_to_axis_angle(face_jaw_6d) + else: + face_jaw = torch.zeros(bs, seq, 3, device=self.vq_model_face.device) + expression = torch.zeros(bs, seq, 100, device=self.vq_model_face.device) + + if upper_index is not None: + upper_6d = self.vq_model_upper.decode(upper_index) + upper = rotation_6d_to_axis_angle(upper_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + elif upper_latent is not None: + upper_6d = self.vq_model_upper.decode_from_latent(upper_latent) + upper = 
rotation_6d_to_axis_angle(upper_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + else: + upper = torch.zeros(bs, seq, 39, device=self.vq_model_upper.device) + + if hands_index is not None: + hands_6d = self.vq_model_hands.decode(hands_index) + hands = rotation_6d_to_axis_angle(hands_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + elif hands_latent is not None: + hands_6d = self.vq_model_hands.decode_from_latent(hands_latent) + hands = rotation_6d_to_axis_angle(hands_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + else: + hands = torch.zeros(bs, seq, 90, device=self.vq_model_hands.device) + + if lower_index is not None: + lower_mix = self.vq_model_lower.decode(lower_index) + lower_6d, transfoot = lower_mix[:, :, :-7], lower_mix[:, :, -7:] + lower = rotation_6d_to_axis_angle(lower_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + elif lower_latent is not None: + lower_mix = self.vq_model_lower.decode_from_latent(lower_latent) + lower_6d, transfoot = lower_mix[:, :, :-7], lower_mix[:, :, -7:] + lower = rotation_6d_to_axis_angle(lower_6d.reshape(bs, seq, -1, 6)).reshape(bs, seq, -1) + else: + lower = torch.zeros(bs, seq, 27, device=self.vq_model_lower.device) + transfoot = torch.zeros(bs, seq, 7, device=self.vq_model_lower.device) + lower_6d = axis_angle_to_rotation_6d(lower.reshape(bs, seq, -1, 3)).reshape(bs, seq, -1) + lower_mix = torch.cat([lower_6d, transfoot], dim=-1) + + upper2all = recover_from_mask_ts(upper, self.joint_mask_upper) + hands2all = recover_from_mask_ts(hands, [False] * 25 + [True] * 30) + lower2all = recover_from_mask_ts(lower, self.joint_mask_lower) + + all_motion_axis_angle = upper2all + hands2all + lower2all + all_motion_axis_angle[:, :, 22 * 3:22 * 3 + 3] = face_jaw + all_motion_rot6d = axis_angle_to_rotation_6d(all_motion_axis_angle.reshape(bs, seq, 55, 3)).reshape(bs, seq, 55 * 6) + all_motion4inference = torch.cat([all_motion_rot6d, transfoot], dim=2) + + global_motion = None + if get_global_motion: + global_motion = 
self._get_global_motion(lower_mix, ref_trans) + + return dict( + expression=expression, + all_motion4inference=all_motion4inference, + motion_axis_angle=all_motion_axis_angle, + trans=global_motion, + ) + + def _get_global_motion(self, lower_body, ref_trans): + global_motion = self.global_motion(lower_body) + rec_trans_v_s = global_motion["rec_pose"][:, :, 54:57] + if len(ref_trans.shape) == 2: + ref_trans = ref_trans.unsqueeze(0).repeat(rec_trans_v_s.shape[0], 1, 1) + rec_x_trans = velocity2position(rec_trans_v_s[:, :, 0:1], 1 / 30, ref_trans[:, 0, 0:1]) + rec_z_trans = velocity2position(rec_trans_v_s[:, :, 2:3], 1 / 30, ref_trans[:, 0, 2:3]) + rec_y_trans = rec_trans_v_s[:, :, 1:2] + return torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1) + + +class EmageAudioModel(PreTrainedModel): + config_class = EmageAudioConfig + base_model_prefix = "emage_audio" + + def __init__(self, config: EmageAudioConfig): + super().__init__(config) + self.cfg = config + self.audio_encoder_face = WavEncoder(self.cfg.audio_f) + self.audio_encoder_body = WavEncoder(self.cfg.audio_f) + self.speaker_embedding_body = nn.Embedding(self.cfg.speaker_dims, self.cfg.hidden_size) + self.speaker_embedding_face = nn.Embedding(self.cfg.speaker_dims, self.cfg.hidden_size) + self.mask_embedding = nn.Parameter(torch.zeros(1, 1, self.cfg.pose_dims + 3 + 4)) + nn.init.normal_(self.mask_embedding, 0, self.cfg.hidden_size ** -0.5) + + args_top = copy.deepcopy(self.cfg) + args_top.vae_layer = 3 + args_top.vae_length = self.cfg.motion_f + args_top.vae_test_dim = self.cfg.pose_dims + 3 + 4 + self.motion_encoder = VQEncoderV6(args_top) + self.bodyhints_face = MLP(self.cfg.motion_f, self.cfg.hidden_size, self.cfg.motion_f) + self.bodyhints_body = MLP(self.cfg.motion_f, self.cfg.hidden_size, self.cfg.motion_f) + self.audio_body_motion_proj = nn.Linear(self.cfg.audio_f, self.cfg.hidden_size) + self.moton_proj = nn.Linear(self.cfg.motion_f, self.cfg.hidden_size) + self.position_embeddings = 
PeriodicPositionalEncoding(self.cfg.hidden_size, period=self.cfg.pose_length, max_seq_len=self.cfg.pose_length) + self.transformer_en_layer = nn.TransformerEncoderLayer(d_model=self.cfg.hidden_size, nhead=4, dim_feedforward=self.cfg.hidden_size * 2) + self.motion_self_encoder = nn.TransformerEncoder(self.transformer_en_layer, num_layers=1) + self.audio_motion_cross_attn_layer = nn.TransformerDecoderLayer(d_model=self.cfg.hidden_size, nhead=4, dim_feedforward=self.cfg.hidden_size * 2) + self.audio_motion_cross_attn = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=8) + self.motion2latent_upper = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size) + self.motion2latent_hands = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size) + self.motion2latent_lower = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size) + self.body_motion_decoder_upper = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=1) + self.body_motion_decoder_hands = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=1) + self.body_motion_decoder_lower = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=1) + self.motion_out_proj_upper = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size) + self.motion_out_proj_hands = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size) + self.motion_out_proj_lower = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size) + self.motion_cls_upper = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size) + self.motion_cls_hands = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size) + self.motion_cls_lower = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size) + self.audio_face_motion_proj = nn.Linear(self.cfg.audio_f + self.cfg.motion_f, self.cfg.hidden_size) + self.face_motion_decoder = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, 
num_layers=4) + self.face_out_proj = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size) + self.face_cls = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size) + + def forward(self, audio, speaker_id, masked_motion, mask, use_audio=True): + masked_embeddings = self.mask_embedding.expand_as(masked_motion) + masked_motion = torch.where(mask == 1, masked_embeddings, masked_motion) + + body_hint = self.motion_encoder(masked_motion) + body_hint_body = self.bodyhints_body(body_hint) + body_hint_face = self.bodyhints_face(body_hint) + + audio2face_fea = self.audio_encoder_face(audio) + audio2body_fea = self.audio_encoder_body(audio) + + if audio2face_fea.shape[1] > body_hint_face.shape[1]: + audio2face_fea = audio2face_fea[:, :body_hint_face.shape[1]] + if audio2body_fea.shape[1] > body_hint_face.shape[1]: + audio2face_fea = audio2face_fea[:, :body_hint_face.shape[1]] + + bs, t, _ = audio2face_fea.shape + + speaker_motion_fea_proj = self.speaker_embedding_body(speaker_id).repeat(1, t, 1) + speaker_face_fea_proj = self.speaker_embedding_face(speaker_id).repeat(1, t, 1) + + audio2face_fea_proj = self.audio_face_motion_proj(torch.cat([audio2face_fea, body_hint_face], dim=2)) + face_proj = self.position_embeddings(speaker_face_fea_proj) + decode_face = self.face_motion_decoder(tgt=face_proj.permute(1, 0, 2), memory=audio2face_fea_proj.permute(1, 0, 2)).permute(1, 0, 2) + face_latent = self.face_out_proj(decode_face) + classify_face = self.face_cls(face_latent) + + masked_motion_proj = self.moton_proj(body_hint_body) + masked_motion_proj = self.position_embeddings(masked_motion_proj) + masked_motion_proj = speaker_motion_fea_proj + masked_motion_proj + motion_fea = self.motion_self_encoder(masked_motion_proj.permute(1, 0, 2)).permute(1, 0, 2) + + audio2body_fea_proj = self.audio_body_motion_proj(audio2body_fea) + motion_fea = motion_fea + speaker_motion_fea_proj + motion_fea = self.position_embeddings(motion_fea) + audio2body_fea_cross = 
self.audio_motion_cross_attn(tgt=motion_fea.permute(1, 0, 2), memory=audio2body_fea_proj.permute(1, 0, 2)).permute(1, 0, 2) + if not use_audio: + audio2body_fea_cross = audio2body_fea_cross * 0.0 + motion_fea = motion_fea + audio2body_fea_cross + + upper_latent = self.motion2latent_upper(motion_fea) + hands_latent = self.motion2latent_hands(motion_fea) + lower_latent = self.motion2latent_lower(motion_fea) + + motion_upper_refine = self.body_motion_decoder_upper(tgt=upper_latent.permute(1, 0, 2) + speaker_motion_fea_proj.permute(1, 0, 2), memory=(hands_latent + lower_latent).permute(1, 0, 2)).permute(1, 0, 2) + motion_hands_refine = self.body_motion_decoder_hands(tgt=hands_latent.permute(1, 0, 2) + speaker_motion_fea_proj.permute(1, 0, 2), memory=(upper_latent + lower_latent).permute(1, 0, 2)).permute(1, 0, 2) + motion_lower_refine = self.body_motion_decoder_lower(tgt=lower_latent.permute(1, 0, 2) + speaker_motion_fea_proj.permute(1, 0, 2), memory=(upper_latent + hands_latent).permute(1, 0, 2)).permute(1, 0, 2) + upper_latent = self.motion_out_proj_upper(upper_latent + motion_upper_refine) + hands_latent = self.motion_out_proj_hands(hands_latent + motion_hands_refine) + lower_latent = self.motion_out_proj_lower(lower_latent + motion_lower_refine) + + classify_upper = self.motion_cls_upper(upper_latent) + classify_hands = self.motion_cls_hands(hands_latent) + classify_lower = self.motion_cls_lower(lower_latent) + + return { + "rec_face": face_latent, + "rec_upper": upper_latent, + "rec_hands": hands_latent, + "rec_lower": lower_latent, + "cls_face": classify_face, + "cls_upper": classify_upper, + "cls_hands": classify_hands, + "cls_lower": classify_lower, + } + + def inference(self, audio, speaker_id, vq_model, masked_motion=None, mask=None): + length = audio.shape[1] * 30 // 16000 + bs = audio.shape[0] + + fake_axis_angle = torch.zeros(bs, length, 55, 3).to(audio.device) + fake_motion = axis_angle_to_rotation_6d(fake_axis_angle).reshape(bs, length, -1) + 
fake_foot_and_trans = torch.zeros(bs, length, 7).to(audio.device) + fake_motion = torch.cat([fake_motion, fake_foot_and_trans], dim=-1) + if masked_motion is not None: + fake_motion[:, :masked_motion.shape[1]] = masked_motion + masked_motion = fake_motion + + fake_mask = torch.ones_like(masked_motion) + if mask is not None: + fake_mask[:, :mask.shape[1]] = mask + mask = fake_mask + + bs, total_len, c = masked_motion.shape + window = self.cfg.pose_length + pre_frames = self.cfg.seed_frames + rounds = (total_len - pre_frames) // (window - pre_frames) + remain = (total_len - pre_frames) % (window - pre_frames) + + rec_all_face, rec_all_lower, rec_all_upper, rec_all_hands = [], [], [], [] + cls_all_face, cls_all_lower, cls_all_upper, cls_all_hands = [], [], [], [] + + last_motion = masked_motion[:, :pre_frames, :] + + for i in range(rounds): + start_idx = i * (window - pre_frames) + end_idx = start_idx + window + + window_mask = mask[:, start_idx:end_idx, :].clone() + window_motion = masked_motion[:, start_idx:end_idx, :].clone() + window_motion[:, :pre_frames, :] = torch.where( + (window_mask[:, :pre_frames, :] == 0), + masked_motion[:, start_idx:start_idx + pre_frames, :], + last_motion, + ) + window_mask[:, :pre_frames, :] = 0 + + audio_slice_len = (end_idx - start_idx) * (16000 // 30) + audio_slice = audio[:, start_idx * (16000 // 30):start_idx * (16000 // 30) + audio_slice_len] + net_out_val = self.forward(audio_slice, speaker_id, masked_motion=window_motion, mask=window_mask, use_audio=True) + + _, cls_face = torch.max(F.log_softmax(net_out_val["cls_face"], dim=2), dim=2) + _, cls_upper = torch.max(F.log_softmax(net_out_val["cls_upper"], dim=2), dim=2) + _, cls_hands = torch.max(F.log_softmax(net_out_val["cls_hands"], dim=2), dim=2) + _, cls_lower = torch.max(F.log_softmax(net_out_val["cls_lower"], dim=2), dim=2) + + face_latent = net_out_val["rec_face"] if self.cfg.lf > 0 and self.cfg.cf == 0 else None + upper_latent = net_out_val["rec_upper"] if self.cfg.lu > 0 
and self.cfg.cu == 0 else None + hands_latent = net_out_val["rec_hands"] if self.cfg.lh > 0 and self.cfg.ch == 0 else None + lower_latent = net_out_val["rec_lower"] if self.cfg.ll > 0 and self.cfg.cl == 0 else None + face_index = cls_face if self.cfg.cf > 0 else None + upper_index = cls_upper if self.cfg.cu > 0 else None + hands_index = cls_hands if self.cfg.ch > 0 else None + lower_index = cls_lower if self.cfg.cl > 0 else None + + decode_dict = vq_model.decode( + face_latent=face_latent, upper_latent=upper_latent, + lower_latent=lower_latent, hands_latent=hands_latent, + face_index=face_index, upper_index=upper_index, + lower_index=lower_index, hands_index=hands_index, + ) + + last_motion = decode_dict["all_motion4inference"][:, -pre_frames:, :] + rec_all_face.append(net_out_val["rec_face"][:, :-pre_frames, :]) + rec_all_upper.append(net_out_val["rec_upper"][:, :-pre_frames, :]) + rec_all_hands.append(net_out_val["rec_hands"][:, :-pre_frames, :]) + rec_all_lower.append(net_out_val["rec_lower"][:, :-pre_frames, :]) + cls_all_face.append(net_out_val["cls_face"][:, :-pre_frames]) + cls_all_upper.append(net_out_val["cls_upper"][:, :-pre_frames]) + cls_all_hands.append(net_out_val["cls_hands"][:, :-pre_frames]) + cls_all_lower.append(net_out_val["cls_lower"][:, :-pre_frames]) + + if remain > pre_frames: + final_start = rounds * (window - pre_frames) + final_end = final_start + pre_frames + remain + + final_mask = mask[:, final_start:final_end, :].clone() + final_motion = masked_motion[:, final_start:final_end, :].clone() + final_motion[:, :pre_frames, :] = torch.where( + (final_mask[:, :pre_frames, :] == 0), + masked_motion[:, final_start:final_start + pre_frames, :], + last_motion, + ) + final_mask[:, :pre_frames, :] = 0 + + audio_slice_len = (final_end - final_start) * (16000 // 30) + audio_slice = audio[:, final_start * (16000 // 30):final_start * (16000 // 30) + audio_slice_len] + net_out_val = self.forward(audio_slice, speaker_id, masked_motion=final_motion, 
mask=final_mask, use_audio=True) + + rec_all_face.append(net_out_val["rec_face"]) + rec_all_upper.append(net_out_val["rec_upper"]) + rec_all_hands.append(net_out_val["rec_hands"]) + rec_all_lower.append(net_out_val["rec_lower"]) + cls_all_face.append(net_out_val["cls_face"]) + cls_all_upper.append(net_out_val["cls_upper"]) + cls_all_hands.append(net_out_val["cls_hands"]) + cls_all_lower.append(net_out_val["cls_lower"]) + + return { + "rec_face": torch.cat(rec_all_face, dim=1), + "rec_upper": torch.cat(rec_all_upper, dim=1), + "rec_hands": torch.cat(rec_all_hands, dim=1), + "rec_lower": torch.cat(rec_all_lower, dim=1), + "cls_face": torch.cat(cls_all_face, dim=1), + "cls_upper": torch.cat(cls_all_upper, dim=1), + "cls_hands": torch.cat(cls_all_hands, dim=1), + "cls_lower": torch.cat(cls_all_lower, dim=1), + } diff --git a/src/modules/gesture/emage/processing.py b/src/modules/gesture/emage/processing.py new file mode 100644 index 0000000..2d128b1 --- /dev/null +++ b/src/modules/gesture/emage/processing.py @@ -0,0 +1,380 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _copysign(a, b): + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + 
def quaternion_to_axis_angle(quaternions):
    """Convert (..., 4) real-first quaternions to (..., 3) axis-angle vectors."""
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    angles = 2 * half_angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # Series expansion of sin(x/2)/x near 0 avoids the 0/0 division.
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    return quaternions[..., 1:] / sin_half_angles_over_angles


def matrix_to_axis_angle(matrix):
    """Convert (..., 3, 3) rotation matrices to axis-angle via quaternions."""
    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))


def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """Convert (..., 6) rotation features to (..., 3, 3) matrices.

    The two 3-vectors are orthonormalised (Gram-Schmidt); their cross product
    supplies the third row.
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)


def rotation_6d_to_axis_angle(rot6d):
    """Convert (..., 6) rotation features to (..., 3) axis-angle vectors."""
    return matrix_to_axis_angle(rotation_6d_to_matrix(rot6d))


def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """Flatten the first two rows of (..., 3, 3) matrices into (..., 6)."""
    return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)


def axis_angle_to_quaternion(axis_angle):
    """Convert (..., 3) axis-angle vectors to (..., 4) real-first quaternions."""
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = 0.5 * angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # Series expansion of sin(x/2)/x near 0 avoids the 0/0 division.
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    quaternions = torch.cat(
        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
    )
    return quaternions


def quaternion_to_matrix(quaternions):
    """Convert (..., 4) real-first quaternions to (..., 3, 3) rotation matrices.

    The ``2 / |q|^2`` factor means the input need not be unit length.
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)
    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def axis_angle_to_matrix(axis_angle):
    """Convert (..., 3) axis-angle vectors to (..., 3, 3) rotation matrices."""
    return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))


def axis_angle_to_rotation_6d(axis_angle):
    """Convert (..., 3) axis-angle vectors to (..., 6) rotation features."""
    return matrix_to_rotation_6d(axis_angle_to_matrix(axis_angle))


def velocity2position(data_seq, dt, init_pos):
    """Integrate per-frame velocities along dim 1 into positions.

    Frame 0 is ``init_pos``; frame ``i`` is frame ``i - 1`` plus
    ``data_seq[:, i - 1] * dt``. The output has as many frames as ``data_seq``.
    """
    res_trans = []
    for i in range(data_seq.shape[1]):
        if i == 0:
            res_trans.append(init_pos.unsqueeze(1))
        else:
            res = data_seq[:, i - 1:i] * dt + res_trans[-1]
            res_trans.append(res)
    return torch.cat(res_trans, dim=1)


def recover_from_mask_ts(selected_motion: torch.Tensor, mask: list) -> torch.Tensor:
    """Scatter a subset of joints back into the full joint layout.

    ``mask`` has one boolean per joint. The last dimension of
    ``selected_motion`` is split evenly across the selected joints, written to
    their slots in a zero tensor of ``len(mask)`` joints, and flattened again.
    """
    device = selected_motion.device
    dtype = selected_motion.dtype
    mask_arr = torch.tensor(mask, dtype=torch.bool, device=device)
    j = len(mask_arr)
    sum_mask = mask_arr.sum().item()
    c_channels = selected_motion.shape[-1] // sum_mask
    new_shape = selected_motion.shape[:-1] + (sum_mask, c_channels)
    selected_motion = selected_motion.reshape(new_shape)
    out_shape = list(selected_motion.shape[:-2]) + [j, c_channels]
    recovered = torch.zeros(out_shape, dtype=dtype, device=device)
    recovered[..., mask_arr, :] = selected_motion
    final_shape = list(recovered.shape[:-2]) + [j * c_channels]
    recovered = recovered.reshape(final_shape)
    return recovered


class Quantizer(nn.Module):
    """Vector quantizer: nearest-neighbour lookup into a learned codebook."""

    def __init__(self, n_e, e_dim, beta):
        """
        Args:
            n_e: number of codebook entries.
            e_dim: dimensionality of each entry.
            beta: weight of the second (commitment) term in the loss.
        """
        super().__init__()
        self.e_dim = e_dim
        self.n_e = n_e
        self.beta = beta
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def _nearest_indices(self, z):
        """Return, for every ``e_dim`` vector in ``z``, the index of the closest code.

        Shared by ``forward`` and ``map2index``; uses squared Euclidean distance
        expanded as ``|z|^2 + |e|^2 - 2 z.e``.
        """
        assert z.shape[-1] == self.e_dim
        z_flattened = z.contiguous().view(-1, self.e_dim)
        d = (
            torch.sum(z_flattened ** 2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight ** 2, dim=1)
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        )
        return torch.argmin(d, dim=1)

    def forward(self, z):
        """Quantize ``z``.

        Returns:
            ``(loss, z_q, min_encoding_indices, perplexity)`` where ``z_q``
            carries straight-through gradients back to ``z``.
        """
        min_encoding_indices = self._nearest_indices(z)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        loss = torch.mean((z_q - z.detach()) ** 2) + self.beta * torch.mean((z_q.detach() - z) ** 2)
        # Straight-through estimator: forward value is z_q, gradient flows to z.
        z_q = z + (z_q - z).detach()
        min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype)
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        return loss, z_q, min_encoding_indices, perplexity

    def map2index(self, z):
        """Return nearest-code indices reshaped to ``(z.shape[0], -1)``."""
        return self._nearest_indices(z).reshape(z.shape[0], -1)

    def get_codebook_entry(self, indices):
        """Look up codebook vectors; output shape is ``indices.shape + (e_dim,)``."""
        index_flattened = indices.view(-1)
        z_q = self.embedding(index_flattened)
        z_q = z_q.view(indices.shape + (self.e_dim,)).contiguous()
        return z_q


def init_weight(m):
    """Xavier-normal weights and zero bias for conv / linear layers."""
    if isinstance(m, (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


class ResBlock(nn.Module):
    """Two length-preserving 1-D convolutions with a residual connection."""

    def __init__(self, channel):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv1d(channel, channel, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv1d(channel, channel, 3, 1, 1),
        )

    def forward(self, x):
        return self.model(x) + x


def _build_conv_encoder(args):
    """Build the conv stack shared by ``VQEncoderV5`` and ``VQEncoderV6``.

    ``args.vae_layer`` stages of Conv1d -> LeakyReLU -> ResBlock, mapping
    ``args.vae_test_dim`` input channels to ``args.vae_length`` channels, with
    ``init_weight`` applied.
    """
    n_down = args.vae_layer
    channels = [args.vae_length] * n_down
    input_size = args.vae_test_dim
    layers = [
        nn.Conv1d(input_size, channels[0], 3, 1, 1),
        nn.LeakyReLU(0.2, True),
        ResBlock(channels[0]),
    ]
    for i in range(1, n_down):
        layers += [
            nn.Conv1d(channels[i - 1], channels[i], 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            ResBlock(channels[i]),
        ]
    main = nn.Sequential(*layers)
    main.apply(init_weight)
    return main


class VQEncoderV5(nn.Module):
    """Length-preserving 1-D conv encoder over (batch, time, channels) input."""

    def __init__(self, args):
        super().__init__()
        self.main = _build_conv_encoder(args)

    def forward(self, inputs):
        # Conv1d wants channels on dim 1; swap in and back out.
        inputs = inputs.permute(0, 2, 1)
        outputs = self.main(inputs).permute(0, 2, 1)
        return outputs


class VQEncoderV6(nn.Module):
    """Same architecture as ``VQEncoderV5``; kept as a separate class name."""

    def __init__(self, args):
        super().__init__()
        self.main = _build_conv_encoder(args)

    def forward(self, inputs):
        inputs = inputs.permute(0, 2, 1)
        outputs = self.main(inputs).permute(0, 2, 1)
        return outputs


class VQDecoderV5(nn.Module):
    """1-D conv decoder mapping ``args.vae_length`` channels to ``args.vae_test_dim``."""

    def __init__(self, args):
        super().__init__()
        n_up = args.vae_layer
        channels = [args.vae_length] * n_up + [args.vae_test_dim]
        input_size = args.vae_length
        n_resblk = 2
        if input_size == channels[0]:
            layers = []
        else:
            layers = [nn.Conv1d(input_size, channels[0], 3, 1, 1)]
        for i in range(n_resblk):
            layers += [ResBlock(channels[0])]
        for i in range(n_up):
            layers += [
                nn.Conv1d(channels[i], channels[i + 1], 3, 1, 1),
                nn.LeakyReLU(0.2, True),
            ]
        layers += [nn.Conv1d(channels[-1], channels[-1], 3, 1, 1)]
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        inputs = inputs.permute(0, 2, 1)
        outputs = self.main(inputs).permute(0, 2, 1)
        return outputs


class BasicBlock(nn.Module):
    """Residual block of two Conv1d+norm layers with an optional projected shortcut.

    ``first_dilation`` is passed as the *padding* of the first convolution (and
    of the shortcut convolution). Any non-None ``downsample`` enables the
    shortcut projection.
    NOTE(review): the ``first_dilation=None`` default is forwarded to Conv1d as
    ``padding=None``; every caller in this file passes it explicitly.
    """

    def __init__(self, inplanes, planes, ker_size, stride=1, downsample=None, dilation=1, first_dilation=None, act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm1d):
        super().__init__()
        self.conv1 = nn.Conv1d(
            inplanes, planes, kernel_size=ker_size, stride=stride,
            padding=first_dilation, dilation=dilation, bias=True,
        )
        self.bn1 = norm_layer(planes)
        self.act1 = act_layer(inplace=True)
        self.conv2 = nn.Conv1d(
            planes, planes, kernel_size=ker_size, padding=ker_size // 2,
            dilation=dilation, bias=True,
        )
        self.bn2 = norm_layer(planes)
        self.act2 = act_layer(inplace=True)
        if downsample is not None:
            self.downsample = nn.Sequential(
                nn.Conv1d(inplanes, planes, stride=stride, kernel_size=ker_size,
                          padding=first_dilation, dilation=dilation, bias=True),
                norm_layer(planes),
            )
        else:
            self.downsample = None

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        x += shortcut
        x = self.act2(x)
        return x


class WavEncoder(nn.Module):
    """Strided 1-D conv stack turning a waveform into ``out_dim``-channel features.

    Strides are 5, 6, 1, 6, 1, 3, i.e. an overall temporal stride of 540.
    """

    def __init__(self, out_dim, audio_in=1):
        super().__init__()
        self.out_dim = out_dim
        self.feat_extractor = nn.Sequential(
            BasicBlock(audio_in, out_dim // 4, 15, 5, first_dilation=1600, downsample=True),
            BasicBlock(out_dim // 4, out_dim // 4, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(out_dim // 4, out_dim // 4, 15, 1, first_dilation=7),
            BasicBlock(out_dim // 4, out_dim // 2, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(out_dim // 2, out_dim // 2, 15, 1, first_dilation=7),
            BasicBlock(out_dim // 2, out_dim, 15, 3, first_dilation=0, downsample=True),
        )

    def forward(self, wav_data):
        """Accept (batch, samples) or (batch, samples, channels); return (batch, frames, out_dim)."""
        if wav_data.dim() == 2:
            wav_data = wav_data.unsqueeze(1)
        else:
            wav_data = wav_data.transpose(1, 2)
        out = self.feat_extractor(wav_data)
        return out.transpose(1, 2)


class MLP(nn.Module):
    """Two linear layers with a LeakyReLU(0.1) in between."""

    def __init__(self, in_dim, middle_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, middle_dim)
        self.fc2 = nn.Linear(middle_dim, out_dim)
        self.act = nn.LeakyReLU(0.1, True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class PeriodicPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding that repeats every ``period`` positions.

    The table is tiled ``max_seq_len // period + 1`` times and stored as the
    non-trainable buffer ``pe``; dropout is applied after the addition.
    """

    def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=60):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(period, d_model)
        position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        repeat_num = (max_seq_len // period) + 1
        pe = pe.repeat(1, repeat_num, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


# [patch] diff --git a/src/modules/gesture/gesture.py b/src/modules/gesture/gesture.py
# [patch] new file mode 100644
# [patch] index 0000000..ecd50f1
# [patch] --- /dev/null
# [patch] +++ b/src/modules/gesture/gesture.py
# [patch] @@ -0,0 +1,440 @@
import asyncio
import os
from dataclasses import dataclass
from typing import AsyncGenerator, Optional

import numpy as np
from ray import serve
from ray.serve import handle

from src.core.module import Module, ModuleWithHandle
from src.modules.text_to_speech.events import Audio


_HF_REPO = os.environ.get("HURI_EMAGE_REPO", "H-Liu1997/emage_audio")
_EMAGE_SR = 16000  # EMAGE expects 16 kHz mono audio
_EMAGE_FPS = 30  # EMAGE emits motion at 30 fps

# Sliding-window defaults. Overridable per-deployment via the module `args`
# block in the client config, or globally via the env vars below.
_CONTEXT_SEC = float(os.environ.get("HURI_GESTURE_CONTEXT_SEC", "2.0"))
_MIN_CHUNK_SEC = float(os.environ.get("HURI_GESTURE_MIN_CHUNK_SEC", "0.5"))

# Seconds over which a fresh window's first frames are eased onto the last
# emitted pose, killing the seam snap between windows and between utterances.
+_BLEND_SEC = float(os.environ.get("HURI_GESTURE_BLEND_SEC", "0.2")) + +# Optional manual GPU split: cap the gesture process to a fraction of the GPU so +# TTS keeps the lion's share. Only applied on CUDA when the value is set (>0). +_GPU_MEM_FRACTION = float(os.environ.get("HURI_GESTURE_GPU_MEM_FRACTION", "0.0")) + +# Source sample rate used to warm the inference path. Real audio arrives from +# the TTS (CosyVoice ≈ 24 kHz), so every real infer() call resamples to 16 kHz. +# Warming at 16 kHz — as the old warmup did — skips that resample entirely, +# leaving librosa's first-call cost to land on the first user-facing gesture. +# Default to the TTS rate so the resample path is warmed too. Override if your +# TTS uses a different rate (the exact value only affects which resampler +# filter is pre-built; the model shapes follow the 16 kHz duration regardless). +_WARMUP_SRC_SR = int(os.environ.get("HURI_GESTURE_WARMUP_SR", "24000")) + + +@dataclass +class Motion: + poses: np.ndarray # (t, 165) SMPL-X axis-angle, 55 joints × 3 + expressions: np.ndarray # (t, 100) facial expression coefficients + trans: np.ndarray # (t, 3) global root translation + fps: int = _EMAGE_FPS + pts: float = 0.0 # presentation timestamp in seconds, paired with Audio.pts + + +@serve.deployment(name="GestureGeneration") +class GestureDeployment: + def __init__( + self, + hf_repo: str = _HF_REPO, + device: Optional[str] = None, + gpu_mem_fraction: float = _GPU_MEM_FRACTION, + ): + print(f"[Gesture] importing torch...") + import torch + + # Pin algorithm selection so the kernels warmed below are the same ones + # used at serve time. With cudnn.benchmark enabled, cuDNN re-autotunes + # for every new input length — and the sliding window feeds a different + # length almost every call — so the first inference at each new shape + # would stall on autotuning, defeating the warmup. Keep it off (also the + # default) and pin it explicitly. 
TF32 just speeds matmul/conv on + # Ampere+ with no meaningful quality impact for gesture. + torch.backends.cudnn.benchmark = False + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + print(f"[Gesture] importing emage...") + from .emage import EmageAudioModel, EmageVAEConv, EmageVQModel, EmageVQVAEConv + + self.device = torch.device( + device if device else ("cuda" if torch.cuda.is_available() else "cpu") + ) + print(f"[Gesture] device={self.device} hf_repo={hf_repo!r}") + + # Manual GPU split: cap this process' share of GPU memory so the audio + # (TTS) path keeps the rest. num_gpus in the Ray serveConfig handles + # scheduling/packing; this caps actual allocation on the device. + if self.device.type == "cuda" and gpu_mem_fraction > 0: + try: + torch.cuda.set_per_process_memory_fraction( + gpu_mem_fraction, self.device.index or 0 + ) + print( + f"[Gesture] GPU memory fraction capped at {gpu_mem_fraction:.2f}", + ) + except Exception as e: # noqa: BLE001 — best-effort knob, never fatal + print(f"[Gesture] WARNING could not cap GPU memory: {e!r}") + + print("[Gesture] loading face_vq...") + face_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/face").to(self.device) + print("[Gesture] loading upper_vq...") + upper_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/upper").to(self.device) + print("[Gesture] loading lower_vq...") + lower_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/lower").to(self.device) + print("[Gesture] loading hands_vq...") + hands_vq = EmageVQVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/hands").to(self.device) + print("[Gesture] loading global_ae...") + global_ae = EmageVAEConv.from_pretrained(hf_repo, subfolder="emage_vq/global").to(self.device) + + self.motion_vq = EmageVQModel( + face_model=face_vq, + upper_model=upper_vq, + lower_model=lower_vq, + hands_model=hands_vq, + global_model=global_ae, + ) + 
self.motion_vq.eval() + + print("[Gesture] loading EmageAudioModel...") + self.model = EmageAudioModel.from_pretrained(hf_repo).to(self.device) + self.model.eval() + + self._warmup() + print(f"[Gesture] ready") + + def _warmup(self) -> None: + # The first inference pays one-time costs that are *shape- and + # path-dependent*: per-input-length kernel/primitive selection (cuDNN + # algo pick on GPU, oneDNN primitive build on CPU), librosa's first-call + # resampler build, the caching allocator's first growth, and CUDA + # context/kernel load. The old warmup ran a single 16 kHz, fixed-length, + # no-resample pass — so it warmed exactly one shape on a path real calls + # never take. The first real gesture (a different length, arriving at the + # TTS rate and therefore resampled) re-paid almost all of it, which is + # why the warmup "did nothing". + # + # Instead, sweep the window lengths the sliding window actually feeds + # infer() — from the small first-chunk window up to a full context+chunk + # steady-state window — on the *real* resample path, twice (the first + # pass pays the costs, the second confirms the path is hot), and + # synchronize so the GPU work is finished before we report ready. + # Best-effort: a failure here must never prevent the deployment coming up. + import time + import torch + + # Representative window lengths (seconds). The dominant per-window + # transformer forward is a fixed shape warmed by any window, but the + # trailing remainder forward varies with total length, so warm a spread. 
+ secs = sorted({ + round(s, 3) + for s in ( + _MIN_CHUNK_SEC, # first tiny window of an utterance + _CONTEXT_SEC, # context-only sized window + _CONTEXT_SEC + _MIN_CHUNK_SEC, # steady-state window + _CONTEXT_SEC + 2 * _MIN_CHUNK_SEC, # a larger fresh chunk + ) + if s and s > 0 + }) or [3.0] + + try: + t0 = time.time() + for pass_idx in range(2): + for s in secs: + n = max(1, int(_WARMUP_SRC_SR * s)) + dummy = np.zeros(n, dtype=np.float32) + ts = time.time() + self.infer(dummy, source_sr=_WARMUP_SRC_SR) + if self.device.type == "cuda": + torch.cuda.synchronize(self.device) + print( + f"[Gesture] warmup pass {pass_idx} {s:.2f}s " + f"({n} samples @ {_WARMUP_SRC_SR} Hz) in {time.time() - ts:.2f}s", + ) + print( + f"[Gesture] warmup done ({len(secs)} shapes x2) in {time.time() - t0:.2f}s", + ) + except Exception as e: # noqa: BLE001 — warmup is an optimisation, never fatal + print(f"[Gesture] WARNING warmup failed: {e!r}") + + def infer(self, audio_np: np.ndarray, source_sr: int = _EMAGE_SR) -> Motion: + import torch + import torch.nn.functional as F + + if source_sr != _EMAGE_SR: + import librosa + audio_np = librosa.resample(audio_np, orig_sr=source_sr, target_sr=_EMAGE_SR) + + audio_ts = torch.from_numpy(audio_np).to(self.device).unsqueeze(0) + speaker_id = torch.zeros(1, 1, dtype=torch.long, device=self.device) + + with torch.no_grad(): + ref_trans = torch.zeros(1, 1, 3, device=self.device) + latent_dict = self.model.inference(audio_ts, speaker_id, self.motion_vq) + + cfg = self.model.cfg + face_latent = latent_dict["rec_face"] if cfg.lf > 0 and cfg.cf == 0 else None + upper_latent = latent_dict["rec_upper"] if cfg.lu > 0 and cfg.cu == 0 else None + hands_latent = latent_dict["rec_hands"] if cfg.lh > 0 and cfg.ch == 0 else None + lower_latent = latent_dict["rec_lower"] if cfg.ll > 0 and cfg.cl == 0 else None + face_index = torch.max(F.log_softmax(latent_dict["cls_face"], dim=2), dim=2)[1] if cfg.cf > 0 else None + upper_index = 
class Gesture(ModuleWithHandle):
    """Gesture Module

    Consumes streaming Audio chunks produced by TTS and generates whole-body
    SMPL-X motion using the EMAGE audio-to-gesture model.

    Sliding window
    ──────────────
    TTS emits short, uneven audio chunks. Running EMAGE on each chunk in
    isolation produces motion that is jerky at chunk seams (the model has no
    context across boundaries) and is slow because the per-chunk overhead is
    re-paid for tiny inputs — and gets worse the longer the utterance runs if
    naively re-fed the whole buffer.

    Instead we keep a rolling buffer and, each time at least ``min_chunk_sec``
    of fresh audio has arrived, run inference over a window of
    ``[context_sec of already-spoken audio] + [the fresh audio]``. The context
    primes the model so the seam is continuous; only the motion frames for the
    fresh audio are emitted. The window length is bounded by
    ``context_sec + chunk size`` so inference cost stays flat regardless of
    utterance length.

    Seam blending
    ─────────────
    Priming with context keeps the *audio* continuous across a window, but the
    motion still snaps at seams: EMAGE has no future context at a window's right
    edge, so its last frames wind down differently from how the next window —
    fully primed — opens, and at utterance boundaries it cold-starts from a rest
    pose entirely. So each fresh segment is eased onto the previously emitted
    frame: poses, expressions and root translation all start exactly continuous
    and a cosine-decaying offset fades to zero over ``blend_sec``, restoring the
    model's intended motion (and avoiding the cumulative drift a constant rebase
    would cause). The anchors survive the end-of-utterance reset, so the first
    window of the next utterance blends out of the pose still on screen instead
    of teleporting.

    input: audio (Audio)
    output: motion (Motion)

    :hf_repo: HuggingFace repository to load EMAGE weights from.
    :device: PyTorch device string; defaults to CUDA when available.
        NOTE(review): ``hf_repo`` and ``device`` are not parameters of this
        class's ``__init__``; presumably they configure ``GestureDeployment``
        — confirm and move them there if so.
    :context_sec: Seconds of prior audio prepended to each window for continuity.
    :min_chunk_sec: Minimum seconds of fresh audio to accumulate before inferring.
    :blend_sec: Seconds over which each window's seam is eased onto the prior frame.
    """

    _handle_cls = GestureDeployment
    input_type = "audio"
    output_type = "motion"

    def __init__(
        self,
        _handle: handle.DeploymentHandle,
        context_sec: float = _CONTEXT_SEC,
        min_chunk_sec: float = _MIN_CHUNK_SEC,
        blend_sec: float = _BLEND_SEC,
    ):
        """Store the window/blend settings and start with an empty buffer.

        :param _handle: Deployment handle whose ``infer`` method is called
            remotely for every window.
        :param context_sec: Seconds of already-emitted audio prepended to
            each window.
        :param min_chunk_sec: Seconds of fresh audio required before a
            window is inferred (ignored on the final flush of an utterance).
        :param blend_sec: Duration of the cosine ease applied at each seam.
        """
        super().__init__(_handle)
        self._context_sec = float(context_sec)
        self._min_chunk_sec = float(min_chunk_sec)
        self._blend_sec = float(blend_sec)

        # Per-utterance sliding-window state. All sample counts are in the
        # source sample rate; resampling to 16 kHz happens once inside infer().
        self._lock = asyncio.Lock()
        self._sr: Optional[int] = None
        self._buffer = np.empty(0, dtype=np.float32)  # trailing audio (ctx + unprocessed)
        self._buf_start = 0  # source-sr sample index of buffer[0] in utterance timeline
        self._emitted = 0  # source-sr samples whose motion has been emitted

        # Last emitted frame per channel, used to ease the next segment's seam.
        # These persist across the end-of-utterance reset (see _end_utterance) so
        # gestures stay continuous when a new utterance starts.
        self._trans_anchor: Optional[np.ndarray] = None
        self._pose_anchor: Optional[np.ndarray] = None
        self._expr_anchor: Optional[np.ndarray] = None

    def _end_utterance(self) -> None:
        """Clear the audio buffer and timeline counters for the next utterance."""
        # Reset only per-utterance buffering/timeline state. The seam anchors
        # deliberately survive so the first window of the next utterance eases
        # out of the pose currently on screen instead of snapping to EMAGE's
        # cold-start rest pose.
        self._sr = None
        self._buffer = np.empty(0, dtype=np.float32)
        self._buf_start = 0
        self._emitted = 0

    async def process(self, audio: Audio) -> AsyncGenerator[Motion, None]:  # type: ignore[override]
        """Buffer one audio chunk and yield at most one ``Motion`` for it.

        The chunk is appended to the rolling buffer; the sample rate of the
        first non-empty chunk of an utterance is adopted for the whole
        utterance. Inference runs once enough fresh audio has accumulated,
        or unconditionally (if any fresh audio exists) when ``audio.end`` is
        set. When ``audio.end`` is set the per-utterance state is reset
        before returning, on every path.
        """
        # Each chunk arrives as its own process() task on the shared per-session
        # instance, so serialise under a lock to keep the buffer ordered.
        async with self._lock:
            if audio.data.size > 0:
                if self._sr is None:
                    self._sr = audio.sample_rate
                self._buffer = np.concatenate(
                    [self._buffer, audio.data.astype(np.float32)]
                )

            sr = self._sr
            end_of_utterance = audio.end

            if sr is None:
                # Nothing buffered yet (e.g. a lone end marker). Reset and bail.
                if end_of_utterance:
                    self._end_utterance()
                return

            ctx_samples = int(self._context_sec * sr)
            min_new_samples = int(self._min_chunk_sec * sr)

            # Absolute (utterance-timeline) index one past the last buffered
            # sample, and how many of those samples have no motion emitted yet.
            global_end = self._buf_start + len(self._buffer)
            new_samples = global_end - self._emitted

            # Wait for more audio unless this is the final flush of the utterance.
            if new_samples <= 0 or (not end_of_utterance and new_samples < min_new_samples):
                if end_of_utterance:
                    self._end_utterance()
                return

            motion = await self._infer_window(sr, ctx_samples, global_end)
            if motion is not None:
                yield motion

            if end_of_utterance:
                self._end_utterance()

    async def _infer_window(
        self, sr: int, ctx_samples: int, global_end: int
    ) -> Optional[Motion]:
        """Infer one context+fresh window and return only the fresh motion.

        Side effects: advances ``_emitted``, trims the buffer and, when
        frames are returned, updates the three seam anchors.

        :param sr: Sample rate of the buffered audio.
        :param ctx_samples: Number of already-emitted samples to prepend.
        :param global_end: Utterance-timeline index one past the last
            buffered sample.
        :returns: The blended fresh segment, or ``None`` when the window is
            empty or the model produced no frames beyond the context.
        """
        # Window = [context of already-emitted audio] + [fresh audio].
        win_start = max(self._buf_start, self._emitted - ctx_samples)
        window = self._buffer[win_start - self._buf_start :]
        if window.size == 0:
            return None

        motion: Motion = await self._handle.infer.remote(window, sr)
        total_frames = motion.poses.shape[0]

        # EMAGE's internal windowing (EmageAudioModel.inference) emits a
        # contiguous *prefix* of the requested window and silently drops up to
        # ~2*seed_frames frames off the END whenever the trailing partial window
        # is shorter than its motion seed. So the returned frames cover only
        # the first total_frames frames from win_start — not necessarily the
        # whole window. Map emission off the actual frame count, not the
        # requested length: otherwise the freshest motion is dropped while
        # _emitted skips over it, tearing a hole in the timeline that reads as
        # a freeze-then-jump (and drifts gesture out of sync with speech).
        covered_end = win_start + int(round(total_frames * sr / motion.fps))

        # Drop the leading frames that correspond to the context (already emitted).
        # NOTE: skip_sec must be computed before _emitted is advanced below.
        skip_sec = (self._emitted - win_start) / sr
        skip_frames = int(round(skip_sec * motion.fps))
        skip_frames = max(0, min(skip_frames, total_frames))

        poses = motion.poses[skip_frames:].copy()
        expressions = motion.expressions[skip_frames:].copy()
        trans = motion.trans[skip_frames:].copy()

        # Advance only past audio the model actually turned into motion; any
        # dropped tail stays buffered and is re-inferred next window, this time
        # with real right-context. Cap at global_end so rounding can't overrun
        # the buffer, and never move backwards.
        self._emitted = min(global_end, max(self._emitted, covered_end))
        self._trim_buffer(ctx_samples)

        if poses.shape[0] == 0:
            return None

        # Ease this segment's seam onto the last emitted frame. Poses and
        # expressions snap because EMAGE regenerates the boundary without the
        # right-context the next window will have (and cold-starts across
        # utterances); root translation snaps because every window restarts near
        # the origin. Blending all three keeps the seam continuous, and the
        # decaying (vs. constant) offset returns to the model's intended motion
        # so root translation doesn't accumulate drift across windows.
        blend_frames = int(round(self._blend_sec * motion.fps))
        self._blend_into(poses, self._pose_anchor, blend_frames)
        self._blend_into(expressions, self._expr_anchor, blend_frames)
        self._blend_into(trans, self._trans_anchor, blend_frames)

        self._pose_anchor = poses[-1].copy()
        self._expr_anchor = expressions[-1].copy()
        self._trans_anchor = trans[-1].copy()

        out = Motion(
            poses=poses,
            expressions=expressions,
            trans=trans,
            fps=motion.fps,
            pts=win_start / sr + skip_sec,  # == (value of self._emitted on entry) / sr
        )
        return out

    @staticmethod
    def _blend_into(
        arr: np.ndarray, anchor: Optional[np.ndarray], blend_frames: int
    ) -> None:
        """Ease the start of a fresh segment onto ``anchor`` in place.

        Frame 0 is shifted to equal ``anchor`` (a continuous seam) and the
        offset fades to zero over ``blend_frames`` with a cosine ease — zero
        slope at both ends, so neither the value nor its velocity jumps — after
        which the segment is the model's untouched output.

        Poses are SMPL-X axis-angle, so this is a linear blend in axis-angle
        space: exact only for small seam offsets, which is the regime here since
        consecutive frames are already close. A quaternion slerp would be needed
        for large discontinuities but is overkill for seam clean-up.

        No-op when ``anchor`` is ``None``, ``blend_frames`` is not positive,
        or ``arr`` has no frames.
        """
        if anchor is None or blend_frames <= 0 or arr.shape[0] == 0:
            return
        # Never blend over more frames than the segment actually has.
        n = min(blend_frames, arr.shape[0])
        # Cosine weights running from 1 (frame 0) down to 0 (frame n-1).
        w = 0.5 * (1.0 + np.cos(np.pi * np.linspace(0.0, 1.0, n, dtype=arr.dtype)))
        arr[:n] += w[:, None] * (anchor - arr[0])

    def _trim_buffer(self, ctx_samples: int) -> None:
        """Drop buffered audio older than one context window before ``_emitted``."""
        # Keep one context window of audio before the last *emitted* sample so
        # the next window stays bounded — but never discard audio whose motion
        # hasn't been emitted yet (the dropped tail above lives between
        # _emitted and the buffer end).
        keep_from = self._emitted - ctx_samples
        if keep_from > self._buf_start:
            self._buffer = self._buffer[keep_from - self._buf_start :]
            self._buf_start = keep_from
- We need this because if we run multiple Ollama containers, - they can't all use port 11434 — each needs its own. - """ - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - -def wait_for_service(url: str, timeout: int = 120) -> bool: - """ - Returns True if ready, False if timeout. - """ - start = time.time() - while time.time() - start < timeout: - try: - resp = httpx.get(url, timeout=5) - if resp.status_code == 200: - return True - except Exception: - pass - time.sleep(2) - return False - - -def is_container_running(name: str) -> bool: - """Check if a Docker container with this name is already running.""" - result = subprocess.run( - ["docker", "ps", "-q", "-f", f"name=^{name}$"], - capture_output=True, - text=True, - ) - return bool(result.stdout.strip()) - - -def remove_container(name: str): - """Force remove a container by name (ignores errors if it doesn't exist).""" - subprocess.run(["docker", "rm", "-f", name], capture_output=True) - - -@serve.deployment -class OllamaService: - """ - Manages one Ollama Docker container. 
- - LIFECYCLE: - __init__: starts container -> waits for it -> pulls model - generate: sends a prompt to the container, returns the answer - __del__: stops and removes the container - """ - - def __init__( - self, - model: str = "mistral:7b", - image: str = "ollama/ollama:latest", - gpu_devices: bool = False, - ): - self.model = model - self.port = find_free_port() - self.container_name = f"ollama-ray-{self.port}" - self.base_url = f"http://localhost:{self.port}" - - remove_container(self.container_name) - - cmd = [ - "docker", - "run", - "-d", - "--name", - self.container_name, - "-p", - f"{self.port}:11434", - "-v", - "ollama_shared:/root/.ollama", - ] - - if gpu_devices: - cmd.extend( - [ - "--device=/dev/kfd", - "--device=/dev/dri", - "--group-add=video", - ] - ) - - cmd.append(image) - - print(f"[OllamaService] Starting container \ -'{self.container_name}' on port {self.port}...") - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - raise RuntimeError(f"Docker failed: {result.stderr}") - - print("[OllamaService] Waiting for Ollama to be ready...") - if not wait_for_service(f"{self.base_url}/api/tags"): - raise RuntimeError(f"Ollama didn't start within \ -timeout on port {self.port}") - - print(f"[OllamaService] Pulling model '{model}'...") - pull_result = subprocess.run( - ["docker", "exec", self.container_name, "ollama", "pull", model], - capture_output=True, - text=True, - ) - if pull_result.returncode != 0: - raise RuntimeError(f"Failed to pull model: {pull_result.stderr}") - - print(f"[OllamaService] Ready! \ -container='{self.container_name}', port={self.port}, model='{model}'") - - async def generate( - self, - messages: list, - max_tokens: int = 1024, - temperature: float = 0.1, - ) -> Any: - """ - Send messages to Ollama and return the response. - This is what RAGHandle calls to get LLM answers. 
- """ - async with httpx.AsyncClient(timeout=60.0) as client: - resp = await client.post( - f"{self.base_url}/api/chat", - json={ - "model": self.model, - "messages": messages, - "stream": False, - "options": { - "num_predict": max_tokens, - "temperature": temperature, - }, - }, - ) - resp.raise_for_status() - return resp.json()["message"]["content"] - - async def health(self) -> dict: - """Check if this Ollama instance is alive.""" - try: - async with httpx.AsyncClient(timeout=5.0) as client: - await client.get(f"{self.base_url}/api/tags") - return { - "status": "ok", - "port": self.port, - "container": self.container_name, - } - except Exception as e: - return {"status": "error", "error": str(e)} - - def __del__(self): - """Cleanup when Ray destroys this replica.""" - print(f"[OllamaService] Removing container '{self.container_name}'") - remove_container(self.container_name) - - -@serve.deployment(num_replicas=1) -class QdrantService: - """ - Manages a Qdrant Docker container. - - LIFECYCLE: - __init__: starts container (or reuses if already running) - get_url: returns the URL other services should connect to - __del__: leaves the container running (it has data!) 
- """ - - def __init__( - self, - port: int = 6333, - image: str = "qdrant/qdrant:latest", - storage_volume: str = "qdrant_data", - ): - self.port = port - self.container_name = "qdrant-ray" - self.url = f"http://localhost:{self.port}" - - if self._is_healthy(): - print(f"[QdrantService] Qdrant already running on port {self.port}") - return - - remove_container(self.container_name) - - cmd = [ - "docker", - "run", - "-d", - "--name", - self.container_name, - "-p", - f"{self.port}:6333", - "-v", - f"{storage_volume}:/qdrant/storage", - image, - ] - - print(f"[QdrantService] Starting Qdrant on port {self.port}...") - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - raise RuntimeError(f"Docker failed: {result.stderr}") - - if not wait_for_service(f"{self.url}/healthz"): - raise RuntimeError( - f"Qdrant didn't start within timeout on port {self.port}" - ) - - print(f"[QdrantService] Ready on port {self.port}") - - def _is_healthy(self) -> bool: - try: - resp = httpx.get(f"{self.url}/healthz", timeout=3) - return resp.status_code == 200 - except Exception: - return False - - async def get_url(self) -> str: - """Return the URL. Called by RAGHandle to know where Qdrant is.""" - return self.url - - async def health(self) -> dict: - try: - async with httpx.AsyncClient(timeout=5.0) as client: - await client.get(f"{self.url}/healthz") - return {"status": "ok", "port": self.port, "url": self.url} - except Exception as e: - return {"status": "error", "error": str(e)} - - def __del__(self): - print(f"[QdrantService] Actor destroyed. 
class RemoteEmbedder:
    """Embed via an OpenAI-compatible ``/v1/embeddings`` endpoint (e.g. llama.cpp).

    Drop-in for the subset of ``SentenceTransformer`` this tool uses: a single
    ``.encode(text, normalize_embeddings=...)`` returning a 1-D numpy array, so
    the existing ``.tolist()`` / ``len(...)`` call sites keep working unchanged.

    The instance owns an HTTP connection pool. Call :meth:`close` when done,
    or use it as a context manager (``with RemoteEmbedder(...) as model:``).

    :param url: Base URL of the embedding server; a trailing ``/`` is ignored.
    :param model_name: Value sent as ``model`` in every request.
    :param verify_ssl: Whether to verify the server's TLS certificate.
        NOTE(review): defaults to ``False`` only to preserve the previous
        hard-coded behaviour; callers that can should pass ``True`` (or wire
        it to the CLI's ``--no-verify-ssl`` flag).
    :param timeout: Per-request timeout in seconds.
    """

    def __init__(
        self,
        url: str,
        model_name: str,
        *,
        verify_ssl: bool = False,
        timeout: float = 60.0,
    ):
        self.url = url.rstrip("/")
        self.model_name = model_name
        self._client = httpx.Client(timeout=timeout, verify=verify_ssl)

    def encode(self, text: str, normalize_embeddings: bool = True) -> np.ndarray:
        """Return the embedding of ``text`` as a 1-D ``float32`` array.

        :param text: Input to embed; converted with ``str()`` before sending.
        :param normalize_embeddings: Scale the vector to unit L2 norm. An
            all-zero vector is returned unchanged to avoid dividing by zero.
        :raises httpx.HTTPStatusError: If the server answers with an error status.
        """
        resp = self._client.post(
            f"{self.url}/v1/embeddings",
            json={"model": self.model_name, "input": str(text)},
        )
        resp.raise_for_status()
        vec = np.asarray(resp.json()["data"][0]["embedding"], dtype=np.float32)
        if normalize_embeddings:
            norm = np.linalg.norm(vec)
            if norm > 0:
                vec = vec / norm
        return vec

    def close(self) -> None:
        """Release the underlying HTTP connection pool."""
        self._client.close()

    def __enter__(self) -> "RemoteEmbedder":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()
def cmd_profile(args, client, model, _user_id):
    """Replace this user's always-on profile facts (name, free-form facts).

    Profile facts differ from regular documents: they are NOT retrieved by
    vector similarity. The RAG handle pulls them by filter
    (``_user_id`` + ``type=profile``) on every query and injects them into the
    system prompt, so the character always knows them.

    The collection is created if needed (sized from a probe embedding), the
    user's previous profile points are deleted, and the new facts are
    ingested. Nothing is deleted when no fact was supplied.
    """
    # A throw-away embedding tells us the vector size for the collection.
    probe = model.encode("test", normalize_embeddings=True)
    ensure_collection(client, args.collection, len(probe))

    name_facts = [f"The user's name is {args.name}."] if args.name else []
    facts = [*name_facts, *(args.fact or [])]

    if not facts:
        print("Nothing to store. Use --name and/or --fact 'some fact'.")
        return

    # Replace the existing profile so facts don't pile up across runs.
    own_profile = Filter(
        must=[
            FieldCondition(key="_user_id", match=MatchValue(value=_user_id)),
            FieldCondition(key="type", match=MatchValue(value="profile")),
        ]
    )
    client.delete(collection_name=args.collection, points_selector=own_profile)

    count = ingest_chunks(
        client,
        model,
        args.collection,
        facts,
        _user_id,
        source="profile",
        doc_type="profile",
    )
    print(f"Stored {count} profile fact(s) for user {_user_id}")
--fact 'Likes cheese' (repeatable)", + ) + subparsers.add_parser("list", help="List ingested documents") p_delete = subparsers.add_parser("delete", help="Delete documents by source") @@ -403,16 +504,42 @@ def main(): args = parser.parse_args() - _user_id = get_user_id(args._user_id) + if args.embedding_url and args.chunking == "semantic": + parser.error( + "--chunking semantic needs a local SentenceTransformer model and " + "cannot run over --embedding-url. Use --chunking fixed." + ) + + _user_id = get_user_id(args.user_id) print(f"User: {_user_id}") - client = QdrantClient(url=args.qdrant_url) - model = SentenceTransformer(args.embedding_model) + verify_ssl = not args.no_verify_ssl + try: + from .qdrant_utils import make_qdrant_client + except ImportError: + from qdrant_utils import make_qdrant_client + client = make_qdrant_client(args.qdrant_url, verify_ssl) + + # Lazy-load the model only if the command needs embeddings. + # Commands that don't need it: list, delete (profile does: cmd_profile embeds each fact). + needs_embeddings = args.command in ("pdf", "text", "write", "profile") + + if needs_embeddings: + if args.embedding_url: + print(f"Embedding remotely via {args.embedding_url} (model={args.embedding_model})") + model = RemoteEmbedder(args.embedding_url, args.embedding_model) + else: + from sentence_transformers import SentenceTransformer + + model = SentenceTransformer(args.embedding_model) + else: + model = None commands = { "pdf": cmd_pdf, "text": cmd_text, "write": cmd_write, + "profile": cmd_profile, "list": cmd_list, "delete": cmd_delete, } diff --git a/src/modules/rag/qdrant_utils.py b/src/modules/rag/qdrant_utils.py new file mode 100644 index 0000000..080980e --- /dev/null +++ b/src/modules/rag/qdrant_utils.py @@ -0,0 +1,31 @@ +"""Shared Qdrant client construction.
def make_qdrant_client(qdrant_url: str, verify_ssl: bool = True) -> QdrantClient:
    """Build a :class:`QdrantClient` from a URL.

    Parses the URL explicitly so the client gets the correct host/port/https.
    When given just ``https://host`` with no port, some qdrant-client versions
    silently fall back to their default port (6333) instead of 443, causing a
    timeout that looks like an SSL issue — so derive the port from the scheme.

    :param qdrant_url: Full URL including scheme, e.g. ``http://localhost:6333``
        or ``https://qdrant.example``.
    :param verify_ssl: Verify the server's TLS certificate. Set to ``False``
        for self-signed certificates.
    :returns: A client bound to the parsed host and port.
    :raises ValueError: If ``qdrant_url`` has no parseable host (for example
        ``"localhost:6333"`` without a scheme), instead of handing
        ``host=None`` to the client.
    """
    parsed = urlparse(qdrant_url)
    if not parsed.hostname:
        raise ValueError(
            f"Invalid Qdrant URL {qdrant_url!r}: expected a full URL such as "
            "'http://localhost:6333' (scheme and host are required)."
        )
    is_https = parsed.scheme == "https"
    return QdrantClient(
        host=parsed.hostname,
        # An explicit port wins; otherwise use the scheme's conventional port.
        port=parsed.port or (443 if is_https else 6333),
        https=is_https,
        verify=verify_ssl,
        # NOTE(review): the client/server version check is tied to the SSL
        # flag here; presumably because that check fails against
        # self-signed certificates — confirm this coupling is intended.
        check_compatibility=verify_ssl,
    )
class RAGDeploymentConfig(BaseModel):
    """Settings for the ``RAGHandle`` deployment.

    Built from the keyword arguments given to ``RAGHandle.__init__`` and
    merged with partial updates in ``RAGHandle.reconfigure``.
    """

    # Qdrant endpoint, parsed by make_qdrant_client().
    qdrant_url: str = "http://localhost:6333"
    # Collection searched for every user; results are filtered by _user_id.
    default_collection: str = "documents"
    # Model name sent to the embeddings endpoint.
    embedding_model: str = "BAAI/bge-large-en-v1.5"
    # Base URL of the /v1/embeddings server; empty means "use llm_url".
    embedding_url: str = ""
    llm_provider: str = "ollama"  # "vllm", "ollama", "api"
    # Base URL and model name of the chat LLM.
    llm_url: str = "http://localhost:11434"
    llm_model: str = "mistral:7b"
    # presumably a bearer token for OpenAI-compatible providers — confirm against the caller
    llm_api_key: str = ""
    # Applied to the Qdrant client and to both HTTP clients.
    verify_ssl: bool = True
    # Maximum number of chunks and minimum score passed to the Qdrant query.
    top_k: int = 5
    score_threshold: float = 0.5
def __init__(self, **kwargs: Any) -> None:
    """Validate ``kwargs`` as a ``RAGDeploymentConfig`` and create the clients.

    Any field omitted from ``kwargs`` takes its ``RAGDeploymentConfig``
    default. The Qdrant and HTTP clients are built by ``_apply_config``.
    """
    self._cfg = RAGDeploymentConfig(**kwargs)
    self._apply_config()
async def _embed(self, text: str) -> list[float]:
    """Embed ``text`` through the OpenAI-compatible ``/v1/embeddings`` endpoint.

    :param text: Input to embed; converted with ``str()`` before sending.
    :returns: The embedding from the first ``data`` entry of the response.
    :raises RuntimeError: If the server does not answer 200, the body is not
        JSON, or the JSON lacks ``data[0].embedding``. The message carries
        the URL and up to 1000 characters of the response for diagnosis.
    """
    # Tolerate a configured base URL that ends in "/": without this the
    # request would go to ".../​/v1/embeddings", which many servers reject.
    url = f"{self.embedding_url.rstrip('/')}/v1/embeddings"
    resp = await self._embed_client.post(
        url,
        json={"model": self._cfg.embedding_model, "input": str(text)},
    )
    if resp.status_code != 200:
        raise RuntimeError(
            f"Embedding HTTP {resp.status_code} from {url}: {resp.text[:1000]}"
        )
    try:
        payload = resp.json()
    except ValueError as e:  # JSONDecodeError / UnicodeDecodeError
        raise RuntimeError(
            f"Embedding non-JSON response from {url}: {resp.text[:1000]}"
        ) from e
    try:
        return payload["data"][0]["embedding"]
    except (KeyError, IndexError, TypeError) as e:
        raise RuntimeError(
            f"Embedding unexpected schema from {url}: {str(payload)[:1000]}"
        ) from e
+ + Retrieved deterministically by filter — NOT by vector similarity — + so they are always available to the prompt regardless of the question. + Populated via `ingestion.py profile`. + """ + try: + points, _ = self._qdrant.scroll( + collection_name=collection, + scroll_filter=Filter( + must=[ + FieldCondition(key="_user_id", match=MatchValue(value=_user_id)), + FieldCondition(key="type", match=MatchValue(value="profile")), + ] + ), + limit=50, + with_payload=True, + with_vectors=False, + ) + except Exception: + return [] + return [p.payload.get("text", "") for p in points if p.payload.get("text")] def _search( self, @@ -101,7 +135,6 @@ def _search( collection: str, filters: dict | None = None, ) -> list[dict]: - qdrant_filter: Any = None if filters: conditions: Any = [ @@ -115,8 +148,8 @@ def _search( collection_name=collection, query=query_vector, query_filter=qdrant_filter, - limit=self.top_k, - score_threshold=self.score_threshold, + limit=self._cfg.top_k, + score_threshold=self._cfg.score_threshold, ).points except Exception: results = [] @@ -134,12 +167,17 @@ def _build_prompt( question: str, chunks: list[dict], preferences: dict, + profile_facts: list[str] | None = None, ) -> tuple[str, str]: + persona = preferences.get("persona", _DEFAULT_PERSONA) + parts = [persona] + + if profile_facts: + parts.append( + "Here is what you know about the person you're talking to: " + + " ".join(profile_facts) + ) - parts = [ - "You are a robot speaking to a user. 
Answer based on the provided context.", - "If the context is insufficient, say so clearly.", - ] if preferences.get("language"): parts.append(f"Always respond in {preferences['language']}.") if preferences.get("tone"): @@ -150,6 +188,14 @@ def _build_prompt( parts.append("Keep your answer to 2-3 sentences maximum.") if preferences.get("extra_instructions"): parts.append(preferences["extra_instructions"]) + + parts.append( + "Use the context in the user's message to inform your answers when " + "it is relevant, but always answer in character. If you don't know " + "something, improvise in character rather than admitting you lack " + "information or breaking character. " + "IMPORTANT: Reply in 1-3 short sentences maximum. Be extremely concise. No lists, no emojis, no long explanations." + ) system_prompt = " ".join(parts) if not chunks: @@ -162,160 +208,236 @@ def _build_prompt( context_parts = [] for i, chunk in enumerate(chunks, 1): source = chunk["metadata"].get("source", "unknown") - context_parts.append(f"[{i}] (source: {source}, score: \ -{chunk['score']:.2f})\n{chunk['text']}") + context_parts.append( + f"[{i}] (source: {source}, score: {chunk['score']:.2f})\n" + f"{chunk['text']}" + ) context_block = "\n\n".join(context_parts) user_prompt = ( f"Context:\n{context_block}\n\n" f"Question: {question}\n\n" - "Answer based on the context above.\ -Don't speak about the sources, just use them to answer the question." + "Answer based on the context above. " + "Don't speak about the sources, just use them to answer." 
) return system_prompt, user_prompt - async def _llm_generate( + async def _stream_ollama( + self, messages: list, max_tokens: int, temperature: float = 0.7 + ) -> AsyncGenerator[str, None]: + async with self._llm_client.stream( + "POST", + f"{self._cfg.llm_url}/api/chat", + json={ + "model": self._cfg.llm_model, + "messages": messages, + "stream": True, + "options": {"num_predict": max_tokens, "temperature": temperature}, + }, + ) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line: + continue + try: + chunk = json.loads(line) + except json.JSONDecodeError: + continue + delta = chunk.get("message", {}).get("content", "") + if delta: + yield delta + if chunk.get("done"): + return + + async def _stream_openai_compatible( self, - system_prompt: str, - user_prompt: str, - preferences: dict, - ) -> Any: - max_tokens = preferences.get("max_length", 1024) - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] - - if self.ollama_handle: - return await self.ollama_handle.generate.remote(messages, max_tokens) - - if self.llm_provider == "vllm": - return await self._call_openai_compatible( - f"{self.llm_url}/v1/chat/completions", messages, max_tokens - ) - elif self.llm_provider == "ollama": - return await self._call_ollama(messages, max_tokens) - - elif self.llm_provider == "api": - return await self._call_openai_compatible( - f"{self.llm_url}/v1/chat/completions", - messages, - max_tokens, - self.llm_api_key, - ) - else: - raise ValueError(f"Unknown llm_provider: {self.llm_provider}") - - async def _call_openai_compatible( - self, url: str, messages: list, max_tokens: int, api_key: str = "" - ) -> Any: + url: str, + messages: list, + max_tokens: int, + api_key: str = "", + temperature: float = 0.7, + ) -> AsyncGenerator[str, None]: headers = {"Content-Type": "application/json"} if api_key: headers["Authorization"] = f"Bearer {api_key}" - async with httpx.AsyncClient(timeout=60.0) 
as client: - resp = await client.post( + async with self._llm_client.stream( + "POST", url, headers=headers, json={ - "model": self.llm_model, + "model": self._cfg.llm_model, "messages": messages, "max_tokens": max_tokens, - "temperature": 0.1, + "temperature": temperature, + "stream": True, }, - ) - resp.raise_for_status() - return resp.json()["choices"][0]["message"]["content"] - - async def _call_ollama(self, messages: list, max_tokens: int) -> Any: - async with httpx.AsyncClient(timeout=60.0) as client: - resp = await client.post( - f"{self.llm_url}/api/chat", - json={ - "model": self.llm_model, - "messages": messages, - "stream": False, - "options": {"num_predict": max_tokens, "temperature": 0.1}, - }, - ) - resp.raise_for_status() - return resp.json()["message"]["content"] - - async def process(self, query: RAGQuery) -> RAGResult: - """ - Main entry point. Called by the RAG module. - Uses _user_id to determine which collection / filters to use. - """ + ) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line or not line.startswith("data:"): + continue + payload = line[len("data:"):].strip() + if payload == "[DONE]": + return + try: + chunk = json.loads(payload) + except json.JSONDecodeError: + continue + delta = ( + chunk.get("choices", [{}])[0] + .get("delta", {}) + .get("content", "") + ) + if delta: + yield delta + + async def _llm_stream( + self, + system_prompt: str, + user_prompt: str, + preferences: dict, + history: list | None = None, + ) -> AsyncGenerator[str, None]: + max_tokens = preferences.get("max_length", 1024) + temperature = preferences.get("temperature", 0.7) + messages = [{"role": "system", "content": system_prompt}] + if history: + messages.extend(history) + messages.append({"role": "user", "content": user_prompt}) + + if self._cfg.llm_provider == "vllm": + async for d in self._stream_openai_compatible( + f"{self._cfg.llm_url}/v1/chat/completions", + messages, + max_tokens, + temperature=temperature, + ): 
+ yield d + elif self._cfg.llm_provider == "api": + async for d in self._stream_openai_compatible( + f"{self._cfg.llm_url}/v1/chat/completions", + messages, + max_tokens, + self._cfg.llm_api_key, + temperature=temperature, + ): + yield d + elif self._cfg.llm_provider == "ollama": + async for d in self._stream_ollama(messages, max_tokens, temperature): + yield d + else: + raise ValueError(f"Unknown llm_provider: {self._cfg.llm_provider}") + async def stream(self, query: RAGQuery) -> AsyncGenerator[str, None]: + """Main streaming entry point — yields LLM text deltas.""" print(f"[RAG] Question: {query.question}") - qdrant = await self._get_qdrant() - collection, filters = self._resolve_user_context(query._user_id) - query_vector = self._embed(query.question) - chunks = self._search(qdrant, query_vector, collection, filters) + query_vector = await self._embed(query.question) - print(f"[RAG] Found {len(chunks)} chunks") - for c in chunks: - print(f" - score: {c['score']:.2f} | {c['text'][:100]}...") + try: + chunks = self._search(self._qdrant, query_vector, collection, filters) + except Exception: + print(f"[RAG] FAILED during Qdrant search:\n{traceback.format_exc()}") + raise + print(f"[RAG] Found {len(chunks)} chunks") + profile_facts = self._get_profile(collection, query._user_id) + if profile_facts: + print(f"[RAG] Loaded {len(profile_facts)} profile fact(s)") system_prompt, user_prompt = self._build_prompt( - query.question, chunks, query.preferences + query.question, chunks, query.preferences, profile_facts ) - print(f"[RAG] System prompt: {system_prompt[:200]}...") - answer = await self._llm_generate(system_prompt, user_prompt, query.preferences) - print(f"[RAG] Answer: {answer}") - - return RAGResult( - answer=answer, - sources=[ - {"text": c["text"], "score": c["score"], "metadata": c["metadata"]} - for c in chunks - ], + + print( + f"[RAG] Streaming from LLM at {self._cfg.llm_url} " + f"(provider={self._cfg.llm_provider}, model={self._cfg.llm_model}, " + 
f"history_msgs={len(query.history)})" ) + try: + async for delta in self._llm_stream( + system_prompt, user_prompt, query.preferences, query.history + ): + yield delta + except Exception: + print(f"[RAG] FAILED during LLM stream:\n{traceback.format_exc()}") + raise class RAG(ModuleWithHandle, ModuleWithId): + """RAG Module — streams LLM tokens. + + input: question (Sentence) + output: token (Token) + """ + _handle_cls = RAGHandle input_type = "question" - output_type = "rag_response" + output_type = "token" def __init__( self, - _handle: handle.DeploymentHandle[RAGHandle], + _handle: handle.DeploymentHandle, _user_id: str, language="en", tone="formal", response_format="paragraph", - max_length=1024, + max_length=220, extra_instructions="", + persona="", + temperature=0.7, + max_history_turns=6, **kwargs, ): super().__init__(_handle=_handle, _user_id=_user_id, **kwargs) + print(f"[RAG] Initialized with user_id={_user_id}, language={language}, tone={tone}, response_format={response_format}, max_length={max_length}, temperature={temperature}, max_history_turns={max_history_turns}") + self.preferences = { "language": language, "tone": tone, "response_format": response_format, "max_length": max_length, "extra_instructions": extra_instructions, + "temperature": temperature, } + if persona: + self.preferences["persona"] = persona - async def process(self, data: Sentence) -> Optional[RAGResult]: - """ - Called when a "question" event arrives through the event bus. - Packages _user_id + question, sends to the stateless RAGHandle. - """ - question_text = data.text + # Per-session conversation memory, kept on the (per-WebSocket) module + # instance because the RAGHandle deployment is stateless/shared. + # Stored as OpenAI-style messages; trimmed to the last N turns. 
+ self._max_history_turns = max_history_turns + self.history: list[dict] = [] + async def process(self, data: Sentence) -> AsyncGenerator[Token, None]: # type: ignore[override] query = RAGQuery( _user_id=self._user_id if self._user_id else "anonymous", - question=question_text, + question=data.text, preferences=self.preferences, + history=list(self.history), # snapshot of prior turns ) - result: RAGResult = await self._handle.process.remote(query) - return result + parts: list[str] = [] + stream = self._handle.options(stream=True).stream.remote(query) + async for delta in stream: + parts.append(delta) + yield Token(text=delta, end=False) + yield Token(text="", end=True) + + self._record_turn(data.text, "".join(parts)) + + def _record_turn(self, question: str, answer: str) -> None: + """Append this turn to the session history (raw Q/A, no RAG context) + and trim to the most recent `max_history_turns` exchanges.""" + answer = answer.strip() + if not answer: + return + self.history.append({"role": "user", "content": question}) + self.history.append({"role": "assistant", "content": answer}) + max_msgs = self._max_history_turns * 2 + if len(self.history) > max_msgs: + del self.history[:-max_msgs] def update_preferences(self, new_preferences: dict): - """Client can update preferences mid-session via the event bus.""" self.preferences.update(new_preferences) diff --git a/src/modules/speech_to_text/microphone_vad.py b/src/modules/speech_to_text/microphone_vad.py index ffb82f7..eb52f74 100644 --- a/src/modules/speech_to_text/microphone_vad.py +++ b/src/modules/speech_to_text/microphone_vad.py @@ -13,7 +13,7 @@ class MIC(Module): Detect voice and silence using WebRTC VAD. - input: audio, + input: audio_in, output: voice :vad_agressiveness: from 0 (low) to 3 (high, can distord audio). @@ -23,7 +23,10 @@ class MIC(Module): Can only be 0.010, 0.020 and 0.030. 
""" - input_type = "audio" + # Inbound microphone frames travel on their own topic so the TTS-output + # "audio" topic (consumed by Gesture and the client Sender) never collides + # with mic input — otherwise raw mic bytes get echoed back to the client. + input_type = "audio_in" output_type = "voice" def __init__( diff --git a/src/modules/speech_to_text/speech_to_text.py b/src/modules/speech_to_text/speech_to_text.py index 1300dd3..9e48a62 100644 --- a/src/modules/speech_to_text/speech_to_text.py +++ b/src/modules/speech_to_text/speech_to_text.py @@ -1,47 +1,93 @@ import asyncio +import os from typing import List, Optional import numpy as np -from faster_whisper import WhisperModel +from ray import serve +from ray.serve import handle -from src.core.module import Module +from src.core.module import ModuleWithHandle from .events import Transcript, Voice +_MODEL_PATH = os.environ.get("HURI_STT_MODEL_PATH", "base") -class STT(Module): + +@serve.deployment(name="STT") +class STTDeployment: + """faster-whisper model wrapper. + + Holds the WhisperModel and runs transcription on its own Ray Serve actor, + off the HuRI master actor — model load and GPU inference no longer block the + websocket ingress / per-session router. Pinned to a GPU worker via + ray_actor_options in the Serve config (see deploy values.yaml). + + Stateless across calls: the per-session sliding-window buffering lives in the + STT module, so this deployment is shared across all sessions. + + :model: path to (or size name of) the faster-whisper model. Defaults to the + HURI_STT_MODEL_PATH env var, falling back to "base". + :device: "cpu", "cuda", or "auto". + :compute_type: e.g. "int8", "float16", or "auto". 
+ """ + + def __init__( + self, + model: str = _MODEL_PATH, + device: str = "auto", + compute_type: str = "auto", + ): + from faster_whisper import WhisperModel + + self.model_faster = WhisperModel( + model, + device=device, + compute_type=compute_type, + ) + + async def transcribe(self, audio: np.ndarray, language: str = "en") -> str: + segments, _ = self.model_faster.transcribe( + audio, + language=language, + beam_size=1, # faster for realtime + ) + return " ".join([seg.text for seg in segments]).strip() + + +class STT(ModuleWithHandle): """STT Module Transcribe voice using Faster_Whisper. + Holds the per-session sliding-window buffer and delegates the actual + transcription to a handle-backed STTDeployment, so the Whisper model runs + off the HuRI master node. + input: voice, output: transcript - :model: size of the model to use (tiny, tiny.en, base, base.en, small, - small.en, distil-small.en, medium, medium.en, distil-medium.en, - large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3, - large-v3-turbo, or turbo). :language: language spoken in the audio. It should be a language code such as "en" or "fr". :sample_rate: size of received voice audio. Usually 8000, 16000 or 48000. :block_duration: size of received voice audio (in s). 
""" + _handle_cls = STTDeployment input_type = "voice" output_type = "transcript" def __init__( self, - model: str = "base", + _handle: handle.DeploymentHandle, language: str = "en", sample_rate: int = 16000, block_duration: float = 0.020, # s transcribe_window: float = 2.0, # s transcribe_step: float = 1.0, # s + **kwargs, ): - super().__init__() + super().__init__(_handle=_handle, **kwargs) - self.model_faster = WhisperModel(model) self.language = language self.sample_rate = sample_rate @@ -58,7 +104,7 @@ def __init__( self.running = False self.lock: asyncio.Lock = asyncio.Lock() - async def process(self, voice: Voice) -> Optional[Transcript]: + async def process(self, voice: Voice) -> Optional[Transcript]: # type: ignore[override] if voice.data is None: self.silence = True else: @@ -83,14 +129,10 @@ async def process(self, voice: Voice) -> Optional[Transcript]: self.pending_silence = False processing_audio = np.concatenate(processing_chunks, axis=0) - segments, _ = self.model_faster.transcribe( - processing_audio, - language=self.language, - beam_size=1, # faster for realtime + current_text = await self._handle.transcribe.remote( + processing_audio, self.language ) - current_text = " ".join([seg.text for seg in segments]).strip() - processed_size = self.window_size - self.step_size async with self.lock: self.buffer = self.buffer[processed_size:] diff --git a/src/modules/text_to_speech/__init__.py b/src/modules/text_to_speech/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/text_to_speech/events.py b/src/modules/text_to_speech/events.py new file mode 100644 index 0000000..dceb269 --- /dev/null +++ b/src/modules/text_to_speech/events.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +import numpy as np + +from src.core.events import EventData + + +@dataclass +class Token(EventData): + text: str + end: bool + + +@dataclass +class Audio(EventData): + data: np.ndarray + sample_rate: int + end: bool = False + pts: float = 0.0 # 
presentation timestamp in seconds from utterance start diff --git a/src/modules/text_to_speech/text_to_speech.py b/src/modules/text_to_speech/text_to_speech.py new file mode 100644 index 0000000..26a4b52 --- /dev/null +++ b/src/modules/text_to_speech/text_to_speech.py @@ -0,0 +1,239 @@ +import asyncio +import os +import queue +import sys +import traceback +import uuid +from typing import AsyncGenerator, Optional + +import numpy as np +from ray import serve +from ray.serve import handle + +from src.core.module import ModuleWithHandle + +from .events import Audio, Token + +# Defaults — overridden by env vars in production (see README.md) +_MODEL_PATH = os.environ.get( + "HURI_MODEL_PATH", "/models/cosytts/FunAudioLLM/Fun-CosyVoice3-0.5B-2512" +) +_VOICE_SAMPLE_PATH = os.environ.get("HURI_VOICE_SAMPLE_PATH", "/assets/voice.wav") +_DEFAULT_INSTRUCTION = "You are a helpful assistant." + + +def _normalize_transcript(raw: str) -> str: + """Make a reference transcript safe for CosyVoice3. + + CosyVoice3 expects "<instruction><|endofprompt|><transcript>". + If the configured transcript supplies a bare transcript (no marker), prepend + the default instruction so the transcript lands AFTER <|endofprompt|> — + otherwise the LM treats it as an instruction and intermittently renders it as + speech (prompt leakage). + """ + return ( + raw + if "<|endofprompt|>" in raw + else f"{_DEFAULT_INSTRUCTION}<|endofprompt|>{raw}" + ) + +_END_TEXT = object() # sentinel pushed into the text queue to close synth +_END_AUDIO = object() # sentinel pushed into the audio queue when synth completes +_DONE = object() # sentinel for exhausted sync generator + + +@serve.deployment(name="TTS", max_ongoing_requests=200) +class TTSDeployment: + """CosyVoice3 wrapper with per-session bistream synthesis. + + The model's `inference_zero_shot` accepts a Python generator as `tts_text` + and yields audio chunks as text arrives — that's the "bistream" mode.
+ Because the model call is fully synchronous, each session runs in a thread + via `run_in_executor` and is fed by a thread-safe `queue.Queue` that the + asyncio side pushes text into. + """ + + def __init__( + self, + model_path: str = _MODEL_PATH, + voice_sample_path: str = _VOICE_SAMPLE_PATH, + voice_sample_transcript: Optional[str] = None, + ): + # Resolve the reference transcript here (deploy time on the GPU worker) + # rather than at module import: importing this module must not require + # HURI_VOICE_TRANSCRIPT, since modules.py imports it inside a broad + # try/except that would otherwise make TTS silently vanish from the + # pipeline when the var is unset. Fail loudly and locally instead. + if voice_sample_transcript is None: + raw = os.environ.get("HURI_VOICE_TRANSCRIPT") + if not raw: + raise RuntimeError( + "HURI_VOICE_TRANSCRIPT is not set. The TTS deployment needs the " + "transcript of the reference voice sample (voice.wav). Set it in " + "the Serve app runtime_env.env_vars (see deploy values.yaml)." 
+ ) + voice_sample_transcript = raw + voice_sample_transcript = _normalize_transcript(voice_sample_transcript) + + cosy_dir = os.environ.get("HURI_COSY_DIR") + if cosy_dir: + matcha_path = os.path.join(cosy_dir, "third_party", "Matcha-TTS") + if os.path.isdir(matcha_path) and matcha_path not in sys.path: + sys.path.insert(0, matcha_path) + + from cosyvoice.cli.cosyvoice import CosyVoice3 + + self.model = CosyVoice3(model_dir=model_path, load_trt=False) + self.sample_rate: int = self.model.sample_rate + + self.prompt_speech = voice_sample_path + self.prompt_text: str = voice_sample_transcript + + self._text_queues: dict[str, queue.Queue] = {} + + async def get_sample_rate(self) -> int: + return self.sample_rate + + async def start_session(self, session_id: str) -> None: + self._text_queues[session_id] = queue.Queue() + + async def push_text(self, session_id: str, text: str, end: bool) -> None: + q = self._text_queues.get(session_id) + if q is None: + return + if text: + q.put(text) + if end: + q.put(_END_TEXT) + + async def stream_audio(self, session_id: str) -> AsyncGenerator[Audio, None]: + text_q = self._text_queues[session_id] + loop = asyncio.get_running_loop() + chunk_count = 0 + + def text_gen(): + while True: + item = text_q.get() + if item is _END_TEXT: + return + yield item + + try: + audio_iter = self.model.inference_zero_shot( + text_gen(), + self.prompt_text, + self.prompt_speech, + stream=True, + ) + while True: + result = await loop.run_in_executor(None, next, audio_iter, _DONE) + if result is _DONE: + break + assert isinstance(result, dict) + chunk_count += 1 + speech = result["tts_speech"].squeeze(0).numpy().astype(np.float32) + yield Audio(data=speech, sample_rate=self.sample_rate) + except Exception: + traceback.print_exc() + raise + finally: + self._text_queues.pop(session_id, None) + + +class TTS(ModuleWithHandle): + """TTS Module — bistream tokens-in / audio-out via CosyVoice3. 
+ + Opens one synthesis session per utterance (delimited by `token.end`). Each + incoming token is pushed straight into the model's text generator so audio + starts coming back before the LLM has finished producing the response. + No clause buffering on our side — CosyVoice's frontend handles segmentation + and stitches LM calls together across the whole utterance. + + input: token (Token) + output: audio (Audio) + """ + + _handle_cls = TTSDeployment + input_type = "token" + output_type = "audio" + + def __init__(self, _handle: handle.DeploymentHandle): + super().__init__(_handle) + self._session_id: str | None = None + self._audio_q: asyncio.Queue | None = None + self._stream_task: asyncio.Task | None = None + # The EventGraph fans each token out as its own concurrent process() + # task on this shared instance. This lock serialises session setup and + # text pushes so tokens reach CosyVoice's text queue in arrival order, + # exactly once. asyncio.Lock wakes waiters FIFO and tokens are created + # in order, so order is preserved — crucially the end-of-utterance token + # can no longer overtake a content token (which would truncate synthesis + # and silently drop trailing words). + self._push_lock = asyncio.Lock() + + async def process(self, token: Token) -> AsyncGenerator[Audio, None]: # type: ignore[override] + # Acquire BEFORE any await so lock-acquisition order matches token order. + # Setup + push happen under the lock; only the first token of an + # utterance goes on to drain/yield audio (outside the lock, so pushes of + # later tokens are never blocked by the long-running drain). 
+ async with self._push_lock: + is_first = self._session_id is None + if is_first: + self._session_id = str(uuid.uuid4()) + self._audio_q = asyncio.Queue() + print(f"[TTS-client] [{self._session_id}] opening new utterance session") + await self._handle.start_session.remote(self._session_id) + self._stream_task = asyncio.create_task( + self._drain_audio(self._session_id, self._audio_q) + ) + + sid = self._session_id + audio_q = self._audio_q + stream_task = self._stream_task + print(f"[TTS-client] [{sid}] push token: {token.text!r} (end={token.end})") + await self._handle.push_text.remote(sid, token.text, token.end) + + if not is_first: + return + + assert audio_q is not None and stream_task is not None + try: + count = 0 + while True: + item = await audio_q.get() + if item is _END_AUDIO: + break + count += 1 + print(f"[TTS-client] [{sid}] yield chunk #{count}") + yield item + await stream_task + print(f"[TTS-client] [{sid}] utterance complete ({count} chunks)") + + sample_rate = await self._handle.get_sample_rate.remote() + yield Audio(data=np.array([], dtype=np.float32), sample_rate=sample_rate, end=True) + finally: + async with self._push_lock: + self._session_id = None + self._audio_q = None + self._stream_task = None + + async def _drain_audio(self, session_id: str, audio_q: asyncio.Queue) -> None: + try: + response = self._handle.options(stream=True).stream_audio.remote(session_id) + count = 0 + pts = 0.0 + async for audio in response: # type: ignore[attr-defined] + count += 1 + audio.pts = pts + pts += audio.data.shape[0] / audio.sample_rate + print( + f"[TTS-client] [{session_id}] drain received chunk #{count} " + f"pts={audio.pts:.3f}s next={pts:.3f}s", + ) + await audio_q.put(audio) + except Exception as e: + print(f"[TTS-client] [{session_id}] drain task FAILED: {e!r}") + raise + finally: + await audio_q.put(_END_AUDIO) + print(f"[TTS-client] [{session_id}] drain task finished") diff --git a/src/modules/utils/sender.py b/src/modules/utils/sender.py index 
f09b0ba..9261662 100644 --- a/src/modules/utils/sender.py +++ b/src/modules/utils/sender.py @@ -1,9 +1,16 @@ +import logging +import struct from dataclasses import asdict +import numpy as np from fastapi import WebSocket from src.core.events import EventData from src.core.module import Module +from src.modules.gesture.gesture import Motion +from src.modules.text_to_speech.events import Audio + +logger = logging.getLogger("ray.serve") class Sender(Module): @@ -11,6 +18,9 @@ class Sender(Module): Send output data to the client. This data must be JSON serialisable, like a dataclass. + Every binary frame starts with [2B topic length uint16][topic bytes]; header fields are big-endian. Audio wire format (after the topic prefix): [4B sample_rate uint32][1B end][8B pts float64][float32 PCM]. + Motion wire format (after the topic prefix): [8B pts float64][4B fps uint32][4B n_frames uint32] + [poses float32 n*165][expressions float32 n*100][trans float32 n*3]. input: auto, output: None""" @@ -21,10 +31,36 @@ def __init__(self, ws: WebSocket, type: str): self.ws: WebSocket = ws self.input_type = type - async def process(self, data: EventData | bytes): + async def process(self, _): + data = _ + logger.info("[Sender:%s] received %s", self.input_type, type(data).__name__) if isinstance(data, bytes): - await self.ws.send_bytes(data) + await self.ws.send_bytes(self._prefix(data)) + elif isinstance(data, Audio): + logger.info( + "[Sender:%s] Audio samples=%d sr=%d end=%s pts=%.3fs", + self.input_type, data.data.shape[0], data.sample_rate, data.end, data.pts, + ) + header = struct.pack(">IBd", data.sample_rate, int(data.end), data.pts) + await self.ws.send_bytes(self._prefix(header + data.data.tobytes())) + elif isinstance(data, Motion): + n_frames = data.poses.shape[0] + logger.info( + "[Sender:%s] Motion frames=%d fps=%d pts=%.3fs", + self.input_type, n_frames, data.fps, data.pts, + ) + header = struct.pack(">dII", data.pts, data.fps, n_frames) + body = ( + data.poses.astype(np.float32).tobytes() + + data.expressions.astype(np.float32).tobytes() + + data.trans.astype(np.float32).tobytes() + ) + await
self.ws.send_bytes(self._prefix(header + body)) elif isinstance(data, EventData): - await self.ws.send_json(asdict(data)) + await self.ws.send_json({"topic": self.input_type, **asdict(data)}) else: - await self.ws.send_text(data) + await self.ws.send_text(str(data)) + + def _prefix(self, payload: bytes) -> bytes: + topic_bytes = self.input_type.encode() + return struct.pack(">H", len(topic_bytes)) + topic_bytes + payload