Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -20,7 +20,7 @@ class ModelWrapper:
|
|
20 |
super().__init__()
|
21 |
torch.set_grad_enabled(False)
|
22 |
|
23 |
-
self.DTYPE =
|
24 |
self.device = accelerator.device
|
25 |
|
26 |
self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
|
@@ -49,7 +49,7 @@ class ModelWrapper:
|
|
49 |
|
50 |
def create_generator(self, model_id, checkpoint_path):
|
51 |
generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
|
52 |
-
state_dict = torch.load(checkpoint_path
|
53 |
generator.load_state_dict(state_dict, strict=True)
|
54 |
generator.requires_grad_(False)
|
55 |
return generator
|
|
|
20 |
super().__init__()
|
21 |
torch.set_grad_enabled(False)
|
22 |
|
23 |
+
self.DTYPE = torch.float16
|
24 |
self.device = accelerator.device
|
25 |
|
26 |
self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
|
|
|
49 |
|
50 |
def create_generator(self, model_id, checkpoint_path):
|
51 |
generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
|
52 |
+
state_dict = torch.load(checkpoint_path)
|
53 |
generator.load_state_dict(state_dict, strict=True)
|
54 |
generator.requires_grad_(False)
|
55 |
return generator
|