momergul committed on
Commit
5f8e458
1 Parent(s): 18e7d92

Tweaked inference

Browse files
Files changed (1) hide show
  1. app.py +40 -21
app.py CHANGED
@@ -23,7 +23,7 @@ css="""
23
  def initialize_game() -> List[List[str]]:
24
  context_dicts = [generate_complete_game() for _ in range(2)]
25
 
26
- roles = ["speaker"] * 3 + ["listener"] * 3
27
  speaker_images = []
28
  listener_images = []
29
  targets = []
@@ -36,46 +36,64 @@ def initialize_game() -> List[List[str]]:
36
 
37
  return list(zip(speaker_images, listener_images, targets, roles))
38
 
39
- @spaces.GPU
40
  def get_model_response(
41
  model, adapter_name, processor, index_to_token, role: str,
42
  image_paths: List[str], user_message: str = "", target_image: str = ""
43
  ) -> str:
44
  model.model.set_adapter(adapter_name)
45
- print(model.model.active_adapter)
46
  if role == "speaker":
47
  img_dir = "tangram_pngs"
 
48
  input_tokens, attn_mask, images, image_attn_mask, label = joint_speaker_input(
49
  processor, image_paths, target_image, model.get_listener().device
50
  )
51
- print("Hi")
52
- with torch.no_grad():
53
- image_paths = [image_paths]
54
- captions, _, _, _, _ = model.generate(
55
- images, input_tokens, attn_mask, image_attn_mask, label,
56
- image_paths, processor, img_dir, index_to_token,
57
- max_steps=30, sampling_type="nucleus", temperature=0.7,
58
- top_k=50, top_p=1, repetition_penalty=1, num_samples=5
59
- )
60
- print("There")
61
  response = captions[0]
62
  else: # listener
 
63
  images, l_input_tokens, l_attn_mask, l_image_attn_mask, s_input_tokens, s_attn_mask, \
64
  s_image_attn_mask, s_target_mask, s_target_label = joint_listener_input(
65
  processor, image_paths, user_message, model.get_listener().device
66
  )
67
 
68
- with torch.no_grad():
69
- # Forward
70
- _, _, joint_log_probs = model.comprehension_side([
71
- images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
72
- s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label,
73
- ])
74
- target_idx = joint_log_probs[0].argmax().item()
75
- response = image_paths[target_idx]
76
 
77
  return response
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def interaction(model, processor, index_to_token, model_iteration: str) -> Tuple[List[str], List[str]]:
80
  image_role_pairs = initialize_game()
81
  conversation = []
@@ -195,6 +213,7 @@ def create_app():
195
  processor = get_processor()
196
  index_to_token = get_index_to_token()
197
 
 
198
  def start_interaction(model_iteration):
199
  if model_iteration is None:
200
  return [], "Please select a model iteration.", "", "", "", gr.update(interactive=False), \
 
23
  def initialize_game() -> List[List[str]]:
24
  context_dicts = [generate_complete_game() for _ in range(2)]
25
 
26
+ roles = ["listener"] * 3 + ["speaker"] * 3
27
  speaker_images = []
28
  listener_images = []
29
  targets = []
 
36
 
37
  return list(zip(speaker_images, listener_images, targets, roles))
38
 
 
39
def get_model_response(
    model, adapter_name, processor, index_to_token, role: str,
    image_paths: List[str], user_message: str = "", target_image: str = ""
) -> str:
    """Run one game turn through the model, acting as speaker or listener.

    Activates `adapter_name` on the underlying model, then dispatches to the
    GPU-decorated helper for the requested role.

    :param role: "speaker" generates a caption for `target_image`;
        anything else treats the turn as listener and resolves `user_message`.
    :return: a caption string (speaker) or the selected image path (listener).
    """
    model.model.set_adapter(adapter_name)

    if role == "speaker":
        img_dir = "tangram_pngs"
        print("Starting processing")
        input_tokens, attn_mask, images, image_attn_mask, label = joint_speaker_input(
            processor, image_paths, target_image, model.get_listener().device
        )
        # The speaker helper expects a batch of contexts, hence the extra nesting.
        image_paths = [image_paths]
        print("Starting inference")
        captions = get_speaker_response(
            model, images, input_tokens, attn_mask, image_attn_mask, label,
            image_paths, processor, img_dir, index_to_token
        )
        print("Done")
        return captions[0]

    # Listener turn: score the candidate images against the user's message.
    print("Starting processing")
    (images, l_input_tokens, l_attn_mask, l_image_attn_mask, s_input_tokens,
     s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label) = joint_listener_input(
        processor, image_paths, user_message, model.get_listener().device
    )

    print("Starting inference")
    response = get_listener_response(
        model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
        s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths
    )
    print("Done")
    return response
71
 
72
@spaces.GPU(duration=20)
def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths,
                         processor, img_dir, index_to_token, *, max_steps=30, sampling_type="nucleus",
                         temperature=0.7, top_k=50, top_p=1, repetition_penalty=1, num_samples=5):
    """Generate speaker captions on the GPU and return them.

    Moves the model and all tensor inputs to CUDA, then samples captions with
    `model.generate` under `torch.no_grad()`.

    Decoding hyperparameters were previously hard-coded; they are now
    keyword-only parameters whose defaults reproduce the original behavior
    exactly, so existing callers are unaffected.

    :param images / input_tokens / attn_mask / image_attn_mask / label:
        tensors produced by `joint_speaker_input` — moved to CUDA here.
    :param image_paths: batched list of image-path contexts.
    :return: list of generated caption strings (first element is used by callers).
    """
    model = model.cuda()
    with torch.no_grad():
        captions, _, _, _, _ = model.generate(
            images.cuda(), input_tokens.cuda(), attn_mask.cuda(), image_attn_mask.cuda(), label.cuda(),
            image_paths, processor, img_dir, index_to_token,
            max_steps=max_steps, sampling_type=sampling_type, temperature=temperature,
            top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, num_samples=num_samples
        )
    return captions
83
+
84
@spaces.GPU(duration=20)
def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
                          s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths):
    """Score each candidate image on the GPU and return the listener's pick.

    Moves the model and every tensor input to CUDA, runs the comprehension
    side under `torch.no_grad()`, and returns the image path with the highest
    joint log-probability.
    """
    model = model.cuda()
    with torch.no_grad():
        # All tensor inputs must live on the same device as the model.
        comprehension_inputs = [
            images.cuda(), l_input_tokens.cuda(), l_attn_mask.cuda(), l_image_attn_mask.cuda(), index_to_token,
            s_input_tokens.cuda(), s_attn_mask.cuda(), s_image_attn_mask.cuda(), s_target_mask.cuda(), s_target_label.cuda(),
        ]
        _, _, joint_log_probs = model.comprehension_side(comprehension_inputs)
        best_idx = joint_log_probs[0].argmax().item()
    return image_paths[best_idx]
96
+
97
  def interaction(model, processor, index_to_token, model_iteration: str) -> Tuple[List[str], List[str]]:
98
  image_role_pairs = initialize_game()
99
  conversation = []
 
213
  processor = get_processor()
214
  index_to_token = get_index_to_token()
215
 
216
+ print("Heyo!")
217
  def start_interaction(model_iteration):
218
  if model_iteration is None:
219
  return [], "Please select a model iteration.", "", "", "", gr.update(interactive=False), \