suryadev1 commited on
Commit
5cb147d
·
2 Parent(s): c9640a0 0b133b0

internal changes

Browse files
Files changed (2) hide show
  1. app.py +503 -234
  2. result.txt +1 -1
app.py CHANGED
@@ -8,6 +8,7 @@ import shutil
8
  import matplotlib.pyplot as plt
9
  from sklearn.metrics import roc_curve, auc
10
  import pandas as pd
 
11
  from sklearn.metrics import roc_auc_score
12
  from matplotlib.figure import Figure
13
  # Define the function to process the input file and model selection
@@ -157,7 +158,6 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
157
 
158
  return opt1_done, opt2_done
159
 
160
- # Read data from test_info.txt
161
  # Read data from test_info.txt
162
  with open(test_info_location, "r") as file:
163
  data = file.readlines()
@@ -167,8 +167,8 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
167
 
168
  # Initialize counters
169
  task_counts = {
170
- 1: {"only_opt1": 0, "only_opt2": 0, "both": 0,"none":0},
171
- 2: {"only_opt1": 0, "only_opt2": 0, "both": 0,"none":0}
172
  }
173
 
174
  # Analyze rows
@@ -182,18 +182,18 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
182
 
183
  if ideal_task == 0:
184
  if opt1_done and not opt2_done:
185
- task_counts[1]["only_opt1"] += 1
186
  elif not opt1_done and opt2_done:
187
- task_counts[1]["only_opt2"] += 1
188
  elif opt1_done and opt2_done:
189
  task_counts[1]["both"] += 1
190
  else:
191
  task_counts[1]["none"] +=1
192
  elif ideal_task == 1:
193
  if opt1_done and not opt2_done:
194
- task_counts[2]["only_opt1"] += 1
195
  elif not opt1_done and opt2_done:
196
- task_counts[2]["only_opt2"] += 1
197
  elif opt1_done and opt2_done:
198
  task_counts[2]["both"] += 1
199
  else:
@@ -205,43 +205,112 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
205
 
206
  # for ideal_task, counts in task_counts.items():
207
  # output_summary += f"Ideal Task = OptionalTask_{ideal_task}:\n"
208
- # output_summary += f" Only OptionalTask_1 done: {counts['only_opt1']}\n"
209
- # output_summary += f" Only OptionalTask_2 done: {counts['only_opt2']}\n"
210
  # output_summary += f" Both done: {counts['both']}\n"
211
 
 
 
 
212
  # Generate pie chart for Task 1
213
  task1_labels = list(task_counts[1].keys())
214
  task1_values = list(task_counts[1].values())
215
 
216
- fig_task1 = Figure()
217
- ax1 = fig_task1.add_subplot(1, 1, 1)
218
- ax1.pie(task1_values, labels=task1_labels, autopct='%1.1f%%', startangle=90)
219
- ax1.set_title('Ideal Task 1 Distribution')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  # Generate pie chart for Task 2
222
  task2_labels = list(task_counts[2].keys())
223
  task2_values = list(task_counts[2].values())
224
 
225
- fig_task2 = Figure()
226
- ax2 = fig_task2.add_subplot(1, 1, 1)
227
- ax2.pie(task2_values, labels=task2_labels, autopct='%1.1f%%', startangle=90)
228
- ax2.set_title('Ideal Task 2 Distribution')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  # print(output_summary)
231
 
232
  progress(0.2, desc="analysis done!! Executing models")
233
  print("finetuned task: ",finetune_task)
234
- # subprocess.run([
235
- # "python", "new_test_saved_finetuned_model.py",
236
- # "-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
237
- # "-finetune_task", finetune_task,
238
- # "-test_dataset_path","../../../../selected_rows.txt",
239
- # # "-test_label_path","../../../../train_label.txt",
240
- # "-finetuned_bert_classifier_checkpoint",
241
- # "ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42",
242
- # "-e",str(1),
243
- # "-b",str(1000)
244
- # ])
245
  progress(0.6,desc="Model execution completed")
246
  result = {}
247
  with open("result.txt", 'r') as file:
@@ -262,18 +331,70 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
262
 
263
 
264
  # Create a matplotlib figure
265
- fig = Figure()
266
- ax = fig.add_subplot(1, 1, 1)
267
- ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
268
- ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
269
- ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title=f'Receiver Operating Curve (ROC)')
270
- ax.legend(loc="lower right")
271
- ax.grid()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  # Save plot to a file
274
- plot_path = "plot.png"
275
- fig.savefig(plot_path)
276
- plt.close(fig)
277
 
278
 
279
 
@@ -283,19 +404,20 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
283
  text_output = f"Model: {model_name}\nResult:\n{result}"
284
  # Prepare text output with HTML formatting
285
  text_output = f"""
286
- Model: {model_name}\n
287
- -----------------\n
288
- Time Taken: {result['time_taken_from_start']:.2f} seconds\n
289
- Total Schools in test: {len(unique_schools):.4f}\n
290
- Total number of instances having Schools with HGR : {len(high_sample):.4f}\n
291
- Total number of instances having Schools with LGR: {len(low_sample):.4f}\n
292
-
293
- ROC score of HGR: {high_roc_auc}\n
294
- ROC score of LGR: {low_roc_auc}\n
295
-
296
- ROC score of opt1: {opt_task1_roc_auc}\n
297
- ROC score of opt2: {opt_task2_roc_auc}\n
298
- -----------------
 
299
  """
