BAAI
/

BoyaWu10 commited on
Commit
07f6838
·
verified ·
1 Parent(s): 0d22253

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -4
README.md CHANGED
@@ -29,6 +29,10 @@ Before running the snippet, you need to install the following dependencies:
29
  pip install torch transformers accelerate pillow
30
  ```
31
 
 
 
 
 
32
  ```python
33
  import torch
34
  import transformers
@@ -42,12 +46,12 @@ transformers.logging.disable_progress_bar()
42
  warnings.filterwarnings('ignore')
43
 
44
  # set device
45
- torch.set_default_device('cpu') # or 'cuda'
46
 
47
  # create model
48
  model = AutoModelForCausalLM.from_pretrained(
49
  'BAAI/Bunny-v1_0-3B-zh',
50
- torch_dtype=torch.float16,
51
  device_map='auto',
52
  trust_remote_code=True)
53
  tokenizer = AutoTokenizer.from_pretrained(
@@ -58,11 +62,11 @@ tokenizer = AutoTokenizer.from_pretrained(
58
  prompt = 'Why is the image funny?'
59
  text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"
60
  text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
61
- input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1][1:], dtype=torch.long).unsqueeze(0)
62
 
63
  # image, sample images can be found in images folder
64
  image = Image.open('example_2.png')
65
- image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
66
 
67
  # generate
68
  output_ids = model.generate(
 
29
  pip install torch transformers accelerate pillow
30
  ```
31
 
32
+ If the CUDA memory is enough, it would be faster to execute this snippet by setting `CUDA_VISIBLE_DEVICES=0`.
33
+
34
+
35
+
36
  ```python
37
  import torch
38
  import transformers
 
46
  warnings.filterwarnings('ignore')
47
 
48
  # set device
49
+ device = 'cuda' # or cpu
50
 
51
  # create model
52
  model = AutoModelForCausalLM.from_pretrained(
53
  'BAAI/Bunny-v1_0-3B-zh',
54
+ torch_dtype=torch.float16, # float32 for cpu
55
  device_map='auto',
56
  trust_remote_code=True)
57
  tokenizer = AutoTokenizer.from_pretrained(
 
62
  prompt = 'Why is the image funny?'
63
  text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"
64
  text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
65
+ input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1][1:], dtype=torch.long).unsqueeze(0).to(device)
66
 
67
  # image, sample images can be found in images folder
68
  image = Image.open('example_2.png')
69
+ image_tensor = model.process_images([image], model.config).to(dtype=model.dtype, device=device)
70
 
71
  # generate
72
  output_ids = model.generate(