rynmurdock commited on
Commit
93ebe82
1 Parent(s): 3b945c7
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -134,7 +134,7 @@ pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-pt
134
  processor = AutoProcessor.from_pretrained('google/paligemma-3b-pt-224')
135
 
136
 
137
-
138
  def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
139
  inputs_embeds = pali.get_input_embeddings()(input_ids)
140
  selected_image_feature = image_outputs.to(dtype).to(device)
@@ -148,7 +148,7 @@ def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None
148
  return inputs_embeds
149
 
150
 
151
-
152
  def generate_pali(user_emb):
153
  prompt = 'caption en'
154
  model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
@@ -540,8 +540,7 @@ def encode_space(x):
540
  im_emb, _ = pipe.encode_image(
541
  image, DEVICE, 1, output_hidden_state
542
  )
543
-
544
-
545
  im = torchvision.transforms.ToTensor()(x).unsqueeze(0)
546
  im = torch.nn.functional.interpolate(im, (224, 224))
547
  im = (im - .5) * 2
 
134
  processor = AutoProcessor.from_pretrained('google/paligemma-3b-pt-224')
135
 
136
 
137
+ @spaces.GPU()
138
  def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
139
  inputs_embeds = pali.get_input_embeddings()(input_ids)
140
  selected_image_feature = image_outputs.to(dtype).to(device)
 
148
  return inputs_embeds
149
 
150
 
151
+ @spaces.GPU()
152
  def generate_pali(user_emb):
153
  prompt = 'caption en'
154
  model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
 
540
  im_emb, _ = pipe.encode_image(
541
  image, DEVICE, 1, output_hidden_state
542
  )
543
+
 
544
  im = torchvision.transforms.ToTensor()(x).unsqueeze(0)
545
  im = torch.nn.functional.interpolate(im, (224, 224))
546
  im = (im - .5) * 2