mkthoma committed
Commit 9798ca5
Parent: 97b7e88

app update

Files changed (1):
  app.py +5 -2
app.py CHANGED
@@ -155,6 +155,9 @@ class BigramLanguageModel(nn.Module):
         self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
         self.ln_f = nn.LayerNorm(n_embd)
         self.lm_head = nn.Linear(n_embd, self.vocab_size)
+        self.encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
+        self.decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
+
 
     def forward(self, idx, targets=None):
         B, T = idx.shape
@@ -236,7 +239,7 @@ wikipedia_model.eval() # Set the model to evaluation mode
 
 def generate_shakespeare_outputs(prompt=None, max_new_tokens=2000):
     if prompt:
-        context = torch.tensor(encode(prompt), dtype=torch.long, device=device).view(1, -1)
+        context = torch.tensor(shakespeare_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
     else:
         context = torch.zeros((1, 1), dtype=torch.long, device=device)
     text_output = decode(shakespeare_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
@@ -245,7 +248,7 @@ def generate_shakespeare_outputs(prompt=None, max_new_tokens=2000):
 
 def generate_wikipedia_outputs(prompt=None, max_new_tokens=2000):
     if prompt:
-        context = torch.tensor(encode(prompt), dtype=torch.long, device=device).view(1, -1)
+        context = torch.tensor(wikipedia_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
     else:
         context = torch.zeros((1, 1), dtype=torch.long, device=device)
     text_output = decode(wikipedia_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
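
The substance of the change: each model instance now carries its own character-level encode/decode lambdas, built from stoi/itos vocabulary tables, so a prompt is tokenized with the vocabulary that model was trained on. Below is a minimal sketch of the construction those lambdas assume; the toy corpus, the stoi/itos build, and the example prompt are illustrative assumptions and are not part of this diff.

import torch

# Toy stand-in for the training text; in app.py the real corpus defines the
# character set (this construction is assumed, not shown in the diff).
corpus = "First Citizen: Before we proceed any further, hear me speak."
chars = sorted(set(corpus))                       # per-model character vocabulary
stoi = {ch: i for i, ch in enumerate(chars)}      # char -> integer id
itos = {i: ch for i, ch in enumerate(chars)}      # integer id -> char

encode = lambda s: [stoi[c] for c in s]           # string -> list of ids
decode = lambda l: ''.join([itos[i] for i in l])  # list of ids -> string

prompt = "hear me"
ids = encode(prompt)
# Shape the ids into the (1, T) context tensor that generate() consumes.
context = torch.tensor(ids, dtype=torch.long).view(1, -1)
print(context.shape)          # torch.Size([1, 7])
print(decode(ids) == prompt)  # True: lossless round-trip

Since the Shakespeare and Wikipedia models are trained on different corpora, their character sets can differ; binding encode/decode to the instance keeps each generate_* function consistent with the right vocabulary, whereas encoding a prompt through the wrong model's stoi could raise a KeyError or silently map characters to different ids.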