NAME committed on
Commit
f4ed285
·
1 Parent(s): 613beb2

Add application file

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
# Hub checkpoint ids. Each id is used for both a model and its matching
# processor/tokenizer, so naming it once keeps the pair from drifting apart.
OCR_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
MATH_MODEL_ID = "Qwen/Qwen2.5-Math-72B-Instruct"

# Load the OCR model and processor.
# Both loads happen at import time, so the checkpoints are fetched/initialised
# before the UI is built.
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
    OCR_MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)

ocr_processor = AutoProcessor.from_pretrained(OCR_MODEL_ID)

# Load the Math model and tokenizer.
math_model = AutoModelForCausalLM.from_pretrained(
    MATH_MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)

math_tokenizer = AutoTokenizer.from_pretrained(MATH_MODEL_ID)
22
+
23
# OCR extraction function
def ocr_and_query(image, question):
    """Ask the vision-language model a question about an image.

    Args:
        image: The image to read (the UI supplies a PIL image).
        question: Text instruction sent alongside the image; may be empty.

    Returns:
        The decoded model reply as a string, with the prompt tokens removed.

    Raises:
        ValueError: If ``image`` is None.
    """
    # Fail early with a readable message instead of an error from deep
    # inside the processor.
    if image is None:
        raise ValueError("ocr_and_query requires an image, got None.")

    # One user turn made of an image placeholder followed by the text.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        }
    ]

    # Render the chat template, then tokenize text and image together.
    text_prompt = ocr_processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = ocr_processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    )

    # The model is placed with device_map="auto", so follow the device it
    # actually reports rather than assuming "cuda".
    inputs = inputs.to(ocr_model.device)
    output_ids = ocr_model.generate(**inputs, max_new_tokens=1024)

    # Each generated sequence starts with its prompt; slice the prompt off so
    # only the newly generated tokens are decoded.
    completion_ids = [
        sequence[len(prompt_ids):]
        for prompt_ids, sequence in zip(inputs.input_ids, output_ids)
    ]
    decoded = ocr_processor.batch_decode(
        completion_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    # A single prompt was submitted, so there is exactly one reply.
    return decoded[0]
55
+
56
# Math problem solving function
def solve_math_problem(prompt):
    """Solve a math problem given as text with the math language model.

    Args:
        prompt: The problem statement, sent as the user message.

    Returns:
        The decoded model reply as a string, with the prompt tokens removed.
    """
    # CoT (Chain of Thought): the system message asks for step-by-step
    # reasoning and a boxed final answer.
    messages = [
        {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
        {"role": "user", "content": prompt},
    ]

    text = math_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The model is placed with device_map="auto", so follow the device it
    # actually reports rather than assuming "cuda".
    model_inputs = math_tokenizer([text], return_tensors="pt").to(math_model.device)

    generated_ids = math_model.generate(
        **model_inputs,
        max_new_tokens=512,
    )

    # Each generated sequence starts with its prompt; slice the prompt off so
    # only the newly generated tokens are decoded.
    completion_ids = [
        sequence[len(prompt_ids):]
        for prompt_ids, sequence in zip(model_inputs.input_ids, generated_ids)
    ]

    # A single prompt was submitted, so there is exactly one reply.
    return math_tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]
82
+
83
# Function to clear inputs and output
def clear_inputs():
    """Return reset values for the Clear button.

    Returns:
        A 3-tuple ``(None, "", "")``; the Clear button wires it to the image
        input, the text box and the output panel, in that order.
    """
    return None, "", ""
86
+
87
# Gradio interface setup
def gradio_app(image, question, task):
    """Run the task selected in the UI and return the refreshed widget values.

    Args:
        image: The uploaded image, or None when nothing was uploaded.
        question: The text box content (a question or a math problem).
        task: The selected radio option, or None when nothing is selected.

    Returns:
        A 3-tuple ``(image, text, output)``. ``image`` is passed through
        unchanged; ``text`` is the original question, except for the
        image-math task where it becomes the text read from the image;
        ``output`` is the model answer or a message asking for missing input.
    """
    if task == "OCR and Query":
        # Same guard as the image-math task: without it a missing image
        # fails inside the OCR call instead of prompting the user.
        if image is None:
            return image, question, "Please upload an image."
        return image, question, ocr_and_query(image, question)

    if task == "Solve Math Problem from Image":
        if image is None:
            return image, question, "Please upload an image."
        # NOTE(review): the OCR model is given an empty instruction here, so
        # what it extracts depends on its default behaviour -- confirm this
        # is intended, or pass an explicit transcription prompt.
        extracted_text = ocr_and_query(image, "")
        math_solution = solve_math_problem(extracted_text)
        return image, extracted_text, math_solution

    if task == "Solve Math Problem from Text":
        # Treat None like an empty string so a missing value cannot raise.
        if not (question or "").strip():
            return image, question, "Please enter a math problem."
        return image, question, solve_math_problem(question)

    return image, question, "Please select a task."
104
+
105
# Gradio interface: layout first, then the button wiring, all declared inside
# the Blocks context.
with gr.Blocks() as app:
    gr.Markdown("# Image OCR and Math Solver")
    gr.Markdown("Upload an image, enter your question or math problem, and select the appropriate task.")

    # Inputs: the image (handed to the callbacks as a PIL image) and free text.
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(
            lines=2,
            placeholder="Enter your question or math problem here...",
            label="Input",
        )

    # Task selector; the option strings are the ones gradio_app compares against.
    with gr.Row():
        task_radio = gr.Radio(
            [
                "OCR and Query",
                "Solve Math Problem from Image",
                "Solve Math Problem from Text",
            ],
            label="Task",
        )

    with gr.Row():
        complete_button = gr.Button("Complete")
        clear_button = gr.Button("Clear")

    output = gr.Markdown(label="Output")

    # Event listeners: both callbacks write back to the image input, the text
    # box and the output panel, in that order.
    complete_button.click(
        fn=gradio_app,
        inputs=[image_input, text_input, task_radio],
        outputs=[image_input, text_input, output],
    )
    clear_button.click(
        fn=clear_inputs,
        outputs=[image_input, text_input, output],
    )

# Launch the app
app.launch(share=True)