x-lai commited on
Commit
da654fa
·
1 Parent(s): bfbc812

Release training script

Browse files

Former-commit-id: 829fb3eb44fa87640c19293156b9256bf82158ab

Files changed (1) hide show
  1. utils/reason_seg_dataset.py +15 -19
utils/reason_seg_dataset.py CHANGED
@@ -55,25 +55,6 @@ class ReasonSegDataset(torch.utils.data.Dataset):
55
  self.long_question_list = LONG_QUESTION_LIST
56
  self.answer_list = ANSWER_LIST
57
 
58
- if explanatory != -1:
59
- self.explanatory_question_list = EXPLANATORY_QUESTION_LIST
60
-
61
- if explanatory != -1:
62
- self.img_to_explanation = {}
63
- for sub_data in [
64
- "train.json",
65
- ]:
66
- with open(
67
- os.path.join(base_image_dir, "reason_seg", "explanatory", sub_data)
68
- ) as f:
69
- items = json.load(f)
70
- for item in items:
71
- img_name = item["image_path"].split("/")[-1]
72
- self.img_to_explanation[img_name] = {
73
- "query": item["query"],
74
- "outputs": item["outputs"],
75
- }
76
-
77
  reason_seg_data, splits = reason_seg_data.split("|")
78
  splits = splits.split("_")
79
  images = []
@@ -87,6 +68,21 @@ class ReasonSegDataset(torch.utils.data.Dataset):
87
  jsons = [path.replace(".jpg", ".json") for path in images]
88
  self.reason_seg_data = (images, jsons)
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def __len__(self):
91
  return self.samples_per_epoch
92
 
 
55
  self.long_question_list = LONG_QUESTION_LIST
56
  self.answer_list = ANSWER_LIST
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  reason_seg_data, splits = reason_seg_data.split("|")
59
  splits = splits.split("_")
60
  images = []
 
68
  jsons = [path.replace(".jpg", ".json") for path in images]
69
  self.reason_seg_data = (images, jsons)
70
 
71
+ if explanatory != -1:
72
+ self.explanatory_question_list = EXPLANATORY_QUESTION_LIST
73
+ self.img_to_explanation = {}
74
+ with open(
75
+ os.path.join(base_image_dir, "reason_seg", reason_seg_data, "explanatory", "train.json")
76
+ ) as f:
77
+ items = json.load(f)
78
+ for item in items:
79
+ img_name = item["image_path"].split("/")[-1]
80
+ self.img_to_explanation[img_name] = {
81
+ "query": item["query"],
82
+ "outputs": item["outputs"],
83
+ }
84
+
85
+
86
  def __len__(self):
87
  return self.samples_per_epoch
88