Spaces:
Running
on
Zero
Running
on
Zero
fix past
Browse files- app.py +3 -1
- midi_model.py +3 -1
app.py
CHANGED
@@ -53,10 +53,11 @@ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0,
|
|
53 |
cur_len = input_tensor.shape[1]
|
54 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
55 |
cache1 = DynamicCache()
|
|
|
56 |
with bar:
|
57 |
while cur_len < max_len:
|
58 |
end = [False] * batch_size
|
59 |
-
hidden = model.forward(input_tensor[:,
|
60 |
next_token_seq = None
|
61 |
event_names = [""] * batch_size
|
62 |
cache2 = DynamicCache()
|
@@ -110,6 +111,7 @@ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0,
|
|
110 |
"constant", value=tokenizer.pad_id)
|
111 |
next_token_seq = next_token_seq.unsqueeze(1)
|
112 |
input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
|
|
|
113 |
cur_len += 1
|
114 |
bar.update(1)
|
115 |
yield next_token_seq[:, 0].cpu().numpy()
|
|
|
53 |
cur_len = input_tensor.shape[1]
|
54 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
55 |
cache1 = DynamicCache()
|
56 |
+
past_len = 0
|
57 |
with bar:
|
58 |
while cur_len < max_len:
|
59 |
end = [False] * batch_size
|
60 |
+
hidden = model.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
|
61 |
next_token_seq = None
|
62 |
event_names = [""] * batch_size
|
63 |
cache2 = DynamicCache()
|
|
|
111 |
"constant", value=tokenizer.pad_id)
|
112 |
next_token_seq = next_token_seq.unsqueeze(1)
|
113 |
input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
|
114 |
+
past_len = cur_len
|
115 |
cur_len += 1
|
116 |
bar.update(1)
|
117 |
yield next_token_seq[:, 0].cpu().numpy()
|
midi_model.py
CHANGED
@@ -160,10 +160,11 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
|
|
160 |
cur_len = input_tensor.shape[1]
|
161 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
162 |
cache1 = DynamicCache()
|
|
|
163 |
with bar:
|
164 |
while cur_len < max_len:
|
165 |
end = [False] * batch_size
|
166 |
-
hidden = self.forward(input_tensor[
|
167 |
next_token_seq = None
|
168 |
event_names = [""] * batch_size
|
169 |
cache2 = DynamicCache()
|
@@ -210,6 +211,7 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
|
|
210 |
"constant", value=tokenizer.pad_id)
|
211 |
next_token_seq = next_token_seq.unsqueeze(1)
|
212 |
input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
|
|
|
213 |
cur_len += 1
|
214 |
bar.update(1)
|
215 |
|
|
|
160 |
cur_len = input_tensor.shape[1]
|
161 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
162 |
cache1 = DynamicCache()
|
163 |
+
past_len = 0
|
164 |
with bar:
|
165 |
while cur_len < max_len:
|
166 |
end = [False] * batch_size
|
167 |
+
hidden = self.forward(input_tensor[:,past_len:], cache=cache1)[:, -1]
|
168 |
next_token_seq = None
|
169 |
event_names = [""] * batch_size
|
170 |
cache2 = DynamicCache()
|
|
|
211 |
"constant", value=tokenizer.pad_id)
|
212 |
next_token_seq = next_token_seq.unsqueeze(1)
|
213 |
input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
|
214 |
+
past_len = cur_len
|
215 |
cur_len += 1
|
216 |
bar.update(1)
|
217 |
|