300
  return text_output,fig,fig_task1,fig_task2
301
 
@@ -304,27 +426,30 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
304
  # models = ["ASTRA-FT-HGR", "ASTRA-FT-LGR", "ASTRA-FT-FULL"]
305
  models = ["ASTRA-FT-HGR", "ASTRA-FT-FULL"]
306
  content = """
307
- <h1 style="color: white;">ASTRA: An AI Model for Analyzing Math Strategies</h1>
 
308
 
309
- <h3 style="color: white;">
310
- <a href="https://drive.google.com/file/d/1lbEpg8Se1ugTtkjreD8eXIg7qrplhWan/view" style="color: #1E90FF; text-decoration: none;">Link To Paper</a> |
311
  <a href="https://github.com/Syudu41/ASTRA---Gates-Project" style="color: #1E90FF; text-decoration: none;">GitHub</a> |
312
- <a href="#" style="color: #1E90FF; text-decoration: none;">Project Page</a>
313
  </h3>
314
 
315
  <p style="color: white;">Welcome to a demo of ASTRA. ASTRA is a collaborative research project between researchers at the
316
- <a href="https://www.memphis.edu" style="color: #1E90FF; text-decoration: none;">University of Memphis</a> and
317
  <a href="https://www.carnegielearning.com" style="color: #1E90FF; text-decoration: none;">Carnegie Learning</a>
318
  to utilize AI to improve our understanding of math learning strategies.</p>
319
 
320
- <p style="color: white;">This demo has been developed with a pre-trained model (based on an architecture similar to BERT)
321
- that learns math strategies using data collected from hundreds of schools in the U.S. who have used
322
- Carnegie Learning's MATHia (formerly known as Cognitive Tutor), the flagship Intelligent Tutor
323
- that is part of a core, blended math curriculum.</p>
324
 
325
- <p style="color: white;">For this demo, we have used data from a specific domain (teaching ratio and proportions) within
326
- 7th grade math. The fine-tuning based on the pre-trained models learns to predict which strategies
327
- lead to correct vs. incorrect solutions.</p>
 
 
328
 
329
  <p style="color: white;">To use the demo, please follow these steps:</p>
330
 
@@ -335,203 +460,327 @@ lead to correct vs. incorrect solutions.</p>
335
  <li style="color: white;">ASTRA-FT-Full: Fine-tuned with a small sample of data from a mix of schools that have high/low graduation rates.</li>
336
  </ul>
337
  </li>
338
- <li style="color: white;">Select a percentage of schools to analyze (selecting a large percentage may take a long time).</li>
339
- <li style="color: white;">View Results:
 
340
  <ul>
341
- <li style="color: white;">The results from the fine-tuned model are displayed on the dashboard.</li>
342
- <li style="color: white;">The results are shown separately for schools that have high and low graduation rates.</li>
 
 
 
 
343
  </ul>
344
  </li>
345
  </ol>
346
  """
347
  # CSS styling for white text
348
  # Create the Gradio interface
