ironjr committed
Commit 64f9101
Parent: 4a8bd93

Update model.py

Files changed (1):
  1. model.py  +11 -8
model.py CHANGED

@@ -140,7 +140,7 @@ class StreamMultiDiffusion(nn.Module):
         self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
         self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
 
-        self.pipe = load_model(model_key, self.sd_version, self.device, self.dtype)
+        self.pipe = load_model(model_key, self.sd_version, self.device, self.dtype).to(dtype=self.dtype)
 
         self.pipe.load_lora_weights(lora_key, weight_name=lora_weight_name, adapter_name='lcm')
         self.pipe.fuse_lora(
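
The only functional change in this hunk is the trailing .to(dtype=self.dtype), which casts every submodule of the loaded pipeline to the working precision before the LCM LoRA weights are fused onto it. Below is a minimal sketch of that casting pattern, with a stock diffusers StableDiffusionPipeline standing in for the project-specific load_model(); the checkpoint id and the float16 choice are illustrative assumptions, not values from this repository.

# Minimal sketch of the dtype cast added above. A stock diffusers pipeline
# stands in for load_model(); checkpoint id and float16 are assumptions.
import torch
from diffusers import StableDiffusionPipeline

dtype = torch.float16
pipe = StableDiffusionPipeline.from_pretrained(
    'CompVis/stable-diffusion-v1-4',  # illustrative checkpoint id
    torch_dtype=dtype,                # load weights in half precision
)
# Mirroring the commit: .to(dtype=...) casts every submodule (UNet, VAE,
# text encoder) to one dtype, so later LoRA fusing and UNet calls see
# consistent weights instead of a float32/float16 mix.
pipe = pipe.to(dtype=dtype)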
@@ -380,7 +380,7 @@ class StreamMultiDiffusion(nn.Module):
         """
         question = 'Question: What are in the image? Answer:'
         inputs = self.i2t_processor(image, question, return_tensors='pt')
-        out = self.i2t_model.generate(**{k: v.to(self.i2t_model.device) for k, v in inputs.items()}, max_new_tokens=77)
+        out = self.i2t_model.generate(**{k: v.to(self.i2t_model.device) for k, v in inputs.items()}) #, max_new_tokens=75)
         prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
         return prompt
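
This edit removes the explicit max_new_tokens cap from the BLIP-2 call that turns an input image into a text prompt, so caption length now follows the model's default generation config (the old cap of 77 survives only as a commented-out 75). The following is a self-contained sketch of the same image-to-text call, assuming a local example.jpg as input; the model ids match the ones loaded in __init__ above.

# Standalone sketch of the BLIP-2 image-to-text call used for automatic
# prompting. The image path is a placeholder; model ids match the diff.
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')

image = Image.open('example.jpg').convert('RGB')  # hypothetical input image
question = 'Question: What are in the image? Answer:'
inputs = processor(image, question, return_tensors='pt')

# Without max_new_tokens, generate() follows the model's default generation
# config; pass max_new_tokens explicitly to bound the caption length.
out = model.generate(**{k: v.to(model.device) for k, v in inputs.items()})
prompt = processor.decode(out[0], skip_special_tokens=True).strip()
print(prompt)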
 
@@ -1121,12 +1121,15 @@ class StreamMultiDiffusion(nn.Module):
         else:
             x_t_latent_plus_uc = x_t_latent  # (T * p, 4, h, w)
 
-        model_pred = self.pipe.unet(
-            x_t_latent_plus_uc.to(self.pipe.unet.dtype),  # (B, 4, h, w)
-            t_list,  # (B,)
-            encoder_hidden_states=self.prompt_embeds,  # (B, 77, 768)
-            return_dict=False,
-        )[0]  # (B, 4, h, w)
+        try:
+            model_pred = self.unet(
+                x_t_latent_plus_uc.to(self.unet.dtype),  # (B, 4, h, w)
+                t_list,  # (B,)
+                encoder_hidden_states=self.prompt_embeds,  # (B, 77, 768)
+                return_dict=False,
+            )[0]  # (B, 4, h, w)
+        except Exception as e:
+            print(e)
         print('222222222222222', model_pred.dtype)
 
         if self.bootstrap_steps[0] > 0:
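
The denoising step now calls self.unet rather than self.pipe.unet, and the call is wrapped in a try/except, presumably to surface dtype or shape errors while the debug print below inspects model_pred.dtype. The following is a self-contained sketch of that guarded UNet call using a plain diffusers UNet2DConditionModel; the checkpoint id, batch size, tensor shapes, and the CUDA/float16 setup are illustrative assumptions, not values from this repository.

# Self-contained sketch of the guarded UNet call above, with a stock
# diffusers UNet2DConditionModel standing in for self.unet. Checkpoint id,
# shapes, and the CUDA/float16 setup are assumptions for illustration.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='unet', torch_dtype=torch.float16
).to('cuda')  # assumes a CUDA device is available

B = 2
latents = torch.randn(B, 4, 64, 64, device='cuda')                           # (B, 4, h, w)
timesteps = torch.tensor([999] * B, device='cuda')                           # (B,)
prompt_embeds = torch.randn(B, 77, 768, device='cuda', dtype=torch.float16)  # (B, 77, 768)

try:
    noise_pred = unet(
        latents.to(unet.dtype),            # cast latents to the UNet's dtype
        timesteps,
        encoder_hidden_states=prompt_embeds,
        return_dict=False,
    )[0]                                   # (B, 4, h, w)
    print(noise_pred.dtype)
except Exception as e:                     # print dtype/shape errors instead of raising
    print(e)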
 