app update
Browse files
app.py
CHANGED
@@ -155,8 +155,8 @@ class BigramLanguageModel(nn.Module):
|
|
155 |
self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
|
156 |
self.ln_f = nn.LayerNorm(n_embd)
|
157 |
self.lm_head = nn.Linear(n_embd, self.vocab_size)
|
158 |
-
self.encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
|
159 |
-
self.decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
|
160 |
|
161 |
|
162 |
def forward(self, idx, targets=None):
|
@@ -242,7 +242,7 @@ def generate_shakespeare_outputs(prompt=None, max_new_tokens=2000):
|
|
242 |
context = torch.tensor(shakespeare_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
|
243 |
else:
|
244 |
context = torch.zeros((1, 1), dtype=torch.long, device=device)
|
245 |
-
text_output = decode(shakespeare_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
|
246 |
return text_output
|
247 |
|
248 |
|
@@ -251,7 +251,7 @@ def generate_wikipedia_outputs(prompt=None, max_new_tokens=2000):
|
|
251 |
context = torch.tensor(wikipedia_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
|
252 |
else:
|
253 |
context = torch.zeros((1, 1), dtype=torch.long, device=device)
|
254 |
-
text_output = decode(wikipedia_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
|
255 |
return text_output
|
256 |
|
257 |
|
|
|
155 |
self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
|
156 |
self.ln_f = nn.LayerNorm(n_embd)
|
157 |
self.lm_head = nn.Linear(n_embd, self.vocab_size)
|
158 |
+
self.encode = lambda s: [self.stoi[c] for c in s] # encoder: take a string, output a list of integers
|
159 |
+
self.decode = lambda l: ''.join([self.itos[i] for i in l]) # decoder: take a list of integers, output a string
|
160 |
|
161 |
|
162 |
def forward(self, idx, targets=None):
|
|
|
242 |
context = torch.tensor(shakespeare_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
|
243 |
else:
|
244 |
context = torch.zeros((1, 1), dtype=torch.long, device=device)
|
245 |
+
text_output = shakespeare_model.decode(shakespeare_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
|
246 |
return text_output
|
247 |
|
248 |
|
|
|
251 |
context = torch.tensor(wikipedia_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
|
252 |
else:
|
253 |
context = torch.zeros((1, 1), dtype=torch.long, device=device)
|
254 |
+
text_output = wikipedia_model.decode(wikipedia_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
|
255 |
return text_output
|
256 |
|
257 |
|