Joash2024 committed on
Commit
1a468e6
1 Parent(s): 5162902

feat: add model comparison with base and fine-tuned

Browse files
Files changed (1) hide show
  1. app.py +66 -32
app.py CHANGED
@@ -17,7 +17,7 @@ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
17
  tokenizer.pad_token = tokenizer.eos_token
18
 
19
  print("Loading base model...")
20
- model = AutoModelForCausalLM.from_pretrained(
21
  BASE_MODEL,
22
  device_map="auto",
23
  torch_dtype=torch.float16,
@@ -25,14 +25,17 @@ model = AutoModelForCausalLM.from_pretrained(
25
  use_safetensors=True
26
  )
27
 
28
- print("Loading LoRA adapter...")
29
- model = PeftModel.from_pretrained(
30
- model,
31
  ADAPTER_MODEL,
32
  torch_dtype=torch.float16,
33
  device_map="auto"
34
  )
35
- model.eval()
 
 
 
36
 
37
  def format_prompt(problem: str, problem_type: str) -> str:
38
  """Format input prompt for the model"""
@@ -46,7 +49,7 @@ The derivative of this function is:"""
46
 
47
  Problem: {problem}
48
  The solution is:"""
49
- else: # Roots or Custom
50
  return f"""Find the roots of this equation.
51
 
52
  Equation: {problem}
@@ -54,7 +57,7 @@ The roots are:"""
54
 
55
  @spaces.GPU
56
  @measure_time
57
- def get_model_response(problem: str, problem_type: str) -> str:
58
  """Generate response from model"""
59
  # Format prompt
60
  prompt = format_prompt(problem, problem_type)
@@ -81,42 +84,68 @@ def get_model_response(problem: str, problem_type: str) -> str:
81
 
82
  @spaces.GPU
83
  def solve_problem(problem: str, problem_type: str) -> tuple:
84
- """Solve math problem and track performance"""
85
  if not problem:
86
- return "Please enter a problem", None
87
 
88
  # Record problem type
89
  monitor.record_problem_type(problem_type)
90
 
91
- # Get model response with timing
92
- response, time_taken = get_model_response(problem, problem_type)
 
93
 
94
- # Format output with steps
95
  if problem_type == "Derivative":
96
- output = f"""Generated derivative: {response}
97
 
98
  Let's verify this step by step:
99
  1. Starting with f(x) = {problem}
100
  2. Applying differentiation rules
101
- 3. We get f'(x) = {response}"""
 
 
 
 
 
 
 
 
102
  elif problem_type == "Addition":
103
- output = f"""Solution: {response}
 
 
 
 
 
 
 
104
 
105
  Let's verify this step by step:
106
  1. Starting with: {problem}
107
  2. Adding the numbers
108
- 3. We get: {response}"""
 
109
  else: # Roots
110
- output = f"""Found roots: {response}
111
 
112
  Let's verify this step by step:
113
  1. Starting with equation: {problem}
114
  2. Solving for x
115
- 3. Roots are: {response}"""
 
 
 
 
 
 
 
116
 
117
  # Record metrics
118
- monitor.record_response_time("model", time_taken)
119
- monitor.record_success("model", not response.startswith("Error"))
 
 
120
 
121
  # Get updated statistics
122
  stats = monitor.get_statistics()
@@ -125,23 +154,25 @@ Let's verify this step by step:
125
  stats_display = f"""
126
  ### Performance Metrics
127
 
128
- #### Response Times
129
- - Average: {stats.get('model_avg_response_time', 0):.2f} seconds
 
130
 
131
- #### Success Rate
132
- - {stats.get('model_success_rate', 0):.1f}%
 
133
 
134
  #### Problem Types Used
135
  """
136
  for ptype, percentage in stats.get('problem_type_distribution', {}).items():
137
  stats_display += f"- {ptype}: {percentage:.1f}%\n"
138
 
139
- return output, stats_display
140
 
141
  # Create Gradio interface
142
  with gr.Blocks(title="Mathematics Problem Solver") as demo:
143
  gr.Markdown("# Mathematics Problem Solver")
144
- gr.Markdown("Using our fine-tuned model to solve mathematical problems")
145
 
146
  with gr.Row():
147
  with gr.Column():
@@ -157,10 +188,13 @@ with gr.Blocks(title="Mathematics Problem Solver") as demo:
157
  solve_btn = gr.Button("Solve", variant="primary")
158
 
159
  with gr.Row():
160
- solution_output = gr.Textbox(
161
- label="Solution with Steps",
162
- lines=6
163
- )
 
 
 
164
 
165
  # Performance metrics display
166
  with gr.Row():
@@ -177,7 +211,7 @@ with gr.Blocks(title="Mathematics Problem Solver") as demo:
177
  ["\\frac{1}{x}", "Derivative"]
178
  ],
179
  inputs=[problem_input, problem_type],
180
- outputs=[solution_output, metrics_display],
181
  fn=solve_problem,
182
  cache_examples=False # Disable caching
183
  )
@@ -186,7 +220,7 @@ with gr.Blocks(title="Mathematics Problem Solver") as demo:
186
  solve_btn.click(
187
  fn=solve_problem,
188
  inputs=[problem_input, problem_type],
189
- outputs=[solution_output, metrics_display]
190
  )
191
 
192
  if __name__ == "__main__":
 
17
  tokenizer.pad_token = tokenizer.eos_token
18
 
19
  print("Loading base model...")
20
+ base_model = AutoModelForCausalLM.from_pretrained(
21
  BASE_MODEL,
22
  device_map="auto",
23
  torch_dtype=torch.float16,
 
25
  use_safetensors=True
26
  )
27
 
