Spaces:
Running
Running
File size: 11,705 Bytes
1c08fbd 5aa78f4 1c08fbd bcfd8ed 5aa78f4 1c08fbd 5aa78f4 bb15705 5aa78f4 bb15705 1c08fbd bb15705 5aa78f4 1c08fbd bb15705 1c08fbd 5aa78f4 1c08fbd d455eb8 5aa78f4 bb15705 5aa78f4 bb15705 5aa78f4 bb15705 5aa78f4 bb15705 5aa78f4 bb15705 5aa78f4 bb15705 5aa78f4 1c08fbd bb15705 1c08fbd bb15705 1c08fbd bb15705 1c08fbd bb15705 1c08fbd bb15705 5aa78f4 1c08fbd 5aa78f4 1c08fbd bb15705 1c08fbd bb15705 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 |
'''
Created By Lewis Kamau Kimaru
Sema translator fastapi implementation
January 2024
Docker deployment
'''
import os
import time
from datetime import datetime
from typing import Optional

from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
import uvicorn
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from pymongo import MongoClient
import jwt
from jwt import encode as jwt_encode
from bson import ObjectId
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import ctranslate2
import sentencepiece as spm
import fasttext
import torch
import gradio as gr
import pytz
app = FastAPI()
# Any origin may call the API; credentials (cookies/auth headers) are not
# allowed on cross-origin requests.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Replace fastText's internal stderr printer with a no-op so its messages are
# not printed.
fasttext.FastText.eprint = lambda x: None
# Expose the Hugging Face token under HUGGINGFACEHUB_API_TOKEN.
# NOTE(review): assumes the Docker deployment provides the secret as the
# `huggingface_token` environment variable — confirm the secret's name.
# When it is absent the variable is left untouched instead of failing at
# import time.
_hf_token = os.environ.get("huggingface_token")
if _hf_token:
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = _hf_token
# Directory holding the HTML front-end served by read_root().
templates_folder = os.path.join(os.path.dirname(__file__), "templates")
# Authentication
class User(BaseModel):
    """Request body accepted by the /login and /register routes."""

    # Optional display name. Declared Optional so both an absent field and an
    # explicit null validate.
    username: Optional[str] = None
    email: str
    # Stored and compared as plain text by register()/login() in this module.
    password: str
# Connect to the MongoDB database. The URI can be overridden with the
# MONGODB_URI environment variable; the default targets a local instance.
client = MongoClient(os.environ.get("MONGODB_URI", "mongodb://localhost:27017"))
db = client["mydatabase"]
users_collection = db["users"]
# Secret key for signing the HS256 JWTs issued by generate_token().
# SECURITY: the "helloworld" fallback is visible in this source file, so anyone
# can forge tokens signed with it. Set JWT_SECRET_KEY in the deployment; the
# fallback only keeps the default behaviour unchanged when it is unset.
SECRET_KEY = os.environ.get("JWT_SECRET_KEY", "helloworld")
# FastAPI dependency that extracts "Authorization: Bearer <token>" credentials.
security = HTTPBearer()
# Login route.
@app.post("/login")
def login(user: User):
    """Authenticate a user by email and password.

    On success returns the stored user document with ``_id`` converted to a
    string, a ``token`` field (JWT from ``generate_token``) added, and the
    ``password`` field removed. On failure returns
    ``{"message": "Invalid email or password"}``.
    """
    # SECURITY: passwords are stored and matched in plain text (see register()).
    # Switching to salted hashes requires migrating existing accounts.
    user_data = users_collection.find_one(
        {"email": user.email, "password": user.password}
    )
    if not user_data:
        return {"message": "Invalid email or password"}
    # Never send the stored credential back to the client.
    user_data.pop("password", None)
    # Convert the ObjectId to a string so the response can be JSON-encoded.
    user_data["_id"] = str(user_data["_id"])
    user_data["token"] = generate_token(user.email)
    return user_data