349
- with gr.Blocks(css="""
350
- body {
351
- background-color: #1e1e1e!important;
352
- font-family: 'Arial', sans-serif;
353
- color: #f5f5f5!important;;
354
- }
 
355
 
356
- .gradio-container {
357
- max-width: 850px!important;
358
- margin: 0 auto!important;;
359
- padding: 20px!important;;
360
- background-color: #292929!important;
361
- border-radius: 10px;
362
- box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2);
363
- }
364
- .gradio-container-4-44-0 .prose h1 {
365
- font-size: var(--text-xxl);
366
- color: #ffffff!important;
367
- }
368
- #title {
369
- color: white!important;
370
- font-size: 2.3em;
371
- font-weight: bold;
372
- text-align: center!important;
373
- margin-bottom: 20px;
374
- }
375
- .description {
376
- text-align: center;
377
- font-size: 1.1em;
378
- color: #bfbfbf;
379
- margin-bottom: 30px;
380
- }
381
- .file-box {
382
- max-width: 180px;
383
- padding: 5px;
384
- background-color: #444!important;
385
- border: 1px solid #666!important;
386
- border-radius: 6px;
387
- height: 80px!important;;
388
- margin: 0 auto!important;;
389
- text-align: center;
390
- color: transparent;
391
- }
392
- .file-box span {
393
- color: #f5f5f5!important;
394
- font-size: 1em;
395
- line-height: 45px; /* Vertically center text */
396
- }
397
- .dropdown-menu {
398
- max-width: 220px;
399
- margin: 0 auto!important;
400
- background-color: #444!important;
401
- color:#444!important;
402
- border-radius: 6px;
403
- padding: 8px;
404
- font-size: 1.1em;
405
- border: 1px solid #666;
406
- }
407
- .button {
408
- background-color: #4CAF50!important;
409
- color: white!important;
410
- font-size: 1.1em;
411
- padding: 10px 25px;
412
- border-radius: 6px;
413
- cursor: pointer;
414
- transition: background-color 0.2s ease-in-out;
415
- }
416
- .button:hover {
417
- background-color: #45a049!important;
418
- }
419
- .output-text {
420
- background-color: #333!important;
421
- padding: 12px;
422
- border-radius: 8px;
423
- border: 1px solid #666;
424
- font-size: 1.1em;
425
- }
426
- .footer {
427
- text-align: center;
428
- margin-top: 50px;
429
- font-size: 0.9em;
430
- color: #b0b0b0;
431
- }
432
- .svelte-12ioyct .wrap {
433
- display: none !important;
434
  }
435
- .file-label-text {
436
- display: none !important;
 
 
 
 
 
 
 
437
  }
438
 
439
- div.svelte-sfqy0y {
440
- display: flex;
441
- flex-direction: inherit;
442
- flex-wrap: wrap;
443
- gap: var(--form-gap-width);
444
- box-shadow: var(--block-shadow);
445
- border: var(--block-border-width) solid var(--border-color-primary);
446
- border-radius: var(--block-radius);
447
- background: #1f2937!important;
448
- overflow-y: hidden;
449
  }
450
 
451
- .block.svelte-12cmxck {
452
- position: relative;
453
- margin: 0;
454
- box-shadow: var(--block-shadow);
455
- border-width: var(--block-border-width);
456
- border-color: var(--block-border-color);
457
- border-radius: var(--block-radius);
458
- background: #1f2937!important;
459
- width: 100%;
460
- line-height: var(--line-sm);
461
  }
462
 
463
- .svelte-12ioyct .wrap {
464
- display: none !important;
 
 
 
465
  }
466
- .file-label-text {
467
- display: none !important;
 
 
 
 
468
  }
469
- input[aria-label="file upload"] {
470
- display: none !important;
 
 
471
  }
472
 
473
- gradio-app .gradio-container.gradio-container-4-44-0 .contain .file-box span {
474
- font-size: 1em;
475
- line-height: 45px;
476
- color: #1f2937 !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  }
478
- .wrap.svelte-12ioyct {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  display: flex;
480
- flex-direction: column;
481
- justify-content: center;
482
  align-items: center;
483
- min-height: var(--size-60);
484
- color: #1f2937 !important;
485
- line-height: var(--line-md);
486
- height: 100%;
487
- padding-top: var(--size-3);
 
 
 
 
 
 
 
 
 
 
 
488
  text-align: center;
489
- margin: auto var(--spacing-lg);
 
 
 
490
  }
491
- span.svelte-1gfkn6j:not(.has-info) {
492
- margin-bottom: var(--spacing-lg);
493
- color: white!important;
 
 
 
 
 
 
494
  }
495
- label.float.svelte-1b6s6s {
496
- position: relative!important;
497
- top: var(--block-label-margin);
498
- left: var(--block-label-margin);
 
 
 
 
 
 
499
  }
500
- label.svelte-1b6s6s {
501
- display: inline-flex;
502
- align-items: center;
503
- z-index: var(--layer-2);
504
- box-shadow: var(--block-label-shadow);
505
- border: var(--block-label-border-width) solid var(--border-color-primary);
506
- border-top: none;
507
- border-left: none;
508
- border-radius: var(--block-label-radius);
509
- background: rgb(120 151 180)!important;
510
- padding: var(--block-label-padding);
511
- pointer-events: none;
512
- color: #1f2937!important;
513
- font-weight: var(--block-label-text-weight);
514
- font-size: var(--block-label-text-size);
515
- line-height: var(--line-sm);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
516
  }
517
- .file.svelte-18wv37q.svelte-18wv37q {
518
- display: block!important;
519
- width: var(--size-full);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
520
  }
521
 
522
- tbody.svelte-18wv37q>tr.svelte-18wv37q:nth-child(odd) {
523
- background: ##7897b4!important;
524
- color: white;
525
- background: #aca7b2;
526
  }
527
 
528
- .gradio-container-4-31-4 .prose h1, .gradio-container-4-31-4 .prose h2, .gradio-container-4-31-4 .prose h3, .gradio-container-4-31-4 .prose h4, .gradio-container-4-31-4 .prose h5 {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
 
530
- color: white;
 
 
 
 
 
 
531
  }
532
- """) as demo:
 
 
 
533
 
534
- gr.Markdown("<h1 id='title'>ASTRA</h1>", elem_id="title")
535
  gr.Markdown(content)
536
 
537
  with gr.Row():
