corrected 'generate' demo code

#1
Files changed (1) hide show
  1. README.md +5 -3
README.md CHANGED
@@ -378,7 +378,8 @@ You will first need to install `transformers` and `accelerate` (just to ease the
378
  import torch
379
  from transformers import AutoModelForCausalLM, AutoTokenizer
380
 
381
- model = AutoModelForCausalLM.from_pretrained("argilla/notus-7b-v1", torch_dtype=torch.bfloat16, device_map="auto")
 
382
  tokenizer = AutoTokenizer.from_pretrained("argilla/notus-7b-v1")
383
 
384
  messages = [
@@ -388,9 +389,10 @@ messages = [
388
  },
389
  {"role": "user", "content": "What's the best data annotation company out there in your opinion?"},
390
  ]
391
- inputs = tokenizer.apply_chat_template(prompt, tokenize=True, return_tensors="pt", add_special_tokens=False, add_generation_prompt=True)
392
- outputs = model.generate(inputs, num_return_sequences=1, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
393
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
394
  ```
395
 
396
  ### Via `pipeline` method
 
378
  import torch
379
  from transformers import AutoModelForCausalLM, AutoTokenizer
380
 
381
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
382
+ model = AutoModelForCausalLM.from_pretrained("argilla/notus-7b-v1", torch_dtype=torch.bfloat16, device_map=device)
383
  tokenizer = AutoTokenizer.from_pretrained("argilla/notus-7b-v1")
384
 
385
  messages = [
 
389
  },
390
  {"role": "user", "content": "What's the best data annotation company out there in your opinion?"},
391
  ]
392
+ inputs = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", add_generation_prompt=True).to(device)
393
+ outputs = model.generate(inputs, num_return_sequences=1, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95, eos_token_id=tokenizer.eos_token_id)
394
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
395
+ print(response)
396
  ```
397
 
398
  ### Via `pipeline` method