realkun committed on
Commit
9d3f1f1
1 Parent(s): 211aa8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -23
app.py CHANGED
@@ -1,30 +1,108 @@
1
- from transformers import AutoProcessor, Pix2StructForConditionalGeneration, Pix2StructProcessor
2
- import requests
3
- import json
 
 
 
4
  from PIL import Image
5
 
6
# Load the DePlot checkpoint together with its matching processor.
model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
processor = AutoProcessor.from_pretrained("google/deplot")

# The chart is read from a local file. A remote sample such as
# https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/5090.png
# could be fetched with requests and opened from the raw stream instead.
image = Image.open('05e57f1c9acff69f1eb6fa72d4805d0.jpg')

# Ask the model for the data table behind the chart.
inputs = processor(
    images=image,
    text="Generate underlying data table of the figure below:",
    return_tensors="pt",
)
predictions = model.generate(**inputs, max_new_tokens=512)

# Decode the first sequence once and show it verbatim.
raw_output = processor.decode(predictions[0], skip_special_tokens=True)
print("prediction")
print(raw_output)

# Rows are separated by the literal "<0x0A>" marker, cells by "|".
split_by_newline = raw_output.split("<0x0A>")
result_array = [
    [cell.strip() for cell in line.split("|")]
    for line in split_by_newline
]

print("result:")
print(result_array)

# Persist the table, one " | "-joined row per line.
with open('test.log', mode='w') as file:
    file.writelines(" | ".join(row) + "\n" for row in result_array)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import logging
4
+ from datetime import datetime
5
+ import torch
6
+ from transformers import AutoProcessor, Pix2StructForConditionalGeneration
7
  from PIL import Image
8
 
9
# Configure logging: INFO and above goes both to stdout and to app.log.
# The log file is opened as UTF-8 explicitly so that non-ASCII image paths or
# error messages are written the same way on every platform instead of
# depending on the locale's default encoding.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('app.log', encoding='utf-8'),
    ],
)
logger = logging.getLogger(__name__)
19
 
20
class ChartAnalyzer:
    """Turn a chart image into its underlying data table using DePlot.

    The ``google/deplot`` model and processor are loaded once in the
    constructor and reused by every :meth:`process_image` call.
    """

    def __init__(self):
        """Load the ``google/deplot`` model and processor.

        Raises:
            Exception: Any error raised while loading is logged (with
                traceback) and re-raised unchanged.
        """
        try:
            logger.info("Initializing model and processor...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
            self.processor = AutoProcessor.from_pretrained("google/deplot")
            logger.info("Model and processor initialized successfully")
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Error initializing model: %s", e)
            raise

    @staticmethod
    def _parse_table(raw_output, row_separator="<0x0A>"):
        """Split decoded model text into rows of stripped cell strings.

        Rows are separated by ``row_separator`` and cells by ``|``; rows that
        are empty or whitespace-only are skipped.
        """
        return [
            [cell.strip() for cell in line.split("|")]
            for line in raw_output.split(row_separator)
            if line.strip()
        ]

    @staticmethod
    def _save_rows(rows, output_file):
        """Write each row to ``output_file`` as one `` | ``-joined UTF-8 line."""
        with open(output_file, mode='w', encoding='utf-8') as file:
            file.writelines(" | ".join(row) + "\n" for row in rows)

    def process_image(self, image_path, prompt=None):
        """Process an image and generate its data table.

        Args:
            image_path: Path of the chart image on disk.
            prompt: Text prompt passed to the processor. When ``None`` the
                default "Generate underlying data table of the figure below:"
                is used.

        Returns:
            A list of rows, each row being a list of stripped cell strings.
            The same rows are also written to ``results_<timestamp>.log`` in
            the current working directory.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            Exception: Any other failure is logged (with traceback) and
                re-raised unchanged.
        """
        try:
            # Verify the file exists before handing it to PIL.
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file not found: {image_path}")

            logger.info("Processing image: %s", image_path)

            if prompt is None:
                prompt = "Generate underlying data table of the figure below:"

            # Open the image in a context manager so the underlying file
            # handle is released as soon as the processor has read the pixels.
            with Image.open(image_path) as image:
                inputs = self.processor(
                    images=image,
                    text=prompt,
                    return_tensors="pt",
                )

            # Generate predictions; no gradients are needed for inference.
            logger.info("Generating predictions...")
            with torch.no_grad():
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=4,
                    length_penalty=1.0,
                )

            # Decode only the first generated sequence.
            raw_output = self.processor.decode(predictions[0], skip_special_tokens=True)
            result_array = self._parse_table(raw_output)

            # Save the result under a timestamped name.
            # NOTE: the timestamp has one-second resolution, so two calls
            # within the same second write to the same file.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_file = f'results_{timestamp}.log'
            self._save_rows(result_array, output_file)

            logger.info("Results saved to %s", output_file)
            return result_array

        except Exception as e:
            logger.exception("Error processing image: %s", e)
            raise
86
+
87
def main(image_path='05e57f1c9acff69f1eb6fa72d4805d0.jpg'):
    """Analyse one chart image and print the extracted table to stdout.

    Args:
        image_path: Path of the image to analyse. Defaults to the sample file
            name the script was previously hard-wired to; in a Space, pass
            the path of the uploaded image instead.

    Raises:
        Exception: Any failure is logged as an application error and
            re-raised unchanged.
    """
    try:
        # Create the analyzer instance (loads the model).
        analyzer = ChartAnalyzer()

        # Process the image.
        results = analyzer.process_image(image_path)

        # Print the results, one " | "-joined row per line.
        print("\nAnalysis Results:")
        for row in results:
            print(" | ".join(row))

    except Exception as e:
        logger.error("Application error: %s", e)
        raise


if __name__ == "__main__":
    main()