@@ -539,24 +788,44 @@ tbody.svelte-18wv37q>tr.svelte-18wv37q:nth-child(odd) {
539
  # label_input = gr.File(label="Upload test labels", file_types=['.txt'], elem_classes="file-box")
540
 
541
  # info_input = gr.File(label="Upload test info", file_types=['.txt'], elem_classes="file-box")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
 
543
- model_dropdown = gr.Dropdown(choices=models, label="Select Fine-tuned Model", elem_classes="dropdown-menu")
 
544
 
545
-
546
- increment_slider = gr.Slider(minimum=1, maximum=100, step=1, label="Schools Percentage", value=1)
547
  gr.Markdown("<p class='description'>Dashboard</p>")
 
548
  with gr.Row():
549
  output_text = gr.Textbox(label="")
550
  # output_image = gr.Image(label="ROC")
551
- plot_output = gr.Plot(label="roc")
552
  with gr.Row():
553
- opt1_pie = gr.Plot(label="opt1")
554
- opt2_pie = gr.Plot(label="opt2")
 
 
 
555
  # output_summary = gr.Textbox(label="Summary")
556
 
557
- btn = gr.Button("Submit")
558
 
559
- btn.click(fn=process_file, inputs=[model_dropdown,increment_slider], outputs=[output_text,plot_output,opt1_pie,opt2_pie])
 
 
 
 
560
 
561
 
562
  # Launch the app
 
8
  import matplotlib.pyplot as plt
9
  from sklearn.metrics import roc_curve, auc
10
  import pandas as pd
11
+ import plotly.graph_objects as go
12
  from sklearn.metrics import roc_auc_score
13
  from matplotlib.figure import Figure
14
  # Define the function to process the input file and model selection
 
158
 
159
  return opt1_done, opt2_done
160
 
 
161
  # Read data from test_info.txt
162
  with open(test_info_location, "r") as file:
163
  data = file.readlines()
 
167
 
168
  # Initialize counters
169
  task_counts = {
170
+ 1: {"ER": 0, "ME": 0, "both": 0,"none":0},
171
+ 2: {"ER": 0, "ME": 0, "both": 0,"none":0}
172
  }
173
 
174
  # Analyze rows
 
182
 
183
  if ideal_task == 0:
184
  if opt1_done and not opt2_done:
185
+ task_counts[1]["ER"] += 1
186
  elif not opt1_done and opt2_done:
187
+ task_counts[1]["ME"] += 1
188
  elif opt1_done and opt2_done:
189
  task_counts[1]["both"] += 1
190
  else:
191
  task_counts[1]["none"] +=1
192
  elif ideal_task == 1:
193
  if opt1_done and not opt2_done:
194
+ task_counts[2]["ER"] += 1
195
  elif not opt1_done and opt2_done:
196
+ task_counts[2]["ME"] += 1
197
  elif opt1_done and opt2_done:
198
  task_counts[2]["both"] += 1
199
  else:
 
205
 
206
  # for ideal_task, counts in task_counts.items():
207
  # output_summary += f"Ideal Task = OptionalTask_{ideal_task}:\n"
208
+ # output_summary += f" Only OptionalTask_1 done: {counts['ER']}\n"
209
+ # output_summary += f" Only OptionalTask_2 done: {counts['ME']}\n"
210
  # output_summary += f" Both done: {counts['both']}\n"
211
 
212
+ # colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
213
+ colors = ["#FF6F61", "#6B5B95", "#88B04B", "#F7CAC9"]
214
+
215
  # Generate pie chart for Task 1
216
  task1_labels = list(task_counts[1].keys())
217
  task1_values = list(task_counts[1].values())
218
 
219
+ # fig_task1 = Figure()
220
+ # ax1 = fig_task1.add_subplot(1, 1, 1)
221
+ # ax1.pie(task1_values, labels=task1_labels, autopct='%1.1f%%', startangle=90)
222
+ # ax1.set_title('Ideal Task 1 Distribution')
223
+
224
+ fig_task1 = go.Figure(data=[go.Pie(
225
+ labels=task1_labels,
226
+ values=task1_values,
227
+ textinfo='percent+label',
228
+ textposition='auto',
229
+ marker=dict(colors=colors),
230
+ sort=False
231
+
232
+ )])
233
+
234
+ fig_task1.update_layout(
235
+ title='Problem Type: ER',
236
+ title_x=0.5,
237
+ font=dict(
238
+ family="sans-serif",
239
+ size=12,
240
+ color="black"
241
+ ),
242
+ )
243
+
244
+ fig_task1.update_layout(
245
+ legend=dict(
246
+ font=dict(
247
+ family="sans-serif",
248
+ size=12,
249
+ color="black"
250
+ ),
251
+ )
252
+ )
253
+
254
+
255
+
256
+ # fig.show()
257
 
258
  # Generate pie chart for Task 2
259
  task2_labels = list(task_counts[2].keys())
260
  task2_values = list(task_counts[2].values())
261
 
262
+ fig_task2 = go.Figure(data=[go.Pie(
263
+ labels=task2_labels,
264
+ values=task2_values,
265
+ textinfo='percent+label',
266
+ textposition='auto',
267
+ marker=dict(colors=colors),
268
+ sort=False
269
+ # pull=[0, 0.2, 0, 0] # for pulling part of pie chart out (depends on position)
270
+
271
+ )])
272
+
273
+ fig_task2.update_layout(
274
+ title='Problem Type: ME',
275
+ title_x=0.5,
276
+ font=dict(
277
+ family="sans-serif",
278
+ size=12,
279
+ color="black"
280
+ ),
281
+ )
282
+
283
+ fig_task2.update_layout(
284
+ legend=dict(
285
+ font=dict(
286
+ family="sans-serif",
287
+ size=12,
288
+ color="black"
289
+ ),
290
+ )
291
+ )
292
+
293
+
294
+ # fig_task2 = Figure()
295
+ # ax2 = fig_task2.add_subplot(1, 1, 1)
296
+ # ax2.pie(task2_values, labels=task2_labels, autopct='%1.1f%%', startangle=90)
297
+ # ax2.set_title('Ideal Task 2 Distribution')
298
 
299
  # print(output_summary)
300
 
301
  progress(0.2, desc="analysis done!! Executing models")
302
  print("finetuned task: ",finetune_task)
303
+ subprocess.run([
304
+ "python", "new_test_saved_finetuned_model.py",
305
+ "-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
306
+ "-finetune_task", finetune_task,
307
+ "-test_dataset_path","../../../../selected_rows.txt",
308
+ # "-test_label_path","../../../../train_label.txt",
309
+ "-finetuned_bert_classifier_checkpoint",
310
+ "ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42",
311
+ "-e",str(1),
312
+ "-b",str(1000)
313
+ ])
314
  progress(0.6,desc="Model execution completed")
315
  result = {}
316
  with open("result.txt", 'r') as file:
 
331
 
332
 
333
  # Create a matplotlib figure
334
+ # fig = Figure()
335
+ # ax = fig.add_subplot(1, 1, 1)
336
+ # ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
337
+ # ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
338
+ # ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title=f'Receiver Operating Curve (ROC)')
339
+ # ax.legend(loc="lower right")
340
+ # ax.grid()
341
+
342
+ fig = go.Figure()
343
+ # Create and style traces
344
+ fig.add_trace(go.Line(x = list(fpr), y = list(tpr), name=f'ROC curve (area = {roc_auc:.2f})',
345
+ line=dict(color='royalblue', width=3,
346
+ ) # dash options include 'dash', 'dot', and 'dashdot'
347
+ ))
348
+ fig.add_trace(go.Line(x = [0,1], y = [0,1], showlegend = False,
349
+ line=dict(color='firebrick', width=2,
350
+ dash='dash',) # dash options include 'dash', 'dot', and 'dashdot'
351
+ ))
352
+
353
+ # Edit the layout
354
+ fig.update_layout(
355
+ showlegend = True,
356
+ title_x=0.5,
357
+ title=dict(
358
+ text='Receiver Operating Curve (ROC)'
359
+ ),
360
+ xaxis=dict(
361
+ title=dict(
362
+ text='False Positive Rate'
363
+ )
364
+ ),
365
+ yaxis=dict(
366
+ title=dict(
367
+ text='False Negative Rate'
368
+ )
369
+ ),
370
+ font=dict(
371
+ family="sans-serif",
372
+ color="black"
373
+ ),
374
+
375
+ )
376
+ fig.update_layout(
377
+ legend=dict(
378
+ x=0.75,
379
+ y=0,
380
+ traceorder="normal",
381
+ font=dict(
382
+ family="sans-serif",
383
+ size=12,
384
+ color="black"
385
+ ),
386
+ )
387
+ )
388
+
389
+
390
+
391
+
392
+
393
 
394
  # Save plot to a file
395
+ # plot_path = "plot.png"
396
+ # fig.savefig(plot_path)
397
+ # plt.close(fig)
398
 
399
 
400
 
 
404
  text_output = f"Model: {model_name}\nResult:\n{result}"
405
  # Prepare text output with HTML formatting
406
  text_output = f"""
407
+ ---------------------------
408
+ Model: {model_name}
409
+ ---------------------------\n
410
+ Time Taken: {result['time_taken_from_start']:.2f} seconds
411
+ Total Schools in test: {len(unique_schools):.4f}
412
+ Total number of instances having Schools with HGR : {len(high_sample):.4f}
413
+ Total number of instances having Schools with LGR: {len(low_sample):.4f}
414
+
415
+ ROC score of HGR: {high_roc_auc:.4f}
416
+ ROC score of LGR: {low_roc_auc:.4f}
417
+
418
+
419
+ ROC-AUC for problems of type ER: {opt_task1_roc_auc:.4f}
420
+ ROC-AUC for problems of type ME: {opt_task2_roc_auc:.4f}
421
  """
422
  return text_output,fig,fig_task1,fig_task2
423
 
 
426
  # models = ["ASTRA-FT-HGR", "ASTRA-FT-LGR", "ASTRA-FT-FULL"]
427
  models = ["ASTRA-FT-HGR", "ASTRA-FT-FULL"]
428
  content = """
429
+ <h1 style="color: black;">A S T R A</h1>
430
+ <h2 style="color: black;">An AI Model for Analyzing Math Strategies</h2>
431
 
432
+ <h3 style="color: white; text-align: center">
433
+ <a href="https://drive.google.com/file/d/1lbEpg8Se1ugTtkjreD8eXIg7qrplhWan/view" style="color: gr.themes.colors.red; text-decoration: none;">Link To Paper</a> |
434
  <a href="https://github.com/Syudu41/ASTRA---Gates-Project" style="color: #1E90FF; text-decoration: none;">GitHub</a> |
435
+ <a href="https://sites.google.com/view/astra-research/home" style="color: #1E90FF; text-decoration: none;">Project Page</a>
436
  </h3>
437
 
438
  <p style="color: white;">Welcome to a demo of ASTRA. ASTRA is a collaborative research project between researchers at the
439
+ <a href="https://sites.google.com/site/dvngopal/" style="color: #1E90FF; text-decoration: none;">University of Memphis</a> and
440
  <a href="https://www.carnegielearning.com" style="color: #1E90FF; text-decoration: none;">Carnegie Learning</a>
441
  to utilize AI to improve our understanding of math learning strategies.</p>
442
 
443
+ <p style="color: white;">This demo has been developed with a pre-trained model (based on an architecture similar to BERT ) that learns math strategies using data
444
+ collected from hundreds of schools in the U.S. who have used Carnegie Learning’s MATHia (formerly known as Cognitive Tutor), the flagship Intelligent Tutor that is part of a core, blended math curriculum.
445
+ For this demo, we have used data from a specific domain (teaching ratio and proportions) within 7th grade math. The fine-tuning based on the pre-trained model learns to predict which strategies lead to correct vs incorrect solutions.
446
+ </p>
447
 
448
+ <p style="color: white;">In this math domain, students were given word problems related to ratio and proportions. Further, the students
449
+ were given a choice of optional tasks to work on in parallel to the main problem to demonstrate their thinking (metacognition).
450
+ The optional tasks are designed based on solving problems using Equivalent Ratios (ER) and solving using Means and Extremes/cross-multiplication (ME).
451
+ When the equivalent ratios are easy to compute (integral values), ER is much more efficient compared to ME and switching between the tasks appropriately demonstrates cognitive flexibility.
452
+ </p>
453
 
454
  <p style="color: white;">To use the demo, please follow these steps:</p>
455
 
 
460
  <li style="color: white;">ASTRA-FT-Full: Fine-tuned with a small sample of data from a mix of schools that have high/low graduation rates.</li>
461
  </ul>
462
  </li>
463
+ <li style="color: white;">Select a percentage of schools to analyze (selecting a large percentage may take a long time). Note that the selected percentage is applied to both High Graduation Rate (HGR) schools and Low Graduation Rate (LGR schools).
464
+ </li>
465
+ <li style="color: white;">The results from the fine-tuned model are displayed in the dashboard:
466
  <ul>
467
+ <li style="color: white;">The model accuracy is computed using the ROC-AUC metric.
468
+ </li>
469
+ <li style="color: white;">The results are shown for HGR, LGR schools and for different problem types (ER/ME).
470
+ </li>
471
+ <li style="color: white;">The distribution over how students utilized the optional tasks (whether they utilized ER/ME, used both of them or none of them) is shown for each problem type.
472
+ </li>
473
  </ul>
474
  </li>
475
  </ol>
476
  """
477
  # CSS styling for white text
478
  # Create the Gradio interface
479
+ available_themes = {
480
+ "default": gr.themes.Default(),
481
+ "soft": gr.themes.Soft(),
482
+ "monochrome": gr.themes.Monochrome(),
483
+ "glass": gr.themes.Glass(),
484
+ "base": gr.themes.Base(),
485
+ }
486
 
487
+ # Comprehensive CSS for all HTML elements
488
+ custom_css = '''
489
+ /* Import Fira Sans font */
490
+ @import url('https://fonts.googleapis.com/css2?family=Fira+Sans:wght@400;500;600;700&family=Inter:wght@400;500;600;700&display=swap');
491
+ @import url('https://fonts.googleapis.com/css2?family=Libre+Caslon+Text:ital,wght@0,400;0,700;1,400&family=Spectral+SC:wght@600&display=swap');
492
+ /* Container modifications for centering */
493
+ .gradio-container {
494
+ color: var(--block-label-text-color) !important;
495
+ max-width: 1000px !important;
496
+ margin: 0 auto !important;
497
+ padding: 2rem !important;
498
+ font-family: Arial, sans-serif !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  }
500
+
501
+ /* Main title (ASTRA) */
502
+ #title {
503
+ text-align: center !important;
504
+ margin: 1rem auto !important; /* Reduced margin */
505
+ font-size: 2.5em !important;
506
+ font-weight: 600 !important;
507
+ font-family: "Spectral SC", 'Fira Sans', sans-serif !important;
508
+ padding-bottom: 0 !important; /* Remove bottom padding */
509
  }
510
 
511
+ /* Subtitle (An AI Model...) */
512
+ h1 {
513
+ text-align: center !important;
514
+ font-size: 30pt !important;
515
+ font-weight: 600 !important;
516
+ font-family: "Spectral SC", 'Fira Sans', sans-serif !important;
517
+ margin-top: 0.5em !important; /* Reduced top margin */
518
+ margin-bottom: 0.3em !important;
 
 
519
  }
520
 
521
+ h2 {
522
+ text-align: center !important;
523
+ font-size: 22pt !important;
524
+ font-weight: 600 !important;
525
+ font-family: "Spectral SC",'Fira Sans', sans-serif !important;
526
+ margin-top: 0.2em !important; /* Reduced top margin */
527
+ margin-bottom: 0.3em !important;
 
 
 
528
  }
529
 
530
+ /* Links container styling */
531
+ .links-container {
532
+ text-align: center !important;
533
+ margin: 1em auto !important;
534
+ font-family: 'Inter' ,'Fira Sans', sans-serif !important;
535
  }
536
+
537
+ /* Links */
538
+ a {
539
+ color: #2563eb !important;
540
+ text-decoration: none !important;
541
+ font-family:'Inter' , 'Fira Sans', sans-serif !important;
542
  }
543
+
544
+ a:hover {
545
+ text-decoration: underline !important;
546
+ opacity: 0.8;
547
  }
548
 
549
+ /* Regular text */
550
+ p, li, .description, .markdown-text {
551
+ font-family: 'Inter', Arial, sans-serif !important;
552
+ color: black !important;
553
+ font-size: 11pt;
554
+ line-height: 1.6;
555
+ font-weight: 500 !important;
556
+ color: var(--block-label-text-color) !important;
557
+ }
558
+
559
+ /* Other headings */
560
+ h3, h4, h5 {
561
+ font-family: 'Fira Sans', sans-serif !important;
562
+ color: var(--block-label-text-color) !important;
563
+ margin-top: 1.5em;
564
+ margin-bottom: 0.75em;
565
+ }
566
+
567
+
568
+ h3 { font-size: 1.5em; font-weight: 600; }
569
+ h4 { font-size: 1.25em; font-weight: 500; }
570
+ h5 { font-size: 1.1em; font-weight: 500; }
571
+
572
+ /* Form elements */
573
+ .select-wrap select, .wrap select,
574
+ input, textarea {
575
+ font-family: 'Inter' ,Arial, sans-serif !important;
576
+ color: var(--block-label-text-color) !important;
577
+ }
578
+
579
+ /* Lists */
580
+ ul, ol {
581
+ margin-left: 0 !important;
582
+ margin-bottom: 1.25em;
583
+ padding-left: 2em;
584
+ }
585
+
586
+ li {
587
+ margin-bottom: 0.75em;
588
  }
589
+
590
+ /* Form container */
591
+ .form-container {
592
+ max-width: 1000px !important;
593
+ margin: 0 auto !important;
594
+ padding: 1rem !important;
595
+ }
596
+
597
+ /* Dashboard */
598
+ .dashboard {
599
+ margin-top: 2rem !important;
600
+ padding: 1rem !important;
601
+ border-radius: 8px !important;
602
+ }
603
+
604
+ /* Slider styling */
605
+ .gradio-slider-row {
606
  display: flex;
 
 
607
  align-items: center;
608
+ justify-content: space-between;
609
+ margin: 1.5em 0;
610
+ max-width: 100% !important;
611
+ }
612
+
613
+ .gradio-slider {
614
+ flex-grow: 1;
615
+ margin-right: 15px;
616
+ }
617
+
618
+ .slider-percentage {
619
+ font-family: 'Inter', Arial, sans-serif !important;
620
+ flex-shrink: 0;
621
+ min-width: 60px;
622
+ font-size: 1em;
623
+ font-weight: bold;
624
  text-align: center;
625
+ background-color: #f0f8ff;
626
+ border: 1px solid #004080;
627
+ border-radius: 5px;
628
+ padding: 5px 10px;
629
  }
630
+
631
+ .progress-bar-wrap.progress-bar-wrap.progress-bar-wrap
632
+ {
633
+ border-radius: var(--input-radius);
634
+ height: 1.25rem;
635
+ margin-top: 1rem;
636
+ overflow: hidden;
637
+ width: 70%;
638
+ font-family: 'Inter', Arial, sans-serif !important;
639
  }
640
+
641
+ /* Add these new styles after your existing CSS */
642
+
643
+ /* Card-like appearance for the dashboard */
644
+ .dashboard {
645
+ background: #ffffff !important;
646
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
647
+ border-radius: 12px !important;
648
+ padding: 2rem !important;
649
+ margin-top: 2.5rem !important;
650
  }
651
+
652
+ /* Enhance ROC graph container */
653
+ #roc {
654
+ background: #ffffff !important;
655
+ padding: 1.5rem !important;
656
+ border-radius: 8px !important;
657
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
658
+ margin: 1.5rem 0 !important;
659
+ }
660
+
661
+ /* Style the dropdown select */
662
+ select {
663
+ background-color: #ffffff !important;
664
+ border: 1px solid #e2e8f0 !important;
665
+ border-radius: 8px !important;
666
+ padding: 0.5rem 1rem !important;
667
+ transition: all 0.2s ease-in-out !important;
668
+ box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05) !important;
669
+ }
670
+
671
+ select:hover {
672
+ border-color: #cbd5e1 !important;
673
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
674
+ }
675
+
676
+ /* Enhance slider appearance */
677
+ .progress-bar-wrap {
678
+ background: #f8fafc !important;
679
+ border: 1px solid #e2e8f0 !important;
680
+ box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.05) !important;
681
+ }
682
+
683
+ /* Style metrics in dashboard */
684
+ .dashboard p {
685
+ padding: 0.5rem 0 !important;
686
+ border-bottom: 1px solid #f1f5f9 !important;
687
+ }
688
+
689
+ /* Add spacing between sections */
690
+ .dashboard > div {
691
+ margin-bottom: 1.5rem !important;
692
+ }
693
+
694
+ /* Style the ROC curve title */
695
+ .dashboard h4 {
696
+ color: #1e293b !important;
697
+ font-weight: 600 !important;
698
+ margin-bottom: 1rem !important;
699
+ padding-bottom: 0.5rem !important;
700
+ border-bottom: 2px solid #e2e8f0 !important;
701
  }
702
+
703
+ /* Enhance link appearances */
704
+ a {
705
+ position: relative !important;
706
+ padding-bottom: 2px !important;
707
+ transition: all 0.2s ease-in-out !important;
708
+ }
709
+
710
+ a:after {
711
+ content: '' !important;
712
+ position: absolute !important;
713
+ width: 0 !important;
714
+ height: 1px !important;
715
+ bottom: 0 !important;
716
+ left: 0 !important;
717
+ background-color: #2563eb !important;
718
+ transition: width 0.3s ease-in-out !important;
719
+ }
720
+
721
+ a:hover:after {
722
+ width: 100% !important;
723
+ }
724
+
725
+ /* Add subtle dividers between sections */
726
+ .form-container > div {
727
+ padding-bottom: 1.5rem !important;
728
+ margin-bottom: 1.5rem !important;
729
+ border-bottom: 1px solid #f1f5f9 !important;
730
+ }
731
+
732
+ /* Style model selection section */
733
+ .select-wrap {
734
+ background: #ffffff !important;
735
+ padding: 1.5rem !important;
736
+ border-radius: 8px !important;
737
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
738
+ margin-bottom: 2rem !important;
739
+ }
740
+
741
+ /* Style the metrics display */
742
+ .dashboard span {
743
+ font-family: 'Inter', sans-serif !important;
744
+ font-weight: 500 !important;
745
+ color: #334155 !important;
746
  }
747
 
748
+ /* Add subtle animation to interactive elements */
749
+ button, select, .slider-percentage {
750
+ transition: all 0.2s ease-in-out !important;
 
751
  }
752
 
753
+ /* Style the ROC curve container */
754
+ .plot-container {
755
+ background: #ffffff !important;
756
+ border-radius: 8px !important;
757
+ padding: 1rem !important;
758
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
759
+ }
760
+
761
+ /* Add container styles for opt1 and opt2 sections */
762
+ #opt1, #opt2 {
763
+ background: #ffffff !important;
764
+ border-radius: 8px !important;
765
+ padding: 1.5rem !important;
766
+ margin-top: 1.5rem !important;
767
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
768
+ }
769
 
770
+ /* Style the distribution titles */
771
+ .distribution-title {
772
+ font-family: 'Inter', sans-serif !important;
773
+ font-weight: 600 !important;
774
+ color: #1e293b !important;
775
+ margin-bottom: 1rem !important;
776
+ text-align: center !important;
777
  }
778
+
779
+ '''
780
+
781
+ with gr.Blocks(theme='gstaff/sketch', css=custom_css) as demo:
782
 
783
+ # gr.Markdown("<h1 id='title'>ASTRA</h1>", elem_id="title")
784
  gr.Markdown(content)
785
 
786
  with gr.Row():
 
788
  # label_input = gr.File(label="Upload test labels", file_types=['.txt'], elem_classes="file-box")
789
 
790
  # info_input = gr.File(label="Upload test info", file_types=['.txt'], elem_classes="file-box")
791
+ model_dropdown = gr.Dropdown(
792
+ choices=models,
793
+ label="Select Fine-tuned Model",
794
+ elem_classes="dropdown-menu"
795
+ )
796
+ increment_slider = gr.Slider(
797
+ minimum=1,
798
+ maximum=100,
799
+ step=1,
800
+ label="Schools Percentage",
801
+ value=1,
802
+ elem_id="increment-slider",
803
+ elem_classes="gradio-slider"
804
+ )
805
 
806
+ with gr.Row():
807
+ btn = gr.Button("Submit")
808
 
 
 
809
  gr.Markdown("<p class='description'>Dashboard</p>")
810
+
811
  with gr.Row():
812
  output_text = gr.Textbox(label="")
813
  # output_image = gr.Image(label="ROC")
 
814
  with gr.Row():
815
+ plot_output = gr.Plot(label="ROC")
816
+
817
+ with gr.Row():
818
+ opt1_pie = gr.Plot(label="ER")
819
+ opt2_pie = gr.Plot(label="ME")
820
  # output_summary = gr.Textbox(label="Summary")
821
 
822
+
823
 
824
+ btn.click(
825
+ fn=process_file,
826
+ inputs=[model_dropdown,increment_slider],
827
+ outputs=[output_text,plot_output,opt1_pie,opt2_pie]
828
+ )
829
 
830
 
831
  # Launch the app
result.txt CHANGED
@@ -3,5 +3,5 @@ total_acc: 69.00702106318957
3
  precisions: 0.7236623191454734
4
  recalls: 0.6900702106318957
5
  f1_scores: 0.6802420656474512
6
- time_taken_from_start: 25.420082330703735
7
  auc_score: 0.7457100293916334
 
3
  precisions: 0.7236623191454734
4
  recalls: 0.6900702106318957
5
  f1_scores: 0.6802420656474512
6
+ time_taken_from_start: 53.13972353935242
7
  auc_score: 0.7457100293916334