omkar56 committed on
Commit cafd34d
1 Parent(s): f486b06

Update main.py

Files changed (1)
  1. main.py +24 -270
main.py CHANGED
@@ -1,118 +1,15 @@
-# import os
-# os.system("sudo apt-get install xclip")
-# import nltk
-# from fastapi import FastAPI, File, Request, UploadFile, Body, Depends, HTTPException
-# from fastapi.security.api_key import APIKeyHeader
-# from typing import Optional, Annotated
-# from fastapi.encoders import jsonable_encoder
-# from PIL import Image
-# from io import BytesIO
-# import pytesseract
-# from nltk.tokenize import sent_tokenize
-# from transformers import MarianMTModel, MarianTokenizer
-# nltk.download('punkt')
-
-# API_KEY = os.environ.get("API_KEY")
-
-# app = FastAPI()
-# api_key_header = APIKeyHeader(name="api_key", auto_error=False)
-
-# def get_api_key(api_key: Optional[str] = Depends(api_key_header)):
-#     if api_key is None or api_key != API_KEY:
-#         raise HTTPException(status_code=401, detail="Unauthorized access")
-#     return api_key
-
-# # Image path
-# img_dir = "./data"
-# # Get tesseract language list
-# choices = os.popen('tesseract --list-langs').read().split('\n')[1:-1]
-# # Convert tesseract language list to pytesseract language
-# def ocr_lang(lang_list):
-#     lang_str = ""
-#     lang_len = len(lang_list)
-#     if lang_len == 1:
-#         return lang_list[0]
-#     else:
-#         for i in range(lang_len):
-#             lang_list.insert(lang_len - i, "+")
-
-#         lang_str = "".join(lang_list[:-1])
-#         return lang_str
-# # ocr tesseract
-# def ocr_tesseract(img, languages):
-#     print("[img]", img)
-#     print("[languages]", languages)
-#     ocr_str = pytesseract.image_to_string(img, lang=ocr_lang(languages))
-#     return ocr_str
-
-# @app.post("/api/ocr", response_model=dict)
-# async def ocr(
-#     api_key: str = Depends(get_api_key),
-#     image: UploadFile = File(...),
-#     # languages: list = Body(["eng"])
-# ):
-
-#     try:
-#         content = await image.read()
-#         image = Image.open(BytesIO(content))
-#         print("[image]",image)
-#         if hasattr(pytesseract, "image_to_string"):
-#             print("Image to string function is available")
-#             # print(pytesseract.image_to_string(image, lang = 'eng'))
-#             text = ocr_tesseract(image, ['eng'])
-#         else:
-#             print("Image to string function is not available")
-#             # text = pytesseract.image_to_string(image, lang="+".join(languages))
-#     except Exception as e:
-#         return {"error": str(e)}, 500
-
-#     return {"ImageText": "text"}
-
-# @app.post("/api/translate", response_model=dict)
-# async def translate(
-#     api_key: str = Depends(get_api_key),
-#     text: str = Body(...),
-#     src: str = "en",
-#     trg: str = "zh",
-# ):
-#     if api_key != API_KEY:
-#         return {"error": "Invalid API key"}, 401
-
-#     tokenizer, model = get_model(src, trg)
-
-#     translated_text = ""
-#     for sentence in sent_tokenize(text):
-#         translated_sub = model.generate(**tokenizer(sentence, return_tensors="pt"))[0]
-#         translated_text += tokenizer.decode(translated_sub, skip_special_tokens=True) + "\n"
-
-#     return jsonable_encoder({"translated_text": translated_text})
-
-# def get_model(src: str, trg: str):
-#     model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
-#     tokenizer = MarianTokenizer.from_pretrained(model_name)
-#     model = MarianMTModel.from_pretrained(model_name)
-#     return tokenizer, model
-
-# OCR Translate v0.2
-
-
 import os
-
 os.system("sudo apt-get install xclip")
-
-# import gradio as gr
 import nltk
-import pyclip
-import pytesseract
-from nltk.tokenize import sent_tokenize
-from transformers import MarianMTModel, MarianTokenizer
-# Added below code
 from fastapi import FastAPI, File, Request, UploadFile, Body, Depends, HTTPException
 from fastapi.security.api_key import APIKeyHeader
 from typing import Optional, Annotated
 from fastapi.encoders import jsonable_encoder
 from PIL import Image
 from io import BytesIO
+import pytesseract
+from nltk.tokenize import sent_tokenize
+from transformers import MarianMTModel, MarianTokenizer
 
 API_KEY = os.environ.get("API_KEY")
 
@@ -130,13 +27,14 @@ async def ocr(
     image: UploadFile = File(...),
     # languages: list = Body(["eng"])
 ):
+
     try:
         content = await image.read()
         image = Image.open(BytesIO(content))
         print("[image]",image)
         if hasattr(pytesseract, "image_to_string"):
             print("Image to string function is available")
-            # print(pytesseract.image_to_string(image, lang = 'eng'))
+            print(pytesseract.image_to_string(image, lang = 'eng'))
             text = ocr_tesseract(image, ['eng'])
         else:
             print("Image to string function is not available")
@@ -146,171 +44,27 @@ async def ocr(
 
     return {"ImageText": "text"}
 
-nltk.download('punkt')
-
-OCR_TR_DESCRIPTION = '''# OCR Translate v0.2
-<div id="content_align">OCR translation system based on Tesseract</div>'''
-
-# Image path
-img_dir = "./data"
-
-# Get tesseract language list
-choices = os.popen('tesseract --list-langs').read().split('\n')[1:-1]
+@app.post("/api/translate", response_model=dict)
+async def translate(
+    api_key: str = Depends(get_api_key),
+    text: str = Body(...),
+    src: str = "en",
+    trg: str = "zh",
+):
+    if api_key != API_KEY:
+        return {"error": "Invalid API key"}, 401
 
+    tokenizer, model = get_model(src, trg)
 
-# Translation model selection
-def model_choice(src="en", trg="zh"):
-    # https://huggingface.co/Helsinki-NLP/opus-mt-zh-en
-    # https://huggingface.co/Helsinki-NLP/opus-mt-en-zh
-    model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}" # Model name
+    translated_text = ""
+    for sentence in sent_tokenize(text):
+        translated_sub = model.generate(**tokenizer(sentence, return_tensors="pt"))[0]
+        translated_text += tokenizer.decode(translated_sub, skip_special_tokens=True) + "\n"
 
-    tokenizer = MarianTokenizer.from_pretrained(model_name) # tokenizer
-    model = MarianMTModel.from_pretrained(model_name) # Model
+    return jsonable_encoder({"translated_text": translated_text})
 
+def get_model(src: str, trg: str):
+    model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
+    tokenizer = MarianTokenizer.from_pretrained(model_name)
+    model = MarianMTModel.from_pretrained(model_name)
     return tokenizer, model
-
-
-# Convert tesseract language list to pytesseract language
-def ocr_lang(lang_list):
-    lang_str = ""
-    lang_len = len(lang_list)
-    if lang_len == 1:
-        return lang_list[0]
-    else:
-        for i in range(lang_len):
-            lang_list.insert(lang_len - i, "+")
-
-        lang_str = "".join(lang_list[:-1])
-        return lang_str
-
-
-# ocr tesseract
-def ocr_tesseract(img, languages):
-    print("[img]", img)
-    print("[languages]", languages)
-    ocr_str = pytesseract.image_to_string(img, lang=ocr_lang(languages))
-    return ocr_str
-
-
-# Clear
-def clear_content():
-    return None
-
-
-# copy to clipboard
-def cp_text(input_text):
-    # sudo apt-get install xclip
-    try:
-        pyclip.copy(input_text)
-    except Exception as e:
-        print("sudo apt-get install xclip")
-        print(e)
-
-
-# clear clipboard
-def cp_clear():
-    pyclip.clear()
-
-
-# translate
-def translate(input_text, inputs_transStyle):
-    # reference:https://huggingface.co/docs/transformers/model_doc/marian
-    if input_text is None or input_text == "":
-        return "System prompt: There is no content to translate!"
-
-    # Select translation model
-    trans_src, trans_trg = inputs_transStyle.split("-")[0], inputs_transStyle.split("-")[1]
-    tokenizer, model = model_choice(trans_src, trans_trg)
-
-    translate_text = ""
-    input_text_list = input_text.split("\n\n")
-
-    translate_text_list_tmp = []
-    for i in range(len(input_text_list)):
-        if input_text_list[i] != "":
-            translate_text_list_tmp.append(input_text_list[i])
-
-    for i in range(len(translate_text_list_tmp)):
-        translated_sub = model.generate(
-            **tokenizer(sent_tokenize(translate_text_list_tmp[i]), return_tensors="pt", truncation=True, padding=True))
-        tgt_text_sub = [tokenizer.decode(t, skip_special_tokens=True) for t in translated_sub]
-        translate_text_sub = "".join(tgt_text_sub)
-        translate_text = translate_text + "\n\n" + translate_text_sub
-
-    return translate_text[2:]
-
-
-# def main():
-
-#     with gr.Blocks(css='style.css') as ocr_tr:
-#         gr.Markdown(OCR_TR_DESCRIPTION)
-
-#         # -------------- OCR text extraction --------------
-#         with gr.Box():
-
-#             with gr.Row():
-#                 gr.Markdown("### Step 01: Text Extraction")
-
-#             with gr.Row():
-#                 with gr.Column():
-#                     with gr.Row():
-#                         inputs_img = gr.Image(image_mode="RGB", source="upload", type="pil", label="image")
-#                     with gr.Row():
-#                         inputs_lang = gr.CheckboxGroup(choices=["chi_sim", "eng"],
-#                                                        type="value",
-#                                                        value=['eng'],
-#                                                        label='language')
-
-#                     with gr.Row():
-#                         clear_img_btn = gr.Button('Clear')
-#                         ocr_btn = gr.Button(value='OCR Extraction', variant="primary")
-
-#                 with gr.Column():
-#                     with gr.Row():
-#                         outputs_text = gr.Textbox(label="Extract content", lines=20)
-#                     with gr.Row():
-#                         inputs_transStyle = gr.Radio(choices=["zh-en", "en-zh"],
-#                                                      type="value",
-#                                                      value="zh-en",
-#                                                      label='translation mode')
-#                     with gr.Row():
-#                         clear_text_btn = gr.Button('Clear')
-#                         translate_btn = gr.Button(value='Translate', variant="primary")
-
-#             with gr.Row():
-#                 example_list = [["./data/test.png", ["eng"]], ["./data/test02.png", ["eng"]],
-#                                 ["./data/test03.png", ["chi_sim"]]]
-#                 gr.Examples(example_list, [inputs_img, inputs_lang], outputs_text, ocr_tesseract, cache_examples=False)
-
-#         # -------------- translate --------------
-#         with gr.Box():
-
-#             with gr.Row():
-#                 gr.Markdown("### Step 02: Translation")
-
-#             with gr.Row():
-#                 outputs_tr_text = gr.Textbox(label="Translate Content", lines=20)
-
-#             with gr.Row():
-#                 cp_clear_btn = gr.Button(value='Clear Clipboard')
-#                 cp_btn = gr.Button(value='Copy to clipboard', variant="primary")
-
-#         # ---------------------- OCR Tesseract ----------------------
-#         ocr_btn.click(fn=ocr_tesseract, inputs=[inputs_img, inputs_lang], outputs=[
-#             outputs_text,])
-#         clear_img_btn.click(fn=clear_content, inputs=[], outputs=[inputs_img])
-
-#         # ---------------------- translate ----------------------
-#         translate_btn.click(fn=translate, inputs=[outputs_text, inputs_transStyle], outputs=[outputs_tr_text])
-#         clear_text_btn.click(fn=clear_content, inputs=[], outputs=[outputs_text])
-
-#         # ---------------------- copy to clipboard ----------------------
-#         cp_btn.click(fn=cp_text, inputs=[outputs_tr_text], outputs=[])
-#         cp_clear_btn.click(fn=cp_clear, inputs=[], outputs=[])
-
-#     ocr_tr.launch(inbrowser=True)
-
-
-# if __name__ == '__main__':
-#     main()
-
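For reference, the snippet below is a minimal client-side sketch (not part of the commit) showing how the two endpoints touched by this change could be exercised once the app is deployed. The base URL is a placeholder, the "api_key" header name is inferred from the APIKeyHeader(name="api_key") setup visible in the removed comments, and the body shape /api/translate expects for its single Body(...) parameter depends on FastAPI's embedding rules, so treat all three as assumptions rather than documented behavior.

# Hypothetical client for the endpoints in main.py (illustration only, not part of the commit).
# Assumptions: the app is served at BASE_URL, the key travels in an "api_key" header
# (matching APIKeyHeader(name="api_key")), and /api/translate accepts its single
# Body(...) parameter as a bare JSON string.
import os

import requests  # third-party HTTP client; not a dependency of main.py itself

BASE_URL = "http://localhost:8000"        # placeholder host/port
HEADERS = {"api_key": os.environ.get("API_KEY", "")}  # same env var the server reads


def ocr_image(path: str) -> dict:
    """Upload an image to /api/ocr and return the JSON response."""
    with open(path, "rb") as f:
        resp = requests.post(
            f"{BASE_URL}/api/ocr",
            headers=HEADERS,
            files={"image": f},           # field name matches `image: UploadFile = File(...)`
        )
    resp.raise_for_status()
    return resp.json()                    # e.g. {"ImageText": ...}


def translate_text(text: str, src: str = "en", trg: str = "zh") -> dict:
    """Send text to /api/translate; src and trg travel as query parameters."""
    resp = requests.post(
        f"{BASE_URL}/api/translate",
        headers=HEADERS,
        params={"src": src, "trg": trg},
        json=text,                        # may need to be {"text": text} if FastAPI embeds the field
    )
    resp.raise_for_status()
    return resp.json()                    # e.g. {"translated_text": ...}


if __name__ == "__main__":
    print(ocr_image("./data/test.png"))   # sample path taken from the removed example_list
    print(translate_text("Hello, world!"))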