|
import os |
|
import argparse |
|
import subprocess |
|
import requests |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
from fastapi import FastAPI, Request, HTTPException |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.responses import FileResponse, JSONResponse |
|
|
|
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomObject, DailyRoomProperties, DailyRoomParams |
|
|
|
|
|
from dotenv import load_dotenv |
|
# Read settings from a local .env file; with override=True its values take
# precedence over variables already set in the process environment.
load_dotenv(override=True)

# Expiry passed to every meeting token minted via get_token() below
# (8 minutes, presumably in seconds — confirm against DailyRESTHelper).
MAX_SESSION_TIME = 480

# Shared Daily REST client for the whole process. A missing DAILY_API_KEY is
# not rejected here: it falls back to an empty string.
daily_rest_helper = DailyRESTHelper(
    os.getenv("DAILY_API_KEY", ""),
    os.getenv("DAILY_API_URL", "https://api.daily.co/v1"),
)
|
|
|
|
|
# ASGI application; the script entry point serves it as "bot_runner:app".
app = FastAPI()

# Wide-open CORS policy: every origin, HTTP method and request header is
# accepted, and credentialed requests are allowed.
# NOTE(review): a "*" origin list together with allow_credentials=True is very
# permissive (Starlette reflects the caller's Origin for credentialed requests
# — verify for the installed version); narrow allow_origins if cookies or auth
# headers are ever relied on.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
|
# Pre-built frontend bundle, relative to the process working directory.
# Also used by the catch-all GET route to serve pages by path.
STATIC_DIR = "frontend/out"

# Expose the bundle under /static; html=True turns on StaticFiles' HTML mode
# (index.html for directory requests, per the Starlette docs).
app.mount(
    "/static",
    StaticFiles(directory=STATIC_DIR, html=True),
    name="static",
)
|
|
|
|
|
def _env_flag(name: str, default: bool = False) -> bool:
    """Read environment variable *name* as a boolean switch.

    Unset or blank -> *default*; "0", "false", "no", "off" (any case) -> False;
    any other value -> True.
    """
    raw = os.getenv(name, "").strip()
    if not raw:
        return default
    return raw.lower() not in ("0", "false", "no", "off")


def _get_or_create_room() -> DailyRoomObject:
    """Return the room named by DAILY_SAMPLE_ROOM_URL, or provision a new one.

    Raises:
        HTTPException: 500 if the room cannot be created or looked up.
    """
    room_url = os.getenv("DAILY_SAMPLE_ROOM_URL", "")
    if not room_url:
        params = DailyRoomParams(
            properties=DailyRoomProperties()
        )
        try:
            return daily_rest_helper.create_room(params=params)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Unable to provision room {e}") from e

    try:
        return daily_rest_helper.get_room_from_url(room_url)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Room not found: {room_url}") from e


def _spawn_bot(room_url: str, token: str) -> None:
    """Start the bot for *room_url*: on a Fly VM when RUN_AS_VM is truthy,
    otherwise as a local ``python3 -m bot`` child process.

    Raises:
        HTTPException: 500 if the VM or the subprocess could not be started.
    """
    if _env_flag("RUN_AS_VM"):
        try:
            virtualize_bot(room_url, token)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to spawn VM: {e}") from e
        return

    try:
        # argv list without a shell: the URL and token reach the bot verbatim,
        # so shell metacharacters in them are never interpreted.
        subprocess.Popen(
            ["python3", "-m", "bot", "-u", room_url, "-t", token],
            cwd=os.path.dirname(os.path.abspath(__file__)))
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to start subprocess: {e}") from e


@app.post("/start_bot")
async def start_bot(request: Request) -> JSONResponse:
    """Get a Daily room, start a bot in it, and return the room URL plus a
    separate meeting token for the browser client.

    A JSON object body containing a "test" key short-circuits with
    ``{"test": true}`` without creating a room or starting a bot.

    NOTE(review): the Daily helper calls and virtualize_bot's HTTP requests
    look synchronous and therefore run on the event-loop thread; consider
    offloading them if concurrent requests matter.

    Raises:
        HTTPException: 403 in production when the Host header is not an
            allowed domain; 500 when the room, a token, or the bot cannot be
            obtained/started.
    """
    if os.getenv("ENV", "dev") == "production":
        # Only serve requests addressed to the deployed hostnames.
        host_header = request.headers.get("host")
        allowed_domains = ["storytelling-chatbot.fly.dev", "www.storytelling-chatbot.fly.dev"]
        if host_header not in allowed_domains:
            raise HTTPException(status_code=403, detail="Access denied")

    try:
        data = await request.json()
    except ValueError:
        # Empty or non-JSON body (JSONDecodeError and UnicodeDecodeError are
        # ValueError subclasses): treat it as a normal start request.
        data = None
    if isinstance(data, dict) and "test" in data:
        return JSONResponse({"test": True})

    room = _get_or_create_room()

    # Mint both tokens before spawning anything so a token failure cannot
    # leave an orphaned bot in the room.
    bot_token = daily_rest_helper.get_token(room.url, MAX_SESSION_TIME)
    user_token = daily_rest_helper.get_token(room.url, MAX_SESSION_TIME)
    if not bot_token or not user_token:
        raise HTTPException(
            status_code=500, detail=f"Failed to get token for room: {room.url}")

    _spawn_bot(room.url, bot_token)

    return JSONResponse({
        "room_url": room.url,
        "token": user_token,
    })
