truong-xuan-linh commited on
Commit
1d3d5c8
·
1 Parent(s): b2813ce

update visualize

Browse files
app.py CHANGED
@@ -2,11 +2,13 @@ import glob
2
  import streamlit as st
3
 
4
  from streamlit_image_select import image_select
 
5
 
6
- #Trick to not init function multitime
7
  if "model" not in st.session_state:
8
  print("INIT MODEL")
9
  from src.model import Model
 
10
  st.session_state.model = Model()
11
  print("DONE INIT MODEL")
12
 
@@ -16,17 +18,25 @@ hide_menu_style = """
16
  footer {visibility: hidden;}
17
  </style>
18
  """
19
- st.markdown(hide_menu_style, unsafe_allow_html= True)
20
 
21
  mapper = {
22
- "images/000000000645.jpg": "Đây là đâu",
23
- "images/000000000661.jpg": "Tốc độ tối đa trên đoạn đường này là bao nhiêu",
24
- "images/000000000674.jpg": "Còn bao xa nữa là tới Huế",
25
- "images/000000000706.jpg": "Cầu này dài bao nhiêu",
26
- "images/000000000777.jpg": "Chè khúc bạch giá bao nhiêu"
27
  }
28
 
29
- image = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png", "webp", ])
 
 
 
 
 
 
 
 
30
  example = image_select("Examples", glob.glob("images/*.jpg"))
31
 
32
  if image:
@@ -40,10 +50,27 @@ else:
40
  st.session_state.question = mapper[example]
41
  st.session_state.image = example
42
 
43
- if 'image' in st.session_state:
44
  st.image(st.session_state.image)
45
  question = st.text_input("**Question:** ", value=st.session_state.question)
 
46
  if question:
47
- answer = st.session_state.model.inference(st.session_state.image, question)
 
 
 
 
48
  st.write(f"**Answer:** {answer}")
49
-
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import streamlit as st
3
 
4
  from streamlit_image_select import image_select
5
+ import streamlit.components.v1 as components
6
 
7
+ # Trick to not init function multitime
8
  if "model" not in st.session_state:
9
  print("INIT MODEL")
10
  from src.model import Model
11
+
12
  st.session_state.model = Model()
13
  print("DONE INIT MODEL")
14
 
 
18
  footer {visibility: hidden;}
19
  </style>
