cahya commited on
Commit
02ae971
1 Parent(s): 4b0f2e1

add .to(0) for inputs

Browse files
Files changed (1) hide show
  1. app/api.py +1 -1
app/api.py CHANGED
@@ -108,7 +108,7 @@ async def text_generate(
108
  max_penalty = 1.5
109
  repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
110
  prompt = f"User: {text}\nAssistant: "
111
- input_ids = text_generator[model_name]["tokenizer"](prompt, return_tensors='pt').input_ids # .to(device)
112
  text_generator[model_name]["model"].eval()
113
  print("Generating text...")
114
  print(f"max_length: {max_length}, do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, "
 
108
  max_penalty = 1.5
109
  repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
110
  prompt = f"User: {text}\nAssistant: "
111
+ input_ids = text_generator[model_name]["tokenizer"](prompt, return_tensors='pt').input_ids.to(0)
112
  text_generator[model_name]["model"].eval()
113
  print("Generating text...")
114
  print(f"max_length: {max_length}, do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, "