app update
Browse files
app.py
CHANGED
@@ -155,6 +155,9 @@ class BigramLanguageModel(nn.Module):
155         self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
156         self.ln_f = nn.LayerNorm(n_embd)
157         self.lm_head = nn.Linear(n_embd, self.vocab_size)
158
159     def forward(self, idx, targets=None):
160         B, T = idx.shape
@@ -236,7 +239,7 @@ wikipedia_model.eval()  # Set the model to evaluation mode
236
237 def generate_shakespeare_outputs(prompt=None, max_new_tokens=2000):
238     if prompt:
239 -       context = torch.tensor(encode(prompt), dtype=torch.long, device=device).view(1, -1)
240     else:
241         context = torch.zeros((1, 1), dtype=torch.long, device=device)
242     text_output = decode(shakespeare_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
@@ -245,7 +248,7 @@ def generate_shakespeare_outputs(prompt=None, max_new_tokens=2000):
245
246 def generate_wikipedia_outputs(prompt=None, max_new_tokens=2000):
247     if prompt:
248 -       context = torch.tensor(encode(prompt), dtype=torch.long, device=device).view(1, -1)
249     else:
250         context = torch.zeros((1, 1), dtype=torch.long, device=device)
251     text_output = decode(wikipedia_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
|
|
155         self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
156         self.ln_f = nn.LayerNorm(n_embd)
157         self.lm_head = nn.Linear(n_embd, self.vocab_size)
158 +       self.encode = lambda s: [stoi[c] for c in s]  # encoder: take a string, output a list of integers
159 +       self.decode = lambda l: ''.join([itos[i] for i in l])  # decoder: take a list of integers, output a string
160 +
161
162     def forward(self, idx, targets=None):
163         B, T = idx.shape
|
|
239
240 def generate_shakespeare_outputs(prompt=None, max_new_tokens=2000):
241     if prompt:
242 +       context = torch.tensor(shakespeare_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
243     else:
244         context = torch.zeros((1, 1), dtype=torch.long, device=device)
245     text_output = decode(shakespeare_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
|
|
248
249 def generate_wikipedia_outputs(prompt=None, max_new_tokens=2000):
250     if prompt:
251 +       context = torch.tensor(wikipedia_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
252     else:
253         context = torch.zeros((1, 1), dtype=torch.long, device=device)
254     text_output = decode(wikipedia_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())