20
  """
21
+ st.markdown(hide_menu_style, unsafe_allow_html=True)
22
 
23
  mapper = {
24
+ "images/000000000645.jpg": "Đây là đâu",
25
+ "images/000000000661.jpg": "Tốc độ tối đa trên đoạn đường này là bao nhiêu",
26
+ "images/000000000674.jpg": "Còn bao xa nữa là tới Huế",
27
+ "images/000000000706.jpg": "Cầu này dài bao nhiêu",
28
+ "images/000000000777.jpg": "Chè khúc bạch giá bao nhiêu",
29
  }
30
 
31
+ image = st.file_uploader(
32
+ "Choose an image file",
33
+ type=[
34
+ "jpg",
35
+ "jpeg",
36
+ "png",
37
+ "webp",
38
+ ],
39
+ )
40
  example = image_select("Examples", glob.glob("images/*.jpg"))
41
 
42
  if image:
 
50
  st.session_state.question = mapper[example]
51
  st.session_state.image = example
52
 
53
+ if "image" in st.session_state:
54
  st.image(st.session_state.image)
55
  question = st.text_input("**Question:** ", value=st.session_state.question)
56
+ visualize = True
57
  if question:
58
+ answer, text_attention_html, images_visualize = (
59
+ st.session_state.model.inference(
60
+ st.session_state.image, question, visualize
61
+ )
62
+ )
63
  st.write(f"**Answer:** {answer}")
64
+
65
+ if visualize:
66
+ st.write("**Explanation**")
67
+ col1, col2 = st.columns([1, 2])
68
+ # st.markdown(text_attention_html, unsafe_allow_html=True)
69
+ with col1:
70
+ st.write("*Text Attention*")
71
+ components.html(text_attention_html, height=960, scrolling=True)
72
+
73
+ with col2:
74
+ st.write("*Image Attention*")
75
+ for image_visualize in images_visualize:
76
+ st.image(image_visualize)
pre-requirements.txt CHANGED
@@ -6,3 +6,5 @@ torchvision==0.18.0
6
  streamlit==1.35.0
7
  transformers==4.41.2
8
  streamlit-image-select==0.6.0
 
 
 
6
  streamlit==1.35.0
7
  transformers==4.41.2
8
  streamlit-image-select==0.6.0
9
+ bertviz==1.4.0
10
+ ipython==8.18.1
src/feature_extraction.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import torch
3
  import requests
4
  from PIL import Image, ImageFont, ImageDraw, ImageTransform
@@ -9,7 +8,9 @@ from src.ocr import OCRDetector
9
 
10
  class ViT:
11
  def __init__(self) -> None:
12
- self.processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
 
 
13
  self.model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
14
  self.model.to(Config.device)
15
 
@@ -23,7 +24,9 @@ class ViT:
23
  with torch.no_grad():
24
  outputs = self.model(**inputs)
25
  last_hidden_states = outputs.last_hidden_state
26
- attention_mask = torch.ones((last_hidden_states.shape[0], last_hidden_states.shape[1]))
 
 
27
 
28
  return last_hidden_states.to(Config.device), attention_mask.to(Config.device)
29
 
@@ -34,16 +37,20 @@ class ViT:
34
  image_outputs = self.model(**image_inputs)
35
  image_pooler_output = image_outputs.pooler_output
36
  image_pooler_output = torch.unsqueeze(image_pooler_output, 0)
37
- image_attention_mask = torch.ones((image_pooler_output.shape[0], image_pooler_output.shape[1]))
 
 
 
 
 
 
38
 
39
- return image_pooler_output.to(Config.device), image_attention_mask.to(Config.device)
40
 
41
  class OCR:
42
  def __init__(self) -> None:
43
  self.ocr_detector = OCRDetector()
44
 
45
  def extraction(self, image_dir):
46
-
47
  ocr_results = self.ocr_detector.text_detector(image_dir)
48
  if not ocr_results:
49
  print("NOT OCR1")
@@ -53,7 +60,6 @@ class OCR:
53
  ocrs = self.post_process(ocr_results)
54
 
55
  if not ocrs:
56
-
57
  return "", [], []
58
 
59
  ocrs.reverse()
@@ -74,10 +80,9 @@ class OCR:
74
  ocr_content = " ".join(ocr_content.split())
75
  ocr_content = "<extra_id_0>" + ocr_content
76
 
77
-
78
  return ocr_content, groups_box, paragraph_boxes
79
 
80
- def post_process(self,ocr_results):
81
  ocrs = []
82
  for result in ocr_results:
83
  text = result["text"]
@@ -96,10 +101,7 @@ class OCR:
96
  # if w*h < 300:
97
  # continue
98
 
99
- ocrs.append(
100
- {"text": text.lower(),
101
- "box": box}
102
- )
103
  return ocrs
104
 
105
  @staticmethod
@@ -107,87 +109,96 @@ class OCR:
107
  (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
108
  w = x2 - x1
109
  h = y4 - y1
110
- scl = h//7
111
- new_box = [max(x1-scl,0), max(y1 - scl, 0)], [x2+scl, y2-scl], [x3+scl, y3+scl], [x4-scl, y4+scl]
 
 
 
 
 
112
  (x1, y1), (x2, y2), (x3, y3), (x4, y4) = new_box
113
  # Define 8-tuple with x,y coordinates of top-left, bottom-left, bottom-right and top-right corners and apply
114
  transform = [x1, y1, x4, y4, x3, y3, x2, y2]
115
- result = image.transform((w,h), ImageTransform.QuadTransform(transform))
116
  return result
117
 
118
-
119
  @staticmethod
120
  def check_point_in_rectangle(box, point, padding_devide):
121
- (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
122
- x_min = min(x1, x4)
123
- x_max = max(x2, x3)
124
 
125
- padding = (x_max-x_min)//padding_devide
126
- x_min = x_min - padding
127
- x_max = x_max + padding
128
 
129
- y_min = min(y1, y2)
130
- y_max = max(y3, y4)
131
 
132
- y_min = y_min - padding
133
- y_max = y_max + padding
134
 
135
- x, y = point
136
 
137
- if x >= x_min and x <= x_max and y >= y_min and y <= y_max:
138
- return True
139
 
140
- return False
141
 
142
  @staticmethod
143
  def check_rectangle_overlap(rec1, rec2, padding_devide):
144
- for point in rec1:
145
- if OCR.check_point_in_rectangle(rec2, point, padding_devide):
146
- return True
147
 
148
- for point in rec2:
149
- if OCR.check_point_in_rectangle(rec1, point, padding_devide):
150
- return True
151
 
152
- return False
153
 
154
  @staticmethod
155
  def group_boxes(boxes, texts):
156
- groups = []
157
- groups_text = []
158
- paragraph_boxes = []
159
- processed = []
160
- boxes_cp = boxes.copy()
161
- for i, (box, text) in enumerate(zip(boxes_cp, texts)):
162
- (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
163
-
164
- if i not in processed:
165
- processed.append(i)
166
- else:
167
- continue
168
-
169
- groups.append([box])
170
- groups_text.append([text])
171
- for j, (box2, text2) in enumerate(zip(boxes_cp[i+1:], texts[i+1:])):
172
- if j+i+1 in processed:
173
- continue
174
- padding_devide = len(groups[-1])*4
175
- is_overlap = OCR.check_rectangle_overlap(box, box2, padding_devide)
176
- if is_overlap:
177
- (xx1, yy1), (xx2, yy2), (xx3, yy3), (xx4, yy4) = box2
178
- processed.append(j+i+1)
179
- groups[-1].append(box2)
180
- groups_text[-1].append(text2)
181
- new_x1 = min(x1, xx1)
182
- new_y1 = min(y1, yy1)
183
- new_x2 = max(x2, xx2)
184
- new_y2 = min(y2, yy2)
185
- new_x3 = max(x3, xx3)
186
- new_y3 = max(y3, yy3)
187
- new_x4 = min(x4, xx4)
188
- new_y4 = max(y4, yy4)
189
-
190
- box = [(new_x1, new_y1), (new_x2, new_y2), (new_x3, new_y3), (new_x4, new_y4)]
191
-
192
- paragraph_boxes.append(box)
193
- return groups, groups_text, paragraph_boxes
 
 
 
 
 
 
 
1
  import torch
2
  import requests
3
  from PIL import Image, ImageFont, ImageDraw, ImageTransform
 
8
 
9
  class ViT:
10
  def __init__(self) -> None:
11
+ self.processor = AutoImageProcessor.from_pretrained(
12
+ "google/vit-base-patch16-224-in21k"
13
+ )
14
  self.model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
15
  self.model.to(Config.device)
16
 
 
24
  with torch.no_grad():
25
  outputs = self.model(**inputs)
26
  last_hidden_states = outputs.last_hidden_state
27
+ attention_mask = torch.ones(
28
+ (last_hidden_states.shape[0], last_hidden_states.shape[1])
29
+ )
30
 
31
  return last_hidden_states.to(Config.device), attention_mask.to(Config.device)
32
 
 
37
  image_outputs = self.model(**image_inputs)
38
  image_pooler_output = image_outputs.pooler_output
39
  image_pooler_output = torch.unsqueeze(image_pooler_output, 0)
40
+ image_attention_mask = torch.ones(
41
+ (image_pooler_output.shape[0], image_pooler_output.shape[1])
42
+ )
43
+
44
+ return image_pooler_output.to(Config.device), image_attention_mask.to(
45
+ Config.device
46
+ )
47
 
 
48
 
49
  class OCR:
50
  def __init__(self) -> None:
51
  self.ocr_detector = OCRDetector()
52
 
53
  def extraction(self, image_dir):
 
54
  ocr_results = self.ocr_detector.text_detector(image_dir)
55
  if not ocr_results:
56
  print("NOT OCR1")
 
60
  ocrs = self.post_process(ocr_results)
61
 
62
  if not ocrs:
 
63
  return "", [], []
64
 
65
  ocrs.reverse()
 
80
  ocr_content = " ".join(ocr_content.split())
81
  ocr_content = "<extra_id_0>" + ocr_content
82
 
 
83
  return ocr_content, groups_box, paragraph_boxes
84
 
85
+ def post_process(self, ocr_results):
86
  ocrs = []
87
  for result in ocr_results:
88
  text = result["text"]
 
101
  # if w*h < 300:
102
  # continue
103
 
104
+ ocrs.append({"text": text.lower(), "box": box})
 
 
 
105
  return ocrs
106
 
107
  @staticmethod
 
109
  (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
110
  w = x2 - x1
111
  h = y4 - y1
112
+ scl = h // 7
113
+ new_box = (
114
+ [max(x1 - scl, 0), max(y1 - scl, 0)],
115
+ [x2 + scl, y2 - scl],
116
+ [x3 + scl, y3 + scl],
117
+ [x4 - scl, y4 + scl],
118
+ )
119
  (x1, y1), (x2, y2), (x3, y3), (x4, y4) = new_box
120
  # Define 8-tuple with x,y coordinates of top-left, bottom-left, bottom-right and top-right corners and apply
121
  transform = [x1, y1, x4, y4, x3, y3, x2, y2]
122
+ result = image.transform((w, h), ImageTransform.QuadTransform(transform))
123
  return result
124
 
 
125
  @staticmethod
126
  def check_point_in_rectangle(box, point, padding_devide):
127
+ (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
128
+ x_min = min(x1, x4)
129
+ x_max = max(x2, x3)
130
 
131
+ padding = (x_max - x_min) // padding_devide
132
+ x_min = x_min - padding
133
+ x_max = x_max + padding
134
 
135
+ y_min = min(y1, y2)
136
+ y_max = max(y3, y4)
137
 
138
+ y_min = y_min - padding
139
+ y_max = y_max + padding
140
 
141
+ x, y = point
142
 
143
+ if x >= x_min and x <= x_max and y >= y_min and y <= y_max:
144
+ return True
145
 
146
+ return False
147
 
148
  @staticmethod
149
  def check_rectangle_overlap(rec1, rec2, padding_devide):
150
+ for point in rec1:
151
+ if OCR.check_point_in_rectangle(rec2, point, padding_devide):
152
+ return True
153
 
154
+ for point in rec2:
155
+ if OCR.check_point_in_rectangle(rec1, point, padding_devide):
156
+ return True
157
 
158
+ return False
159
 
160
  @staticmethod
161
  def group_boxes(boxes, texts):
162
+ groups = []
163
+ groups_text = []
164
+ paragraph_boxes = []
165
+ processed = []
166
+ boxes_cp = boxes.copy()
167
+ for i, (box, text) in enumerate(zip(boxes_cp, texts)):
168
+ (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
169
+
170
+ if i not in processed:
171
+ processed.append(i)
172
+ else:
173
+ continue
174
+
175
+ groups.append([box])
176
+ groups_text.append([text])
177
+ for j, (box2, text2) in enumerate(zip(boxes_cp[i + 1 :], texts[i + 1 :])):
178
+ if j + i + 1 in processed:
179
+ continue
180
+ padding_devide = len(groups[-1]) * 4
181
+ is_overlap = OCR.check_rectangle_overlap(box, box2, padding_devide)
182
+ if is_overlap:
183
+ (xx1, yy1), (xx2, yy2), (xx3, yy3), (xx4, yy4) = box2
184
+ processed.append(j + i + 1)
185
+ groups[-1].append(box2)
186
+ groups_text[-1].append(text2)
187
+ new_x1 = min(x1, xx1)
188
+ new_y1 = min(y1, yy1)
189
+ new_x2 = max(x2, xx2)
190
+ new_y2 = min(y2, yy2)
191
+ new_x3 = max(x3, xx3)
192
+ new_y3 = max(y3, yy3)
193
+ new_x4 = min(x4, xx4)
194
+ new_y4 = max(y4, yy4)
195
+
196
+ box = [
197
+ (new_x1, new_y1),
198
+ (new_x2, new_y2),
199
+ (new_x3, new_y3),
200
+ (new_x4, new_y4),
201
+ ]
202
+
203
+ paragraph_boxes.append(box)
204
+ return groups, groups_text, paragraph_boxes
src/image_visualization.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+
3
+
4
+ # Show attention
5
+ def plot_attention(img, result, attention_plot, image_dir):
6
+ # img = img.numpy().transpose((1, 2, 0))
7
+ temp_image = img
8
+
9
+ fig = plt.figure(figsize=(15, 15))
10
+
11
+ len_result = len(result)
12
+ for l in range(len_result):
13
+ temp_att = attention_plot[l][1:].reshape(14, 14)
14
+ # temp_att = np.resize(attention_plot[l].detach().numpy(),(98,98))
15
+ ax = fig.add_subplot(len_result // 2, len_result // 2, l + 1)
16
+ ax.set_title(result[l], fontsize=18)
17
+ img = ax.imshow(temp_image)
18
+ ax.imshow(temp_att, alpha=0.6, cmap="jet", extent=img.get_extent())
19
+
20
+ plt.tight_layout()
21
+ plt.savefig(image_dir)
src/model.py CHANGED
@@ -8,12 +8,16 @@ from typing import *
8
  from transformers import T5ForConditionalGeneration, AutoTokenizer
9
  from utils.config import Config
10
  from src.feature_extraction import ViT, OCR
 
 
 
 
11
 
12
  _CONFIG_FOR_DOC = "T5Config"
13
  _CHECKPOINT_FOR_DOC = "google-t5/t5-small"
14
 
15
- class CustomT5Stack(T5Stack):
16
 
 
17
  def forward(
18
  self,
19
  input_ids=None,
@@ -35,11 +39,19 @@ class CustomT5Stack(T5Stack):
35
  torch.cuda.set_device(self.first_device)
36
  self.embed_tokens = self.embed_tokens.to(self.first_device)
37
  use_cache = use_cache if use_cache is not None else self.config.use_cache
38
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
39
  output_hidden_states = (
40
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
41
  )
42
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
43
 
44
  if input_ids is not None and inputs_embeds is not None:
45
  err_msg_prefix = "decoder_" if self.is_decoder else ""
@@ -53,11 +65,15 @@ class CustomT5Stack(T5Stack):
53
  input_shape = inputs_embeds.size()[:-1]
54
  else:
55
  err_msg_prefix = "decoder_" if self.is_decoder else ""
56
- raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
 
 
57
 
58
  if inputs_embeds is None:
59
  if self.embed_tokens is None:
60
- raise ValueError("You have to initialize the model with valid token embeddings")
 
 
61
  inputs_embeds = self.embed_tokens(input_ids)
62
  if not self.is_decoder and images_embeds is not None:
63
  inputs_embeds = torch.concat([inputs_embeds, images_embeds], dim=1)
@@ -66,33 +82,47 @@ class CustomT5Stack(T5Stack):
66
  batch_size, seq_length = input_shape
67
 
68
  # required mask seq length can be calculated via length of past
69
- mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
 
 
 
 
70
 
71
  if use_cache is True:
72
  if not self.is_decoder:
73
- raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
 
 
74
 
75
  # initialize past_key_values with `None` if past does not exist
76
  if past_key_values is None:
77
  past_key_values = [None] * len(self.block)
78
 
79
  if attention_mask is None:
80
- attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
 
 
81
 
82
  # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
83
  # ourselves in which case we just need to make it broadcastable to all heads.
84
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
 
 
85
 
86
  # If a 2D or 3D attention mask is provided for the cross-attention
87
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
88
  if self.is_decoder and encoder_hidden_states is not None:
89
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
 
 
90
  encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
91
  if encoder_attention_mask is None:
92
  encoder_attention_mask = torch.ones(
93
  encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
94
  )
95
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
 
 
96
  else:
97
  encoder_extended_attention_mask = None
98
 
@@ -105,7 +135,9 @@ class CustomT5Stack(T5Stack):
105
 
106
  # Prepare head mask if needed
107
  head_mask = self.get_head_mask(head_mask, self.config.num_layers)
108
- cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
 
 
109
  present_key_value_states = () if use_cache else None
110
  all_hidden_states = () if output_hidden_states else None
111
  all_attentions = () if output_attentions else None
@@ -115,7 +147,9 @@ class CustomT5Stack(T5Stack):
115
 
116
  hidden_states = self.dropout(inputs_embeds)
117
 
118
- for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
 
 
119
  layer_head_mask = head_mask[i]
120
  cross_attn_layer_head_mask = cross_attn_head_mask[i]
121
  # Model parallel
@@ -127,15 +161,23 @@ class CustomT5Stack(T5Stack):
127
  if position_bias is not None:
128
  position_bias = position_bias.to(hidden_states.device)
129
  if encoder_hidden_states is not None:
130
- encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
 
 
131
  if encoder_extended_attention_mask is not None:
132
- encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
 
 
133
  if encoder_decoder_position_bias is not None:
134
- encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
 
 
135
  if layer_head_mask is not None:
136
  layer_head_mask = layer_head_mask.to(hidden_states.device)
137
  if cross_attn_layer_head_mask is not None:
138
- cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
 
 
139
  if output_hidden_states:
140
  all_hidden_states = all_hidden_states + (hidden_states,)
141
 
@@ -181,10 +223,14 @@ class CustomT5Stack(T5Stack):
181
  # (cross-attention position bias), (cross-attention weights)
182
  position_bias = layer_outputs[2]
183
  if self.is_decoder and encoder_hidden_states is not None:
184
- encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
 
 
185
  # append next layer key value states
186
  if use_cache:
187
- present_key_value_states = present_key_value_states + (present_key_value_state,)
 
 
188
 
189
  if output_attentions:
190
  all_attentions = all_attentions + (layer_outputs[3],)
@@ -227,7 +273,9 @@ class CustomT5Stack(T5Stack):
227
 
228
  class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
229
  @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
230
- @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
 
 
231
  def forward(
232
  self,
233
  input_ids: Optional[torch.LongTensor] = None,
@@ -280,7 +328,9 @@ class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
280
  >>> # studies have shown that owning a dog is good for you.
281
  ```"""
