realkun commited on
Commit
eee8d68
1 Parent(s): 9d3f1f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -12
app.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  from transformers import AutoProcessor, Pix2StructForConditionalGeneration
7
  from PIL import Image
8
 
9
- # 配置日志
10
  logging.basicConfig(
11
  level=logging.INFO,
12
  format='%(asctime)s - %(levelname)s - %(message)s',
@@ -17,26 +17,59 @@ logging.basicConfig(
17
  )
18
  logger = logging.getLogger(__name__)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  class ChartAnalyzer:
21
  def __init__(self):
22
  try:
23
- logger.info("Initializing model and processor...")
 
24
  self.model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
25
  self.processor = AutoProcessor.from_pretrained("google/deplot")
26
- logger.info("Model and processor initialized successfully")
27
  except Exception as e:
 
28
  logger.error(f"Error initializing model: {str(e)}")
29
  raise
30
 
31
  def process_image(self, image_path, prompt=None):
32
  """处理图片并生成数据表格"""
33
  try:
 
 
34
  # 验证文件存在
35
  if not os.path.exists(image_path):
36
- raise FileNotFoundError(f"Image file not found: {image_path}")
37
 
38
  # 打开并处理图片
39
- logger.info(f"Processing image: {image_path}")
40
  image = Image.open(image_path)
41
 
42
  # 准备输入
@@ -50,8 +83,8 @@ class ChartAnalyzer:
50
  )
51
 
52
  # 生成预测
53
- logger.info("Generating predictions...")
54
- with torch.no_grad(): # 提高性能并减少内存使用
55
  predictions = self.model.generate(
56
  **inputs,
57
  max_new_tokens=512,
@@ -77,31 +110,36 @@ class ChartAnalyzer:
77
  for row in result_array:
78
  file.write(" | ".join(row) + "\n")
79
 
80
- logger.info(f"Results saved to {output_file}")
81
  return result_array
82
 
83
  except Exception as e:
 
84
  logger.error(f"Error processing image: {str(e)}")
85
  raise
86
 
87
  def main():
88
  try:
 
 
89
  # 创建分析器实例
90
  analyzer = ChartAnalyzer()
91
 
92
- # 指定图片路径(在Space中使用上传的图片路径)
93
  image_path = '05e57f1c9acff69f1eb6fa72d4805d0.jpg'
94
 
95
  # 处理图片
96
  results = analyzer.process_image(image_path)
97
 
98
  # 打印结果
99
- print("\nAnalysis Results:")
100
- for row in results:
101
- print(" | ".join(row))
 
102
 
103
  except Exception as e:
104
  logger.error(f"Application error: {str(e)}")
 
105
  raise
106
 
107
  if __name__ == "__main__":
 
6
  from transformers import AutoProcessor, Pix2StructForConditionalGeneration
7
  from PIL import Image
8
 
9
+ # 配置日志格式
10
  logging.basicConfig(
11
  level=logging.INFO,
12
  format='%(asctime)s - %(levelname)s - %(message)s',
 
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
+ def print_section(title, char='='):
21
+ """打印格式化的章节标题"""
22
+ print(f"\n{char * 50}")
23
+ print(f"{title.center(50)}")
24
+ print(f"{char * 50}\n")
25
+
26
+ def print_table(data):
27
+ """格式化打印表格数据"""
28
+ if not data:
29
+ print("No data available")
30
+ return
31
+
32
+ # 计算每列的最大宽度
33
+ col_widths = []
34
+ for i in range(len(data[0])):
35
+ col_width = max(len(str(row[i])) for row in data)
36
+ col_widths.append(col_width)
37
+
38
+ # 打印表头
39
+ header = data[0]
40
+ header_str = " | ".join(str(header[i]).ljust(col_widths[i]) for i in range(len(header)))
41
+ print(header_str)
42
+ print("-" * len(header_str))
43
+
44
+ # 打印数据行
45
+ for row in data[1:]:
46
+ row_str = " | ".join(str(row[i]).ljust(col_widths[i]) for i in range(len(row)))
47
+ print(row_str)
48
+
49
  class ChartAnalyzer:
50
  def __init__(self):
51
  try:
52
+ print_section("初始化模型")
53
+ print("正在加载模型和处理器...")
54
  self.model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
55
  self.processor = AutoProcessor.from_pretrained("google/deplot")
56
+ print(" 模型加载完成")
57
  except Exception as e:
58
+ print("✗ 模型加载失败")
59
  logger.error(f"Error initializing model: {str(e)}")
60
  raise
61
 
62
  def process_image(self, image_path, prompt=None):
63
  """处理图片并生成数据表格"""
64
  try:
65
+ print_section("图片处理", char='-')
66
+
67
  # 验证文件存在
68
  if not os.path.exists(image_path):
69
+ raise FileNotFoundError(f"找不到图片文件: {image_path}")
70
 
71
  # 打开并处理图片
72
+ print(f"正在处理图片: {image_path}")
73
  image = Image.open(image_path)
74
 
75
  # 准备输入
 
83
  )
84
 
85
  # 生成预测
86
+ print("\n正在生成数据分析...")
87
+ with torch.no_grad():
88
  predictions = self.model.generate(
89
  **inputs,
90
  max_new_tokens=512,
 
110
  for row in result_array:
111
  file.write(" | ".join(row) + "\n")
112
 
113
+ print(f"\n✓ 结果已保存至: {output_file}")
114
  return result_array
115
 
116
  except Exception as e:
117
+ print("\n✗ 处理失败")
118
  logger.error(f"Error processing image: {str(e)}")
119
  raise
120
 
121
  def main():
122
  try:
123
+ print_section("图表数据提取系统", char='*')
124
+
125
  # 创建分析器实例
126
  analyzer = ChartAnalyzer()
127
 
128
+ # 指定图片路径
129
  image_path = '05e57f1c9acff69f1eb6fa72d4805d0.jpg'
130
 
131
  # 处理图片
132
  results = analyzer.process_image(image_path)
133
 
134
  # 打印结果
135
+ print_section("分析结果")
136
+ print_table(results)
137
+
138
+ print_section("处理完成", char='*')
139
 
140
  except Exception as e:
141
  logger.error(f"Application error: {str(e)}")
142
+ print("\n✗ 程序执行出错,请查看日志获取详细信息")
143
  raise
144
 
145
  if __name__ == "__main__":