# Registration route.
@app.post("/register")
def register(user: User):
    """Create a new user and return their record plus a JWT.

    Returns ``{"message": "User already exists"}`` if the email is already
    registered. Otherwise inserts the user document and returns it with ``_id``
    converted to a string, a ``token`` field added, and ``password`` omitted.
    """
    # NOTE(review): check-then-insert is not atomic; only a unique index on
    # "email" rules out duplicates from concurrent registrations.
    if users_collection.find_one({"email": user.email}):
        return {"message": "User already exists"}
    # SECURITY: the password is stored in plain text (login() matches it as-is).
    user_dict = user.dict()
    # insert_one adds the generated "_id" to user_dict in place.
    users_collection.insert_one(user_dict)
    # Build the response without the stored credential.
    response = {key: value for key, value in user_dict.items() if key != "password"}
    response["_id"] = str(response["_id"])
    response["token"] = generate_token(user.email)
    return response
# `/api/user`: return the profile of the user identified by the bearer JWT.
@app.get("/api/user")
def get_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return ``{"username": ..., "email": ...}`` for the token's user.

    The bearer token must be an HS256 JWT signed with ``SECRET_KEY`` (as issued
    by ``generate_token``) whose ``email`` claim matches a stored user.

    Raises:
        HTTPException: 401 if the token fails verification or names no user.
    """
    token = credentials.credentials
    try:
        # Verifies the signature, so tokens not signed with SECRET_KEY fail here.
        payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token") from None
    email = payload.get("email")
    user_data = users_collection.find_one({"email": email}) if email else None
    if not user_data:
        raise HTTPException(status_code=401, detail="Invalid token")
    return {"username": user_data.get("username"), "email": user_data["email"]}
# Helper that issues the JWTs returned by /login and /register.
def generate_token(email: str) -> str:
    """Return a JWT, signed with HS256 and ``SECRET_KEY``, carrying only ``email``.

    No expiry claim is included.
    """
    claims = dict(email=email)
    return jwt_encode(claims, SECRET_KEY, algorithm="HS256")
# Timestamps for request/response logging, in Nairobi local time.
def get_time():
    """Return ``(full_date, time_only)`` for the current moment in Africa/Nairobi.

    ``full_date`` has the form ``"<weekday> | YYYY-MM-DD | HH:MM:SS"``;
    ``time_only`` is the ``"HH:MM:SS"`` part alone.
    """
    now = datetime.now(pytz.timezone('Africa/Nairobi'))
    clock = now.strftime('%H:%M:%S')
    stamp = " | ".join((now.strftime('%A'), now.strftime('%Y-%m-%d'), clock))
    return stamp, clock
def load_models():
    """Load each configured Hugging Face seq2seq model together with its tokenizer.

    Returns:
        A flat dict with two entries per model: ``"<name>_model"`` and
        ``"<name>_tokenizer"``.
    """
    # Short name -> Hugging Face Hub id. Smaller NLLB checkpoints that could be
    # listed here instead: facebook/nllb-200-distilled-600M,
    # facebook/nllb-200-1.3B, facebook/nllb-200-distilled-1.3B,
    # facebook/nllb-200-3.3B.
    checkpoints = {
        'nllb-moe-54b': 'facebook/nllb-moe-54b',
    }
    loaded = {}
    for short_name, hub_id in checkpoints.items():
        print(f'\tLoading model: {short_name}')
        loaded[f'{short_name}_model'] = AutoModelForSeq2SeqLM.from_pretrained(hub_id)
        loaded[f'{short_name}_tokenizer'] = AutoTokenizer.from_pretrained(hub_id)
    return loaded
# Load every model and tokenizer once, at import time.
beam_size = 1  # beam width passed to CTranslate2's translate_batch
device = "cpu"  # or "cuda"
# Language identification model (fastText).
print("\nimporting Language Prediction model")
lang_model_file = "lid218e.bin"
lang_model_full_path = os.path.join(os.path.dirname(__file__), lang_model_file)
lang_model = fasttext.load_model(lang_model_full_path)
# SentencePiece model used to (de)subword text for the CTranslate2 translator.
print("\nimporting SentencePiece model")
sp_model_file = "spm.model"
sp_model_full_path = os.path.join(os.path.dirname(__file__), sp_model_file)
sp = spm.SentencePieceProcessor()
sp.load(sp_model_full_path)
# CTranslate2 translator used by translate_detect() / translate_enter().
# Load it only when the converted model directory ships with the app; otherwise
# keep the name defined as None so callers do not hit an undefined name.
print("\nimporting Translator model")
ct_model_file = "sematrans-3.3B"
ct_model_full_path = os.path.join(os.path.dirname(__file__), ct_model_file)
if os.path.isdir(ct_model_full_path):
    translator = ctranslate2.Translator(ct_model_full_path, device)
else:
    translator = None
    print(f"\nTranslator model not found at {ct_model_full_path}; "
          "translate_detect/translate_enter are unavailable")
# Hugging Face model(s) used by translate_faster().
model_dict = load_models()
print('\nDone importing models\n')
def translate_detect(userinput: str, target_lang: str):
    """Detect the language of ``userinput`` and translate it into ``target_lang``.

    Args:
        userinput: Text to translate; surrounding whitespace is stripped.
        target_lang: Target language token understood by the translator
            (presumably FLORES-style such as ``"swh_Latn"`` — confirm).

    Returns:
        ``(source_lang, [translated_text])``: the detected language label with
        the ``__label__`` prefix removed, and a one-element list holding the
        translation.
    """
    text = userinput.strip()
    # fastText's predict() raises ValueError on strings containing "\n", so
    # flatten line breaks for detection only; translation gets the text as-is.
    labels, _scores = lang_model.predict(text.replace("\n", " "), k=1)
    source_lang = labels[0].replace('__label__', '')
    # The subword -> CTranslate2 -> desubword pipeline is identical to
    # translate_enter(), which returns the single translated string.
    return source_lang, [translate_enter(text, source_lang, target_lang)]
def translate_enter(userinput: str, source_lang: str, target_lang: str):
    """Translate ``userinput`` from ``source_lang`` into ``target_lang``.

    The stripped text is split into SentencePiece pieces, wrapped as
    ``[source_lang, *pieces, "</s>"]`` and decoded by the CTranslate2
    ``translator`` with ``target_lang`` forced as the first output token. The
    first ``len(target_lang)`` characters are then sliced off the desubworded
    result.

    Returns:
        The translated string.
    """
    pieces = sp.encode([userinput.strip()], out_type=str)
    batch = [[source_lang, *tokens, "</s>"] for tokens in pieces]
    results = translator.translate_batch(
        batch,
        batch_type="tokens",
        max_batch_size=2024,
        beam_size=beam_size,
        target_prefix=[[target_lang]] * len(batch),
    )
    hypotheses = [result[0]['tokens'] for result in results]
    decoded = sp.decode(hypotheses)
    return decoded[0][len(target_lang):]
def translate_faster(userinput3: str, source_lang3: str, target_lang3: str,
                     model_name: str = 'nllb-moe-54b'):
    """Translate text with a Hugging Face model preloaded by ``load_models()``.

    Args:
        userinput3: Text to translate.
        source_lang3: Source language code, passed to the pipeline as ``src_lang``.
        target_lang3: Target language code, passed as ``tgt_lang``.
        model_name: Short name of a loaded model; ``model_dict`` must contain
            ``<model_name>_model`` and ``<model_name>_tokenizer``.

    Returns:
        ``{'inference_time': seconds, 'source': source_lang3,
        'target': target_lang3, 'result': translated_text}``.

    Raises:
        KeyError: if ``model_name`` is not present in ``model_dict``.
    """
    start_time = time.time()
    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']
    # src_lang/tgt_lang are fixed when the pipeline is built, hence one per call.
    # Named to avoid shadowing the module-level CTranslate2 `translator`.
    nllb_translator = pipeline('translation', model=model, tokenizer=tokenizer,
                               src_lang=source_lang3, tgt_lang=target_lang3)
    output = nllb_translator(userinput3, max_length=400)
    end_time = time.time()
    return {'inference_time': end_time - start_time,
            'source': source_lang3,
            'target': target_lang3,
            'result': output[0]['translation_text']}
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve the translator front-end page, ``templates/translator.html``."""
    page_path = os.path.join(templates_folder, "translator.html")
    # Context manager closes the handle; UTF-8 is explicit instead of relying
    # on the platform's locale default.
    with open(page_path, "r", encoding="utf-8") as page:
        content = page.read()
    return HTMLResponse(content=content, status_code=200)