282
  use_cache = use_cache if use_cache is not None else self.config.use_cache
283
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
284
 
285
  # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
286
  if head_mask is not None and decoder_head_mask is None:
@@ -299,7 +349,7 @@ class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
299
  output_attentions=output_attentions,
300
  output_hidden_states=output_hidden_states,
301
  return_dict=return_dict,
302
- images_embeds=images_embeds
303
  )
304
  elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
305
  encoder_outputs = BaseModelOutput(
@@ -313,7 +363,11 @@ class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
313
  if self.model_parallel:
314
  torch.cuda.set_device(self.decoder.first_device)
315
 
316
- if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
 
 
 
 
317
  # get decoder inputs from shifting lm labels to the right
318
  decoder_input_ids = self._shift_right(labels)
319
 
@@ -326,7 +380,9 @@ class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
326
  if attention_mask is not None:
327
  attention_mask = attention_mask.to(self.decoder.first_device)
328
  if decoder_attention_mask is not None:
329
- decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
 
 
330
 
331
  # Decode
332
  decoder_outputs = self.decoder(
@@ -382,64 +438,124 @@ class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
382
  encoder_hidden_states=encoder_outputs.hidden_states,
383
  encoder_attentions=encoder_outputs.attentions,
384
  )
385
-
 
386
  transformers.models.t5.modeling_t5.T5Stack = CustomT5Stack
387
- transformers.models.t5.modeling_t5.T5ForConditionalGeneration = CustomT5ForConditionalGeneration
 
 
388
  transformers.T5ForConditionalGeneration = CustomT5ForConditionalGeneration
 
389
 
390
 
391
  class Model:
392
  def __init__(self) -> None:
393
  os.makedirs("storage", exist_ok=True)
394
-
395
  if not os.path.exists("storage/vlsp_transfomer_vietocr.pth"):
396
  print("DOWNLOADING model")
397
- gdown.download(Config.model_url, output="storage/vlsp_transfomer_vietocr.pth")
 
 
398
  self.vit5_tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
399
- self.model = T5ForConditionalGeneration.from_pretrained("truong-xuan-linh/VQA-vit5",
400
- revision=Config.revision,
401
- output_attentions=True)
 
 
402
  self.model.to(Config.device)
403
 
404
  self.vit = ViT()
405
  self.ocr = OCR()
406
 
407
  def get_inputs(self, image_dir: str, question: str):
408
- #VIT
409
  image_feature, image_mask = self.vit.extraction(image_dir)
410
 
411
  ocr_content, groups_box, paragraph_boxes = self.ocr.extraction(image_dir)
412
  print("Input: ", question + " " + ocr_content)
413
- #VIT5
414
- input_ = self.vit5_tokenizer(question + " " + ocr_content,
415
- padding="max_length",
416
- truncation=True,
417
- max_length=Config.question_maxlen + Config.ocr_maxlen,
418
- return_tensors="pt")
 
 
419
 
420
  input_ids = input_.input_ids
421
  attention_mask = input_.attention_mask
422
  mask = torch.cat((attention_mask, image_mask), 1)
423
  return {
424
- "input_ids": input_ids,
425
- "attention_mask": mask,
426
- "images_embeds": image_feature,
427
- }
428
 
429
- def inference(self, image_dir: str, question: str):
430
  inputs = self.get_inputs(image_dir, question)
431
  with torch.no_grad():
432
  input_ids = inputs["input_ids"]
433
  attention_mask = inputs["attention_mask"]
434
  images_embeds = inputs["images_embeds"]
435
  generated_ids = self.model.generate(
436
- input_ids=input_ids, \
437
- attention_mask=attention_mask, \
438
- images_embeds=images_embeds, \
439
- num_beams=2,
440
- max_length=Config.answer_maxlen
441
- )
 
 
 
 
 
 
442
 
443
- pred_answer = self.vit5_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
 
 
444
 
445
- return pred_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from transformers import T5ForConditionalGeneration, AutoTokenizer
9
  from utils.config import Config
10
  from src.feature_extraction import ViT, OCR
11
+ from bertviz import model_view, head_view
12
+ from src.image_visualization import plot_attention
13
+ import numpy as np
14
+ from PIL import Image
15
 
16
  _CONFIG_FOR_DOC = "T5Config"
17
  _CHECKPOINT_FOR_DOC = "google-t5/t5-small"
18
 
 
19
 
20
+ class CustomT5Stack(T5Stack):
21
  def forward(
22
  self,
23
  input_ids=None,
 
39
  torch.cuda.set_device(self.first_device)
40
  self.embed_tokens = self.embed_tokens.to(self.first_device)
41
  use_cache = use_cache if use_cache is not None else self.config.use_cache
42
+ output_attentions = (
43
+ output_attentions
44
+ if output_attentions is not None
45
+ else self.config.output_attentions
46
+ )
47
  output_hidden_states = (
48
+ output_hidden_states
49
+ if output_hidden_states is not None
50
+ else self.config.output_hidden_states
51
+ )
52
+ return_dict = (
53
+ return_dict if return_dict is not None else self.config.use_return_dict
54
  )
 
55
 
56
  if input_ids is not None and inputs_embeds is not None:
57
  err_msg_prefix = "decoder_" if self.is_decoder else ""
 
65
  input_shape = inputs_embeds.size()[:-1]
66
  else:
67
  err_msg_prefix = "decoder_" if self.is_decoder else ""
68
+ raise ValueError(
69
+ f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
70
+ )
71
 
72
  if inputs_embeds is None:
73
  if self.embed_tokens is None:
74
+ raise ValueError(
75
+ "You have to initialize the model with valid token embeddings"
76
+ )
77
  inputs_embeds = self.embed_tokens(input_ids)
78
  if not self.is_decoder and images_embeds is not None:
79
  inputs_embeds = torch.concat([inputs_embeds, images_embeds], dim=1)
 
82
  batch_size, seq_length = input_shape
83
 
84
  # required mask seq length can be calculated via length of past
85
+ mask_seq_length = (
86
+ past_key_values[0][0].shape[2] + seq_length
87
+ if past_key_values is not None
88
+ else seq_length
89
+ )
90
 
91
  if use_cache is True:
92
  if not self.is_decoder:
93
+ raise ValueError(
94
+ f"`use_cache` can only be set to `True` if {self} is used as a decoder"
95
+ )
96
 
97
  # initialize past_key_values with `None` if past does not exist
98
  if past_key_values is None:
99
  past_key_values = [None] * len(self.block)
100
 
101
  if attention_mask is None:
102
+ attention_mask = torch.ones(
103
+ batch_size, mask_seq_length, device=inputs_embeds.device
104
+ )
105
 
106
  # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
107
  # ourselves in which case we just need to make it broadcastable to all heads.
108
+ extended_attention_mask = self.get_extended_attention_mask(
109
+ attention_mask, input_shape
110
+ )
111
 
112
  # If a 2D or 3D attention mask is provided for the cross-attention
113
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
114
  if self.is_decoder and encoder_hidden_states is not None:
115
+ encoder_batch_size, encoder_sequence_length, _ = (
116
+ encoder_hidden_states.size()
117
+ )
118
  encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
119
  if encoder_attention_mask is None:
120
  encoder_attention_mask = torch.ones(
121
  encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
122
  )
123
+ encoder_extended_attention_mask = self.invert_attention_mask(
124
+ encoder_attention_mask
125
+ )
126
  else:
127
  encoder_extended_attention_mask = None
128
 
 
135
 
136
  # Prepare head mask if needed
137
  head_mask = self.get_head_mask(head_mask, self.config.num_layers)
138
+ cross_attn_head_mask = self.get_head_mask(
139
+ cross_attn_head_mask, self.config.num_layers
140
+ )
141
  present_key_value_states = () if use_cache else None
142
  all_hidden_states = () if output_hidden_states else None
143
  all_attentions = () if output_attentions else None
 
147
 
148
  hidden_states = self.dropout(inputs_embeds)
149
 
150
+ for i, (layer_module, past_key_value) in enumerate(
151
+ zip(self.block, past_key_values)
152
+ ):
153
  layer_head_mask = head_mask[i]
154
  cross_attn_layer_head_mask = cross_attn_head_mask[i]
155
  # Model parallel
 
161
  if position_bias is not None:
162
  position_bias = position_bias.to(hidden_states.device)
163
  if encoder_hidden_states is not None:
164
+ encoder_hidden_states = encoder_hidden_states.to(
165
+ hidden_states.device
166
+ )
167
  if encoder_extended_attention_mask is not None:
168
+ encoder_extended_attention_mask = (
169
+ encoder_extended_attention_mask.to(hidden_states.device)
170
+ )
171
  if encoder_decoder_position_bias is not None:
172
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(
173
+ hidden_states.device
174
+ )
175
  if layer_head_mask is not None:
176
  layer_head_mask = layer_head_mask.to(hidden_states.device)
177
  if cross_attn_layer_head_mask is not None:
178
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
179
+ hidden_states.device
180
+ )
181
  if output_hidden_states:
182
  all_hidden_states = all_hidden_states + (hidden_states,)
183
 
 
223
  # (cross-attention position bias), (cross-attention weights)
224
  position_bias = layer_outputs[2]
225
  if self.is_decoder and encoder_hidden_states is not None:
226
+ encoder_decoder_position_bias = layer_outputs[
227
+ 4 if output_attentions else 3
228
+ ]
229
  # append next layer key value states
230
  if use_cache:
231
+ present_key_value_states = present_key_value_states + (
232
+ present_key_value_state,
233
+ )
234
 
235
  if output_attentions:
236
  all_attentions = all_attentions + (layer_outputs[3],)
 
273
 
274
  class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
275
  @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
276
+ @replace_return_docstrings(
277
+ output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
278
+ )
279
  def forward(
280
  self,
281
  input_ids: Optional[torch.LongTensor] = None,
 
328
  >>> # studies have shown that owning a dog is good for you.
329
  ```"""
330
  use_cache = use_cache if use_cache is not None else self.config.use_cache
331
+ return_dict = (
332
+ return_dict if return_dict is not None else self.config.use_return_dict
333
+ )
334
 
335
  # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
336
  if head_mask is not None and decoder_head_mask is None:
 
349
  output_attentions=output_attentions,
350
  output_hidden_states=output_hidden_states,
351
  return_dict=return_dict,
352
+ images_embeds=images_embeds,
353
  )
354
  elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
355
  encoder_outputs = BaseModelOutput(
 
363
  if self.model_parallel:
364
  torch.cuda.set_device(self.decoder.first_device)
365
 
366
+ if (
367
+ labels is not None
368
+ and decoder_input_ids is None
369
+ and decoder_inputs_embeds is None
370
+ ):
371
  # get decoder inputs from shifting lm labels to the right
372
  decoder_input_ids = self._shift_right(labels)
373
 
 
380
  if attention_mask is not None:
381
  attention_mask = attention_mask.to(self.decoder.first_device)
382
  if decoder_attention_mask is not None:
383
+ decoder_attention_mask = decoder_attention_mask.to(
384
+ self.decoder.first_device
385
+ )
386
 
387
  # Decode
388
  decoder_outputs = self.decoder(
 
438
  encoder_hidden_states=encoder_outputs.hidden_states,
439
  encoder_attentions=encoder_outputs.attentions,
440
  )
441
+
442
+
443
  transformers.models.t5.modeling_t5.T5Stack = CustomT5Stack
444
+ transformers.models.t5.modeling_t5.T5ForConditionalGeneration = (
445
+ CustomT5ForConditionalGeneration
446
+ )
447
  transformers.T5ForConditionalGeneration = CustomT5ForConditionalGeneration
448
+ from transformers import T5ForConditionalGeneration
449
 
450
 
451
  class Model:
452
  def __init__(self) -> None:
453
  os.makedirs("storage", exist_ok=True)
454
+
455
  if not os.path.exists("storage/vlsp_transfomer_vietocr.pth"):
456
  print("DOWNLOADING model")
457
+ gdown.download(
458
+ Config.model_url, output="storage/vlsp_transfomer_vietocr.pth"
459
+ )
460
  self.vit5_tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
461
+ self.model = T5ForConditionalGeneration.from_pretrained(
462
+ "truong-xuan-linh/VQA-vit5",
463
+ revision=Config.revision,
464
+ output_attentions=True,
465
+ )
466
  self.model.to(Config.device)
467
 
468
  self.vit = ViT()
469
  self.ocr = OCR()
470
 
471
  def get_inputs(self, image_dir: str, question: str):
472
+ # VIT
473
  image_feature, image_mask = self.vit.extraction(image_dir)
474
 
475
  ocr_content, groups_box, paragraph_boxes = self.ocr.extraction(image_dir)
476
  print("Input: ", question + " " + ocr_content)
477
+ # VIT5
478
+ input_ = self.vit5_tokenizer(
479
+ question + " " + ocr_content,
480
+ padding="max_length",
481
+ truncation=True,
482
+ max_length=Config.question_maxlen + Config.ocr_maxlen,
483
+ return_tensors="pt",
484
+ )
485
 
486
  input_ids = input_.input_ids
487
  attention_mask = input_.attention_mask
488
  mask = torch.cat((attention_mask, image_mask), 1)
489
  return {
490
+ "input_ids": input_ids,
491
+ "attention_mask": mask,
492
+ "images_embeds": image_feature,
493
+ }
494
 
495
+ def inference(self, image_dir: str, question: str, explain: bool = False):
496
  inputs = self.get_inputs(image_dir, question)
497
  with torch.no_grad():
498
  input_ids = inputs["input_ids"]
499
  attention_mask = inputs["attention_mask"]
500
  images_embeds = inputs["images_embeds"]
501
  generated_ids = self.model.generate(
502
+ input_ids=input_ids,
503
+ attention_mask=attention_mask,
504
+ images_embeds=images_embeds,
505
+ num_beams=2,
506
+ max_length=Config.answer_maxlen,
507
+ )
508
+
509
+ pred_answer = self.vit5_tokenizer.decode(
510
+ generated_ids[0], skip_special_tokens=True
511
+ )
512
+ if not explain:
513
+ return pred_answer, None, None
514
 
515
+ with self.vit5_tokenizer.as_target_tokenizer():
516
+ decoder_input_ids = self.vit5_tokenizer(
517
+ pred_answer, return_tensors="pt", add_special_tokens=True
518
+ ).input_ids
519
 
520
+ with torch.no_grad():
521
+ outputs = self.model(
522
+ input_ids=input_ids,
523
+ attention_mask=attention_mask,
524
+ images_embeds=images_embeds,
525
+ decoder_input_ids=decoder_input_ids,
526
+ )
527
+
528
+ encoder_text = self.vit5_tokenizer.convert_ids_to_tokens(input_ids[0])
529
+ decoder_text = self.vit5_tokenizer.convert_ids_to_tokens(decoder_input_ids[0])
530
+ while "<pad>" in encoder_text:
531
+ encoder_text.remove("<pad>")
532
+
533
+ text_encoder_attentions = [
534
+ att[:, :, : len(encoder_text), : len(encoder_text)]
535
+ for att in outputs.encoder_attentions
536
+ ]
537
+ text_cross_attentions = [
538
+ att[:, :, :, : len(encoder_text)] for att in outputs.cross_attentions
539
+ ]
540
+
541
+ html_output = head_view(
542
+ encoder_attention=text_encoder_attentions,
543
+ decoder_attention=outputs.decoder_attentions,
544
+ cross_attention=text_cross_attentions,
545
+ encoder_tokens=encoder_text[: len(encoder_text)],
546
+ decoder_tokens=decoder_text,
547
+ # display_mode="light",
548
+ html_action="return",
549
+ )
550
+
551
+ img = Image.open(image_dir).convert("RGB")
552
+ image_dirs = []
553
+
554
+ for i in range(len(outputs.cross_attentions[:1])):
555
+ image_dir = f"visualization/test_image_visualize_{i}.jpg"
556
+ image_dirs.append(image_dir)
557
+ attention_plot = np.mean(
558
+ outputs.cross_attentions[i][0, :, :, -197:].detach().numpy(), axis=0
559
+ )
560
+ plot_attention(img, decoder_text, attention_plot, image_dir)
561
+ return pred_answer, html_output.data, image_dirs
src/ocr.py CHANGED
@@ -6,74 +6,80 @@ import requests
6
  import numpy as np
7
  from PIL import Image, ImageTransform
8
 
 
9
  class OCRDetector:
10
- def __init__(self) -> None:
11
- self.paddle_ocr = PaddleOCR(lang='en',
12
- use_angle_cls=False,
13
- use_gpu=True if Config.device == "cpu" else False,
14
- show_log=False )
15
- # config['weights'] = './weights/transformerocr.pth'
 
 
16
 
17
- vietocr_config = Cfg.load_config_from_name('vgg_transformer')
18
- vietocr_config['weights'] = Config.ocr_path
19
- vietocr_config['cnn']['pretrained']=False
20
- vietocr_config['device'] = Config.device
21
- vietocr_config['predictor']['beamsearch']=False
22
- self.viet_ocr = Predictor(vietocr_config)
23
 
24
- def find_box(self, image):
25
- '''Xác định box dựa vào mô hình paddle_ocr'''
26
- result = self.paddle_ocr.ocr(image, cls = False, rec=False)
27
- result = result[0]
28
- # Extracting detected components
29
- boxes = result #[res[0] for res in result]
30
- boxes = np.array(boxes).astype(int)
31
 
32
- # scores = [res[1][1] for res in result]
33
- return boxes
34
 
35
- def cut_image_polygon(self, image, box):
36
- (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
37
- w = x2 - x1
38
- h = y4 - y1
39
- scl = h//7
40
- new_box = [max(x1-scl,0), max(y1 - scl, 0)], [x2+scl, y2-scl], [x3+scl, y3+scl], [x4-scl, y4+scl]
41
- (x1, y1), (x2, y2), (x3, y3), (x4, y4) = new_box
42
- # Define 8-tuple with x,y coordinates of top-left, bottom-left, bottom-right and top-right corners and apply
43
- transform = [x1, y1, x4, y4, x3, y3, x2, y2]
44
- result = image.transform((w,h), ImageTransform.QuadTransform(transform))
45
- return result
 
 
 
 
 
46
 
47
- def vietnamese_text(self, boxes, image):
48
- '''Xác định text dựa vào mô hình viet_ocr'''
49
- results = []
50
- for box in boxes:
51
- try:
52
- cut_image = self.cut_image_polygon(image, box)
53
- # cut_image = Image.fromarray(np.uint8(cut_image))
54
- text, score = self.viet_ocr.predict(cut_image, return_prob=True)
55
- if score > Config.vietocr_threshold:
56
- results.append({"text": text,
57
- "score": score,
58
- "box": box})
59
- except:
60
- continue
61
- return results
62
 
63
- #Merge
64
- def text_detector(self, image_path):
65
- if image_path.startswith("https://"):
66
- image = Image.open(requests.get(image_path, stream=True).raw).convert("RGB")
67
- else:
68
- image = Image.open(image_path).convert("RGB")
69
- # np_image = np.array(image)
70
 
71
- boxes = self.find_box(image_path)
72
- if not boxes.any():
73
- return None
74
 
75
- results = self.vietnamese_text(boxes, image)
76
- if results != []:
77
- return results
78
- else:
79
- return None
 
6
  import numpy as np
7
  from PIL import Image, ImageTransform
8
 
9
+
10
  class OCRDetector:
11
+ def __init__(self) -> None:
12
+ self.paddle_ocr = PaddleOCR(
13
+ lang="en",
14
+ use_angle_cls=False,
15
+ use_gpu=True if Config.device == "cpu" else False,
16
+ show_log=False,
17
+ )
18
+ # config['weights'] = './weights/transformerocr.pth'
19
 
20
+ vietocr_config = Cfg.load_config_from_name("vgg_transformer")
21
+ vietocr_config["weights"] = Config.ocr_path
22
+ vietocr_config["cnn"]["pretrained"] = False
23
+ vietocr_config["device"] = Config.device
24
+ vietocr_config["predictor"]["beamsearch"] = False
25
+ self.viet_ocr = Predictor(vietocr_config)
26
 
27
+ def find_box(self, image):
28
+ """Xác định box dựa vào mô hình paddle_ocr"""
29
+ result = self.paddle_ocr.ocr(image, cls=False, rec=False)
30
+ result = result[0]
31
+ # Extracting detected components
32
+ boxes = result # [res[0] for res in result]
33
+ boxes = np.array(boxes).astype(int)
34
 
35
+ # scores = [res[1][1] for res in result]
36
+ return boxes
37
 
38
+ def cut_image_polygon(self, image, box):
39
+ (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
40
+ w = x2 - x1
41
+ h = y4 - y1
42
+ scl = h // 7
43
+ new_box = (
44
+ [max(x1 - scl, 0), max(y1 - scl, 0)],
45
+ [x2 + scl, y2 - scl],
46
+ [x3 + scl, y3 + scl],
47
+ [x4 - scl, y4 + scl],
48
+ )
49
+ (x1, y1), (x2, y2), (x3, y3), (x4, y4) = new_box
50
+ # Define 8-tuple with x,y coordinates of top-left, bottom-left, bottom-right and top-right corners and apply
51
+ transform = [x1, y1, x4, y4, x3, y3, x2, y2]
52
+ result = image.transform((w, h), ImageTransform.QuadTransform(transform))
53
+ return result
54
 
55
+ def vietnamese_text(self, boxes, image):
56
+ """Xác định text dựa vào mô hình viet_ocr"""
57
+ results = []
58
+ for box in boxes:
59
+ try:
60
+ cut_image = self.cut_image_polygon(image, box)
61
+ # cut_image = Image.fromarray(np.uint8(cut_image))
62
+ text, score = self.viet_ocr.predict(cut_image, return_prob=True)
63
+ if score > Config.vietocr_threshold:
64
+ results.append({"text": text, "score": score, "box": box})
65
+ except:
66
+ continue
67
+ return results
 
 
68
 
69
+ # Merge
70
+ def text_detector(self, image_path):
71
+ if image_path.startswith("https://"):
72
+ image = Image.open(requests.get(image_path, stream=True).raw).convert("RGB")
73
+ else:
74
+ image = Image.open(image_path).convert("RGB")
75
+ # np_image = np.array(image)
76
 
77
+ boxes = self.find_box(image_path)
78
+ if not boxes.any():
79
+ return None
80
 
81
+ results = self.vietnamese_text(boxes, image)
82
+ if results != []:
83
+ return results
84
+ else:
85
+ return None
utils/config.py CHANGED
@@ -10,4 +10,4 @@ class Config:
10
  ocr_maxobj = 10000
11
  num_ocr = 32
12
  num_beams = 3
13
- revision = "version_2_with_extra_id_0"
 
10
  ocr_maxobj = 10000
11
  num_ocr = 32
12
  num_beams = 3
13
+ revision = "version_2_with_extra_id_0"
visualization/.gitkeep ADDED
File without changes