"""
niki_desktop.py — Niki Desktop helper (v2.0).

The voice subsystem of YBA Edge Practice v2 runs in this single Python file.
It is the v2 port of `yba-platform/cass-desktop/cass_desktop.py` (v0.31.x), with:

  * Multi-tenant aware handshake — every connection sends
    {tenant_slug, practitioner_id, helper_version, gpu_available, model_name}.
  * Piper TTS — local, free, high-quality neural voices. Replaces pyttsx3
    for the read-aloud path. pyttsx3 remains as a last-ditch fallback when
    the Piper voice model can't be downloaded (offline first launch).
  * /v2 protocol bump — the WebSocket still listens on
    `ws://localhost:9876`, but messages are tagged with `"protocol": "v2"`.
    The browser bridge (`niki-bar.js` → NikiBridge) speaks v2; if the v1
    Cass helper happens to be running on the same box, the browser ignores
    its messages because they lack the v2 tag.
  * Renames — every user-visible string says "Niki" (matching Captain's
    2026-05-11 rebrand decision). Code-level filenames/symbols stay close
    to v1 so the diff against `cass_desktop.py` reads cleanly.

Hard-won fixes preserved from v1:
  * Windows CUDA DLL preload (Python 3.14 fix: `os.add_dll_directory` +
    PATH prepend + `ctypes.CDLL` belt-and-braces).
  * Whisper hallucination filter (drops "thanks for watching", "you.",
    "[music]" etc. so they never close a command and never poison the
    brain's input).
  * Auto-detect CUDA → CPU int8 fallback inside `WhisperEngine.__init__`.
  * TLS-style WebSocket close handling with dead-client pruning.

Pipeline:
    Mic (sounddevice) → ring buffer → faster-whisper inference loop
       → streaming transcript + wake/close-word detection
       → WebSocket server on ws://localhost:9876
       → browser bar (niki-bar.js → NikiBridge)

Protocol (JSON-line over WebSocket — v2):

    Server → Client:
        {"type":"hello.ack","protocol":"v2","helper_version":"2.0.0",
         "model":"small.en","device":"cuda|cpu","gpu_available":true|false,
         "tenant_slug":"...","practitioner_id":"..."}
        {"type":"transcript.partial","text":"...","ts":"ISO8601"}
        {"type":"transcript.final","text":"...","ts":"ISO8601"}
        {"type":"wake","buffer":"...","ts":"ISO8601"}
        {"type":"close","word":"thanks","buffer":"...","ts":"ISO8601"}
        {"type":"state","state":"passive|command|paused|error"}
        {"type":"audio","db":-45.0,"vu":0.0..1.0}
        {"type":"tts.start","request_id":"..."}
        {"type":"tts.end","request_id":"..."}
        {"type":"error","message":"..."}

    Client → Server:
        {"type":"hello","tenant_slug":"...","practitioner_id":"..."}
        {"type":"listen.start"}
        {"type":"listen.stop"}
        {"type":"tts.speak","text":"...","voice":"default","request_id":"..."}
        {"type":"tts.stop"}
        {"type":"pause"}
        {"type":"resume"}
        {"type":"shutdown"}

Usage:
    pip install --user -r requirements.txt
    python niki_desktop.py [--tenant-slug matrix-dental-demo]
                          [--practitioner-id 1]
                          [--model small.en] [--device auto] [--port 9876]
    python niki_desktop.py --selftest

Stop with Ctrl+C.
Logs to ~/.niki-desktop.log and stderr.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import logging
import os
import queue
import re
import secrets
import signal
import sys
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Optional

# Version of this helper build; shown in the first-run setup banner.
# NOTE(review): the module docstring's hello.ack example still shows
# "2.0.0" — confirm which value is meant to be authoritative.
HELPER_VERSION = "2.1.0"
# Tag stamped onto every broadcast frame so the browser bridge can tell v2
# frames apart from a v1 helper running on the same machine.
PROTOCOL_VERSION = "v2"


# ---------------------------------------------------------------------------
# Self-bootstrap — if the user just downloads this single .py file with no
# pip installs and runs it, the imports below would explode. Detect a
# missing core dep and pip-install the whole dependency set automatically
# (no prompt). Once the deps import cleanly, later launches skip this branch.
#
# This is the "single-file installer" path Captain asked for on 2026-05-14:
# download niki_desktop.py → double-click → it sets itself up.
# ---------------------------------------------------------------------------
# Requirement specifiers handed to `pip install` by the first-run bootstrap.
# Only faster-whisper, sounddevice and websockets are probed to decide whether
# the bootstrap runs; the remaining entries are installed alongside them.
_BOOTSTRAP_DEPS = [
    "faster-whisper>=1.0.3",
    "sounddevice>=0.4.6",
    "websockets>=12.0",
    "pyttsx3>=2.90",
    "piper-tts>=1.2.0",
    "numpy>=1.24",
]


def _bootstrap_if_needed() -> None:
    """First-run pip install when launched as a bare .py with no deps yet.

    Does nothing when NIKI_SKIP_BOOTSTRAP is set to 1/true/yes, or when the
    three core packages (faster_whisper, sounddevice, websockets) already
    import. Otherwise installs everything in ``_BOOTSTRAP_DEPS`` with pip.

    Exits the process with status 1 if the dependency install fails.
    """
    if os.environ.get("NIKI_SKIP_BOOTSTRAP", "").lower() in ("1", "true", "yes"):
        return
    try:
        import faster_whisper  # noqa: F401
        import sounddevice  # noqa: F401
        import websockets  # noqa: F401
        return  # all good — nothing to do
    except ImportError:
        pass

    print("=" * 60, flush=True)
    print(f" Niki Desktop v{HELPER_VERSION} · first-run setup", flush=True)
    print("=" * 60, flush=True)
    print(
        "\nI'll install the Python packages I need (~500 MB download,\n"
        "2-5 min on a typical broadband connection). One-time only —\n"
        "future launches skip this step.\n",
        flush=True,
    )
    import importlib
    import site
    import subprocess

    # pip refuses `--user` inside a virtual environment, so only request a
    # per-user install when running on the base interpreter.
    in_venv = (
        getattr(sys, "base_prefix", sys.prefix) != sys.prefix
        or hasattr(sys, "real_prefix")
    )
    pip_install = [sys.executable, "-m", "pip", "install"]
    if not in_venv:
        pip_install.append("--user")

    # Upgrading pip is a nicety; a failure here must not abort the setup.
    try:
        subprocess.check_call(
            [*pip_install, "--upgrade", "pip"],
            stdout=subprocess.DEVNULL,
        )
    except subprocess.CalledProcessError as exc:
        print(
            f"[niki-desktop] pip self-upgrade failed (exit {exc.returncode}) "
            "— continuing with the existing pip",
            file=sys.stderr,
        )

    try:
        subprocess.check_call([*pip_install, *_BOOTSTRAP_DEPS])
    except subprocess.CalledProcessError as exc:
        print(
            f"\n[ERROR] pip install failed (exit {exc.returncode}).\n"
            "If the error mentions 'Microsoft Visual C++', install:\n"
            "  https://aka.ms/vs/17/release/vc_redist.x64.exe\n"
            "then re-run this file.\n",
            file=sys.stderr,
        )
        sys.exit(1)

    # If the per-user site directory did not exist when the interpreter
    # started it is not on sys.path, so the packages installed a moment ago
    # would still be unimportable in this process. Register it now and drop
    # stale import-finder caches.
    if not in_venv:
        user_site = site.getusersitepackages()
        if user_site not in sys.path:
            site.addsitedir(user_site)
    importlib.invalidate_caches()

    print("\n✓ Setup complete. Starting Niki Desktop …\n", flush=True)


_bootstrap_if_needed()


# ---------------------------------------------------------------------------
# Windows CUDA DLL fix — lifted verbatim from v1. faster-whisper / CTranslate2
# needs cublas64_12.dll + cudnn_*.dll loaded BEFORE faster_whisper imports
# anywhere. Python 3.14 silently lost add_dll_directory propagation to worker
# threads, so we use three mechanisms (add_dll_directory + PATH prepend +
# ctypes.CDLL preload). See yba-platform/cass-desktop/cass_desktop.py for the
# bug history.
# ---------------------------------------------------------------------------
def _register_cuda_dll_paths() -> None:
    """Make pip-installed NVIDIA DLLs loadable on Windows (no-op elsewhere).

    Looks for ``nvidia/<package>/bin`` directories under every site-packages
    location and, for each one found:

      1. registers it with ``os.add_dll_directory``,
      2. prepends it to ``PATH`` unless it is already an entry there,

    then tries to preload a fixed list of cuBLAS / cuDNN DLLs via
    ``ctypes.CDLL``. Every step is best-effort; failures are swallowed or
    reported on stderr and never raise.
    """
    if not sys.platform.startswith("win"):
        return
    try:
        import site

        candidates: list[Path] = []
        for sp in site.getsitepackages() + [site.getusersitepackages()]:
            nvidia_root = Path(sp) / "nvidia"
            if nvidia_root.exists():
                for lib_dir in nvidia_root.iterdir():
                    bin_dir = lib_dir / "bin"
                    if bin_dir.exists():
                        candidates.append(bin_dir)
        added = 0
        for p in candidates:
            p_str = str(p)
            try:
                os.add_dll_directory(p_str)  # type: ignore[attr-defined]
            except Exception:
                pass
            try:
                cur = os.environ.get("PATH", "")
                # Compare whole PATH entries (case-normalised), not substrings:
                # a longer entry that merely starts with this directory must
                # not suppress the prepend.
                existing = {
                    os.path.normcase(os.path.normpath(entry))
                    for entry in cur.split(os.pathsep)
                    if entry
                }
                if os.path.normcase(os.path.normpath(p_str)) not in existing:
                    os.environ["PATH"] = p_str + os.pathsep + cur
                added += 1
            except Exception:
                pass

        preloaded: list[str] = []
        try:
            import ctypes

            # NOTE(review): the two "*_infer64_9" names look like cuDNN 8-style
            # names with a 9 suffix — confirm they exist in the shipped cuDNN
            # build; a missing DLL is silently skipped below.
            for name in (
                "cublas64_12.dll",
                "cublasLt64_12.dll",
                "cudnn_ops_infer64_9.dll",
                "cudnn_cnn_infer64_9.dll",
                "cudnn64_9.dll",
            ):
                try:
                    ctypes.CDLL(name)
                    preloaded.append(name)
                except OSError:
                    pass
        except Exception:
            pass

        if added or preloaded:
            print(
                f"[niki-desktop] CUDA DLL setup: {added} paths registered, "
                f"{len(preloaded)} libs preloaded ({', '.join(preloaded) or 'none'})",
                file=sys.stderr,
            )
    except Exception as e:
        print(
            f"[niki-desktop] CUDA DLL path setup failed (non-fatal): {e}",
            file=sys.stderr,
        )


_register_cuda_dll_paths()


# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------

# Log file lives in the user's home directory.
LOG_PATH = Path.home() / ".niki-desktop.log"
# Configured at import time: every record goes both to the log file and to
# stderr, using the same format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    handlers=[
        logging.FileHandler(LOG_PATH, encoding="utf-8"),
        logging.StreamHandler(sys.stderr),
    ],
)
# Raise the threshold for chatty third-party loggers to WARNING.
for noisy in ("faster_whisper", "websockets.server", "websockets.protocol"):
    logging.getLogger(noisy).setLevel(logging.WARNING)
# Module-wide logger used by everything in this file.
log = logging.getLogger("niki-desktop")


# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

# Fallback auth-token file in the user's home directory (read and, if
# necessary, created by load_config()).
TOKEN_PATH = Path.home() / ".niki-desktop-token.json"
# Preferred config file, expected next to this script.
CONFIG_PATH = Path(__file__).resolve().parent / "niki_desktop_config.json"
# Default WebSocket bind address / port (loopback only).
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 9876
# Used when neither the config file nor NIKI_DASHBOARD_URL supplies a URL.
DEFAULT_DASHBOARD_URL = "https://matrix-v2.tenants.ybaedge.com"

# Where Piper voice models live. Cached locally — first-run download is
# ~60 MB. Reused on every subsequent launch.
PIPER_CACHE_DIR = Path.home() / ".niki-desktop" / "piper-voices"
PIPER_DEFAULT_VOICE = "en_GB-jenny_dioco-medium"
# Hugging Face URL pattern for Piper voices — pinned to rhasspy/piper-voices.
# jenny_dioco is a young female UK voice — matches Niki's persona.
# Override via NIKI_PIPER_VOICE env var if needed (e.g. en_US-amy-medium).
# NOTE: this base URL points at the en_GB/jenny_dioco/medium directory
# specifically, i.e. it matches PIPER_DEFAULT_VOICE only.
PIPER_HF_BASE = (
    "https://huggingface.co/rhasspy/piper-voices/resolve/main/en/en_GB/"
    "jenny_dioco/medium"
)

# Audio format. SAMPLE_RATE (Hz) sizes the AudioRing buffer; CHANNELS and
# DTYPE are presumably the microphone stream parameters — their use is not
# visible in this part of the file.
SAMPLE_RATE = 16_000
CHANNELS = 1
DTYPE = "float32"

# Timing, in seconds. RING_SEC is the default AudioRing length; WINDOW_SEC
# and STRIDE_SEC are presumably the inference window and step — TODO confirm
# against the inference loop.
WINDOW_SEC = 2.5
STRIDE_SEC = 0.5
RING_SEC = 30.0

# Wake phrase: optional "hey" followed by "niki" or one of the listed
# spelling variants, case-insensitive, on word boundaries.
WAKE_RE = re.compile(
    r"\b(hey\s+)?(niki|nicky|nikki|niko|nicki|neeky|nikey|knee\s*key)\b",
    re.IGNORECASE,
)
# Close phrases: "thank you" / "thanks" / "that's all" / "that is all" /
# "that's it" / "that is it", case-insensitive, on word boundaries.
CLOSE_RE = re.compile(
    r"\b(thank\s*you|thanks|that'?s\s+all|that\s+is\s+all|that'?s\s+it|that\s+is\s+it)\b",
    re.IGNORECASE,
)

# Silence detection thresholds — presumably the level (dB) below which audio
# counts as silence and how long it must last to finalise; usage is not
# visible in this part of the file.
SILENCE_DB = -42.0
SILENCE_FINAL_SEC = 0.8


# ---------------------------------------------------------------------------
# Config / token loader — same shape as v1 with renamed paths
# ---------------------------------------------------------------------------


def _dashboard_url_from_env() -> str:
    """Return NIKI_DASHBOARD_URL if set, else the built-in default URL."""
    return os.environ.get("NIKI_DASHBOARD_URL", DEFAULT_DASHBOARD_URL)


def _usable_token(payload) -> Optional[str]:
    """Return payload["token"] when it is a string of at least 16 chars, else None."""
    tok = payload.get("token")
    if isinstance(tok, str) and len(tok) >= 16:
        return tok
    return None


def _persist_token(path: Path, tok: str) -> bool:
    """Write a freshly generated token to `path`; return True when it was saved."""
    try:
        # Create the file owner-only *before* the secret is written, and
        # tighten a pre-existing file too, so the token is never on disk
        # with wider permissions.
        path.touch(mode=0o600, exist_ok=True)
        try:
            os.chmod(path, 0o600)
        except OSError:
            pass
        path.write_text(
            json.dumps({"token": tok, "created": _now_iso()}, indent=2),
            encoding="utf-8",
        )
    except OSError as exc:
        log.warning(
            "could not persist auth token to %s (%s) — using it for this run only",
            path,
            exc,
        )
        return False
    return True


def load_config() -> dict:
    """
    Return {token, dashboard_url, tenant_slug, practitioner_id, source} from:

        1. niki_desktop_config.json next to this script (v2 — preferred,
           stamped by the dashboard's one-click download).
        2. ~/.niki-desktop-token.json (auto-generated fallback).
        3. Fresh local generation.

    A source is used only when it parses as JSON and carries a string
    ``token`` of at least 16 characters; otherwise the next source is tried.
    ``source`` is "config.json", "home-token" or "generated". The generated
    token is saved to the home token file when possible; if that write fails
    the token is still returned and a warning is logged.
    """
    if CONFIG_PATH.exists():
        try:
            payload = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))
            tok = _usable_token(payload)
            if tok is not None:
                log.info(
                    "loaded config from %s (tenant=%s, dashboard=%s)",
                    CONFIG_PATH,
                    payload.get("tenant_slug"),
                    payload.get("dashboard_url"),
                )
                return {
                    "token": tok,
                    "dashboard_url": payload.get("dashboard_url")
                    or _dashboard_url_from_env(),
                    "tenant_slug": payload.get("tenant_slug"),
                    "practitioner_id": payload.get("practitioner_id"),
                    "source": "config.json",
                }
            log.warning("config.json has no usable token — falling through")
        except Exception as exc:
            log.warning("config.json unreadable (%s) — falling through", exc)

    if TOKEN_PATH.exists():
        try:
            payload = json.loads(TOKEN_PATH.read_text(encoding="utf-8"))
            tok = _usable_token(payload)
            if tok is not None:
                log.info("loaded legacy token from %s", TOKEN_PATH)
                return {
                    "token": tok,
                    "dashboard_url": _dashboard_url_from_env(),
                    "tenant_slug": None,
                    "practitioner_id": None,
                    "source": "home-token",
                }
            log.warning("token file has no usable token — regenerating")
        except Exception as exc:
            log.warning("token file unreadable (%s) — regenerating", exc)

    tok = secrets.token_urlsafe(24)
    if _persist_token(TOKEN_PATH, tok):
        log.info("generated new auth token at %s", TOKEN_PATH)
    return {
        "token": tok,
        "dashboard_url": _dashboard_url_from_env(),
        "tenant_slug": None,
        "practitioner_id": None,
        "source": "generated",
    }


def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with millisecond precision."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now.isoformat(timespec="milliseconds")


# ---------------------------------------------------------------------------
# Audio ring buffer (verbatim from v1 — float32 lock-protected ring)
# ---------------------------------------------------------------------------


class AudioRing:
    """Fixed-capacity float32 ring buffer of the most recent audio samples.

    ``push`` appends samples (overwriting the oldest once full) and records
    the RMS of the pushed chunk; ``last`` returns a copy of the newest
    samples in chronological order. Buffer access is guarded by a lock.
    """

    def __init__(self, seconds: float = RING_SEC):
        import numpy as np

        self._np = np
        # Capacity in samples at SAMPLE_RATE.
        self._capacity = int(seconds * SAMPLE_RATE)
        self._buf = np.zeros(self._capacity, dtype=np.float32)
        self._write = 0  # index where the next sample will be written
        self._lock = threading.Lock()
        self._level_lin = 0.0  # RMS of the most recently pushed chunk

    def push(self, chunk: "np.ndarray"):  # noqa: F821
        """Append `chunk` (flattened) to the ring and update the level.

        An empty chunk is ignored. A chunk at least as long as the ring
        keeps only its newest ``capacity`` samples.
        """
        np = self._np
        flat = chunk.reshape(-1)
        n = flat.shape[0]
        if n == 0:
            # Nothing to store; the mean of an empty array would be NaN.
            return
        rms = float(np.sqrt(np.mean(np.square(flat)) + 1e-12))
        with self._lock:
            if n >= self._capacity:
                # The chunk alone fills (or overflows) the ring: keep its
                # tail and restart the write position at the beginning.
                self._buf[:] = flat[n - self._capacity:]
                self._write = 0
            else:
                end = self._write + n
                if end <= self._capacity:
                    self._buf[self._write:end] = flat
                else:
                    # Wrap around: fill to the end, then continue at index 0.
                    first = self._capacity - self._write
                    self._buf[self._write:] = flat[:first]
                    self._buf[: n - first] = flat[first:]
                self._write = end % self._capacity
            self._level_lin = rms

    def last(self, seconds: float) -> "np.ndarray":  # noqa: F821
        """Return a copy of the newest `seconds` of audio, oldest sample first.

        The request is clamped to the range [0, capacity] samples.
        """
        np = self._np
        n = max(0, min(int(seconds * SAMPLE_RATE), self._capacity))
        with self._lock:
            if n <= self._write:
                return self._buf[self._write - n:self._write].copy()
            # The requested span wraps: older part sits at the end of the
            # buffer, newer part at the start.
            tail = self._buf[self._capacity - (n - self._write):]
            head = self._buf[: self._write]
            return np.concatenate([tail, head])

    @property
    def level_lin(self) -> float:
        """RMS (linear) of the most recently pushed non-empty chunk."""
        return self._level_lin


# ---------------------------------------------------------------------------
# Whisper engine — v1's hard-won CUDA detect + int8 fallback
# ---------------------------------------------------------------------------


class WhisperEngine:
    """faster-whisper model wrapper with CUDA auto-detection and CPU fallback."""

    def __init__(self, model_size: str, device: str):
        """Load `model_size` on `device` ("auto" picks via `_autodetect_device`).

        CUDA loads use float16, everything else int8. If a CUDA load raises,
        the model is loaded again on CPU with int8; a failure on any other
        device propagates to the caller.
        """
        from faster_whisper import WhisperModel  # local — heavy

        resolved = self._autodetect_device() if device == "auto" else device
        if resolved == "cuda":
            precision = "float16"
        else:
            precision = "int8"
        log.info(
            "loading faster-whisper model=%s device=%s compute_type=%s",
            model_size,
            resolved,
            precision,
        )
        started = time.perf_counter()
        try:
            model = WhisperModel(
                model_size, device=resolved, compute_type=precision
            )
        except Exception as e:
            if resolved != "cuda":
                raise
            log.warning(
                "CUDA model load failed (%s) — falling back to CPU int8", e
            )
            resolved, precision = "cpu", "int8"
            model = WhisperModel(
                model_size, device=resolved, compute_type=precision
            )
        self.model = model
        self.device = resolved
        self.model_size = model_size
        log.info(
            "faster-whisper ready in %.2fs (device=%s compute_type=%s)",
            time.perf_counter() - started,
            self.device,
            precision,
        )

    @staticmethod
    def _autodetect_device() -> str:
        """Return "cuda" or "cpu".

        "cuda" when ctranslate2 reports at least one CUDA device, or failing
        that when both the nvidia.cublas and nvidia.cudnn packages import.
        """
        try:
            import ctranslate2  # type: ignore

            gpu_seen = ctranslate2.get_cuda_device_count() > 0
        except Exception:
            gpu_seen = False
        if gpu_seen:
            return "cuda"

        try:
            import nvidia.cublas  # type: ignore  # noqa: F401
            import nvidia.cudnn  # type: ignore  # noqa: F401
        except Exception:
            return "cpu"
        return "cuda"

    @property
    def gpu_available(self) -> bool:
        """True when the model ended up loaded on CUDA."""
        return self.device == "cuda"

    def transcribe(self, audio) -> str:
        """Transcribe `audio` as English and return the segments joined by spaces."""
        options = {
            "language": "en",
            "beam_size": 1,
            "vad_filter": True,
            "vad_parameters": {"min_silence_duration_ms": 250},
        }
        segments, _info = self.model.transcribe(audio, **options)
        pieces = [segment.text.strip() for segment in segments]
        return " ".join(pieces).strip()


# ---------------------------------------------------------------------------
# Piper TTS — local GPU/CPU TTS via Piper (https://github.com/rhasspy/piper).
# Replaces pyttsx3 for higher quality + optional GPU acceleration. pyttsx3
# remains as the last-ditch fallback when:
#   * the `piper` package can't be imported (very old Python), OR
#   * the voice model can't be downloaded (offline first run).
# ---------------------------------------------------------------------------


class PiperTTS:
    """Local Piper neural TTS. Plays through sounddevice. Non-blocking.

    The voice model is downloaded into `cache_dir` on first use. When Piper
    cannot be initialised (`available` is False) or fails while speaking,
    `speak` falls back to pyttsx3.
    """

    # Repository root that PIPER_HF_BASE points into.
    _HF_VOICES_ROOT = "https://huggingface.co/rhasspy/piper-voices/resolve/main"

    def __init__(self, voice_name: str = PIPER_DEFAULT_VOICE,
                 cache_dir: Optional[Path] = None):
        self.voice_name = voice_name
        self.cache_dir = cache_dir or PIPER_CACHE_DIR
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._voice = None              # loaded PiperVoice instance, or None
        self._init_error: Optional[str] = None
        self._stop_flag = threading.Event()
        self._play_lock = threading.Lock()
        self._sample_rate: int = 22_050  # Piper default — overwritten on load
        self._try_init()

    # ---- model bootstrap ----------------------------------------------

    def _voice_paths(self) -> tuple[Path, Path]:
        """Return the cached (.onnx, .onnx.json) paths for the configured voice."""
        onnx = self.cache_dir / f"{self.voice_name}.onnx"
        json_path = self.cache_dir / f"{self.voice_name}.onnx.json"
        return onnx, json_path

    def _voice_base_url(self) -> str:
        """Return the remote directory URL holding this voice's model files.

        The default voice uses PIPER_HF_BASE unchanged. For any other voice
        named ``<locale>-<speaker>-<quality>`` (e.g. ``en_US-amy-medium``) the
        directory is derived as ``<lang>/<locale>/<speaker>/<quality>`` — the
        same shape PIPER_HF_BASE has for the default voice. Names that do not
        fit that pattern fall back to PIPER_HF_BASE.
        NOTE(review): assumes every voice in the repository follows the
        default voice's directory layout — confirm for exotic voices.
        """
        if self.voice_name == PIPER_DEFAULT_VOICE:
            return PIPER_HF_BASE
        try:
            locale, rest = self.voice_name.split("-", 1)
            speaker, quality = rest.rsplit("-", 1)
        except ValueError:
            return PIPER_HF_BASE
        language = locale.split("_", 1)[0]
        if not (language and speaker and quality):
            return PIPER_HF_BASE
        return f"{self._HF_VOICES_ROOT}/{language}/{locale}/{speaker}/{quality}"

    def _ensure_voice_downloaded(self) -> bool:
        """Download voice model on first run. Returns True if files are usable."""
        onnx, jsn = self._voice_paths()
        if onnx.exists() and jsn.exists():
            return True

        import urllib.request

        base_url = self._voice_base_url()
        for dest in (onnx, jsn):
            if dest.exists():
                continue
            url = f"{base_url}/{dest.name}"
            log.info("[piper] downloading voice model: %s", url)
            # Download to a sibling ".part" file so a partial transfer is
            # never mistaken for a usable model.
            tmp = dest.with_suffix(dest.suffix + ".part")
            try:
                with urllib.request.urlopen(url, timeout=60) as resp, tmp.open("wb") as f:
                    while True:
                        chunk = resp.read(64 * 1024)
                        if not chunk:
                            break
                        f.write(chunk)
                # replace() also succeeds when the destination already exists.
                tmp.replace(dest)
            except Exception as exc:
                log.warning("[piper] voice model download failed (%s)", exc)
                try:
                    tmp.unlink()
                except OSError:
                    pass
                return False
        log.info("[piper] voice model ready at %s", onnx)
        return True

    def _try_init(self) -> None:
        """Load the Piper voice; on any failure record `_init_error` and leave `_voice` None."""
        try:
            if not self._ensure_voice_downloaded():
                self._init_error = "voice-model-download-failed"
                return
            from piper import PiperVoice  # type: ignore

            onnx, _ = self._voice_paths()
            # use_cuda is best-effort: an older signature (TypeError) or an
            # unusable GPU runtime (any other error) retries the plain load.
            try:
                self._voice = PiperVoice.load(str(onnx), use_cuda=True)
            except Exception as gpu_exc:
                log.info("[piper] CUDA load failed (%s) — retrying without CUDA", gpu_exc)
                self._voice = PiperVoice.load(str(onnx))
            cfg = getattr(self._voice, "config", None)
            sr = getattr(cfg, "sample_rate", None) if cfg else None
            if isinstance(sr, int):
                self._sample_rate = sr
            log.info(
                "[piper] loaded voice=%s sample_rate=%d", self.voice_name, self._sample_rate
            )
        except Exception as exc:
            self._voice = None
            self._init_error = f"piper-init-failed: {exc!r}"
            log.warning("[piper] init failed (%s) — pyttsx3 fallback will be used", exc)

    @property
    def available(self) -> bool:
        """True when a Piper voice was loaded successfully."""
        return self._voice is not None

    # ---- playback -----------------------------------------------------

    def speak(self, text: str, callback: Optional[Callable[[], None]] = None) -> None:
        """Synthesize + play `text` on a background thread and return immediately.

        Blank text is ignored (no thread, no callback). `callback`, when
        given, is invoked from the worker thread once speech has finished,
        been stopped, or failed; exceptions it raises are swallowed. When
        Piper is unavailable the pyttsx3 fallback runs on the worker thread
        too, so the caller is never blocked.
        """
        if not text or not text.strip():
            return

        def _run():
            try:
                if self.available:
                    self._stop_flag.clear()
                    self._synthesize_and_play(text)
                else:
                    # Serialise with other utterances so only one engine
                    # drives the speakers at a time.
                    with self._play_lock:
                        _pyttsx3_say(text)
            except Exception as exc:
                log.warning("[piper] speak failed (%s) — pyttsx3 fallback", exc)
                _pyttsx3_say(text)
            finally:
                if callback:
                    try:
                        callback()
                    except Exception:
                        pass

        threading.Thread(target=_run, name="niki-piper-tts", daemon=True).start()

    def _synthesize_and_play(self, text: str) -> None:
        """Pull PCM from Piper and stream it through sounddevice."""
        import numpy as np

        try:
            import sounddevice as sd
        except Exception as exc:
            log.warning("[piper] sounddevice unavailable (%s) — pyttsx3 fallback", exc)
            _pyttsx3_say(text)
            return

        with self._play_lock:
            voice = self._voice
            if voice is None:
                raise RuntimeError("piper voice is not loaded")
            sr = self._sample_rate
            buf = bytearray()
            try:
                # Two API generations are supported. Voices that still offer
                # synthesize_stream_raw(text) (raw int16 bytes) are handled
                # first, because on those builds synthesize() also exists
                # and would otherwise always be picked.
                # NOTE(review): the older synthesize() is understood to need
                # a wave-file argument — confirm against the pinned piper-tts.
                if hasattr(voice, "synthesize_stream_raw"):
                    for raw in voice.synthesize_stream_raw(text):
                        if self._stop_flag.is_set():
                            return
                        buf.extend(raw)
                elif hasattr(voice, "synthesize"):
                    for chunk in voice.synthesize(text):
                        if self._stop_flag.is_set():
                            return
                        # Newer piper-tts returns objects with .audio_int16_bytes / .sample_rate
                        raw = getattr(chunk, "audio_int16_bytes", None) or getattr(chunk, "audio_bytes", None) or bytes(chunk)
                        buf.extend(raw)
                        rate = getattr(chunk, "sample_rate", None)
                        if isinstance(rate, int):
                            sr = rate
                else:
                    raise RuntimeError("piper voice exposes no synthesize method")
            except Exception as exc:
                if self._stop_flag.is_set():
                    # A stop was requested; do not start a fallback utterance.
                    return
                log.warning("[piper] synth failed (%s) — pyttsx3 fallback", exc)
                _pyttsx3_say(text)
                return

            if self._stop_flag.is_set():
                return
            audio = np.frombuffer(bytes(buf), dtype=np.int16)
            if audio.size == 0:
                return
            try:
                sd.play(audio, samplerate=sr, blocking=False)
                # Poll for completion so we can interrupt cleanly
                while sd.get_stream().active:  # type: ignore[attr-defined]
                    if self._stop_flag.is_set():
                        sd.stop()
                        return
                    time.sleep(0.05)
            except Exception as exc:
                if self._stop_flag.is_set():
                    # Playback was torn down by stop(); stay silent.
                    return
                log.warning("[piper] sounddevice playback failed (%s) — pyttsx3 fallback", exc)
                _pyttsx3_say(text)

    def stop(self) -> None:
        """Request that the current utterance stop and halt sounddevice playback."""
        self._stop_flag.set()
        try:
            import sounddevice as sd

            sd.stop()
        except Exception:
            pass


def _pyttsx3_say(text: str) -> None:
    """Last-resort cross-platform TTS (SAPI/NSSpeech/espeak).

    Speaks `text` with a freshly created pyttsx3 engine and blocks until the
    engine's runAndWait() returns. Never raises: any failure (including
    pyttsx3 not being installed) is logged as a warning and the reply is
    simply not spoken.
    """
    try:
        # Imported lazily so the module loads even without pyttsx3.
        import pyttsx3  # type: ignore

        eng = pyttsx3.init()
        eng.setProperty("rate", 185)
        eng.say(text)
        eng.runAndWait()
    except Exception as exc:
        log.warning("pyttsx3 fallback also failed (%s) — silent reply", exc)


# ---------------------------------------------------------------------------
# State machine + Hub (verbatim from v1)
# ---------------------------------------------------------------------------


class NikiState:
    """Mutable listening-state record; a new instance starts in PASSIVE."""

    # State names — the same strings the protocol's "state" message lists.
    PASSIVE, COMMAND, PAUSED, ERROR = "passive", "command", "paused", "error"

    def __init__(self):
        # Current state name (one of the class constants above).
        self.state: str = NikiState.PASSIVE
        # Timestamp slot for the most recent wake; 0.0 until one is recorded.
        self.last_wake_ts: float = 0.0
        # Text buffer for the command in progress; starts empty.
        self.command_buffer: str = ""


class Hub:
    """Set of connected WebSocket clients with fan-out broadcast."""

    def __init__(self):
        self._clients: set = set()
        # Guards mutation of the client set.
        self._lock = asyncio.Lock()

    async def add(self, ws) -> None:
        """Register a client connection."""
        async with self._lock:
            self._clients.add(ws)

    async def remove(self, ws) -> None:
        """Forget a client connection; unknown connections are ignored."""
        async with self._lock:
            self._clients.discard(ws)

    async def broadcast(self, msg: dict) -> None:
        """Send `msg` as JSON to every client, pruning clients whose send fails.

        The frame is tagged with the protocol version unless `msg` already
        carries a "protocol" key. The caller's dict is left untouched.
        Does nothing when no client is connected.
        """
        if not self._clients:
            return
        # Tag every server-emitted frame with the protocol version so the
        # browser bridge can ignore stale v1 helpers running on the same box.
        # Work on a copy so the caller's dict is not modified.
        frame = dict(msg)
        frame.setdefault("protocol", PROTOCOL_VERSION)
        payload = json.dumps(frame)
        dead = []
        # Iterate over a snapshot: the set may change while we await sends.
        for ws in list(self._clients):
            try:
                await ws.send(payload)
            except Exception:
                dead.append(ws)
        if dead:
            async with self._lock:
                self._clients.difference_update(dead)


# ---------------------------------------------------------------------------
# NikiDesktop — server + inference loop. Same architecture as v1 with v2
# protocol on top.
# ---------------------------------------------------------------------------


class NikiDesktop:
    def __init__(
        self,
        model_size: str = "small.en",
        device: str = "auto",
        host: str = DEFAULT_HOST,
        port: int = DEFAULT_PORT,
        dashboard_url: Optional[str] = None,
        tenant_slug: Optional[str] = None,
        practitioner_id: Optional[str] = None,
        enable_tts: bool = True,
        piper_voice: str = PIPER_DEFAULT_VOICE,
    ):
        """Load configuration and build every component the helper needs.

        Explicit ``dashboard_url``, ``tenant_slug`` and ``practitioner_id``
        arguments win over the values returned by ``load_config()``; falsy
        arguments fall back to the config. With ``enable_tts=False`` no
        ``PiperTTS`` object is created and ``self.tts`` is ``None``.
        """
        settings = load_config()

        # Identity and configuration.
        self.token = settings["token"]
        self.config_source = settings["source"]
        self.dashboard_url = dashboard_url if dashboard_url else settings["dashboard_url"]
        self.tenant_slug = tenant_slug if tenant_slug else settings.get("tenant_slug")
        self.practitioner_id = (
            practitioner_id if practitioner_id else settings.get("practitioner_id")
        )

        # Core components.
        self.engine = WhisperEngine(model_size=model_size, device=device)
        self.ring = AudioRing()
        self.state = NikiState()
        self.hub = Hub()

        # Network endpoint used by run().
        self.host = host
        self.port = port

        # Runtime bookkeeping.
        self.last_text: str = ""
        self._silence_started: Optional[float] = None
        self._paused: bool = False
        self._stop_event = asyncio.Event()
        self._audio_stream = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Bounded hand-off of audio windows to the inference thread.
        self._infer_q: queue.Queue = queue.Queue(maxsize=4)

        # TTS — Piper preferred, pyttsx3 fallback baked in
        if enable_tts:
            self.tts: Optional[PiperTTS] = PiperTTS(voice_name=piper_voice)
        else:
            self.tts = None

    # ---- audio + inference plumbing (verbatim from v1) ----------------

    def _audio_callback(self, indata, _frames, _time_info, status):
        """Input-stream callback: push each captured block into the ring buffer.

        Registered as the ``callback`` of the ``sd.InputStream`` opened in
        ``run()``. A truthy ``status`` is only logged at debug level; the block
        is pushed either way. ``_frames`` and ``_time_info`` are unused.
        """
        if status:
            log.debug("audio status: %s", status)
        self.ring.push(indata)

    async def _audio_meter_loop(self):
        """Broadcast an ``audio`` level frame every 0.1 s until stop is requested.

        ``vu`` is the ring's linear level times 8, capped at 1.0. ``db`` is
        20·log10 of the level, or -120.0 when the level is not positive.
        """
        while not self._stop_event.is_set():
            level = self.ring.level_lin
            # -6.0 stands in for log10 at silence, giving -120 dB after scaling.
            exponent = math_log10(level + 1e-9) if level > 0 else -6.0
            frame = {
                "type": "audio",
                "vu": min(1.0, level * 8.0),
                "db": float(20.0 * exponent),
            }
            await self.hub.broadcast(frame)
            await asyncio.sleep(0.1)

    def _inference_thread(self):
        """Worker-thread body: transcribe queued windows and hand results to the loop.

        Runs until the stop event is set or a ``None`` sentinel is dequeued.
        Empty transcripts and transcripts identical to the previous one are
        skipped. Everything that touches the hub is scheduled onto the event
        loop with ``run_coroutine_threadsafe`` because this runs off-loop.
        """
        loop = self._loop
        assert loop is not None
        while not self._stop_event.is_set():
            try:
                window = self._infer_q.get(timeout=0.2)
            except queue.Empty:
                # The timeout lets the stop event be re-checked regularly.
                continue
            if window is None:
                return

            try:
                transcript = self.engine.transcribe(window)
            except Exception as exc:
                log.exception("transcribe failed: %s", exc)
                failure = {"type": "error", "message": f"transcribe: {exc}"}
                asyncio.run_coroutine_threadsafe(self.hub.broadcast(failure), loop)
                continue

            if not transcript or transcript == self.last_text:
                continue

            previous, self.last_text = self.last_text, transcript
            delta = self._diff_suffix(previous, transcript)
            # Fall back to the whole transcript when no new suffix was found.
            asyncio.run_coroutine_threadsafe(
                self._on_partial(delta or transcript, full=transcript), loop
            )

    @staticmethod
    def _diff_suffix(prev: str, curr: str) -> str:
        prev_words = prev.split()
        curr_words = curr.split()
        i = 0
        while (
            i < len(prev_words)
            and i < len(curr_words)
            and prev_words[i].lower() == curr_words[i].lower()
        ):
            i += 1
        return " ".join(curr_words[i:])

    async def _on_partial(self, new_text: str, full: str):
        """Publish a partial transcript, then run silence and wake detection.

        ``full`` is the whole current transcript; ``new_text`` is what the
        inference thread passed as the newly added part. When the input level
        has stayed below ``SILENCE_DB`` for ``SILENCE_FINAL_SEC`` the transcript
        is also published as final, evaluated for wake/close words, and
        ``last_text`` is reset.
        """
        stamp = _now_iso()
        await self.hub.broadcast(
            {"type": "transcript.partial", "text": full, "ts": stamp}
        )

        level = self.ring.level_lin
        if level > 0:
            db = 20.0 * math_log10(level + 1e-9)
        else:
            db = -120.0
        now = time.monotonic()

        if not db < SILENCE_DB:
            # Sound present: restart the silence timer.
            self._silence_started = None
        elif self._silence_started is None:
            self._silence_started = now
        elif now - self._silence_started >= SILENCE_FINAL_SEC:
            self._silence_started = None
            await self.hub.broadcast(
                {"type": "transcript.final", "text": full, "ts": stamp}
            )
            await self._evaluate_wake_close(full)
            self.last_text = ""

        # Wake as early as possible: check the fresh words of every partial.
        wake_hit = WAKE_RE.search(new_text)
        if wake_hit and self.state.state == NikiState.PASSIVE:
            await self._wake(full)

    async def _evaluate_wake_close(self, finalised_text: str):
        """Act on a finalised utterance: drop noise, grow/close a command, or wake.

        * Text recognised as a Whisper hallucination (or empty /
          punctuation-only text) is dropped without touching any state.
        * In COMMAND state the text is appended to the command buffer; if it
          contains a close word, a ``close`` frame carrying the buffer (with
          close words removed) is broadcast and the state returns to PASSIVE.
          A buffer that is empty or pure noise returns to PASSIVE silently.
        * In PASSIVE state a wake-word match triggers ``_wake``.
        """
        # Hallucination filter from v1 — Whisper LOVES to emit YT/silence
        # garbage; never close on it, never count it toward command buffer.
        clean = (finalised_text or "").strip().lower()
        HALLUCINATIONS = {
            "thanks for watching", "thanks for watching.", "thanks for watching!",
            "thank you.", "thank you", "thanks.", "thanks", "you.", "you",
            "for watching!", "for watching", "for watching.",
            "[music]", "[applause]", "[silence]", ".",
            "subscribe", "subscribe.", "subscribe to my channel",
            "okay.", "ok.", "bye.", "bye",
        }
        FILLER_WORDS = {"you", "thanks", "thank", "bye", "ok", "okay"}
        # Compare the text with sentence punctuation removed. The previous
        # check gated on len(clean) < 6 of the *unstripped* text, so variants
        # such as "thanks!", "thank." or "you..." could never match.
        core = clean.translate(str.maketrans("", "", ".!?")).strip()
        if clean in HALLUCINATIONS or not core or core in FILLER_WORDS:
            log.debug("dropped whisper hallucination: %r", finalised_text)
            return

        if self.state.state == NikiState.COMMAND:
            self.state.command_buffer = (
                (self.state.command_buffer + " " + finalised_text).strip()
            )
            m = CLOSE_RE.search(finalised_text)
            if m:
                # presumably group 1 of CLOSE_RE is the close word itself —
                # the pattern is defined elsewhere in this file.
                word = m.group(1)
                cmd_clean = CLOSE_RE.sub("", self.state.command_buffer).strip()
                if not cmd_clean or cmd_clean.lower() in HALLUCINATIONS:
                    log.info("close-word in noise-only buffer — silent passive")
                    self._set_state(NikiState.PASSIVE)
                    self.state.command_buffer = ""
                    return
                await self.hub.broadcast(
                    {
                        "type": "close",
                        "word": word,
                        "ts": _now_iso(),
                        "buffer": cmd_clean,
                    }
                )
                self._set_state(NikiState.PASSIVE)
                self.state.command_buffer = ""
        elif self.state.state == NikiState.PASSIVE and WAKE_RE.search(
            finalised_text
        ):
            await self._wake(finalised_text)

    async def _wake(self, buffer: str):
        """Announce a wake event and enter COMMAND state, debounced to 3 s.

        A call arriving less than 3.0 s (monotonic clock) after the previously
        accepted wake is ignored entirely.
        """
        now = time.monotonic()
        elapsed = now - self.state.last_wake_ts
        if elapsed < 3.0:
            return
        self.state.last_wake_ts = now
        frame = {"type": "wake", "buffer": buffer, "ts": _now_iso()}
        await self.hub.broadcast(frame)
        self._set_state(NikiState.COMMAND)

    def _set_state(self, new_state: str):
        """Switch to *new_state*, log the transition and schedule a ``state`` frame.

        Setting the current state again is a no-op. The broadcast is scheduled
        on the stored event loop rather than awaited, so this method can stay
        synchronous; nothing is sent while no loop has been recorded yet.
        """
        previous = self.state.state
        if new_state == previous:
            return
        log.info("state: %s → %s", previous, new_state)
        self.state.state = new_state

        loop = self._loop
        if loop is None:
            return
        announcement = self.hub.broadcast({"type": "state", "state": new_state})
        asyncio.run_coroutine_threadsafe(announcement, loop)

    async def _window_scheduler(self):
        """Every ``STRIDE_SEC``, queue the latest ``WINDOW_SEC`` of audio for inference.

        Nothing is queued while paused. When the bounded queue is full the
        window is dropped instead of blocking the event loop.
        """
        while not self._stop_event.is_set():
            if self._paused:
                await asyncio.sleep(STRIDE_SEC)
                continue
            latest = self.ring.last(WINDOW_SEC)
            try:
                self._infer_q.put_nowait(latest)
            except queue.Full:
                log.debug("inference queue full — dropping window")
            await asyncio.sleep(STRIDE_SEC)

    # ---- WebSocket handler (v2 protocol) ------------------------------

    async def _ws_handler(self, ws):
        """Serve one WebSocket connection (v2 protocol) from accept to disconnect.

        The client is registered with the hub first, then the hello / hello.ack
        exchange runs, then every further frame goes to ``_on_client_msg``.
        Any exception ends the connection with a warning; the client is always
        removed from the hub on the way out.
        """
        peer = getattr(ws, "remote_address", "?")
        log.info("client connected from %s", peer)
        await self.hub.add(ws)
        try:
            await self._send_hello_ack(ws)
            async for frame in ws:
                await self._on_client_msg(ws, frame)
        except Exception as exc:
            log.warning("client %s closed with error: %s", peer, exc)
        finally:
            await self.hub.remove(ws)
            log.info("client %s disconnected", peer)

    async def _send_hello_ack(self, ws) -> None:
        """Wait up to 5 s for the first frame, then send ``hello.ack`` and the state.

        A ``hello`` frame may override the launch-time tenant slug and
        practitioner id echoed in the ack. Any other parseable, non-empty
        first frame is passed to ``_on_client_msg`` so it is not lost. On
        timeout the launch defaults are used. The ack is sent in every case.

        Per the original note: localhost-only listener with the same trust
        model as v1 — no token gate beyond config.
        """
        # Start from the launch defaults; only a proper hello changes them.
        t_slug = self.tenant_slug
        p_id = self.practitioner_id
        try:
            raw = await asyncio.wait_for(ws.recv(), timeout=5.0)
            try:
                first = json.loads(raw)
            except Exception:
                first = {}
            said_hello = isinstance(first, dict) and first.get("type") == "hello"
            if said_hello:
                # Per-connection tenant context — overrides launch defaults
                t_slug = first.get("tenant_slug") or t_slug
                p_id = first.get("practitioner_id") or p_id
            elif first:
                # Replay the unexpected first message through the normal handler
                await self._on_client_msg(ws, raw)
        except asyncio.TimeoutError:
            log.info("client %s sent no hello — accepted with launch defaults", ws.remote_address)

        tts_engine = "piper" if self.tts and self.tts.available else "pyttsx3"
        ack = {
            "type": "hello.ack",
            "protocol": PROTOCOL_VERSION,
            "helper_version": HELPER_VERSION,
            "model": self.engine.model_size,
            "device": self.engine.device,
            "gpu_available": self.engine.gpu_available,
            "model_name": self.engine.model_size,
            "tenant_slug": t_slug,
            "practitioner_id": p_id,
            "tts_engine": tts_engine,
        }
        await ws.send(json.dumps(ack))

        # Follow up with the current state so the client can sync immediately.
        current = {"type": "state", "state": self.state.state, "protocol": PROTOCOL_VERSION}
        await ws.send(json.dumps(current))

    async def _on_client_msg(self, ws, raw):
        """Decode one client frame and dispatch it by its ``type`` field.

        Frames that are not valid JSON objects, and unknown types, are ignored.
        ``listen.start``/``resume`` and ``listen.stop``/``pause`` are synonyms.
        The legacy v1 ``tts`` type is accepted as ``tts.speak`` with an empty
        request id so a half-migrated dashboard still hears Niki speak.
        """
        try:
            msg = json.loads(raw)
        except Exception:
            return
        if not isinstance(msg, dict):
            return

        kind = msg.get("type")
        if kind == "tts.speak" or kind == "tts":
            text = str(msg.get("text") or "").strip()
            # Only the v2 message carries a request id.
            request_id = str(msg.get("request_id") or "") if kind == "tts.speak" else ""
            if text:
                await self._tts_speak(text, request_id)
        elif kind == "tts.stop":
            if self.tts:
                self.tts.stop()
        elif kind == "listen.start" or kind == "resume":
            self._paused = False
            self._set_state(NikiState.PASSIVE)
        elif kind == "listen.stop" or kind == "pause":
            self._paused = True
            self._set_state(NikiState.PAUSED)
        elif kind == "shutdown":
            log.info("shutdown requested by client")
            self._stop_event.set()

    async def _tts_speak(self, text: str, request_id: str) -> None:
        """Speak *text*, bracketing it with ``tts.start`` / ``tts.end`` broadcasts.

        ``tts.start`` is broadcast before playback is requested. ``tts.end`` is
        scheduled onto the event loop from the completion callback, which is
        invoked by the TTS object or by the fallback thread.

        NOTE(review): when ``self.tts`` is None (TTS disabled at construction)
        this still speaks through pyttsx3, although the ``--no-tts`` help text
        says pyttsx3 is disabled too — confirm which is intended.
        """
        loop = asyncio.get_running_loop()
        await self.hub.broadcast(
            {"type": "tts.start", "request_id": request_id, "ts": _now_iso()}
        )

        def _announce_end():
            finished = {"type": "tts.end", "request_id": request_id, "ts": _now_iso()}
            asyncio.run_coroutine_threadsafe(self.hub.broadcast(finished), loop)

        if self.tts:
            self.tts.speak(text, callback=_announce_end)
            return

        # Pure pyttsx3 fallback in its own thread so we don't block asyncio
        def _speak_then_announce():
            _pyttsx3_say(text)
            _announce_end()

        worker = threading.Thread(
            target=_speak_then_announce, name="niki-pyttsx3", daemon=True
        )
        worker.start()

    # ---- dashboard registration (sanity check, fire-and-forget) -------

    def _register_with_dashboard(self) -> None:
        """POST this helper's identity to ``<dashboard_url>/niki-desktop/register``.

        Best-effort: does nothing without a dashboard URL, and any failure is
        logged as a non-fatal warning. The call blocks for up to 5 s, so
        ``run()`` executes it on its own thread.
        """
        base = self.dashboard_url
        if not base:
            return
        try:
            import urllib.request

            identity = {
                "token": self.token,
                "helper_version": HELPER_VERSION,
                "protocol": PROTOCOL_VERSION,
                "helper_host": f"{self.host}:{self.port}",
                "tenant_slug": self.tenant_slug,
                "practitioner_id": self.practitioner_id,
                "gpu_available": self.engine.gpu_available,
                "model_name": self.engine.model_size,
            }
            request = urllib.request.Request(
                f"{base.rstrip('/')}/niki-desktop/register",
                data=json.dumps(identity).encode("utf-8"),
                headers={"Content-Type": "application/json"},
                method="POST",
            )
            with urllib.request.urlopen(request, timeout=5) as resp:
                reply = resp.read().decode("utf-8", errors="replace")
                # Only the first 200 characters of the reply are logged.
                log.info(
                    "registered with dashboard %s · http=%d · body=%s",
                    self.dashboard_url,
                    resp.status,
                    reply[:200],
                )
        except Exception as exc:
            log.warning(
                "dashboard register failed (non-fatal · %s): %s",
                type(exc).__name__,
                exc,
            )

    # ---- entry point --------------------------------------------------

    async def run(self):
        """Run the helper until a stop is requested, then tear everything down.

        Opens the microphone stream, starts the inference thread, serves the
        WebSocket endpoint, kicks off the dashboard registration on its own
        thread and runs the two background loops. It then blocks until
        ``_stop_event`` is set (client ``shutdown`` message or OS signal).

        Teardown lives in ``finally`` blocks so the stream and the inference
        thread are also released when start-up fails part-way, for example
        when the port is already in use.
        """
        import sounddevice as sd
        import websockets

        self._loop = asyncio.get_running_loop()

        log.info("opening audio stream at %d Hz mono", SAMPLE_RATE)
        self._audio_stream = sd.InputStream(
            samplerate=SAMPLE_RATE,
            channels=CHANNELS,
            dtype=DTYPE,
            callback=self._audio_callback,
        )
        infer_t = threading.Thread(
            target=self._inference_thread, name="niki-infer", daemon=True
        )

        try:
            self._audio_stream.start()
            infer_t.start()

            async with websockets.serve(
                self._ws_handler,
                self.host,
                self.port,
                ping_interval=20,
                ping_timeout=20,
            ):
                log.info(
                    "Niki Desktop v%s listening on ws://%s:%d (tenant=%s, source=%s, gpu=%s)",
                    HELPER_VERSION,
                    self.host,
                    self.port,
                    self.tenant_slug,
                    self.config_source,
                    self.engine.gpu_available,
                )
                threading.Thread(
                    target=self._register_with_dashboard,
                    name="niki-register",
                    daemon=True,
                ).start()
                tasks = [
                    asyncio.create_task(self._window_scheduler()),
                    asyncio.create_task(self._audio_meter_loop()),
                ]
                try:
                    await self._stop_event.wait()
                finally:
                    for t in tasks:
                        t.cancel()
                    # Let the cancellations finish so no task is left pending.
                    await asyncio.gather(*tasks, return_exceptions=True)
        finally:
            # On an error path the event may not be set yet; the inference
            # thread polls it and must see it to exit.
            self._stop_event.set()
            self._audio_stream.stop()
            self._audio_stream.close()
            # Non-blocking sentinel: a blocking put() on a full queue would
            # hang forever because the worker has stopped draining it. The
            # stop event already ends the worker, so dropping the sentinel
            # is safe.
            try:
                self._infer_q.put_nowait(None)
            except queue.Full:
                pass
            if infer_t.is_alive():
                infer_t.join(timeout=2.0)
            if self.tts:
                self.tts.stop()
            log.info("Niki Desktop shut down cleanly")


def math_log10(x: float) -> float:
    """Return log10 of *x*, with *x* first raised to a floor of 1e-12.

    The floor keeps zero and negative inputs from raising; they all map to
    the logarithm of 1e-12.
    """
    from math import log10

    floored = max(x, 1e-12)
    return log10(floored)


# ---------------------------------------------------------------------------
# Self-test
# ---------------------------------------------------------------------------


def _run_selftest(model_size: str, device: str) -> int:
    """Run the install self-test and return a process exit code.

    Loads the Whisper engine, transcribes ``test_audio.wav`` from this file's
    directory (or three seconds of synthetic noise when the file is missing or
    unreadable), then checks whether Piper TTS can be initialised. No
    microphone is opened.

    Returns:
        0 on success, 2 if the Whisper engine cannot be loaded, 3 if
        transcription raises. Piper problems are reported but never fail
        the test.
    """
    print(f"Niki Desktop v{HELPER_VERSION} · self-test")
    print(f"  model:  {model_size}")
    print(f"  device: {device}")

    cfg = load_config()
    print(f"  token:  {cfg['token'][:8]}… (source: {cfg['source']})")
    print(f"  tenant: {cfg.get('tenant_slug') or '(unset)'}")
    print(f"  dashboard: {cfg['dashboard_url']}")

    try:
        engine = WhisperEngine(model_size=model_size, device=device)
    except Exception as exc:
        print(f"\n[FAIL] could not load faster-whisper: {exc}")
        return 2

    audio_path = Path(__file__).resolve().parent / "test_audio.wav"
    audio_buf = None
    if audio_path.exists():
        try:
            import wave

            import numpy as np

            with wave.open(str(audio_path), "rb") as wf:
                n = wf.getnframes()
                frames = wf.readframes(n)
                sr = wf.getframerate()
                ch = wf.getnchannels()
                sw = wf.getsampwidth()
            # Decode PCM by sample width. WAV stores 8-bit samples unsigned
            # and wider samples as signed little-endian integers.
            if sw == 2:
                arr = np.frombuffer(frames, dtype=np.int16).astype("float32") / 32768.0
            elif sw == 4:
                arr = (
                    np.frombuffer(frames, dtype=np.int32).astype("float32")
                    / 2147483648.0
                )
            elif sw == 3:
                # 24-bit has no numpy dtype: assemble each sample from its
                # three bytes, then sign-extend. Previously this width fell
                # through to the uint8 branch and produced garbage.
                b = np.frombuffer(frames, dtype=np.uint8).reshape(-1, 3).astype(np.int32)
                vals = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16)
                vals = np.where(vals >= 0x800000, vals - 0x1000000, vals)
                arr = vals.astype("float32") / 8388608.0
            elif sw == 1:
                arr = (
                    np.frombuffer(frames, dtype=np.uint8).astype("float32") / 128.0
                    - 1.0
                )
            else:
                # Caught below: falls back to synthetic audio with a warning.
                raise ValueError(f"unsupported sample width: {sw} bytes")
            if ch > 1:
                # Down-mix interleaved channels to mono.
                arr = arr.reshape(-1, ch).mean(axis=1)
            if sr != SAMPLE_RATE:
                print(f"  [warn] sample rate {sr} != {SAMPLE_RATE}; resampling naive")
                # Nearest-sample resampling — good enough for a smoke test.
                ratio = SAMPLE_RATE / sr
                idx = (np.arange(int(len(arr) * ratio)) / ratio).astype(int)
                arr = arr[idx]
            audio_buf = arr.astype("float32")
            print(f"  audio:  test_audio.wav ({len(audio_buf) / SAMPLE_RATE:.2f}s)")
        except Exception as exc:
            print(f"  [warn] test_audio.wav unreadable ({exc}) — synthetic")
    if audio_buf is None:
        import numpy as np

        # Fixed seed keeps the synthetic input reproducible between runs.
        rng = np.random.default_rng(0xCA55)
        audio_buf = (rng.standard_normal(3 * SAMPLE_RATE) * 0.01).astype("float32")
        print("  audio:  synthetic 3.00s (no test_audio.wav found)")

    t0 = time.perf_counter()
    try:
        text = engine.transcribe(audio_buf)
    except Exception as exc:
        print(f"\n[FAIL] transcribe crashed: {exc}")
        return 3
    dt = time.perf_counter() - t0
    print(f"\n[OK] transcribe completed in {dt:.2f}s (device={engine.device})")
    print(f"     transcript: {text!r}")

    # Bonus: Piper bootstrap check (download voice if missing, don't speak)
    print("\nChecking Piper TTS bootstrap (may download ~60 MB on first run)…")
    try:
        tts = PiperTTS()
        if tts.available:
            print("[OK] Piper voice loaded — neural TTS ready")
        else:
            print(f"[INFO] Piper unavailable ({tts._init_error}) — pyttsx3 fallback will be used")
    except Exception as exc:
        print(f"[INFO] Piper bootstrap raised ({exc}) — pyttsx3 fallback will be used")

    print("\nSelf-test passed. Launch with:")
    print("     python niki_desktop.py --tenant-slug <your-slug>")
    return 0


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def main(argv: Optional[list] = None) -> int:
    """Command-line entry point: parse options, then self-test or run the helper.

    Args:
        argv: Argument list to parse; ``None`` makes argparse use ``sys.argv``.

    Returns:
        The self-test's exit code when ``--selftest`` is given. Otherwise 0
        after a normal shutdown or Ctrl-C, and 1 when the run ends with an
        unexpected exception.
    """
    parser = argparse.ArgumentParser(
        prog="niki-desktop",
        description=f"Niki Desktop helper v{HELPER_VERSION} (local Whisper + Piper TTS).",
    )
    parser.add_argument(
        "--model",
        default=os.environ.get("NIKI_MODEL", "small.en"),
        help="faster-whisper model size (tiny.en, base.en, small.en, medium.en)",
    )
    parser.add_argument(
        "--device",
        default=os.environ.get("NIKI_DEVICE", "auto"),
        help="auto | cpu | cuda",
    )
    parser.add_argument("--host", default=DEFAULT_HOST)
    parser.add_argument("--port", type=int, default=DEFAULT_PORT)
    parser.add_argument(
        "--tenant-slug",
        default=os.environ.get("NIKI_TENANT_SLUG"),
        help="Tenant slug for the multi-tenant handshake (e.g. matrix-dental-demo).",
    )
    parser.add_argument(
        "--practitioner-id",
        default=os.environ.get("NIKI_PRACTITIONER_ID"),
        help="Practitioner ID this helper is bound to (matches YBA_PRACTITIONER_ID).",
    )
    parser.add_argument(
        "--dashboard",
        default=os.environ.get("NIKI_DASHBOARD_URL"),
        help="Dashboard base URL for the /niki-desktop/register handshake.",
    )
    parser.add_argument(
        "--no-tts",
        action="store_true",
        help="Disable Piper/pyttsx3 — the bar will fall back to browser speech.",
    )
    parser.add_argument(
        "--piper-voice",
        default=os.environ.get("NIKI_PIPER_VOICE", PIPER_DEFAULT_VOICE),
        help="Piper voice model name (default: en_GB-jenny_dioco-medium — young UK female).",
    )
    parser.add_argument(
        "--selftest",
        action="store_true",
        help="End-to-end install check (loads Whisper + checks Piper, no mic).",
    )
    args = parser.parse_args(argv)

    if args.selftest:
        return _run_selftest(model_size=args.model, device=args.device)

    niki = NikiDesktop(
        model_size=args.model,
        device=args.device,
        host=args.host,
        port=args.port,
        dashboard_url=args.dashboard,
        tenant_slug=args.tenant_slug,
        practitioner_id=args.practitioner_id,
        enable_tts=not args.no_tts,
        piper_voice=args.piper_voice,
    )

    def _on_signal(*_):
        log.info("signal received — shutting down")
        # asyncio primitives are not thread-safe and setting the event
        # directly does not wake a loop that is blocked waiting for I/O.
        # Hand the set() to the running loop; fall back to a direct set
        # while no loop is running yet or once it has been closed.
        loop = niki._loop
        if loop is not None and not loop.is_closed():
            try:
                loop.call_soon_threadsafe(niki._stop_event.set)
                return
            except RuntimeError:
                # The loop was closed between the check and the call.
                pass
        niki._stop_event.set()

    signal.signal(signal.SIGINT, _on_signal)
    if hasattr(signal, "SIGTERM"):
        signal.signal(signal.SIGTERM, _on_signal)

    try:
        asyncio.run(niki.run())
    except KeyboardInterrupt:
        log.info("interrupted by user")
    except Exception as exc:
        log.exception("fatal: %s", exc)
        return 1
    return 0


# Script entry point: run the CLI and use main()'s return value as the exit code.
if __name__ == "__main__":
    sys.exit(main())
