darshankr committed on
Commit
a0c8166
·
verified ·
1 Parent(s): 4bce69c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -40
app.py CHANGED
@@ -7,12 +7,8 @@ import torch
7
  import asyncio
8
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
  from IndicTransToolkit import IndicProcessor
10
- import uvicorn
11
- import nest_asyncio
12
- import threading
13
-
14
- # Initialize FastAPI
15
- app = FastAPI()
16
 
17
  # Initialize models and processors
18
  model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -27,25 +23,13 @@ ip = IndicProcessor(inference=True)
27
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
  model = model.to(DEVICE)
29
 
30
- class InputData(BaseModel):
31
- sentences: List[str]
32
- target_lang: str
33
-
34
- # FastAPI endpoints
35
- @app.get("/health")
36
- async def health_check():
37
- return {"status": "healthy"}
38
-
39
- @app.post("/translate/")
40
- async def translate(input_data: InputData):
41
  try:
42
  src_lang = "eng_Latn"
43
- tgt_lang = input_data.target_lang
44
-
45
  batch = ip.preprocess_batch(
46
- input_data.sentences,
47
  src_lang=src_lang,
48
- tgt_lang=tgt_lang
49
  )
50
 
51
  inputs = tokenizer(
@@ -73,16 +57,16 @@ async def translate(input_data: InputData):
73
  clean_up_tokenization_spaces=True
74
  )
75
 
76
- translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
77
 
78
  return {
79
  "translations": translations,
80
  "source_language": src_lang,
81
- "target_language": tgt_lang
82
  }
83
 
84
  except Exception as e:
85
- raise HTTPException(status_code=500, detail=str(e))
86
 
87
  # Streamlit interface
88
  def main():
@@ -112,18 +96,11 @@ def main():
112
 
113
  if st.button("Translate"):
114
  try:
115
- # Prepare input data
116
- input_data = InputData(
117
  sentences=[text_input],
118
  target_lang=target_languages[target_lang]
119
  )
120
 
121
- # Create event loop and run translation
122
- loop = asyncio.new_event_loop()
123
- asyncio.set_event_loop(loop)
124
- result = loop.run_until_complete(translate(input_data))
125
- loop.close()
126
-
127
  # Display result
128
  st.success("Translation:")
129
  st.write(result["translations"][0])
@@ -131,14 +108,35 @@ def main():
131
  except Exception as e:
132
  st.error(f"Translation failed: {str(e)}")
133
 
134
- def run_fastapi():
135
- nest_asyncio.apply()
136
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  if __name__ == "__main__":
139
- # Start FastAPI in a separate thread
140
- api_thread = threading.Thread(target=run_fastapi, daemon=True)
141
- api_thread.start()
142
-
143
- # Run Streamlit interface
144
  main()
 
7
  import asyncio
8
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
  from IndicTransToolkit import IndicProcessor
10
+ import requests
11
+ import json
 
 
 
 
12
 
13
  # Initialize models and processors
14
  model = AutoModelForSeq2SeqLM.from_pretrained(
 
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
  model = model.to(DEVICE)
25
 
26
+ def translate_text(sentences: List[str], target_lang: str):
 
 
 
 
 
 
 
 
 
 
27
  try:
28
  src_lang = "eng_Latn"
 
 
29
  batch = ip.preprocess_batch(
30
+ sentences,
31
  src_lang=src_lang,
32
+ tgt_lang=target_lang
33
  )
34
 
35
  inputs = tokenizer(
 
57
  clean_up_tokenization_spaces=True
58
  )
59
 
60
+ translations = ip.postprocess_batch(generated_tokens, lang=target_lang)
61
 
62
  return {
63
  "translations": translations,
64
  "source_language": src_lang,
65
+ "target_language": target_lang
66
  }
67
 
68
  except Exception as e:
69
+ raise Exception(f"Translation failed: {str(e)}")
70
 
71
  # Streamlit interface
72
  def main():
 
96
 
97
  if st.button("Translate"):
98
  try:
99
+ result = translate_text(
 
100
  sentences=[text_input],
101
  target_lang=target_languages[target_lang]
102
  )
103
 
 
 
 
 
 
 
104
  # Display result
105
  st.success("Translation:")
106
  st.write(result["translations"][0])
 
108
  except Exception as e:
109
  st.error(f"Translation failed: {str(e)}")
110
 
111
+ # Add API documentation
112
+ st.markdown("---")
113
+ st.header("API Documentation")
114
+ st.markdown("""
115
+ To use the translation API, send POST requests to:
116
+ ```
117
+ https://USERNAME-SPACE_NAME.hf.space/translate
118
+ ```
119
+
120
+ Request body format:
121
+ ```json
122
+ {
123
+ "sentences": ["Your text here"],
124
+ "target_lang": "hin_Deva"
125
+ }
126
+ ```
127
+
128
+ Available target languages:
129
+ - Hindi: `hin_Deva`
130
+ - Bengali: `ben_Beng`
131
+ - Tamil: `tam_Taml`
132
+ - Telugu: `tel_Telu`
133
+ - Marathi: `mar_Deva`
134
+ - Gujarati: `guj_Gujr`
135
+ - Kannada: `kan_Knda`
136
+ - Malayalam: `mal_Mlym`
137
+ - Punjabi: `pan_Guru`
138
+ - Odia: `ori_Orya`
139
+ """)
140
 
141
  if __name__ == "__main__":
 
 
 
 
 
142
  main()