injilashah committed on
Commit
a740fc3
·
verified ·
1 Parent(s): 97d06ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -10,7 +10,7 @@ b_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-1b1")#using small
10
  b_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-1b1",device_map = "auto")
11
 
12
  g_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b",token = hf_token)#using small paramerter version of model for faster inference on hf
13
- g_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b",token = hf_token)
14
 
15
  def Sentence_Commpletion(model_name, input):
16
 
@@ -18,11 +18,11 @@ def Sentence_Commpletion(model_name, input):
18
  if model_name == "Bloom":
19
  tokenizer, model = b_tokenizer, b_model
20
  inputss = tokenizer(input, return_tensors="pt")
21
- outputs = model.generate(inputss.input_ids, max_length=100, num_return_sequences=1)
22
  elif model_name == "Gemma":
23
  tokenizer, model = g_tokenizer, g_model
24
- input_ids = tokenizer(input, return_tensors="pt")
25
- outputs = model.generate(**input_ids, max_new_tokens=32)
26
  return tokenizer.decode(outputs[0])
27
 
28
 
 
10
  b_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-1b1",device_map = "auto")
11
 
12
  g_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b",token = hf_token)#using small paramerter version of model for faster inference on hf
13
+ g_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b",token = hf_token,device_map="auto")
14
 
15
def Sentence_Commpletion(model_name, input):
    """Complete a prompt with the selected causal language model.

    Args:
        model_name: Either ``"Bloom"`` or ``"Gemma"`` — selects one of the
            module-level tokenizer/model pairs.
        input: Prompt text to complete.

    Returns:
        str: The decoded generation.

    Raises:
        ValueError: If ``model_name`` is not a supported choice.
    """
    if model_name == "Bloom":
        tokenizer, model = b_tokenizer, b_model
        max_new = 31
    elif model_name == "Gemma":
        tokenizer, model = g_tokenizer, g_model
        max_new = 32
    else:
        raise ValueError(
            f"Unknown model_name: {model_name!r}; expected 'Bloom' or 'Gemma'"
        )
    inputs = tokenizer(input, return_tensors="pt")
    outputs = model.generate(inputs.input_ids, max_new_tokens=max_new)
    return tokenizer.decode(outputs[0])
27
 
28