# Source: Hugging Face Space "trans-en-indic" (app.py) by darshankr.
# Revision 45a86ac ("Update app.py"), ~5 kB.
# app.py
import streamlit as st
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import List
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import json
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from starlette.applications import Starlette
from starlette.routing import Mount, Route
from starlette.staticfiles import StaticFiles
import asyncio
import nest_asyncio
# Let an already-running asyncio loop be re-entered (nest_asyncio).
# NOTE(review): nothing else in this file appears to rely on nested loops.
nest_asyncio.apply()

# FastAPI application that owns the JSON API routes ("/api/...").
app = FastAPI()

# Permissive CORS policy: every origin, method and header is accepted,
# and credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# IndicTrans2 English->Indic checkpoint, shared by the model and tokenizer
# loads below.
_CHECKPOINT = "ai4bharat/indictrans2-en-indic-1B"

# Prefer the GPU when torch can see one; otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): under `streamlit run` the whole script is re-executed on each
# widget interaction, so these loads may repeat per rerun. Consider
# st.cache_resource, but note that translate_text toggles the shared
# tokenizer via as_target_tokenizer(), so sharing it across sessions needs care.
model = AutoModelForSeq2SeqLM.from_pretrained(
    _CHECKPOINT,
    trust_remote_code=True,
).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(
    _CHECKPOINT,
    trust_remote_code=True,
)

# Batch pre/post-processing helper used by translate_text.
ip = IndicProcessor(inference=True)
class TranslationRequest(BaseModel):
    """Request body accepted by ``POST /api/translate``."""

    # English input sentences; one translation is returned per sentence.
    sentences: List[str]
    # IndicTrans2 target language tag, e.g. "hin_Deva".
    target_lang: str
def translate_text(
    sentences: List[str],
    target_lang: str,
    *,
    num_beams: int = 5,
    max_length: int = 256,
):
    """Translate English sentences into an Indic language with IndicTrans2.

    Args:
        sentences: English input sentences, translated as one batch.
        target_lang: IndicTrans2 language tag for the output, e.g. "hin_Deva".
        num_beams: Beam width passed to ``model.generate`` (default 5).
        max_length: Maximum generated length passed to ``model.generate``
            (default 256). Over-long inputs are truncated by the tokenizer
            (``truncation=True``).

    Returns:
        dict with ``"translations"`` (list of str, one per input sentence,
        since ``num_return_sequences=1``), ``"source_language"`` (always
        ``"eng_Latn"``) and ``"target_language"`` (echo of ``target_lang``).

    Raises:
        RuntimeError: wraps any failure in preprocessing, generation or
            decoding. The message is prefixed with "Translation failed: "
            and the original exception is kept as ``__cause__``.
    """
    src_lang = "eng_Latn"

    # An empty batch would otherwise reach the tokenizer/generate step with
    # nothing to process; answer directly instead.
    if not sentences:
        return {
            "translations": [],
            "source_language": src_lang,
            "target_language": target_lang,
        }

    try:
        batch = ip.preprocess_batch(
            sentences,
            src_lang=src_lang,
            tgt_lang=target_lang,
        )
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=max_length,
                num_beams=num_beams,
                num_return_sequences=1,
            )

        # Decode inside as_target_tokenizer(), which presumably switches the
        # (shared) tokenizer to its target-side mode for the duration.
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        translations = ip.postprocess_batch(decoded, lang=target_lang)
    except Exception as e:
        # Chain the cause so the real traceback survives the re-raise.
        raise RuntimeError(f"Translation failed: {e}") from e

    return {
        "translations": translations,
        "source_language": src_lang,
        "target_language": target_lang,
    }
# ---- FastAPI routes ----
@app.get("/api/health")
async def health_check():
    """Report that the API process is up and answering requests."""
    payload = {"status": "healthy"}
    return payload
