mkthoma committed on
Commit
753da27
1 Parent(s): 56138b3

app update

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -155,8 +155,8 @@ class BigramLanguageModel(nn.Module):
155
  self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
156
  self.ln_f = nn.LayerNorm(n_embd)
157
  self.lm_head = nn.Linear(n_embd, self.vocab_size)
158
- self.encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
159
- self.decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
160
 
161
 
162
  def forward(self, idx, targets=None):
@@ -242,7 +242,7 @@ def generate_shakespeare_outputs(prompt=None, max_new_tokens=2000):
242
  context = torch.tensor(shakespeare_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
243
  else:
244
  context = torch.zeros((1, 1), dtype=torch.long, device=device)
245
- text_output = decode(shakespeare_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
246
  return text_output
247
 
248
 
@@ -251,7 +251,7 @@ def generate_wikipedia_outputs(prompt=None, max_new_tokens=2000):
251
  context = torch.tensor(wikipedia_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
252
  else:
253
  context = torch.zeros((1, 1), dtype=torch.long, device=device)
254
- text_output = decode(wikipedia_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
255
  return text_output
256
 
257
 
 
155
  self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
156
  self.ln_f = nn.LayerNorm(n_embd)
157
  self.lm_head = nn.Linear(n_embd, self.vocab_size)
158
+ self.encode = lambda s: [self.stoi[c] for c in s] # encoder: take a string, output a list of integers
159
+ self.decode = lambda l: ''.join([self.itos[i] for i in l]) # decoder: take a list of integers, output a string
160
 
161
 
162
  def forward(self, idx, targets=None):
 
242
  context = torch.tensor(shakespeare_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
243
  else:
244
  context = torch.zeros((1, 1), dtype=torch.long, device=device)
245
+ text_output = shakespeare_model.decode(shakespeare_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
246
  return text_output
247
 
248
 
 
251
  context = torch.tensor(wikipedia_model.encode(prompt), dtype=torch.long, device=device).view(1, -1)
252
  else:
253
  context = torch.zeros((1, 1), dtype=torch.long, device=device)
254
+ text_output = wikipedia_model.decode(wikipedia_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
255
  return text_output
256
 
257