Spaces:
Runtime error
Runtime error
# app.py | |
import json
from typing import List

import streamlit as st
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from IndicTransToolkit import IndicProcessor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# ASGI application object that uvicorn serves.
app = FastAPI()

# Open the API to browser clients on any origin, method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is
# discouraged by the FastAPI CORS docs -- confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Pick the compute device once; request-time tensors are moved to it as well.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the English -> Indic checkpoint at import time and place it on DEVICE.
# trust_remote_code=True lets the checkpoint's own model/tokenizer code run.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
)

# Pre/post-processor used around the model in inference mode.
ip = IndicProcessor(inference=True)
def translate_text(sentences: List[str], target_lang: str):
    """Translate English sentences into a single target language.

    The sentences are preprocessed with the module-level ``ip`` processor,
    tokenized, run through ``model.generate`` with beam search, decoded, and
    finally postprocessed with ``ip`` again.

    Args:
        sentences: English source sentences. An empty list yields an empty
            ``"translations"`` list without touching the model.
        target_lang: Target language tag handed to the processor
            (e.g. ``"hin_Deva"``).

    Returns:
        A dict with the keys ``"translations"``, ``"source_language"`` and
        ``"target_language"``.

    Raises:
        RuntimeError: If any pipeline step fails. The message is
            ``"Translation failed: <original error>"`` and the original
            exception is attached as ``__cause__``.
    """
    src_lang = "eng_Latn"

    # Nothing to translate: return early rather than pushing an empty batch
    # through the tokenizer and the model.
    if not sentences:
        return {
            "translations": [],
            "source_language": src_lang,
            "target_language": target_lang,
        }

    try:
        batch = ip.preprocess_batch(
            sentences, src_lang=src_lang, tgt_lang=target_lang
        )

        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # NOTE(review): as_target_tokenizer() is deprecated in recent
        # transformers releases -- confirm this checkpoint's custom tokenizer
        # still requires it for decoding.
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        translations = ip.postprocess_batch(decoded, lang=target_lang)
    except Exception as e:
        # Keep the historical message format, but chain the cause so the
        # original traceback is not lost.
        raise RuntimeError(f"Translation failed: {e}") from e

    return {
        "translations": translations,
        "source_language": src_lang,
        "target_language": target_lang,
    }
# FastAPI routes | |
@app.get("/health")
async def health_check():
    """Liveness probe: return a fixed payload showing the service is up."""
    return {"status": "healthy"}
@app.post("/translate")
async def translate_endpoint(sentences: List[str], target_lang: str):
    """HTTP handler that translates the submitted sentences.

    Delegates to ``translate_text`` and returns its result dict unchanged.

    Args:
        sentences: English sentences to translate.
        target_lang: Target language tag (e.g. ``"hin_Deva"``).

    Returns:
        The dict produced by ``translate_text``.

    Raises:
        HTTPException: Status 500 with the error text as ``detail`` when the
            translation fails.
    """
    # NOTE(review): with this signature FastAPI is expected to read
    # ``sentences`` from the JSON body and ``target_lang`` from the query
    # string -- confirm this matches what clients actually send.
    try:
        return translate_text(sentences=sentences, target_lang=target_lang)
    except Exception as e:
        # Turn any failure into a proper HTTP 500 response for the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
# # Streamlit interface | |
# def main(): | |
# st.title("Indic Language Translator") | |
# # Input text | |
# text_input = st.text_area("Enter text to translate:", "Hello, how are you?") | |
# # Language selection | |
# target_languages = { | |
# "Hindi": "hin_Deva", | |
# "Bengali": "ben_Beng", | |
# "Tamil": "tam_Taml", | |
# "Telugu": "tel_Telu", | |
# "Marathi": "mar_Deva", | |
# "Gujarati": "guj_Gujr", | |
# "Kannada": "kan_Knda", | |
# "Malayalam": "mal_Mlym", | |
# "Punjabi": "pan_Guru", | |
# "Odia": "ori_Orya", | |
# } | |
# target_lang = st.selectbox( | |
# "Select target language:", options=list(target_languages.keys()) | |
# ) | |
# if st.button("Translate"): | |
# try: | |
# result = translate_text( | |
# sentences=[text_input], target_lang=target_languages[target_lang] | |
# ) | |
# st.success("Translation:") | |
# st.write(result["translations"][0]) | |
# except Exception as e: | |
# st.error(f"Translation failed: {str(e)}") | |
# # Add API documentation | |
# st.markdown("---") | |
# st.header("API Documentation") | |
# st.markdown( | |
# """ | |
# To use the translation API, send POST requests to: | |
# ``` | |
# https://darshankr-trans-en-indic.hf.space/translate | |
# ``` | |
# Request body format: | |
# ```json | |
# { | |
# "sentences": ["Your text here"], | |
# "target_lang": "hin_Deva" | |
# } | |
# ``` | |
# """ | |
# ) | |
# st.markdown("Available target languages:") | |
# for lang, code in target_languages.items(): | |
# st.markdown(f"- {lang}: `{code}`") | |
# if __name__ == "__main__": | |
# # Run both Streamlit and FastAPI | |
# import threading | |
# def run_fastapi(): | |
# uvicorn.run(app, host="0.0.0.0", port=8000)
# # Start FastAPI in a separate thread | |
# api_thread = threading.Thread(target=run_fastapi) | |
# api_thread.start() | |
# # Run Streamlit | |
# main() | |