28
+ print("Loading fine-tuned model...")
29
+ finetuned_model = PeftModel.from_pretrained(
30
+ base_model,
31
  ADAPTER_MODEL,
32
  torch_dtype=torch.float16,
33
  device_map="auto"
34
  )
35
+
36
+ # Set models to eval mode
37
+ base_model.eval()
38
+ finetuned_model.eval()
39
 
40
  def format_prompt(problem: str, problem_type: str) -> str:
41
  """Format input prompt for the model"""
 
49
 
50
  Problem: {problem}
51
  The solution is:"""
52
+ else: # Roots
53
  return f"""Find the roots of this equation.
54
 
55
  Equation: {problem}
 
57
 
58
  @spaces.GPU
59
  @measure_time
60
+ def get_model_response(problem: str, problem_type: str, model) -> str:
61
  """Generate response from model"""
62
  # Format prompt
63
  prompt = format_prompt(problem, problem_type)
 
84
 
85
  @spaces.GPU
86
  def solve_problem(problem: str, problem_type: str) -> tuple:
87
+ """Solve math problem with both models"""
88
  if not problem:
89
+ return "Please enter a problem", "Please enter a problem", None
90
 
91
  # Record problem type
92
  monitor.record_problem_type(problem_type)
93
 
94
+ # Get responses from both models with timing
95
+ base_response, base_time = get_model_response(problem, problem_type, base_model)
96
+ finetuned_response, finetuned_time = get_model_response(problem, problem_type, finetuned_model)
97
 
98
+ # Format outputs with steps
99
  if problem_type == "Derivative":
100
+ base_output = f"""Generated derivative: {base_response}
101
 
102
  Let's verify this step by step:
103
  1. Starting with f(x) = {problem}
104
  2. Applying differentiation rules
105
+ 3. We get f'(x) = {base_response}"""
106
+
107
+ finetuned_output = f"""Generated derivative: {finetuned_response}
108
+
109
+ Let's verify this step by step:
110
+ 1. Starting with f(x) = {problem}
111
+ 2. Applying differentiation rules
112
+ 3. We get f'(x) = {finetuned_response}"""
113
+
114
  elif problem_type == "Addition":
115
+ base_output = f"""Solution: {base_response}
116
+
117
+ Let's verify this step by step:
118
+ 1. Starting with: {problem}
119
+ 2. Adding the numbers
120
+ 3. We get: {base_response}"""
121
+
122
+ finetuned_output = f"""Solution: {finetuned_response}
123
 
124
  Let's verify this step by step:
125
  1. Starting with: {problem}
126
  2. Adding the numbers
127
+ 3. We get: {finetuned_response}"""
128
+
129
  else: # Roots
130
+ base_output = f"""Found roots: {base_response}
131
 
132
  Let's verify this step by step:
133
  1. Starting with equation: {problem}
134
  2. Solving for x
135
+ 3. Roots are: {base_response}"""
136
+
137
+ finetuned_output = f"""Found roots: {finetuned_response}
138
+
139
+ Let's verify this step by step:
140
+ 1. Starting with equation: {problem}
141
+ 2. Solving for x
142
+ 3. Roots are: {finetuned_response}"""
143
 
144
  # Record metrics
145
+ monitor.record_response_time("base", base_time)
146
+ monitor.record_response_time("finetuned", finetuned_time)
147
+ monitor.record_success("base", not base_response.startswith("Error"))
148
+ monitor.record_success("finetuned", not finetuned_response.startswith("Error"))
149
 
150
  # Get updated statistics
151
  stats = monitor.get_statistics()
 
154
  stats_display = f"""
155
  ### Performance Metrics
156
 
157
+ #### Response Times (seconds)
158
+ - Base Model: {stats.get('base_avg_response_time', 0):.2f} avg
159
+ - Fine-tuned Model: {stats.get('finetuned_avg_response_time', 0):.2f} avg
160
 
161
+ #### Success Rates
162
+ - Base Model: {stats.get('base_success_rate', 0):.1f}%
163
+ - Fine-tuned Model: {stats.get('finetuned_success_rate', 0):.1f}%
164
 
165
  #### Problem Types Used
166
  """
167
  for ptype, percentage in stats.get('problem_type_distribution', {}).items():
168
  stats_display += f"- {ptype}: {percentage:.1f}%\n"
169
 
170
+ return base_output, finetuned_output, stats_display
171
 
172
  # Create Gradio interface
173
  with gr.Blocks(title="Mathematics Problem Solver") as demo:
174
  gr.Markdown("# Mathematics Problem Solver")
175
+ gr.Markdown("Compare solutions between base and fine-tuned models")
176
 
177
  with gr.Row():
178
  with gr.Column():
 
188
  solve_btn = gr.Button("Solve", variant="primary")
189
 
190
  with gr.Row():
191
+ with gr.Column():
192
+ gr.Markdown("### Base Model")
193
+ base_output = gr.Textbox(label="Base Model Solution", lines=6)
194
+
195
+ with gr.Column():
196
+ gr.Markdown("### Fine-tuned Model")
197
+ finetuned_output = gr.Textbox(label="Fine-tuned Model Solution", lines=6)
198
 
199
  # Performance metrics display
200
  with gr.Row():
 
211
  ["\\frac{1}{x}", "Derivative"]
212
  ],
213
  inputs=[problem_input, problem_type],
214
+ outputs=[base_output, finetuned_output, metrics_display],
215
  fn=solve_problem,
216
  cache_examples=False # Disable caching
217
  )
 
220
  solve_btn.click(
221
  fn=solve_problem,
222
  inputs=[problem_input, problem_type],
223
+ outputs=[base_output, finetuned_output, metrics_display]
224
  )
225
 
226
  if __name__ == "__main__":