Spaces:
Sleeping
Sleeping
feat: add model comparison with base and fine-tuned
Browse files
app.py
CHANGED
@@ -17,7 +17,7 @@ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
|
17 |
tokenizer.pad_token = tokenizer.eos_token
|
18 |
|
19 |
print("Loading base model...")
|
20 |
-
|
21 |
BASE_MODEL,
|
22 |
device_map="auto",
|
23 |
torch_dtype=torch.float16,
|
@@ -25,14 +25,17 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
25 |
use_safetensors=True
|
26 |
)
|
27 |
|
28 |
-
print("Loading
|
29 |
-
|
30 |
-
|
31 |
ADAPTER_MODEL,
|
32 |
torch_dtype=torch.float16,
|
33 |
device_map="auto"
|
34 |
)
|
35 |
-
|
|
|
|
|
|
|
36 |
|
37 |
def format_prompt(problem: str, problem_type: str) -> str:
|
38 |
"""Format input prompt for the model"""
|
@@ -46,7 +49,7 @@ The derivative of this function is:"""
|
|
46 |
|
47 |
Problem: {problem}
|
48 |
The solution is:"""
|
49 |
-
else: # Roots
|
50 |
return f"""Find the roots of this equation.
|
51 |
|
52 |
Equation: {problem}
|
@@ -54,7 +57,7 @@ The roots are:"""
|
|
54 |
|
55 |
@spaces.GPU
|
56 |
@measure_time
|
57 |
-
def get_model_response(problem: str, problem_type: str) -> str:
|
58 |
"""Generate response from model"""
|
59 |
# Format prompt
|
60 |
prompt = format_prompt(problem, problem_type)
|
@@ -81,42 +84,68 @@ def get_model_response(problem: str, problem_type: str) -> str:
|
|
81 |
|
82 |
@spaces.GPU
|
83 |
def solve_problem(problem: str, problem_type: str) -> tuple:
|
84 |
-
"""Solve math problem
|
85 |
if not problem:
|
86 |
-
return "Please enter a problem", None
|
87 |
|
88 |
# Record problem type
|
89 |
monitor.record_problem_type(problem_type)
|
90 |
|
91 |
-
# Get
|
92 |
-
|
|
|
93 |
|
94 |
-
# Format
|
95 |
if problem_type == "Derivative":
|
96 |
-
|
97 |
|
98 |
Let's verify this step by step:
|
99 |
1. Starting with f(x) = {problem}
|
100 |
2. Applying differentiation rules
|
101 |
-
3. We get f'(x) = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
elif problem_type == "Addition":
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
Let's verify this step by step:
|
106 |
1. Starting with: {problem}
|
107 |
2. Adding the numbers
|
108 |
-
3. We get: {
|
|
|
109 |
else: # Roots
|
110 |
-
|
111 |
|
112 |
Let's verify this step by step:
|
113 |
1. Starting with equation: {problem}
|
114 |
2. Solving for x
|
115 |
-
3. Roots are: {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
# Record metrics
|
118 |
-
monitor.record_response_time("
|
119 |
-
monitor.
|
|
|
|
|
120 |
|
121 |
# Get updated statistics
|
122 |
stats = monitor.get_statistics()
|
@@ -125,23 +154,25 @@ Let's verify this step by step:
|
|
125 |
stats_display = f"""
|
126 |
### Performance Metrics
|
127 |
|
128 |
-
#### Response Times
|
129 |
-
-
|
|
|
130 |
|
131 |
-
#### Success
|
132 |
-
- {stats.get('
|
|
|
133 |
|
134 |
#### Problem Types Used
|
135 |
"""
|
136 |
for ptype, percentage in stats.get('problem_type_distribution', {}).items():
|
137 |
stats_display += f"- {ptype}: {percentage:.1f}%\n"
|
138 |
|
139 |
-
return
|
140 |
|
141 |
# Create Gradio interface
|
142 |
with gr.Blocks(title="Mathematics Problem Solver") as demo:
|
143 |
gr.Markdown("# Mathematics Problem Solver")
|
144 |
-
gr.Markdown("
|
145 |
|
146 |
with gr.Row():
|
147 |
with gr.Column():
|
@@ -157,10 +188,13 @@ with gr.Blocks(title="Mathematics Problem Solver") as demo:
|
|
157 |
solve_btn = gr.Button("Solve", variant="primary")
|
158 |
|
159 |
with gr.Row():
|
160 |
-
|
161 |
-
|
162 |
-
lines=6
|
163 |
-
|
|
|
|
|
|
|
164 |
|
165 |
# Performance metrics display
|
166 |
with gr.Row():
|
@@ -177,7 +211,7 @@ with gr.Blocks(title="Mathematics Problem Solver") as demo:
|
|
177 |
["\\frac{1}{x}", "Derivative"]
|
178 |
],
|
179 |
inputs=[problem_input, problem_type],
|
180 |
-
outputs=[
|
181 |
fn=solve_problem,
|
182 |
cache_examples=False # Disable caching
|
183 |
)
|
@@ -186,7 +220,7 @@ with gr.Blocks(title="Mathematics Problem Solver") as demo:
|
|
186 |
solve_btn.click(
|
187 |
fn=solve_problem,
|
188 |
inputs=[problem_input, problem_type],
|
189 |
-
outputs=[
|
190 |
)
|
191 |
|
192 |
if __name__ == "__main__":
|
|
|
17 |
tokenizer.pad_token = tokenizer.eos_token
|
18 |
|
19 |
print("Loading base model...")
|
20 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
21 |
BASE_MODEL,
|
22 |
device_map="auto",
|
23 |
torch_dtype=torch.float16,
|
|
|
25 |
use_safetensors=True
|
26 |
)
|
27 |
|
28 |
+
print("Loading fine-tuned model...")
|
29 |
+
finetuned_model = PeftModel.from_pretrained(
|
30 |
+
base_model,
|
31 |
ADAPTER_MODEL,
|
32 |
torch_dtype=torch.float16,
|
33 |
device_map="auto"
|
34 |
)
|
35 |
+
|
36 |
+
# Set models to eval mode
|
37 |
+
base_model.eval()
|
38 |
+
finetuned_model.eval()
|
39 |
|
40 |
def format_prompt(problem: str, problem_type: str) -> str:
|
41 |
"""Format input prompt for the model"""
|
|
|
49 |
|
50 |
Problem: {problem}
|
51 |
The solution is:"""
|
52 |
+
else: # Roots
|
53 |
return f"""Find the roots of this equation.
|
54 |
|
55 |
Equation: {problem}
|
|
|
57 |
|
58 |
@spaces.GPU
|
59 |
@measure_time
|
60 |
+
def get_model_response(problem: str, problem_type: str, model) -> str:
|
61 |
"""Generate response from model"""
|
62 |
# Format prompt
|
63 |
prompt = format_prompt(problem, problem_type)
|
|
|
84 |
|
85 |
@spaces.GPU
|
86 |
def solve_problem(problem: str, problem_type: str) -> tuple:
|
87 |
+
"""Solve math problem with both models"""
|
88 |
if not problem:
|
89 |
+
return "Please enter a problem", "Please enter a problem", None
|
90 |
|
91 |
# Record problem type
|
92 |
monitor.record_problem_type(problem_type)
|
93 |
|
94 |
+
# Get responses from both models with timing
|
95 |
+
base_response, base_time = get_model_response(problem, problem_type, base_model)
|
96 |
+
finetuned_response, finetuned_time = get_model_response(problem, problem_type, finetuned_model)
|
97 |
|
98 |
+
# Format outputs with steps
|
99 |
if problem_type == "Derivative":
|
100 |
+
base_output = f"""Generated derivative: {base_response}
|
101 |
|
102 |
Let's verify this step by step:
|
103 |
1. Starting with f(x) = {problem}
|
104 |
2. Applying differentiation rules
|
105 |
+
3. We get f'(x) = {base_response}"""
|
106 |
+
|
107 |
+
finetuned_output = f"""Generated derivative: {finetuned_response}
|
108 |
+
|
109 |
+
Let's verify this step by step:
|
110 |
+
1. Starting with f(x) = {problem}
|
111 |
+
2. Applying differentiation rules
|
112 |
+
3. We get f'(x) = {finetuned_response}"""
|
113 |
+
|
114 |
elif problem_type == "Addition":
|
115 |
+
base_output = f"""Solution: {base_response}
|
116 |
+
|
117 |
+
Let's verify this step by step:
|
118 |
+
1. Starting with: {problem}
|
119 |
+
2. Adding the numbers
|
120 |
+
3. We get: {base_response}"""
|
121 |
+
|
122 |
+
finetuned_output = f"""Solution: {finetuned_response}
|
123 |
|
124 |
Let's verify this step by step:
|
125 |
1. Starting with: {problem}
|
126 |
2. Adding the numbers
|
127 |
+
3. We get: {finetuned_response}"""
|
128 |
+
|
129 |
else: # Roots
|
130 |
+
base_output = f"""Found roots: {base_response}
|
131 |
|
132 |
Let's verify this step by step:
|
133 |
1. Starting with equation: {problem}
|
134 |
2. Solving for x
|
135 |
+
3. Roots are: {base_response}"""
|
136 |
+
|
137 |
+
finetuned_output = f"""Found roots: {finetuned_response}
|
138 |
+
|
139 |
+
Let's verify this step by step:
|
140 |
+
1. Starting with equation: {problem}
|
141 |
+
2. Solving for x
|
142 |
+
3. Roots are: {finetuned_response}"""
|
143 |
|
144 |
# Record metrics
|
145 |
+
monitor.record_response_time("base", base_time)
|
146 |
+
monitor.record_response_time("finetuned", finetuned_time)
|
147 |
+
monitor.record_success("base", not base_response.startswith("Error"))
|
148 |
+
monitor.record_success("finetuned", not finetuned_response.startswith("Error"))
|
149 |
|
150 |
# Get updated statistics
|
151 |
stats = monitor.get_statistics()
|
|
|
154 |
stats_display = f"""
|
155 |
### Performance Metrics
|
156 |
|
157 |
+
#### Response Times (seconds)
|
158 |
+
- Base Model: {stats.get('base_avg_response_time', 0):.2f} avg
|
159 |
+
- Fine-tuned Model: {stats.get('finetuned_avg_response_time', 0):.2f} avg
|
160 |
|
161 |
+
#### Success Rates
|
162 |
+
- Base Model: {stats.get('base_success_rate', 0):.1f}%
|
163 |
+
- Fine-tuned Model: {stats.get('finetuned_success_rate', 0):.1f}%
|
164 |
|
165 |
#### Problem Types Used
|
166 |
"""
|
167 |
for ptype, percentage in stats.get('problem_type_distribution', {}).items():
|
168 |
stats_display += f"- {ptype}: {percentage:.1f}%\n"
|
169 |
|
170 |
+
return base_output, finetuned_output, stats_display
|
171 |
|
172 |
# Create Gradio interface
|
173 |
with gr.Blocks(title="Mathematics Problem Solver") as demo:
|
174 |
gr.Markdown("# Mathematics Problem Solver")
|
175 |
+
gr.Markdown("Compare solutions between base and fine-tuned models")
|
176 |
|
177 |
with gr.Row():
|
178 |
with gr.Column():
|
|
|
188 |
solve_btn = gr.Button("Solve", variant="primary")
|
189 |
|
190 |
with gr.Row():
|
191 |
+
with gr.Column():
|
192 |
+
gr.Markdown("### Base Model")
|
193 |
+
base_output = gr.Textbox(label="Base Model Solution", lines=6)
|
194 |
+
|
195 |
+
with gr.Column():
|
196 |
+
gr.Markdown("### Fine-tuned Model")
|
197 |
+
finetuned_output = gr.Textbox(label="Fine-tuned Model Solution", lines=6)
|
198 |
|
199 |
# Performance metrics display
|
200 |
with gr.Row():
|
|
|
211 |
["\\frac{1}{x}", "Derivative"]
|
212 |
],
|
213 |
inputs=[problem_input, problem_type],
|
214 |
+
outputs=[base_output, finetuned_output, metrics_display],
|
215 |
fn=solve_problem,
|
216 |
cache_examples=False # Disable caching
|
217 |
)
|
|
|
220 |
solve_btn.click(
|
221 |
fn=solve_problem,
|
222 |
inputs=[problem_input, problem_type],
|
223 |
+
outputs=[base_output, finetuned_output, metrics_display]
|
224 |
)
|
225 |
|
226 |
if __name__ == "__main__":
|