@app.post("/translate_detect/")
async def translate_detect_endpoint(request: Request):
    """Translate ``userinput`` into ``target_lang``, auto-detecting the source language.

    Expects a JSON body with ``userinput`` and ``target_lang``; responds with
    ``{"source_language": ..., "translated_text": ...}``.

    Raises:
        HTTPException: 422 when either field is missing or empty.
    """
    payload = await request.json()
    text = payload.get("userinput")
    target = payload.get("target_lang")
    stamp, _ = get_time()
    print(f"\nrequest: {stamp}\nTarget Language; {target}, User Input: {text}\n")
    if not (text and target):
        raise HTTPException(status_code=422, detail="Both 'userinput' and 'target_lang' are required.")
    detected, translations = translate_detect(text, target)
    _, clock = get_time()
    print(f"\nresponse: {clock}; ... Source_language: {detected}, Translated Text: {translations}\n\n")
    return {"source_language": detected, "translated_text": translations[0]}
@app.post("/translate_enter/")
async def translate_enter_endpoint(request: Request):
    """Translate ``userinput`` from an explicit ``source_lang`` into ``target_lang``.

    Expects a JSON body with ``userinput``, ``source_lang`` and ``target_lang``;
    responds with ``{"translated_text": ...}``.

    Raises:
        HTTPException: 422 when any of the three fields is missing or empty.
    """
    datae = await request.json()
    userinpute = datae.get("userinput")
    source_lange = datae.get("source_lang")
    target_lange = datae.get("target_lang")
    efull_date = get_time()[0]
    print(f"\nrequest: {efull_date}\nSource_language; {source_lange}, Target Language; {target_lange}, User Input: {userinpute}\n")
    # source_lang is required too: translate_enter() emits it as the first
    # token of the model input.
    if not userinpute or not source_lange or not target_lange:
        raise HTTPException(status_code=422, detail="'userinput', 'source_lang' and 'target_lang' are required.")
    translated_text_e = translate_enter(userinpute, source_lange, target_lange)
    ecurrent_time = get_time()[1]
    print(f"\nresponse: {ecurrent_time}; ... Translated Text: {translated_text_e}\n\n")
    return {
        "translated_text": translated_text_e,
    }
