Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
@@ -140,7 +140,7 @@ class StreamMultiDiffusion(nn.Module):
|
|
140 |
self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
|
141 |
self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
|
142 |
|
143 |
-
self.pipe = load_model(model_key, self.sd_version, self.device, self.dtype)
|
144 |
|
145 |
self.pipe.load_lora_weights(lora_key, weight_name=lora_weight_name, adapter_name='lcm')
|
146 |
self.pipe.fuse_lora(
|
@@ -380,7 +380,7 @@ class StreamMultiDiffusion(nn.Module):
|
|
380 |
"""
|
381 |
question = 'Question: What are in the image? Answer:'
|
382 |
inputs = self.i2t_processor(image, question, return_tensors='pt')
|
383 |
-
out = self.i2t_model.generate(**{k: v.to(self.i2t_model.device) for k, v in inputs.items()}
|
384 |
prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
|
385 |
return prompt
|
386 |
|
@@ -1121,12 +1121,15 @@ class StreamMultiDiffusion(nn.Module):
|
|
1121 |
else:
|
1122 |
x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
|
1123 |
|
1124 |
-
|
1125 |
-
|
1126 |
-
|
1127 |
-
|
1128 |
-
|
1129 |
-
|
|
|
|
|
|
|
1130 |
print('222222222222222', model_pred.dtype)
|
1131 |
|
1132 |
if self.bootstrap_steps[0] > 0:
|
|
|
140 |
self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
|
141 |
self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
|
142 |
|
143 |
+
self.pipe = load_model(model_key, self.sd_version, self.device, self.dtype).to(dtype=self.dtype)
|
144 |
|
145 |
self.pipe.load_lora_weights(lora_key, weight_name=lora_weight_name, adapter_name='lcm')
|
146 |
self.pipe.fuse_lora(
|
|
|
380 |
"""
|
381 |
question = 'Question: What are in the image? Answer:'
|
382 |
inputs = self.i2t_processor(image, question, return_tensors='pt')
|
383 |
+
out = self.i2t_model.generate(**{k: v.to(self.i2t_model.device) for k, v in inputs.items()}) #, max_new_tokens=75)
|
384 |
prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
|
385 |
return prompt
|
386 |
|
|
|
1121 |
else:
|
1122 |
x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
|
1123 |
|
1124 |
+
try:
|
1125 |
+
model_pred = self.unet(
|
1126 |
+
x_t_latent_plus_uc.to(self.unet.dtype), # (B, 4, h, w)
|
1127 |
+
t_list, # (B,)
|
1128 |
+
encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
|
1129 |
+
return_dict=False,
|
1130 |
+
)[0] # (B, 4, h, w)
|
1131 |
+
except e:
|
1132 |
+
print(e)
|
1133 |
print('222222222222222', model_pred.dtype)
|
1134 |
|
1135 |
if self.bootstrap_steps[0] > 0:
|