vilarin commited on
Commit
c2b3f2d
·
verified ·
1 Parent(s): 5989c67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -20,7 +20,7 @@ class ModelWrapper:
20
  super().__init__()
21
  torch.set_grad_enabled(False)
22
 
23
- self.DTYPE = getattr(torch, precision)
24
  self.device = accelerator.device
25
 
26
  self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
@@ -49,7 +49,7 @@ class ModelWrapper:
49
 
50
  def create_generator(self, model_id, checkpoint_path):
51
  generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
52
- state_dict = torch.load(checkpoint_path, map_location="cuda")
53
  generator.load_state_dict(state_dict, strict=True)
54
  generator.requires_grad_(False)
55
  return generator
 
20
  super().__init__()
21
  torch.set_grad_enabled(False)
22
 
23
+ self.DTYPE = torch.float16
24
  self.device = accelerator.device
25
 
26
  self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
 
49
 
50
  def create_generator(self, model_id, checkpoint_path):
51
  generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
52
+ state_dict = torch.load(checkpoint_path)
53
  generator.load_state_dict(state_dict, strict=True)
54
  generator.requires_grad_(False)
55
  return generator