@app.post("/api/translate")
async def translate_endpoint(request: TranslationRequest):
    """Translate ``request.sentences`` into ``request.target_lang``.

    Returns the dict produced by ``translate_text`` as a JSON response. Any
    failure is reported as HTTP 500 with the exception text as ``detail``.
    """
    # NOTE(review): translate_text calls model.generate synchronously, so this
    # coroutine blocks the event loop while a translation runs. Offloading to
    # a worker thread would also need to guard the shared tokenizer's
    # as_target_tokenizer() mode switch.
    try:
        payload = translate_text(
            sentences=request.sentences,
            target_lang=request.target_lang,
        )
        response = JSONResponse(content=payload)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    return response
# Streamlit interface
def streamlit_app():
    """Render the Streamlit translator UI followed by the HTTP API notes.

    Shows a text area and a target-language picker. When "Translate" is
    clicked, the whole input is passed to ``translate_text`` as a single
    sentence and the result is displayed. Blank input is rejected with a
    warning instead of running the model.
    """
    st.title("Indic Language Translator")

    # Input text
    text_input = st.text_area("Enter text to translate:", "Hello, how are you?")

    # Display name -> IndicTrans2 language tag.
    target_languages = {
        "Hindi": "hin_Deva",
        "Bengali": "ben_Beng",
        "Tamil": "tam_Taml",
        "Telugu": "tel_Telu",
        "Marathi": "mar_Deva",
        "Gujarati": "guj_Gujr",
        "Kannada": "kan_Knda",
        "Malayalam": "mal_Mlym",
        "Punjabi": "pan_Guru",
        "Odia": "ori_Orya",
    }
    target_lang = st.selectbox(
        "Select target language:",
        options=list(target_languages.keys()),
    )

    if st.button("Translate"):
        if not text_input.strip():
            # Nothing meaningful to translate; don't invoke the model.
            st.warning("Please enter some text to translate.")
        else:
            try:
                result = translate_text(
                    sentences=[text_input],
                    target_lang=target_languages[target_lang],
                )
            except Exception as e:
                # translate_text already prefixes its errors with
                # "Translation failed: "; only add the prefix when missing.
                message = str(e)
                if not message.startswith("Translation failed"):
                    message = f"Translation failed: {message}"
                st.error(message)
            else:
                st.success("Translation:")
                st.write(result["translations"][0])

    # Add API documentation
    st.markdown("---")
    st.header("API Documentation")
    st.markdown("""
To use the translation API, send POST requests to:
```
https://darshankr-trans-en-indic.hf.space/api/translate
```
Request body format:
```json
{
"sentences": ["Your text here"],
"target_lang": "hin_Deva"
}
```
""")
    st.markdown("Available target languages:")
    # Emit one markdown list instead of a separate one-item list per language.
    st.markdown(
        "\n".join(f"- {lang}: `{code}`" for lang, code in target_languages.items())
    )
# Create a unified application
def create_app():
    """Build the combined ASGI application served by uvicorn.

    Requests under ``/api/`` are forwarded to the FastAPI ``app`` with their
    full path. Every other path is served from the ``static`` directory, with
    ``html=True`` so ``index.html`` is served for directory paths.

    Returns:
        Starlette: the combined application.
    """
    routes = [
        # The FastAPI routes already carry the "/api" prefix ("/api/health",
        # "/api/translate"). Mount("/api", app) would strip that prefix, so
        # FastAPI would see "/translate" and return 404. A Route whose
        # endpoint is an ASGI app (not a function) forwards the request
        # without rewriting its path.
        Route("/api/{path:path}", endpoint=app),
        Mount("/", StaticFiles(directory="static", html=True), name="static"),
    ]
    return Starlette(routes=routes)
if __name__ == "__main__":
    import sys

    from streamlit import runtime as st_runtime

    # `streamlit run app.py` executes this file as __main__ but sets
    # sys.argv[0] to the script path, so the argv check alone does not detect
    # Streamlit. runtime.exists() reports whether a Streamlit runtime is
    # active; the argv check is kept as a fallback.
    if st_runtime.exists() or "streamlit" in sys.argv[0]:
        streamlit_app()
    else:
        # Plain `python app.py`: serve the API and static files with uvicorn.
        uvicorn.run(create_app(), host="0.0.0.0", port=7860)