jan-hq committed on
Commit
4069267
·
verified ·
1 Parent(s): 4a480d8

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -2
README.md CHANGED
@@ -71,10 +71,9 @@ sound_tokens = audio_to_sound_tokens("/path/to/your/audio/file")
71
  Then, we can run inference on the model the same way as with any other LLM.
72
 
73
  ```python
74
- import torch
75
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
76
 
77
- def setup_pipeline(model_path, use_4bit=True):
78
  tokenizer = AutoTokenizer.from_pretrained(model_path)
79
 
80
  model_kwargs = {"device_map": "auto"}
@@ -86,6 +85,12 @@ def setup_pipeline(model_path, use_4bit=True):
86
  bnb_4bit_use_double_quant=True,
87
  bnb_4bit_quant_type="nf4",
88
  )
 
 
 
 
 
 
89
  else:
90
  model_kwargs["torch_dtype"] = torch.bfloat16
91
 
 
71
  Then, we can run inference on the model the same way as with any other LLM.
72
 
73
  ```python
 
74
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
75
 
76
+ def setup_pipeline(model_path, use_4bit=False, use_8bit=False):
77
  tokenizer = AutoTokenizer.from_pretrained(model_path)
78
 
79
  model_kwargs = {"device_map": "auto"}
 
85
  bnb_4bit_use_double_quant=True,
86
  bnb_4bit_quant_type="nf4",
87
  )
88
+ elif use_8bit:
89
+ model_kwargs["quantization_config"] = BitsAndBytesConfig(
90
+ load_in_8bit=True,
91
+ bnb_8bit_compute_dtype=torch.bfloat16,
92
+ bnb_8bit_use_double_quant=True,
93
+ )
94
  else:
95
  model_kwargs["torch_dtype"] = torch.bfloat16
96