|
|
|
|
|
@app.get("/{path_name:path}", response_class=FileResponse)
async def catch_all(path_name: Optional[str] = ""):
    """Serve a file from the static frontend bundle for any unmatched GET path.

    Resolution order:
      1. empty path -> ``STATIC_DIR/index.html``
      2. an existing file at ``STATIC_DIR/<path_name>``
      3. the same path with its suffix replaced by ``.html``
         (so ``/about`` serves ``about.html``)

    Paths that resolve outside STATIC_DIR (``..`` segments, or absolute
    values such as the one produced by a request for ``//etc/passwd``) are
    rejected.

    Raises:
        HTTPException: 404 when no matching file exists inside STATIC_DIR.
    """
    if not path_name:
        return FileResponse(f"{STATIC_DIR}/index.html")

    static_root = Path(STATIC_DIR).resolve()
    try:
        file_path = (static_root / path_name).resolve()
    except (OSError, RuntimeError, ValueError):
        # Unresolvable input, e.g. an embedded NUL byte or a symlink loop.
        raise HTTPException(status_code=404, detail="Incorrect API call")

    # path_name comes straight from the URL; joining an absolute or
    # ".."-laden value would otherwise escape the static directory.
    if static_root not in file_path.parents:
        raise HTTPException(status_code=404, detail="Incorrect API call")

    if file_path.is_file():
        return FileResponse(file_path)

    # file_path is strictly inside static_root, so it has a non-empty name and
    # with_suffix() keeps the result in the same directory.
    html_file_path = file_path.with_suffix(".html")
    if html_file_path.is_file():
        return FileResponse(html_file_path)

    raise HTTPException(status_code=404, detail="Incorrect API call")
|
|
|
|
|
|
|
|
|
def virtualize_bot(room_url: str, token: str):
    """
    This is an example of how to virtualize the bot using Fly.io.
    You can adapt this method to use whichever cloud provider you prefer.

    Reuses the image of the app's first existing Fly machine to create a new
    auto-destroying machine (restart policy "no", 1 shared CPU, 512 MB) that
    runs ``python3 src/bot.py -u <room_url> -t <token>``, then waits until
    that machine reports the "started" state.

    Configuration comes from FLY_API_HOST, FLY_APP_NAME and FLY_API_KEY.

    Raises:
        Exception: if a Fly Machines API call returns a non-200 status, or the
            app has no existing machine to take the image from.
        requests.RequestException: on connection errors or timeouts.
    """
    fly_api_host = os.getenv("FLY_API_HOST", "https://api.machines.dev/v1")
    fly_app_name = os.getenv("FLY_APP_NAME", "storytelling-chatbot")
    fly_api_key = os.getenv("FLY_API_KEY", "")
    fly_headers = {
        'Authorization': f"Bearer {fly_api_key}",
        'Content-Type': 'application/json'
    }
    machines_url = f"{fly_api_host}/apps/{fly_app_name}/machines"

    # requests waits indefinitely without a timeout, which would hang the
    # /start_bot request that calls this function.
    api_timeout = 30
    # The /wait endpoint blocks server-side until the state is reached (Fly
    # applies its own wait limit), so give it a longer read timeout.
    wait_timeout = 90

    # Reuse the image of an existing machine for the worker.
    res = requests.get(machines_url, headers=fly_headers, timeout=api_timeout)
    if res.status_code != 200:
        raise Exception(f"Unable to get machine info from Fly: {res.text}")
    machines = res.json()
    if not machines:
        raise Exception(
            f"Unable to get machine info from Fly: no machines found for app {fly_app_name}")
    image = machines[0]['config']['image']

    # Build argv directly rather than splitting a formatted string, so the
    # values are never re-tokenized on whitespace.
    cmd = ["python3", "src/bot.py", "-u", room_url, "-t", token]
    worker_props = {
        "config": {
            "image": image,
            "auto_destroy": True,
            "init": {
                "cmd": cmd
            },
            "restart": {
                "policy": "no"
            },
            "guest": {
                "cpu_kind": "shared",
                "cpus": 1,
                "memory_mb": 512
            }
        },
    }

    res = requests.post(
        machines_url,
        headers=fly_headers,
        json=worker_props,
        timeout=api_timeout)
    if res.status_code != 200:
        raise Exception(f"Problem starting a bot worker: {res.text}")

    vm_id = res.json()['id']

    res = requests.get(
        f"{machines_url}/{vm_id}/wait?state=started",
        headers=fly_headers,
        timeout=wait_timeout)
    if res.status_code != 200:
        raise Exception(f"Bot was unable to enter started state: {res.text}")

    print(f"Machine joined room: {room_url}")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Refuse to start unless every expected variable is present in the
    # environment (load_dotenv has already run at import time). Only the
    # first missing name, in the order listed, is reported.
    required_env_vars = ('OPENAI_API_KEY', 'DAILY_API_KEY',
                         'FAL_KEY', 'OPENAI_BASE_URL')
    missing_var = next(
        (name for name in required_env_vars if name not in os.environ), None)
    if missing_var is not None:
        raise Exception(f"Missing environment variable: {missing_var}.")

    # Imported lazily: only needed when this module runs as a script.
    import uvicorn

    # CLI defaults can be overridden from the environment.
    host_default = os.getenv("HOST", "0.0.0.0")
    port_default = int(os.getenv("FAST_API_PORT", "7860"))

    cli = argparse.ArgumentParser(
        description="Daily Storyteller FastAPI server")
    cli.add_argument("--host", type=str, default=host_default,
                     help="Host address")
    cli.add_argument("--port", type=int, default=port_default,
                     help="Port number")
    cli.add_argument("--reload", action="store_true",
                     help="Reload code on change")
    args = cli.parse_args()

    # NOTE(review): the module name in the import string is hard-coded;
    # renaming this file requires updating it.
    uvicorn.run("bot_runner:app", host=args.host, port=args.port,
                reload=args.reload)
|
|