gchhablani commited on
Commit
f384719
1 Parent(s): 571a3f6

Add auto scaling image

Browse files
Files changed (2) hide show
  1. apps/mlm.py +18 -6
  2. apps/vqa.py +1 -1
apps/mlm.py CHANGED
@@ -50,8 +50,11 @@ def app(state):
50
  if mlm_state.mlm_image_file is None:
51
  mlm_state.mlm_image_file = dummy_data.loc[first_index, "image_file"]
52
  caption = dummy_data.loc[first_index, "caption"].strip("- ")
 
53
  ids = bert_tokenizer.encode(caption)
54
- ids[np.random.randint(1, len(ids) - 1)] = bert_tokenizer.mask_token_id
 
 
55
  mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
56
  mlm_state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]
57
 
@@ -72,8 +75,11 @@ def app(state):
72
  sample = dummy_data.sample(1).reset_index()
73
  mlm_state.mlm_image_file = sample.loc[0, "image_file"]
74
  caption = sample.loc[0, "caption"].strip("- ")
 
75
  ids = bert_tokenizer.encode(caption)
76
- ids[np.random.randint(1, len(ids) - 1)] = bert_tokenizer.mask_token_id
 
 
77
  mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
78
  mlm_state.caption_lang_id = sample.loc[0, "lang_id"]
79
 
@@ -99,7 +105,7 @@ def app(state):
99
  new_col1, new_col2 = st.beta_columns([5, 5])
100
 
101
  # Display Image
102
- new_col1.image(mlm_state.mlm_image, use_column_width="always")
103
 
104
  # Display caption
105
  new_col2.write("Write your text with exactly one [MASK] token.")
@@ -109,9 +115,14 @@ def app(state):
109
  help="Type your masked caption regarding the image above in one of the four languages.",
110
  )
111
 
112
- new_col2.markdown(
113
- f"""**English Translation**: {caption if mlm_state.caption_lang_id == "en" else translate(caption, 'en')}"""
114
- )
 
 
 
 
 
115
  caption_inputs = get_text_attributes(caption)
116
 
117
  # Display Top-5 Predictions
@@ -119,6 +130,7 @@ def app(state):
119
  scores = predict(transformed_image, dict(caption_inputs))
120
  scores = softmax(scores)
121
  labels, values = get_top_5_predictions(scores)
 
122
  # newer_col1, newer_col2 = st.beta_columns([6,4])
123
  fig = plotly_express_horizontal_bar_plot(values, labels)
124
  st.dataframe(pd.DataFrame({"Tokens":labels, "English Translation": list(map(lambda x: translate(x),labels))}).T)
 
50
  if mlm_state.mlm_image_file is None:
51
  mlm_state.mlm_image_file = dummy_data.loc[first_index, "image_file"]
52
  caption = dummy_data.loc[first_index, "caption"].strip("- ")
53
+ mlm_state.unmasked_caption = caption
54
  ids = bert_tokenizer.encode(caption)
55
+ mask_index = np.random.randint(1, len(ids) - 1)
56
+ mlm_state.currently_masked_token = ids[mask_index]
57
+ ids[mask_index] = bert_tokenizer.mask_token_id
58
  mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
59
  mlm_state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]
60
 
 
75
  sample = dummy_data.sample(1).reset_index()
76
  mlm_state.mlm_image_file = sample.loc[0, "image_file"]
77
  caption = sample.loc[0, "caption"].strip("- ")
78
+ mlm_state.unmasked_caption = caption
79
  ids = bert_tokenizer.encode(caption)
80
+ mask_index = np.random.randint(1, len(ids) - 1)
81
+ mlm_state.currently_masked_token = ids[mask_index]
82
+ ids[mask_index] = bert_tokenizer.mask_token_id
83
  mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
84
  mlm_state.caption_lang_id = sample.loc[0, "lang_id"]
85
 
 
105
  new_col1, new_col2 = st.beta_columns([5, 5])
106
 
107
  # Display Image
108
+ new_col1.image(mlm_state.mlm_image, use_column_width="auto")
109
 
110
  # Display caption
111
  new_col2.write("Write your text with exactly one [MASK] token.")
 
115
  help="Type your masked caption regarding the image above in one of the four languages.",
116
  )
117
 
118
+ if caption == mlm_state.caption:
119
+ new_col2.markdown("**Masked Token**: "+mlm_state.currently_masked_token)
120
+ new_col2.markdown("**English Translation: " + mlm_state.unmasked_caption if mlm_state.caption_lang_id == "en" else translate(mlm_state.unmasked_caption, 'en'))
121
+
122
+ else:
123
+ new_col2.markdown(
124
+ f"""**English Translation**: {caption if mlm_state.caption_lang_id == "en" else translate(caption, 'en')}"""
125
+ )
126
  caption_inputs = get_text_attributes(caption)
127
 
128
  # Display Top-5 Predictions
 
130
  scores = predict(transformed_image, dict(caption_inputs))
131
  scores = softmax(scores)
132
  labels, values = get_top_5_predictions(scores)
133
+ print(labels)
134
  # newer_col1, newer_col2 = st.beta_columns([6,4])
135
  fig = plotly_express_horizontal_bar_plot(values, labels)
136
  st.dataframe(pd.DataFrame({"Tokens":labels, "English Translation": list(map(lambda x: translate(x),labels))}).T)
apps/vqa.py CHANGED
@@ -109,7 +109,7 @@ def app(state):
109
  new_col1, new_col2 = st.beta_columns([5, 5])
110
 
111
  # Display Image
112
- new_col1.image(vqa_state.vqa_image, use_column_width="always")
113
 
114
  # Display Question
115
  question = new_col2.text_input(
 
109
  new_col1, new_col2 = st.beta_columns([5, 5])
110
 
111
  # Display Image
112
+ new_col1.image(vqa_state.vqa_image, use_column_width="auto")
113
 
114
  # Display Question
115
  question = new_col2.text_input(