guyyariv commited on
Commit
2821e52
1 Parent(s): 56d047b

AudioTokenDemo

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -35,7 +35,7 @@ class AudioTokenWrapper(torch.nn.Module):
35
  )
36
 
37
  checkpoint = torch.load(
38
- 'BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
39
  cfg = BEATsConfig(checkpoint['cfg'])
40
  self.aud_encoder = BEATs(cfg)
41
  self.aud_encoder.load_state_dict(checkpoint['model'])
@@ -69,12 +69,12 @@ class AudioTokenWrapper(torch.nn.Module):
69
  self.unet.set_attn_processor(lora_attn_procs)
70
  self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
71
  self.lora_layers.eval()
72
- lora_layers_learned_embeds = 'sd1_lora_qi_lora_layers_learned_embeds-40000.bin'
73
  self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=device))
74
  self.unet.load_attn_procs(lora_layers_learned_embeds)
75
 
76
  self.embedder.eval()
77
- embedder_learned_embeds = 'sd1_lora_qi_learned_embeds-40000.bin'
78
  self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=device))
79
 
80
  self.placeholder_token = '<*>'
 
35
  )
36
 
37
  checkpoint = torch.load(
38
+ 'models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
39
  cfg = BEATsConfig(checkpoint['cfg'])
40
  self.aud_encoder = BEATs(cfg)
41
  self.aud_encoder.load_state_dict(checkpoint['model'])
 
69
  self.unet.set_attn_processor(lora_attn_procs)
70
  self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
71
  self.lora_layers.eval()
72
+ lora_layers_learned_embeds = 'models/embedder_learned_embeds.bin'
73
  self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=device))
74
  self.unet.load_attn_procs(lora_layers_learned_embeds)
75
 
76
  self.embedder.eval()
77
+ embedder_learned_embeds = 'models/lora_layers_learned_embeds.bin'
78
  self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=device))
79
 
80
  self.placeholder_token = '<*>'