trans-en-indic / app.py
Darshan
Add apis
6f55a35
# app.py
import json
from typing import List

import streamlit as st
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from IndicTransToolkit import IndicProcessor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# --- FastAPI application -------------------------------------------------
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): the FastAPI CORS docs state that credentials cannot be allowed
# when allow_origins is ["*"]; nothing in this file uses cookies/auth, so
# allow_credentials=True looks unnecessary — confirm before changing it.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)

# --- Translation model ---------------------------------------------------
# IndicTrans2 English -> Indic checkpoint. trust_remote_code=True lets
# transformers execute the Python code shipped in the model repository.
_MODEL_ID = "ai4bharat/indictrans2-en-indic-1B"

# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_ID, trust_remote_code=True).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)

# Batch pre/post-processor used by translate_text (preprocess_batch /
# postprocess_batch) around the model call.
ip = IndicProcessor(inference=True)
def translate_text(sentences: List[str], target_lang: str):
    """Translate a batch of English sentences into an Indic language.

    Uses the module-level IndicTrans2 ``model``, ``tokenizer``, the
    ``IndicProcessor`` ``ip`` and ``DEVICE``.

    Args:
        sentences: English sentences to translate as a single batch.
        target_lang: IndicTrans2 language tag for the output, e.g. ``"hin_Deva"``.

    Returns:
        A dict with ``"translations"`` (the post-processed outputs; one per
        input, since ``num_return_sequences=1``), ``"source_language"``
        (always ``"eng_Latn"``) and ``"target_language"`` (``target_lang``
        echoed back). An empty/None ``sentences`` yields an empty list of
        translations without touching the model.

    Raises:
        RuntimeError: if preprocessing, generation, decoding or
            postprocessing fails. The message starts with
            ``"Translation failed: "`` and the original exception is chained
            as ``__cause__``.
    """
    src_lang = "eng_Latn"

    # Nothing to translate: return early rather than pushing an empty batch
    # through the tokenizer and generate(), which would only fail.
    if not sentences:
        return {
            "translations": [],
            "source_language": src_lang,
            "target_language": target_lang,
        }

    try:
        batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode while the tokenizer is in target (output-side) mode.
        # NOTE(review): as_target_tokenizer() is deprecated in recent
        # transformers releases — verify against the installed version.
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        translations = ip.postprocess_batch(decoded, lang=target_lang)
    except Exception as e:
        # Chain the cause so the real traceback survives in logs; RuntimeError
        # is still caught by callers that catch Exception.
        raise RuntimeError(f"Translation failed: {e}") from e

    return {
        "translations": translations,
        "source_language": src_lang,
        "target_language": target_lang,
    }
# ---- FastAPI routes ----
@app.get("/health")
async def health_check():
    """Return a constant status payload.

    This does not exercise the model; it only shows the server is answering.
    """
    return dict(status="healthy")
@app.post("/translate")
async def translate_endpoint(sentences: List[str], target_lang: str):
    """POST /translate — translate English sentences into ``target_lang``.

    With this signature FastAPI reads ``sentences`` from the request body as a
    bare JSON array of strings, and ``target_lang`` (a scalar) from the query
    string, e.g. ``POST /translate?target_lang=hin_Deva`` with body
    ``["Hello"]``.

    Returns:
        The dict produced by ``translate_text``.

    Raises:
        HTTPException: status 500, with the failure message as ``detail``,
            when translation fails.
    """
    # NOTE: translate_text is synchronous, so it runs on the event loop and
    # requests are handled one at a time while a translation is in progress.
    # That also keeps concurrent requests from sharing the tokenizer while it
    # is inside as_target_tokenizer().
    try:
        return translate_text(sentences=sentences, target_lang=target_lang)
    except Exception as e:
        # HTTPException must be imported from fastapi at module level.
        raise HTTPException(status_code=500, detail=str(e)) from e
# # Streamlit interface
# def main():
# st.title("Indic Language Translator")
# # Input text
# text_input = st.text_area("Enter text to translate:", "Hello, how are you?")
# # Language selection
# target_languages = {
# "Hindi": "hin_Deva",
# "Bengali": "ben_Beng",
# "Tamil": "tam_Taml",
# "Telugu": "tel_Telu",
# "Marathi": "mar_Deva",
# "Gujarati": "guj_Gujr",
# "Kannada": "kan_Knda",
# "Malayalam": "mal_Mlym",
# "Punjabi": "pan_Guru",
# "Odia": "ori_Orya",
# }
# target_lang = st.selectbox(
# "Select target language:", options=list(target_languages.keys())
# )
# if st.button("Translate"):
# try:
# result = translate_text(
# sentences=[text_input], target_lang=target_languages[target_lang]
# )
# st.success("Translation:")
# st.write(result["translations"][0])
# except Exception as e:
# st.error(f"Translation failed: {str(e)}")
# # Add API documentation
# st.markdown("---")
# st.header("API Documentation")
# st.markdown(
# """
# To use the translation API, send POST requests to:
# ```
# https://darshankr-trans-en-indic.hf.space/translate
# ```
# Request body format (a JSON array of sentences); pass the target language as
# a query parameter, e.g. `/translate?target_lang=hin_Deva`:
# ```json
# ["Your text here"]
# ```
# """
# )
# st.markdown("Available target languages:")
# for lang, code in target_languages.items():
# st.markdown(f"- {lang}: `{code}`")
# if __name__ == "__main__":
# # Run both Streamlit and FastAPI
# import threading
# def run_fastapi():
# uvicorn.run(app, host="0.0.0.0", port=8000)
# # Start FastAPI in a separate thread
# api_thread = threading.Thread(target=run_fastapi)
# api_thread.start()
# # Run Streamlit
# main()