rbattle committed on
Commit
0278fdb
1 Parent(s): 6a94432

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -4
README.md CHANGED
@@ -33,16 +33,16 @@ import os
33
  import torch
34
  from transformers import AutoModelForCausalLM, AutoTokenizer
35
 
36
- model_name = 'VMware/open-llama-7B-open-instruct'
37
 
38
 
39
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
40
 
41
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype= torch.float16, device_map = 'sequential')
42
 
43
  prompt_template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:"
44
 
45
- prompt= 'Explain in simple terms how the attention mechanism of a transformer model works'
46
 
47
 
48
  inputt = prompt_template.format(instruction= prompt)
@@ -51,7 +51,7 @@ input_ids = tokenizer(inputt, return_tensors="pt").input_ids.to("cuda")
51
  output1 = model.generate(input_ids, max_length=512)
52
  input_length = input_ids.shape[1]
53
  output1 = output1[:, input_length:]
54
- output= tokenizer.decode(output1[0])
55
 
56
  print(output)
57
 
 
33
  import torch
34
  from transformers import AutoModelForCausalLM, AutoTokenizer
35
 
36
+ model_name = 'VMware/open-llama-7b-open-instruct'
37
 
38
 
39
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
40
 
41
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='sequential')
42
 
43
  prompt_template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:"
44
 
45
+ prompt = 'Explain in simple terms how the attention mechanism of a transformer model works'
46
 
47
 
48
  inputt = prompt_template.format(instruction= prompt)
 
51
  output1 = model.generate(input_ids, max_length=512)
52
  input_length = input_ids.shape[1]
53
  output1 = output1[:, input_length:]
54
+ output = tokenizer.decode(output1[0])
55
 
56
  print(output)
57