Spaces:
Runtime error
Runtime error
guyyariv
commited on
Commit
•
2821e52
1
Parent(s):
56d047b
AudioTokenDemo
Browse files
app.py
CHANGED
@@ -35,7 +35,7 @@ class AudioTokenWrapper(torch.nn.Module):
|
|
35 |
)
|
36 |
|
37 |
checkpoint = torch.load(
|
38 |
-
'BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
|
39 |
cfg = BEATsConfig(checkpoint['cfg'])
|
40 |
self.aud_encoder = BEATs(cfg)
|
41 |
self.aud_encoder.load_state_dict(checkpoint['model'])
|
@@ -69,12 +69,12 @@ class AudioTokenWrapper(torch.nn.Module):
|
|
69 |
self.unet.set_attn_processor(lora_attn_procs)
|
70 |
self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
|
71 |
self.lora_layers.eval()
|
72 |
-
lora_layers_learned_embeds = '
|
73 |
self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=device))
|
74 |
self.unet.load_attn_procs(lora_layers_learned_embeds)
|
75 |
|
76 |
self.embedder.eval()
|
77 |
-
embedder_learned_embeds = '
|
78 |
self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=device))
|
79 |
|
80 |
self.placeholder_token = '<*>'
|
|
|
35 |
)
|
36 |
|
37 |
checkpoint = torch.load(
|
38 |
+
'models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
|
39 |
cfg = BEATsConfig(checkpoint['cfg'])
|
40 |
self.aud_encoder = BEATs(cfg)
|
41 |
self.aud_encoder.load_state_dict(checkpoint['model'])
|
|
|
69 |
self.unet.set_attn_processor(lora_attn_procs)
|
70 |
self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
|
71 |
self.lora_layers.eval()
|
72 |
+
lora_layers_learned_embeds = 'models/embedder_learned_embeds.bin'
|
73 |
self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=device))
|
74 |
self.unet.load_attn_procs(lora_layers_learned_embeds)
|
75 |
|
76 |
self.embedder.eval()
|
77 |
+
embedder_learned_embeds = 'models/lora_layers_learned_embeds.bin'
|
78 |
self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=device))
|
79 |
|
80 |
self.placeholder_token = '<*>'
|