Azure99 commited on
Commit
f6bc3e9
1 Parent(s): c1f8c68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -66,7 +66,8 @@ def generate(
66
  generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(llm.device), do_sample=True,
67
  max_new_tokens=512, temperature=0.5, top_p=0.85, top_k=50, repetition_penalty=1.05)
68
  llm_result = llm.generate(**generation_kwargs)
69
- llm_result = BOT_PREFIX + tokenizer.decode(llm_result.cpu()[0], skip_special_tokens=True)
 
70
  print(llm_result)
71
  expanded_description = json.loads(llm_result)["expanded_description"]
72
  print(expanded_description)
 
66
  generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(llm.device), do_sample=True,
67
  max_new_tokens=512, temperature=0.5, top_p=0.85, top_k=50, repetition_penalty=1.05)
68
  llm_result = llm.generate(**generation_kwargs)
69
+ llm_result = llm_result.cpu()[0][len(input_ids):]
70
+ llm_result = BOT_PREFIX + tokenizer.decode(llm_result, skip_special_tokens=True)
71
  print(llm_result)
72
  expanded_description = json.loads(llm_result)["expanded_description"]
73
  print(expanded_description)