Spaces:
Paused
Paused
Update utils.py
Browse files
utils.py
CHANGED
@@ -123,6 +123,28 @@ def load_tokenizer_and_model_gpt2(base_model,load_8bit=False):
|
|
123 |
|
124 |
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
# Greedy Search
|
127 |
def greedy_search(input_ids: torch.Tensor,
|
128 |
model: torch.nn.Module,
|
|
|
123 |
|
124 |
|
125 |
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
def load_tokenizer_and_model_bloke_gpt(base_model, model_basename, use_triton=False):
    """Load a GPTQ-quantized causal LM (TheBloke-style checkpoint) and its tokenizer.

    Args:
        base_model: Hugging Face model id or local path of the quantized model.
        model_basename: Basename of the quantized safetensors weight file.
        use_triton: Whether AutoGPTQ should use its Triton kernels.
            Defaults to False. (Previously this name was read as an undefined
            global, which raised NameError — it is now an explicit parameter.)

    Returns:
        (tokenizer, model, device) where device is "cuda" or "cpu".
    """
    # Pick the execution device; fall back to CPU when no GPU is present.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
    # NOTE(review): device was previously hard-coded to "cuda:0", which broke
    # loading on CPU-only hosts despite the fallback computed above. Keep the
    # "cuda:0" form on GPU hosts (what AutoGPTQ expects) and honor "cpu" otherwise.
    model = AutoGPTQForCausalLM.from_quantized(
        base_model,
        model_basename=model_basename,
        use_safetensors=True,
        trust_remote_code=False,
        device="cuda:0" if device == "cuda" else device,
        use_triton=use_triton,
        quantize_config=None,
    )

    return tokenizer, model, device
|
146 |
+
|
147 |
+
|
148 |
# Greedy Search
|
149 |
def greedy_search(input_ids: torch.Tensor,
|
150 |
model: torch.nn.Module,
|