File size: 4,474 Bytes
b6cd2a4
 
 
 
3b25c41
 
 
0b7c166
b6cd2a4
 
 
 
6f55a35
b6cd2a4
 
6f55a35
b6cd2a4
 
 
 
 
 
72ec471
0b7c166
b6cd2a4
6f55a35
b6cd2a4
 
6f55a35
b6cd2a4
0b7c166
 
 
45a86ac
6f55a35
a0c8166
3b25c41
 
6f55a35
b6cd2a4
 
 
 
 
6f55a35
b6cd2a4
6f55a35
3b25c41
 
b6cd2a4
3b25c41
 
 
 
6f55a35
3b25c41
6f55a35
3b25c41
 
 
 
6f55a35
3b25c41
6f55a35
a0c8166
3b25c41
 
 
6f55a35
3b25c41
 
a0c8166
3b25c41
6f55a35
b6cd2a4
6f55a35
b6cd2a4
 
 
6f55a35
 
b6cd2a4
 
 
 
 
 
 
6f55a35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# app.py
import json
from typing import List

import streamlit as st
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from IndicTransToolkit import IndicProcessor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Initialize FastAPI: the ASGI application object the route decorators
# below attach to.
app = FastAPI()

# Add CORS middleware: accept any origin, method and header.
# NOTE(review): allow_origins=["*"] is combined with allow_credentials=True.
# FastAPI's CORS documentation says wildcards must not be used for
# origins/methods/headers when credentials are allowed — confirm whether
# credentialed requests are actually needed; if so, list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize models and processors.
# Everything here runs at import time, so importing this module loads the
# "ai4bharat/indictrans2-en-indic-1B" checkpoint and its tokenizer.
# NOTE(review): trust_remote_code=True presumably lets transformers execute
# code shipped with the checkpoint — confirm the source is trusted/pinned.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
)
# Pre/post-processing helper used by translate_text around tokenization.
ip = IndicProcessor(inference=True)
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Move the model to the chosen device; translate_text moves inputs there too.
model = model.to(DEVICE)


def translate_text(sentences: List[str], target_lang: str):
    """Translate a batch of English sentences into ``target_lang``.

    The batch is preprocessed with the module-level ``ip`` processor,
    tokenized, run through ``model.generate`` with beam search, decoded,
    and postprocessed with ``ip`` again.

    Args:
        sentences: Sentences to translate. The source language is fixed to
            ``"eng_Latn"``.
        target_lang: Target language tag handed to the processor, e.g.
            ``"hin_Deva"``.

    Returns:
        A dict with the keys ``"translations"`` (the postprocessed output
        for the batch), ``"source_language"`` and ``"target_language"``.
        An empty ``sentences`` yields an empty ``"translations"`` list
        without invoking the tokenizer or the model.

    Raises:
        Exception: With the message ``"Translation failed: <reason>"`` when
            any step fails; the original error is attached as ``__cause__``.
    """
    src_lang = "eng_Latn"

    # Nothing to translate: skip tokenizer/model entirely instead of letting
    # an empty batch fail somewhere inside them.
    if not sentences:
        return {
            "translations": [],
            "source_language": src_lang,
            "target_language": target_lang,
        }

    try:
        batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Inference only: no gradients are needed for generation.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode with the tokenizer switched to its target-side mode.
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        translations = ip.postprocess_batch(decoded, lang=target_lang)
        return {
            "translations": translations,
            "source_language": src_lang,
            "target_language": target_lang,
        }
    except Exception as e:
        # Chain explicitly so the underlying traceback is kept as __cause__.
        raise Exception(f"Translation failed: {str(e)}") from e


# FastAPI routes
@app.get("/health")
async def health_check():
    """Liveness probe: always reports the service as healthy."""
    status_payload = {"status": "healthy"}
    return status_payload


@app.post("/translate")
async def translate_endpoint(sentences: List[str], target_lang: str):
    """Translate English sentences over HTTP via ``translate_text``.

    Args:
        sentences: Sentences to translate.
        target_lang: Target language tag, e.g. ``"hin_Deva"``.

    Returns:
        The dict produced by ``translate_text``.

    Raises:
        HTTPException: Status 500 with the failure message as ``detail``
            when translation fails.

    NOTE(review): per FastAPI's parameter rules a bare ``List[str]``
    parameter is read from the JSON request body and a bare ``str`` from
    the query string — confirm clients call the endpoint that way.
    """
    try:
        return translate_text(sentences=sentences, target_lang=target_lang)
    except Exception as e:
        # Report the failure as a proper HTTP 500; HTTPException comes from
        # the module-level fastapi import.
        raise HTTPException(status_code=500, detail=str(e)) from e


# # Streamlit interface
# def main():
#     st.title("Indic Language Translator")

#     # Input text
#     text_input = st.text_area("Enter text to translate:", "Hello, how are you?")

#     # Language selection
#     target_languages = {
#         "Hindi": "hin_Deva",
#         "Bengali": "ben_Beng",
#         "Tamil": "tam_Taml",
#         "Telugu": "tel_Telu",
#         "Marathi": "mar_Deva",
#         "Gujarati": "guj_Gujr",
#         "Kannada": "kan_Knda",
#         "Malayalam": "mal_Mlym",
#         "Punjabi": "pan_Guru",
#         "Odia": "ori_Orya",
#     }

#     target_lang = st.selectbox(
#         "Select target language:", options=list(target_languages.keys())
#     )

#     if st.button("Translate"):
#         try:
#             result = translate_text(
#                 sentences=[text_input], target_lang=target_languages[target_lang]
#             )
#             st.success("Translation:")
#             st.write(result["translations"][0])
#         except Exception as e:
#             st.error(f"Translation failed: {str(e)}")

#     # Add API documentation
#     st.markdown("---")
#     st.header("API Documentation")
#     st.markdown(
#         """
#     To use the translation API, send POST requests to:
#     ```
#     https://darshankr-trans-en-indic.hf.space/translate
#     ```
#     Request body format:
#     ```json
#     {
#         "sentences": ["Your text here"],
#         "target_lang": "hin_Deva"
#     }
#     ```
#     """
#     )
#     st.markdown("Available target languages:")
#     for lang, code in target_languages.items():
#         st.markdown(f"- {lang}: `{code}`")


# if __name__ == "__main__":
#     # Run both Streamlit and FastAPI
#     import threading

#     def run_fastapi():
#         uvicorn.run(app, host="0.0.0.0", port=8000)

#     # Start FastAPI in a separate thread
#     api_thread = threading.Thread(target=run_fastapi)
#     api_thread.start()

#     # Run Streamlit
#     main()