pr0mila committed
Commit 09a32cd · 1 Parent(s): 2a320b5

Update app.py

Files changed (1)
  1. app.py +11 -9
app.py CHANGED
@@ -26,17 +26,19 @@ import gradio as gr
 from transformers import GPTJForCausalLM
 import torch
 
-model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=False)
-tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
+tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
 
 def get_result_with_bloom(text):
-    context = text
-
-    input_ids = tokenizer(context, return_tensors="pt").input_ids
-    gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100,)
-    gen_text = tokenizer.batch_decode(gen_tokens)[0]
-    return gen_text
-
+    result_length = 200
+    inputs1 = tokenizer(text, return_tensors="pt")
+    output1 = tokenizer.decode(model.generate(inputs1["input_ids"],
+                                              max_length=result_length,
+                                              num_beams=2,
+                                              no_repeat_ngram_size=2,
+                                              early_stopping=True
+                                              )[0])
+    return output1
 
 
 
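
For reference, a minimal self-contained sketch of the generation path this commit introduces, runnable outside the diff. The AutoTokenizer and AutoModelForCausalLM imports are assumed to exist earlier in app.py (only the GPTJForCausalLM import is visible as context in this hunk), and the __main__ prompt is purely illustrative, not from the commit.

    from transformers import AutoTokenizer, AutoModelForCausalLM

    # bigscience/bloom-560m replaces EleutherAI/gpt-j-6B: a ~560M-parameter
    # checkpoint that is far lighter to load than the 6B GPT-J model.
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

    def get_result_with_bloom(text):
        # Beam search (num_beams=2) replaces the earlier temperature-0.9
        # sampling; no_repeat_ngram_size=2 blocks any bigram from repeating.
        output_ids = model.generate(
            tokenizer(text, return_tensors="pt")["input_ids"],
            max_length=200,
            num_beams=2,
            no_repeat_ngram_size=2,
            early_stopping=True,
        )
        return tokenizer.decode(output_ids[0])

    if __name__ == "__main__":
        print(get_result_with_bloom("Once upon a time"))  # illustrative prompt

The swap from a 6B-parameter GPT-J checkpoint to the 560M-parameter BLOOM likely trades output quality for a footprint that fits modest Space hardware, and moving from sampling to beam search with early_stopping=True also makes outputs reproducible across runs.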