Amitontheweb commited on
Commit
bf0c3af
·
verified ·
1 Parent(s): b27477f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -10,7 +10,7 @@ import gradio as gr
10
 
11
 
12
  tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
13
- model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
14
 
15
 
16
  # Define functions
@@ -21,7 +21,7 @@ global chosen_strategy
21
  def generate(input_text, number_steps, number_beams, number_beam_groups, diversity_penalty, length_penalty, num_return_sequences, temperature, no_repeat_ngram_size, repetition_penalty, early_stopping, beam_temperature, top_p, top_k,penalty_alpha,top_p_box,top_k_box,strategy_selected,model_selected):
22
 
23
  chosen_strategy = strategy_selected
24
- inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
25
 
26
  if chosen_strategy == "Sampling":
27
 
@@ -47,7 +47,7 @@ def generate(input_text, number_steps, number_beams, number_beam_groups, diversi
47
  beam_temp_flag = beam_temperature
48
  early_stop_flag = early_stopping
49
 
50
- inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
51
  outputs = model.generate(
52
 
53
  **inputs,
@@ -82,7 +82,7 @@ def generate(input_text, number_steps, number_beams, number_beam_groups, diversi
82
  if number_beam_groups > number_beams:
83
  number_beams = number_beam_groups
84
 
85
- inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
86
  outputs = model.generate(
87
 
88
  **inputs,
@@ -130,12 +130,12 @@ def load_model(model_selected):
130
 
131
  if model_selected == "gpt2":
132
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
133
- model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id).to(torch_device)
134
  #print (model_selected + " loaded")
135
 
136
  if model_selected == "Gemma 2":
137
  tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
138
- model = AutoModelForCausalLM.from_pretrained("google/gemma-2b").to(torch_device)
139
 
140
 
141
 
 
10
 
11
 
12
  tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
13
+ model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
14
 
15
 
16
  # Define functions
 
21
  def generate(input_text, number_steps, number_beams, number_beam_groups, diversity_penalty, length_penalty, num_return_sequences, temperature, no_repeat_ngram_size, repetition_penalty, early_stopping, beam_temperature, top_p, top_k,penalty_alpha,top_p_box,top_k_box,strategy_selected,model_selected):
22
 
23
  chosen_strategy = strategy_selected
24
+ inputs = tokenizer(input_text, return_tensors="pt")
25
 
26
  if chosen_strategy == "Sampling":
27
 
 
47
  beam_temp_flag = beam_temperature
48
  early_stop_flag = early_stopping
49
 
50
+ inputs = tokenizer(input_text, return_tensors="pt")
51
  outputs = model.generate(
52
 
53
  **inputs,
 
82
  if number_beam_groups > number_beams:
83
  number_beams = number_beam_groups
84
 
85
+ inputs = tokenizer(input_text, return_tensors="pt")
86
  outputs = model.generate(
87
 
88
  **inputs,
 
130
 
131
  if model_selected == "gpt2":
132
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
133
+ model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
134
  #print (model_selected + " loaded")
135
 
136
  if model_selected == "Gemma 2":
137
  tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
138
+ model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
139
 
140
 
141