kokuma commited on
Commit
f44a566
·
verified ·
1 Parent(s): c1ba109

Single language

Browse files
Files changed (1) hide show
  1. app.py +18 -22
app.py CHANGED
@@ -1049,9 +1049,9 @@ if not precomputed_results:
1049
  model = model.to(device)
1050
 
1051
 
1052
- def change_language(lang, randomize_imgs, randomize_labels):
1053
  # compute text embeddings
1054
- labels = babel_imagenet[lang][1]
1055
  class_order = list(range(len(labels)))
1056
  if randomize_labels:
1057
  np.random.shuffle(class_order)
@@ -1065,7 +1065,7 @@ def change_language(lang, randomize_imgs, randomize_labels):
1065
  else:
1066
  text_features = None
1067
  correct_text = gr.Text(
1068
- f"Correct was: ''. Question 1/{len(babel_imagenet[lang][0])} ", label="Game"
1069
  )
1070
  player_score_text = gr.Text(f"Your choice: (Score: 0) ", label="Player")
1071
  clip_score_text = gr.Text(f"mSigLIP chose: '' (Score: 0)", label="Opponent")
@@ -1082,7 +1082,7 @@ def change_language(lang, randomize_imgs, randomize_labels):
1082
  )
1083
 
1084
 
1085
- def select(idx, lang, choice, correct, model_choice, player_score, clip_score, choices):
1086
  # checks if answer choice is correct and updated scores
1087
  correct_name, correct_value = correct
1088
  model_choice_name, model_choice_value = model_choice
