132 lines
4.7 KiB
Python
132 lines
4.7 KiB
Python
import os
|
|
import json
|
|
import base64
|
|
from pathlib import Path
|
|
from contextlib import asynccontextmanager
|
|
from dotenv import load_dotenv
|
|
|
|
# Load .env before importing modules that use env vars
|
|
load_dotenv()
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.responses import FileResponse
|
|
|
|
from stt import transcribe, load_model as load_whisper
|
|
from llm import get_response
|
|
from tts import synthesize, set_voice, TTS_ENGINE
|
|
|
|
# Filesystem layout. BASE_DIR is two `.parent` hops above this file, i.e. the
# project root (this module presumably lives in backend/, see VOICES_DIR).
BASE_DIR = Path(__file__).parent.parent
# Static frontend (index.html and assets) served by the routes below.
FRONTEND_DIR = BASE_DIR.joinpath("frontend")
# Directory scanned at startup for TTS voice models (*.onnx).
VOICES_DIR = BASE_DIR.joinpath("backend", "voices")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app):
    """FastAPI lifespan hook: prepare models at startup, log at shutdown.

    Startup loads the Whisper speech-to-text model via ``load_whisper()``.
    If ``VOICES_DIR`` contains ``*.onnx`` voice files, it activates the
    alphabetically first one with ``set_voice()``. Otherwise it prints setup
    instructions for downloading a Piper voice. Shutdown prints a message.
    The shutdown message is printed even if the server stops because of an
    error.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).
    """
    # Startup
    print("Loading Whisper model...")
    load_whisper()
    print("Whisper model loaded.")

    # Path.glob() yields entries in filesystem order, which is not guaranteed
    # to be stable. Sorting makes the chosen voice deterministic when more
    # than one .onnx model is present.
    voice_files = sorted(VOICES_DIR.glob("*.onnx")) if VOICES_DIR.exists() else []
    if voice_files:
        set_voice(str(voice_files[0]))
        print(f"TTS voice loaded: {voice_files[0].name}")
    else:
        print("WARNING: No voice model found in backend/voices/")
        print("Download a voice from: https://github.com/rhasspy/piper/blob/master/VOICES.md")
        print("Place the .onnx and .onnx.json files in backend/voices/")

    try:
        yield
    finally:
        # Shutdown: `finally` so this also runs when the context is exited
        # with an exception rather than a normal shutdown.
        print("Server shutting down.")
|
|
|
|
|
|
# ASGI application. The `lifespan` hook loads the speech-to-text model and
# the TTS voice before requests are served.
app = FastAPI(lifespan=lifespan, title="EcoBot - Scoala Verde")

# Expose everything under FRONTEND_DIR at the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory=str(FRONTEND_DIR)),
    name="static",
)
|
|
|
|
|
|
@app.get("/")
async def index():
    """Serve the frontend entry page, FRONTEND_DIR/index.html, at the root URL."""
    index_page = FRONTEND_DIR / "index.html"
    return FileResponse(str(index_page))
|
|
|
|
|
|
# Transcript substrings that count as the wake word when the client sends
# audio with "mode": "wake". Matched case-insensitively in websocket_endpoint.
_WAKE_WORDS = ("ecobot", "eco bot", "eco-bot", "hello bot", "helo bot")


async def _send_json(ws: WebSocket, payload: dict) -> None:
    """Serialize *payload* with json.dumps and send it as one text frame."""
    await ws.send_text(json.dumps(payload))


async def _run_blocking(func, *args):
    """Run the synchronous ``func(*args)`` in the event loop's default executor.

    transcribe() and get_response() are plain (non-async) functions. Calling
    them directly inside the async handler blocks the event loop, which
    freezes every other connection until they return.

    NOTE(review): concurrent clients can now call them from different worker
    threads at the same time. Confirm the stt/llm modules tolerate that.
    """
    import asyncio

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func, *args)


def _log_exception(tag: str, exc: Exception) -> None:
    """Print *exc* under a ``[tag]`` prefix, then the active traceback.

    Must be called from inside an ``except`` block so print_exc() has an
    exception to report.
    """
    import traceback

    # ASCII-replace, like the other diagnostic prints in this handler,
    # presumably to avoid UnicodeEncodeError on non-UTF-8 consoles.
    err_msg = str(exc).encode("ascii", "replace").decode()
    print(f"[{tag}] ERROR: {type(exc).__name__}: {err_msg}")
    traceback.print_exc()


