ruslanruslanruslan committed on
Commit aa9fefe
1 Parent(s): 7181d11

Borgesian caching added, truncation fixed

Files changed (1)
  1. pages/Borgesian.py +9 -4
pages/Borgesian.py CHANGED
@@ -3,16 +3,21 @@ import transformers
 import torch
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
-borgesian = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2', output_attentions = False, output_hidden_states = False)
-borgesian.load_state_dict(torch.load('borgesian_weights.pt', map_location=torch.device('cpu')))
-tokenizer = GPT2Tokenizer.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")
+@st.cache_resource
+def load_model():
+    borgesian = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2', output_attentions = False, output_hidden_states = False)
+    borgesian.load_state_dict(torch.load('borgesian_weights.pt', map_location=torch.device('cpu')))
+    tokenizer = GPT2Tokenizer.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")
+    return borgesian, tokenizer
+
+borgesian, tokenizer = load_model()
 borgesian.to('cpu')
 borgesian.eval()
 
 def generate_response(text, temperature, length, top_p):
     input_ids = tokenizer.encode(text, return_tensors="pt")
     with torch.no_grad():
-        out = borgesian.generate(input_ids, do_sample=True, num_beams=2, temperature=float(temperature), top_p=float(top_p), max_length=length)
+        out = borgesian.generate(input_ids, do_sample=True, num_beams=2, temperature=float(temperature), top_p=float(top_p), max_length=length, truncate=".")
     generated_text = list(map(tokenizer.decode, out))[0]
     st.write(generated_text)
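
If the `truncate="."` keyword is not handled by `generate()` in a given transformers version, the same sentence-boundary truncation can be applied to the decoded string after generation instead. A minimal sketch, assuming the decoded output is available as `generated_text` inside `generate_response()`; the `trim_to_sentence` helper is illustrative and not part of this repo:

def trim_to_sentence(text: str, stop_char: str = ".") -> str:
    # Keep everything up to and including the last occurrence of stop_char,
    # so the generated text does not end mid-sentence.
    cut = text.rfind(stop_char)
    return text[: cut + 1] if cut != -1 else text

# e.g. inside generate_response(), after decoding:
# generated_text = trim_to_sentence(list(map(tokenizer.decode, out))[0])
# st.write(generated_text)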