Update app.py
Browse files
app.py
CHANGED
@@ -8,8 +8,9 @@ import matplotlib.pyplot as plt
|
|
8 |
import tempfile
|
9 |
from gradio_client import Client, handle_file
|
10 |
from dataclasses import dataclass
|
11 |
-
from typing import List, Optional
|
12 |
import logging
|
|
|
13 |
|
14 |
# Logging configuration
|
15 |
logging.basicConfig(level=logging.INFO)
|
@@ -55,18 +56,27 @@ class MathSolver:
|
|
55 |
raise
|
56 |
|
57 |
@staticmethod
|
58 |
-
def query_qwen2(image_path: str, question: str) -> str:
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
@staticmethod
|
72 |
def extract_and_execute_python_code(text: str) -> Optional[List[str]]:
|
@@ -98,7 +108,7 @@ app = Flask(__name__)
|
|
98 |
|
99 |
token = os.environ.get("TOKEN")
|
100 |
gemini_config = GeminiConfig(
|
101 |
-
token,
|
102 |
generation_config={
|
103 |
"temperature": 1,
|
104 |
"max_output_tokens": 8192,
|
@@ -135,11 +145,16 @@ def upload_image():
|
|
135 |
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
136 |
file.save(temp_file.name)
|
137 |
|
138 |
-
|
139 |
-
math_solver.query_gemini(temp_file.name, prompt)
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
# Extract and generate graphs
|
145 |
image_paths = math_solver.extract_and_execute_python_code(result)
|
|
|
8 |
import tempfile
|
9 |
from gradio_client import Client, handle_file
|
10 |
from dataclasses import dataclass
|
11 |
+
from typing import List, Optional, Tuple
|
12 |
import logging
|
13 |
+
import time
|
14 |
|
15 |
# Logging configuration
|
16 |
logging.basicConfig(level=logging.INFO)
|
|
|
56 |
raise
|
57 |
|
58 |
@staticmethod
def query_qwen2(image_path: str, question: str, max_retries: int = 3) -> Tuple[str, bool]:
    """Ask the hosted Qwen2.5-Math demo to answer *question* about an image.

    Retries up to ``max_retries`` times, sleeping with exponential backoff
    between failed attempts.

    Returns:
        A ``(result, success)`` pair. On total failure ``result`` is a
        fixed error string and ``success`` is ``False``.
    """
    attempt_no = 0
    while attempt_no < max_retries:
        try:
            # A fresh client per attempt so a stale connection can't
            # poison every retry.
            answer = Client("Qwen/Qwen2.5-Math-Demo").predict(
                image=handle_file(image_path),
                sketchpad=None,
                question=question,
                api_name="/math_chat_bot",
            )
        except Exception as err:
            logger.error(f"Qwen2 Error (attempt {attempt_no + 1}/{max_retries}): {str(err)}")
            if attempt_no < max_retries - 1:
                time.sleep(2 ** attempt_no)  # Exponential backoff: 1s, 2s, 4s, ...
            attempt_no += 1
        else:
            return answer, True
    return "Error: Unable to process with Qwen2 model", False
|
80 |
|
81 |
@staticmethod
|
82 |
def extract_and_execute_python_code(text: str) -> Optional[List[str]]:
|
|
|
108 |
|
109 |
token = os.environ.get("TOKEN")
|
110 |
gemini_config = GeminiConfig(
|
111 |
+
token,
|
112 |
generation_config={
|
113 |
"temperature": 1,
|
114 |
"max_output_tokens": 8192,
|
|
|
145 |
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
146 |
file.save(temp_file.name)
|
147 |
|
148 |
+
if model_choice == "gemini":
|
149 |
+
result = math_solver.query_gemini(temp_file.name, prompt)
|
150 |
+
success = True
|
151 |
+
else:
|
152 |
+
result, success = math_solver.query_qwen2(temp_file.name, prompt)
|
153 |
+
if not success:
|
154 |
+
# Fallback to Gemini if Qwen2 fails
|
155 |
+
logger.info("Falling back to Gemini model")
|
156 |
+
result = math_solver.query_gemini(temp_file.name, prompt)
|
157 |
+
model_choice = "gemini (fallback)"
|
158 |
|
159 |
# Extract and generate graphs
|
160 |
image_paths = math_solver.extract_and_execute_python_code(result)
|