@app.post("/translate_faster/")
async def translate_faster_endpoint(request: Request):
    """Translate with the Hugging Face pipeline (``translate_faster``).

    Expects a JSON body with ``userinput``, ``source_lang`` and ``target_lang``.
    Responds with ``{"translated_text": result}``, where ``result`` is the dict
    returned by ``translate_faster`` (``inference_time``, ``source``,
    ``target``, ``result``).

    Raises:
        HTTPException: 422 when any of the three fields is missing or empty.
    """
    dataf = await request.json()
    userinputf = dataf.get("userinput")
    source_langf = dataf.get("source_lang")
    target_langf = dataf.get("target_lang")
    ffull_date = get_time()[0]
    print(f"\nrequest: {ffull_date}\nSource_language; {source_langf}, Target Language; {target_langf}, User Input: {userinputf}\n")
    # The pipeline needs both src_lang and tgt_lang.
    if not userinputf or not source_langf or not target_langf:
        raise HTTPException(status_code=422, detail="'userinput', 'source_lang' and 'target_lang' are required.")
    # NOTE(review): translate_faster runs synchronously inside this async
    # handler, so the event loop is blocked for the whole inference.
    translated_text_f = translate_faster(userinputf, source_langf, target_langf)
    fcurrent_time = get_time()[1]
    print(f"\nresponse: {fcurrent_time}; ... Translated Text: {translated_text_f}\n\n")
    return {
        "translated_text": translated_text_f,
    }
# Printed once this module has finished executing (models loaded, routes registered).
_STARTUP_MESSAGE = "\nAPI started successfully .......\n"
print(_STARTUP_MESSAGE)
|