@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Run one voice-assistant session over a WebSocket.

    Client frames are JSON objects. Frames with ``"type": "audio"`` carry
    base64-encoded audio in ``"data"``. An optional ``"mode": "wake"`` asks
    only for wake-word detection. Any other frame (unknown type, invalid JSON,
    non-object JSON) is ignored.

    For a normal audio frame the pipeline is STT -> LLM -> TTS. The server
    replies with JSON text frames: ``status`` progress updates, ``user_text``
    (the transcript), ``bot_text`` (the LLM answer), and then either
    ``audio_response`` (base64 audio plus MIME type) or, if TTS fails,
    ``text_only_response``. In wake mode the reply is ``wake_detected`` or
    ``wake_not_detected`` together with the transcript.

    A failure while handling one frame is logged and reported to the client
    as a ``status`` message; the session then continues. The loop ends when
    the client disconnects.
    """
    await ws.accept()
    print("Client connected")

    try:
        while True:
            # Receive audio data from browser
            data = await ws.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                print("[WS] Ignoring message that is not valid JSON")
                continue
            if not isinstance(message, dict) or message.get("type") != "audio":
                continue

            try:
                audio_bytes = base64.b64decode(message["data"])
            except (KeyError, TypeError, ValueError) as exc:
                # ValueError also covers binascii.Error (malformed base64).
                print(f"[WS] Bad audio payload: {type(exc).__name__}")
                await _send_json(ws, {"type": "status", "text": "Nu am inteles. Incearca din nou."})
                continue
            is_wake_mode = message.get("mode") == "wake"

            # Step 1: Speech to Text
            await _send_json(ws, {"type": "status", "text": "Ascult..."})
            try:
                text = await _run_blocking(transcribe, audio_bytes)
            except Exception as exc:
                _log_exception("STT", exc)
                text = ""  # reported below like an empty transcript
            if not text:
                await _send_json(ws, {"type": "status", "text": "Nu am inteles. Incearca din nou."})
                continue

            # Wake word detection: report the result and skip the LLM/TTS steps.
            if is_wake_mode:
                transcript = text.lower().strip()
                detected = any(word in transcript for word in _WAKE_WORDS)
                await _send_json(ws, {
                    "type": "wake_detected" if detected else "wake_not_detected",
                    "text": text,
                })
                continue

            await _send_json(ws, {"type": "user_text", "text": text})

            # Step 2: Get AI response
            await _send_json(ws, {"type": "status", "text": "Ma gandesc..."})
            try:
                response_text = await _run_blocking(get_response, text)
            except Exception as exc:
                _log_exception("LLM", exc)
                await _send_json(ws, {"type": "status", "text": "A aparut o eroare. Incearca din nou."})
                continue
            await _send_json(ws, {"type": "bot_text", "text": response_text})

            # Step 3: Text to Speech
            await _send_json(ws, {"type": "status", "text": "Pregatesc raspunsul..."})
            print(f"[TTS] Engine: {TTS_ENGINE}")
            print(f"[TTS] Text to synthesize: {response_text.encode('ascii', 'replace').decode()}")
            # The try covers only synthesis and encoding. Errors from sending
            # the reply are not reported as TTS failures, and a broken socket
            # does not trigger a second send attempt.
            try:
                audio_response = await synthesize(response_text)
                audio_b64 = base64.b64encode(audio_response).decode("utf-8")
            except Exception as exc:
                _log_exception("TTS", exc)
                await _send_json(ws, {"type": "text_only_response", "text": response_text})
                continue

            print(f"[TTS] OK - {len(audio_response)} bytes")
            await _send_json(ws, {
                "type": "audio_response",
                "data": audio_b64,
                "text": response_text,
                # The "edge" engine yields MP3; any other engine yields WAV.
                "mime": "audio/mpeg" if TTS_ENGINE == "edge" else "audio/wav",
            })

    except WebSocketDisconnect:
        print("Client disconnected")
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Bind address and port come from the environment (a .env file is loaded
    # by load_dotenv() at the top of this module). Defaults: every interface,
    # port 8000.
    uvicorn.run(
        app,
        host=os.environ.get("SERVER_HOST", "0.0.0.0"),
        port=int(os.environ.get("SERVER_PORT", "8000")),
    )
|