ake178178 commited on
Commit
07189f7
·
verified ·
1 Parent(s): fc6cd94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -31
app.py CHANGED
@@ -1,40 +1,30 @@
1
  import gradio as gr
2
- import os
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
  from PIL import Image
 
5
 
6
- # 加载模型和处理器
7
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
8
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
9
 
10
- # 定义图像OCR识别函数
11
- def ocr_images(images):
12
- results = {}
13
- for image in images:
14
- # 确保图片是RGB格式
15
- image = Image.open(image).convert("RGB")
16
- pixel_values = processor(images=image, return_tensors="pt").pixel_values
17
- output_ids = model.generate(pixel_values)
18
- transcription = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
19
- results[image.filename] = transcription
20
- return results
21
 
22
- # 定义Gradio界面
23
- def ocr_interface(images):
24
- results = ocr_images(images)
25
- result_text = "\n\n".join([f"{filename}:\n{transcription}" for filename, transcription in results.items()])
26
- return result_text
 
 
 
27
 
28
- # 创建Gradio应用
29
- with gr.Blocks() as demo:
30
- gr.Markdown("## 多图片OCR识别")
31
- with gr.Row():
32
- image_input = gr.File(label="选择多张图片", file_count="multiple", type="file")
33
- output_text = gr.Textbox(label="OCR 识别结果")
34
-
35
- # 添加按钮和功能绑定
36
- submit_button = gr.Button("开始识别")
37
- submit_button.click(ocr_interface, inputs=image_input, outputs=output_text)
38
-
39
- # 启动应用
40
- demo.launch()
 
1
  import gradio as gr
 
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
+ import torch
5
 
6
+ # 加载模型
7
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
8
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
9
 
10
+ # 定义图像文字识别函数
11
+ def recognize_text_from_image(image):
12
+ # 图像预处理
13
+ image = Image.open(image).convert("RGB")
14
+ # 使用OCR模型预测
15
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values
16
+ generated_ids = model.generate(pixel_values)
17
+ text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
18
+ return text
 
 
19
 
20
+ # Gradio界面
21
+ iface = gr.Interface(
22
+ fn=recognize_text_from_image,
23
+ inputs=gr.Image(type="file", label="上传文件图片"),
24
+ outputs="text",
25
+ title="桌面文件扫描器",
26
+ description="上传文件图片,将自动识别其中的文字内容"
27
+ )
28
 
29
+ # 启动 Gradio
30
+ iface.launch()