@@ -1095,7 +1095,7 @@ def select(idx, lang, choice, correct, model_choice, player_score, clip_score, c
1095
  clip_score = clip_score + int(model_correct)
1096
 
1097
  correct_text = gr.Text(
1098
- f"Correct was: '{correct_name}'. Question {idx+1}/{len(babel_imagenet[lang][0])} ",
1099
  label="Game",
1100
  )
1101
  player_score_text = gr.Text(
@@ -1110,19 +1110,19 @@ def select(idx, lang, choice, correct, model_choice, player_score, clip_score, c
1110
  return correct_text, player_score_text, clip_score_text, player_score, clip_score
1111
 
1112
 
1113
- def prepare(raw_idx, lang, text_embeddings, class_order, randomize_images):
1114
  # prepared next question, loads image, and computes choices
1115
 
1116
- raw_idx = (raw_idx + 1) % len(babel_imagenet[lang][0])
1117
  idx = class_order[raw_idx]
1118
- lang_class_idxs = babel_imagenet[lang][0]
1119
  class_idx = lang_class_idxs[idx]
1120
 
1121
  # skip classes with no images
1122
  while class_idx in no_image_idxs:
1123
- raw_idx = (raw_idx + 1) % len(babel_imagenet[lang][0])
1124
  idx = class_order[raw_idx]
1125
- lang_class_idxs = babel_imagenet[lang][0] if lang != "EN" else list(range(1000))
1126
  class_idx = lang_class_idxs[idx]
1127
 
1128
  img_idx = 0
@@ -1131,7 +1131,7 @@ def prepare(raw_idx, lang, text_embeddings, class_order, randomize_images):
1131
  min(len(babelnet_images[class_idx]), max_image_choices)
1132
  )
1133
  img_url = babelnet_images[class_idx][img_idx]["url"]
1134
- class_labels = babel_imagenet[lang][1] if lang != "EN" else openai_en_classes
1135
 
1136
  if not precomputed_results:
1137
  try:
@@ -1150,14 +1150,14 @@ def prepare(raw_idx, lang, text_embeddings, class_order, randomize_images):
1150
  except:
1151
  gr.Warning("There is a problem with the next class. Skipping it.")
1152
  return prepare(
1153
- raw_idx, lang, text_embeddings, class_order, randomize_images
1154
  )
1155
 
1156
  similarity = (text_embeddings @ image_features.cpu().numpy().T).squeeze()
1157
  choices = np.argsort(similarity)[-4:].tolist()
1158
  else:
1159
  choices = list(
1160
- reversed(precomputed_results[lang][idx][img_idx])
1161
  ) # precomputing script uses torch.topk which sorts in reverse here
1162
  if idx not in choices:
1163
  choices = [idx] + choices[1:]
@@ -1194,11 +1194,11 @@ def prepare(raw_idx, lang, text_embeddings, class_order, randomize_images):
1194
  return next_radio, next_image, raw_idx, correct_choice, model_choice, choice_values
1195
 
1196
 
1197
- def reroll(raw_idx, lang, text_embeddings, class_order, randomize_images):
1198
  # prepared next question, loads image, and computes choices
1199
 
1200
  idx = class_order[raw_idx]
1201
- lang_class_idxs = babel_imagenet[lang][0]
1202
  class_idx = lang_class_idxs[idx]
1203
 
1204
  img_idx = 0
@@ -1207,7 +1207,7 @@ def reroll(raw_idx, lang, text_embeddings, class_order, randomize_images):
1207
  min(len(babelnet_images[class_idx]), max_image_choices)
1208
  )
1209
  img_url = babelnet_images[class_idx][img_idx]["url"]
1210
- class_labels = babel_imagenet[lang][1] if lang != "EN" else openai_en_classes
1211
 
1212
  if not precomputed_results:
1213
  try:
@@ -1226,14 +1226,14 @@ def reroll(raw_idx, lang, text_embeddings, class_order, randomize_images):
1226
  except:
1227
  gr.Warning("There is a problem with the next class. Skipping it.")
1228
  return prepare(
1229
- raw_idx, lang, text_embeddings, class_order, randomize_images
1230
  )
1231
 
1232
  similarity = (text_embeddings @ image_features.cpu().numpy().T).squeeze()
1233
  choices = np.argsort(similarity)[-4:].tolist()
1234
  else:
1235
  choices = list(
1236
- reversed(precomputed_results[lang][idx][img_idx])
1237
  ) # precomputing script uses torch.topk which sorts in reverse here
1238
  if idx not in choices:
1239
  choices = [idx] + choices[1:]
@@ -1390,7 +1390,6 @@ Select your language, click 'Start' and start guessing! We'll keep track of your
1390
  fn=select,
1391
  inputs=[
1392
  class_idx,
1393
- "EN",
1394
  options,
1395
  correct_choice,
1396
  model_choice,
@@ -1409,7 +1408,6 @@ Select your language, click 'Start' and start guessing! We'll keep track of your
1409
  fn=prepare,
1410
  inputs=[
1411
  class_idx,
1412
- "EN",
1413
  text_embeddings,
1414
  class_order,
1415
  randomize_images,
@@ -1434,7 +1432,6 @@ Select your language, click 'Start' and start guessing! We'll keep track of your
1434
  fn=prepare,
1435
  inputs=[
1436
  class_idx,
1437
- "EN",
1438
  text_embeddings,
1439
  class_order,
1440
  randomize_images,
@@ -1446,7 +1443,6 @@ Select your language, click 'Start' and start guessing! We'll keep track of your
1446
  fn=reroll,
1447
  inputs=[
1448
  class_idx,
1449
- "EN",
1450
  text_embeddings,
1451
  class_order,
1452
  randomize_images,
 
1049
  model = model.to(device)
1050
 
1051
 
1052
+ def change_language(randomize_imgs, randomize_labels):
1053
  # compute text embeddings
1054
+ labels = babel_imagenet["EN"][1]
1055
  class_order = list(range(len(labels)))
1056
  if randomize_labels:
1057
  np.random.shuffle(class_order)
 
1065
  else:
1066
  text_features = None
1067
  correct_text = gr.Text(
1068
+ f"Correct was: ''. Question 1/{len(babel_imagenet["EN"][0])} ", label="Game"
1069
  )
1070
  player_score_text = gr.Text(f"Your choice: (Score: 0) ", label="Player")
1071
  clip_score_text = gr.Text(f"mSigLIP chose: '' (Score: 0)", label="Opponent")
 
1082
  )
1083
 
1084
 
1085
+ def select(idx, choice, correct, model_choice, player_score, clip_score, choices):
1086
  # checks if answer choice is correct and updated scores
1087
  correct_name, correct_value = correct
1088
  model_choice_name, model_choice_value = model_choice
 
1095
  clip_score = clip_score + int(model_correct)
1096
 
1097
  correct_text = gr.Text(
1098
+ f"Correct was: '{correct_name}'. Question {idx+1}/{len(babel_imagenet["EN"][0])} ",
1099
  label="Game",
1100
  )
1101
  player_score_text = gr.Text(
 
1110
  return correct_text, player_score_text, clip_score_text, player_score, clip_score
1111
 
1112
 
1113
+ def prepare(raw_idx, text_embeddings, class_order, randomize_images):
1114
  # prepared next question, loads image, and computes choices
1115
 
1116
+ raw_idx = (raw_idx + 1) % len(babel_imagenet["EN"][0])
1117
  idx = class_order[raw_idx]
1118
+ lang_class_idxs = babel_imagenet["EN"][0]
1119
  class_idx = lang_class_idxs[idx]
1120
 
1121
  # skip classes with no images
1122
  while class_idx in no_image_idxs:
1123
+ raw_idx = (raw_idx + 1) % len(babel_imagenet["EN"][0])
1124
  idx = class_order[raw_idx]
1125
+ lang_class_idxs = babel_imagenet["EN"][0] if "EN" != "EN" else list(range(1000))
1126
  class_idx = lang_class_idxs[idx]
1127
 
1128
  img_idx = 0
 
1131
  min(len(babelnet_images[class_idx]), max_image_choices)
1132
  )
1133
  img_url = babelnet_images[class_idx][img_idx]["url"]
1134
+ class_labels = babel_imagenet["EN"][1] if "EN" != "EN" else openai_en_classes
1135
 
1136
  if not precomputed_results:
1137
  try:
 
1150
  except:
1151
  gr.Warning("There is a problem with the next class. Skipping it.")
1152
  return prepare(
1153
+ raw_idx, text_embeddings, class_order, randomize_images
1154
  )
1155
 
1156
  similarity = (text_embeddings @ image_features.cpu().numpy().T).squeeze()
1157
  choices = np.argsort(similarity)[-4:].tolist()
1158
  else:
1159
  choices = list(
1160
+ reversed(precomputed_results["EN"][idx][img_idx])
1161
  ) # precomputing script uses torch.topk which sorts in reverse here
1162
  if idx not in choices:
1163
  choices = [idx] + choices[1:]
 
1194
  return next_radio, next_image, raw_idx, correct_choice, model_choice, choice_values
1195
 
1196
 
1197
+ def reroll(raw_idx, text_embeddings, class_order, randomize_images):
1198
  # prepared next question, loads image, and computes choices
1199
 
1200
  idx = class_order[raw_idx]
1201
+ lang_class_idxs = babel_imagenet["EN"][0]
1202
  class_idx = lang_class_idxs[idx]
1203
 
1204
  img_idx = 0
 
1207
  min(len(babelnet_images[class_idx]), max_image_choices)
1208
  )
1209
  img_url = babelnet_images[class_idx][img_idx]["url"]
1210
+ class_labels = babel_imagenet["EN"][1] if "EN" != "EN" else openai_en_classes
1211
 
1212
  if not precomputed_results:
1213
  try:
 
1226
  except:
1227
  gr.Warning("There is a problem with the next class. Skipping it.")
1228
  return prepare(
1229
+ raw_idx, text_embeddings, class_order, randomize_images
1230
  )
1231
 
1232
  similarity = (text_embeddings @ image_features.cpu().numpy().T).squeeze()
1233
  choices = np.argsort(similarity)[-4:].tolist()
1234
  else:
1235
  choices = list(
1236
+ reversed(precomputed_results["EN"][idx][img_idx])
1237
  ) # precomputing script uses torch.topk which sorts in reverse here
1238
  if idx not in choices:
1239
  choices = [idx] + choices[1:]
 
1390
  fn=select,
1391
  inputs=[
1392
  class_idx,
 
1393
  options,
1394
  correct_choice,
1395
  model_choice,
 
1408
  fn=prepare,
1409
  inputs=[
1410
  class_idx,
 
1411
  text_embeddings,
1412
  class_order,
1413
  randomize_images,
 
1432
  fn=prepare,
1433
  inputs=[
1434
  class_idx,
 
1435
  text_embeddings,
1436
  class_order,
1437
  randomize_images,
 
1443
  fn=reroll,
1444
  inputs=[
1445
  class_idx,
 
1446
  text_embeddings,
1447
  class_order,
1448
  randomize_images,