# File: backend/main.py (scraped from a repository web page; 12.2 kB)
# Last commit: "Update backend/main.py" by benjolo, 000aa13 (verified)
from operator import itemgetter
import os
from datetime import datetime
import uvicorn
from typing import Any, Optional, Tuple, Dict, TypedDict
from urllib import parse
from uuid import uuid4
import logging
from fastapi.logger import logger as fastapi_logger
import sys
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi import APIRouter, Body, Request, status
from pymongo import MongoClient
from dotenv import dotenv_values
from routes import router as api_router
from contextlib import asynccontextmanager
import requests
from typing import List
from datetime import date
from mongodb.operations.calls import *
from mongodb.models.calls import UserCall, UpdateCall
# from mongodb.endpoints.calls import *
from utils.text_rank import extract_terms
from transformers import AutoProcessor, SeamlessM4Tv2Model
# from seamless_communication.inference import Translator
from Client import Client
#----------------------------------
# base seamless imports
# ---------------------------------
import numpy as np
import torch
# ---------------------------------
import socketio
###############################################
# Configure logger
# Three named loggers are looked up: "gunicorn.error", "gunicorn" and
# "uvicorn.access". All three are set to propagate records to their parents.
gunicorn_error_logger = logging.getLogger("gunicorn.error")
gunicorn_logger = logging.getLogger("gunicorn")
uvicorn_access_logger = logging.getLogger("uvicorn.access")
gunicorn_error_logger.propagate = True
gunicorn_logger.propagate = True
uvicorn_access_logger.propagate = True
# Share the "gunicorn.error" handler list with uvicorn's access logger and
# with FastAPI's logger so all three write through the same handlers.
# NOTE(review): the handler list is shared by reference, not copied.
uvicorn_access_logger.handlers = gunicorn_error_logger.handlers
fastapi_logger.handlers = gunicorn_error_logger.handlers
###############################################
# sio is the main socket.io entrypoint
# It is an asyncio Socket.IO server in ASGI mode that accepts any origin and
# sends both its own and the Engine.IO log output to `gunicorn_logger`.
sio = socketio.AsyncServer(
async_mode="asgi",
cors_allowed_origins="*",
logger=gunicorn_logger,
engineio_logger=gunicorn_logger,
)
# sio.logger.setLevel(logging.DEBUG)
# ASGI wrapper around `sio`; it is mounted on the FastAPI app at the end of
# this file (the early mount below was left disabled).
socketio_app = socketio.ASGIApp(sio)
# app.mount("/", socketio_app)
# Load key/value settings from the local .env file (if any).
config = dotenv_values(".env")
# Resolve the MongoDB connection string.
# The .env file takes precedence, exactly as before. When the key is missing or
# empty there (e.g. a container that only receives secrets through environment
# variables), fall back to the process environment instead of failing at once.
# Raises KeyError when MONGODB_URI is defined in neither place.
uri = config.get('MONGODB_URI') or os.environ['MONGODB_URI']
# MongoDB Connection Lifespan Events
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open the MongoDB connection on startup and close it on shutdown.

    Attaches two attributes to ``app`` for the rest of the application:
    ``app.mongodb_client`` (the ``MongoClient``) and ``app.database`` (the
    ``IT-Cluster1`` database handle).

    A failed ``ping`` is only printed, not raised, so the application still
    starts when the database is unreachable (best-effort, as before).
    """
    # startup logic
    app.mongodb_client = MongoClient(uri)
    app.database = app.mongodb_client['IT-Cluster1']  # connect to interpretalk primary db
    try:
        app.mongodb_client.admin.command('ping')
        print("MongoDB Connection Established...")
    except Exception as e:
        print(e)
    try:
        yield
    finally:
        # shutdown logic
        # Runs in a ``finally`` so the client is closed even when an exception
        # or cancellation is thrown into the context manager at the yield.
        print("Closing MongoDB Connection...")
        app.mongodb_client.close()
# Create the FastAPI application; `lifespan` manages the MongoDB connection.
# NOTE(review): `logger` does not look like a documented FastAPI constructor
# argument - confirm it has any effect.
app = FastAPI(lifespan=lifespan, logger=gunicorn_logger)
# New CORS functionality
# NOTE(review): every origin, method and header is allowed, with credentials
# enabled - confirm this is intended outside of development.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # configured node app port
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(api_router) # include routers for user, calls and transcripts operations
# NOTE(review): DEBUG and ESCAPE_HATCH_SERVER_LOCK_RELEASE_NAME are not
# referenced anywhere in the visible part of this file.
DEBUG = True
ESCAPE_HATCH_SERVER_LOCK_RELEASE_NAME = "remove_server_lock"
# Sampling rate passed to the processor together with the resampled audio
# in `incoming_audio`.
TARGET_SAMPLING_RATE = 16000
# Buffer length at which `incoming_audio` stops accumulating and processes the
# buffered audio (compared against `Client.get_length()`).
MAX_BYTES_BUFFER = 960_000
# Startup banner
print("")
print("")
print("=" * 20 + " ⭐️ Starting Server... ⭐️ " + "=" * 20)
###############################################
# Configure socketio server
###############################################
# TODO PM - change this to the actual path
# seamless remnant code
# NOTE(review): `static_files` is built here but is not passed to
# `socketio.ASGIApp` in the visible part of this file.
CLIENT_BUILD_PATH = "../streaming-react-app/dist/"
static_files = {
"/": CLIENT_BUILD_PATH,
"/assets/seamless-db6a2555.svg": {
"filename": CLIENT_BUILD_PATH + "assets/seamless-db6a2555.svg",
"content_type": "image/svg+xml",
},
}
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Processor and model are loaded once, at import time, and shared by all
# socket handlers.
processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
# PM - hardcoding temporarily as my GPU doesnt have enough vram
# NOTE(review): the comment above mentions hardcoding, but the line below moves
# the model to `device` as selected above - the comment may be stale.
model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to(device)
# NOTE(review): `bytes_data`, `model_name` and `vocoder_name` are not used in
# the visible part of this file.
bytes_data = bytearray()
model_name = "seamlessM4T_v2_large"
vocoder_name = "vocoder_v2" if model_name == "seamlessM4T_v2_large" else "vocoder_36langs"
# Socket session id -> Client object (filled in `connect`, removed in `disconnect`).
clients = {}
# Call id -> list of socket session ids in that call (at most two are admitted).
rooms = {}
def get_collection_users():
    """Return the ``user_records`` collection of the application database."""
    database = app.database
    return database["user_records"]
def get_collection_calls():
    """Return the collection that call records are read from and written to.

    Currently this is ``call_test``; ``call_records`` is the alternative name
    that was left disabled in this function.
    """
    collection_name = "call_test"
    return app.database[collection_name]
@app.get("/home/", response_description="Welcome User")
def test():
    """Handle ``GET /home/`` by returning a fixed welcome message."""
    welcome = {"message": "Welcome to InterpreTalk!"}
    return welcome
async def send_translated_text(client_id, original_text, translated_text, room_id):
    """Emit a ``translated_text`` event to every socket in ``room_id``.

    The payload carries the author id, the original text, the translated text
    and the current local time, all converted to strings.
    """
    print('SEND_TRANSLATED_TEXT IS WOKRING IN FASTAPI BACKEND...')
    # Debugging: dump the room and client registries.
    for registry in (rooms, clients):
        print(registry)

    payload = dict(
        author=str(client_id),
        original_text=str(original_text),
        translated_text=str(translated_text),
        timestamp=str(datetime.now()),
    )

    gunicorn_logger.info("SENDING TRANSLATED TEXT TO CLIENT")
    await sio.emit("translated_text", payload, room=room_id)
    gunicorn_logger.info("SUCCESSFULLY SEND AUDIO TO FRONTEND")
@sio.on("connect")
async def connect(sid, environ):
    """Register a newly connected socket in the ``clients`` registry.

    ``sid`` is the socket id of this connection; ``client_id`` is read from the
    ``client_id`` query-string parameter (``None`` when absent) and is meant to
    stay the same for the same user across connections.
    """
    print(f"📥 [event: connected] sid={sid}")

    raw_query = environ["QUERY_STRING"]
    client_id = dict(parse.parse_qsl(raw_query)).get("client_id")
    gunicorn_logger.info(f"📥 [event: connected] sid={sid}, client_id={client_id}")

    new_client = Client(sid, client_id)
    clients[sid] = new_client

    gunicorn_logger.warning(f"Client connected: {sid}")
    gunicorn_logger.warning(clients)
@sio.on("disconnect")
async def disconnect(sid):
    """Forget a disconnected socket and release its slot in its call room.

    Removes ``sid`` from ``clients`` and, if the client had joined a call, from
    that call's member list in ``rooms``. Without this, stale sids would keep a
    two-person room "full" forever and could be picked as the translation peer
    in ``incoming_audio``. A room left with no members is deleted.
    """
    gunicorn_logger.debug(f"📤 [event: disconnected] sid={sid}")
    client = clients.pop(sid, None)

    # The client may never have called/answered, so call_id can be missing.
    call_id = getattr(client, "call_id", None)
    if call_id is None:
        return

    members = rooms.get(call_id)
    if members and sid in members:
        members.remove(sid)
        if not members:
            del rooms[call_id]
@sio.on("term_extraction")
async def term_extraction(sid, call_id):
    """Extract key terms from a call's captions and store them on the call record.

    Does nothing beyond logging when the call has no caption text.
    """
    gunicorn_logger.debug(f"📤 [event: term_extraction] sid={sid}, call={call_id}")

    # Combined caption text of the call record identified by call_id.
    combined_text = get_caption_text(get_collection_calls(), call_id)
    if not combined_text:
        # No captions: nothing to extract (term extraction is poor on short text).
        return

    print("THE COMBINED TEXT IS:", combined_text)

    # Extract key terms from the concatenated caption text.
    text_length = len(combined_text)
    key_terms = extract_terms(combined_text, text_length)
    print("THE KEY TERMS ARE:", key_terms)

    # BO -> Update the call record with the extracted key terms.
    update_calls(get_collection_calls(), call_id, {"key_terms": key_terms})
@sio.on("target_language")
async def target_language(sid, target_lang):
    """Record the target language chosen by the client behind ``sid``."""
    gunicorn_logger.info(f"📥 [event: target_language] sid={sid}, target_lang={target_lang}")
    client = clients[sid]
    client.target_language = target_lang
@sio.on("call_user")
async def call_user(sid, call_id):
    """Put the caller into the call's room and create the call record.

    The room admits at most two distinct sids. Whether or not the sid was
    admitted, a call record with ``call_id``, ``caller_id`` and
    ``creation_date`` is created in the calls collection.
    NOTE(review): creating the record even when the room is full matches the
    previous behaviour - confirm it is intended.

    Raises KeyError if ``sid`` has no entry in ``clients``.
    """
    # Local import: `inspect` is not imported at the top of this file.
    import inspect

    clients[sid].call_id = call_id
    gunicorn_logger.info(f"CALL {sid}: entering room {call_id}")

    members = rooms.setdefault(call_id, [])
    if sid not in members and len(members) < 2:
        members.append(sid)
        # AsyncServer.enter_room is a coroutine in recent python-socketio
        # releases and a plain method in older ones; awaiting only when needed
        # makes sure the sid really joins the room on both.
        joined = sio.enter_room(sid, call_id)
        if inspect.isawaitable(joined):
            await joined
    else:
        gunicorn_logger.info(f"CALL {sid}: room {call_id} is full")
        # await sio.emit("room_full", room=call_id, to=sid)

    # BO - Get the client id stored when the socket connected.
    client_id = clients[sid].client_id
    gunicorn_logger.warning(f"NOW TRYING TO CREATE DB RECORD FOR Caller with ID: {client_id} for call: {call_id}")

    # BO -> Create the call record with caller and call_id
    # (callee, duration and key terms are filled in later).
    request_data = {
        "call_id": str(call_id),
        "caller_id": str(client_id),
        "creation_date": str(datetime.now()),
    }
    response = create_calls(get_collection_calls(), request_data)
    print(response)  # BO - print created db call record
@sio.on("audio_config")
async def audio_config(sid, sample_rate):
    """Store the sample rate reported by the client as its ``original_sr``."""
    client = clients[sid]
    client.original_sr = sample_rate
@sio.on("answer_call")
async def answer_call(sid, call_id):
    """Put the callee into the call's room and record it on the call record.

    The room admits at most two distinct sids. Whether or not the sid was
    admitted, the call record identified by ``call_id`` is updated with
    ``callee_id`` set to this client's id.
    NOTE(review): updating the record even when the room is full matches the
    previous behaviour - confirm it is intended.

    Raises KeyError if ``sid`` has no entry in ``clients``.
    """
    # Local import: `inspect` is not imported at the top of this file.
    import inspect

    clients[sid].call_id = call_id
    gunicorn_logger.info(f"ANSWER {sid}: entering room {call_id}")

    members = rooms.setdefault(call_id, [])
    if sid not in members and len(members) < 2:
        members.append(sid)
        # AsyncServer.enter_room is a coroutine in recent python-socketio
        # releases and a plain method in older ones; awaiting only when needed
        # makes sure the sid really joins the room on both.
        joined = sio.enter_room(sid, call_id)
        if inspect.isawaitable(joined):
            await joined
    else:
        gunicorn_logger.info(f"ANSWER {sid}: room {call_id} is full")
        # await sio.emit("room_full", room=call_id, to=sid)

    # BO - Get the client id stored when the socket connected.
    client_id = clients[sid].client_id

    # BO -> Update the call record with the callee field based on call_id.
    gunicorn_logger.warning(f"NOW UPDATING MongoDB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
    request_data = {
        "callee_id": client_id,
    }
    response = update_calls(get_collection_calls(), call_id, request_data)
    print(response)  # BO - print updated db call record
@sio.on("incoming_audio")
async def incoming_audio(sid, data, call_id):
    """Buffer incoming audio and, once enough is collected, transcribe and translate it.

    Each call appends ``data`` to the sender's buffer. When the buffer reaches
    ``MAX_BYTES_BUFFER`` it is resampled and cleared; if the client's VAD
    check reports speech, the audio is transcribed in the speaker's language,
    translated into the other participant's target language, emitted to the
    call room and stored as a caption on the call record.

    Errors are logged with their traceback and never propagate to the
    Socket.IO server.
    """
    try:
        client = clients[sid]
        client.add_bytes(data)
        if client.get_length() < MAX_BYTES_BUFFER:
            return  # keep buffering

        gunicorn_logger.info('Buffer full, now outputting...')
        resampled_audio = client.resample_and_clear()
        vad_result = client.vad_analyse(resampled_audio)
        # source lang is the speaker's own target language
        src_lang = client.target_language
        if not vad_result:
            return  # no speech in this chunk

        gunicorn_logger.info('Speech detected, now processing audio.....')

        # The other participant of the call decides the translation target.
        # Use a default so a lone participant does not raise StopIteration.
        tgt_sid = next((peer for peer in rooms.get(call_id, []) if peer != sid), None)
        if tgt_sid is None:
            gunicorn_logger.warning(f"No other participant in room {call_id}; skipping translation")
            return
        tgt_lang = clients[tgt_sid].target_language

        # following example from https://github.com/facebookresearch/seamless_communication/blob/main/docs/m4t/README.md#transformers-usage
        # Step 1: speech -> text in the speaker's language (tgt_lang=src_lang).
        output_tokens = processor(audios=resampled_audio, src_lang=src_lang, return_tensors="pt", sampling_rate=TARGET_SAMPLING_RATE).to(device)
        model_output = model.generate(**output_tokens, tgt_lang=src_lang, generate_speech=False)[0].tolist()[0]
        asr_text = processor.decode(model_output, skip_special_tokens=True)
        print(f"ASR TEXT = {asr_text}")

        # Step 2: text in the speaker's language -> text in the peer's language.
        t2t_tokens = processor(text=asr_text, src_lang=src_lang, tgt_lang=tgt_lang, return_tensors="pt").to(device)
        print(f"FIRST TYPE = {type(output_tokens)}, SECOND TYPE = {type(t2t_tokens)}")
        translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
        translated_text = processor.decode(translated_data, skip_special_tokens=True)
        print(f"TRANSLATED TEXT = {translated_text}")

        # Push the result to the clients in the call room...
        await send_translated_text(client.client_id, asr_text, translated_text, call_id)
        # BO -> ...and store it on the call record as a caption.
        await send_captions(client.client_id, asr_text, translated_text, call_id)
    except Exception as e:
        # Logger.exception records the traceback itself. The previous
        # e.with_traceback() call lacked its required argument and raised
        # TypeError from inside this handler.
        gunicorn_logger.exception(f"Error in incoming_audio: {e}")
async def send_captions(client_id, original_text, translated_text, call_id):
    """Store one caption entry on the call record identified by ``call_id``.

    The caption carries the author id, the original text, the translated text
    and the current local time, all converted to strings. Returns whatever
    ``update_captions`` returns.
    """
    print(f"Now updating Caption field in call record for Caller with ID: {client_id} for call: {call_id}")
    caption = dict(
        author=str(client_id),
        original_text=str(original_text),
        translated_text=str(translated_text),
        timestamp=str(datetime.now()),
    )
    return update_captions(get_collection_calls(), call_id, caption)
# Serve the Socket.IO ASGI app from the root path; mounted here, after the API
# routes have been registered on `app`.
app.mount("/", socketio_app)

# Pick the FastAPI logger level before any server starts.
# Imported by an external server (e.g. running in a Docker container): follow
# the gunicorn logger's level. Run directly as a script: log at DEBUG.
if __name__ != "__main__":
    fastapi_logger.setLevel(gunicorn_logger.level)
else:
    fastapi_logger.setLevel(logging.DEBUG)

if __name__ == '__main__':
    # Pass the app object rather than the "main:app" import string: the string
    # form makes uvicorn import this file a second time under the name "main",
    # which repeats every module-level side effect, including the model load.
    # uvicorn.run blocks until shutdown, so it has to come last.
    uvicorn.run(app, host='0.0.0.0', port=7860, log_level="info")