Spaces:
Paused
Paused
Update utils.py
Browse files
utils.py
CHANGED
@@ -10,6 +10,7 @@ import requests
|
|
10 |
import re
|
11 |
import html
|
12 |
import torch
|
|
|
13 |
import sys
|
14 |
import gc
|
15 |
from pygments.lexers import guess_lexer, ClassNotFound
|
@@ -18,7 +19,7 @@ from pygments import highlight
|
|
18 |
from pygments.lexers import guess_lexer,get_lexer_by_name
|
19 |
from pygments.formatters import HtmlFormatter
|
20 |
import transformers
|
21 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
22 |
import datasets
|
23 |
from datasets import load_dataset
|
24 |
import evaluate
|
@@ -335,12 +336,11 @@ def daten_laden(name):
|
|
335 |
|
336 |
|
337 |
#Quantisation - to speed up training
|
338 |
-
def bnb_config(load4Bit, double_quant):
    """Build a ``BitsAndBytesConfig`` for quantised model loading.

    Args:
        load4Bit: Forwarded unchanged as ``load_in_4bit``.
        double_quant: Forwarded unchanged as ``bnb_4bit_use_double_quant``.

    Returns:
        A ``BitsAndBytesConfig`` with quant type ``"nf4"`` and a float16
        compute dtype.
    """
    # The compute dtype was looked up but never handed to the config (the
    # keyword argument was left without a value); pass it through explicitly.
    compute_dtype = torch.float16
    return BitsAndBytesConfig(
        load_in_4bit=load4Bit,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=double_quant,
    )
|
|
|
10 |
import re
|
11 |
import html
|
12 |
import torch
|
13 |
+
from torch import cuda, bfloat16
|
14 |
import sys
|
15 |
import gc
|
16 |
from pygments.lexers import guess_lexer, ClassNotFound
|
|
|
19 |
from pygments.lexers import guess_lexer,get_lexer_by_name
|
20 |
from pygments.formatters import HtmlFormatter
|
21 |
import transformers
|
22 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
23 |
import datasets
|
24 |
from datasets import load_dataset
|
25 |
import evaluate
|
|
|
336 |
|
337 |
|
338 |
#Quantisation - to speed up training
|
339 |
+
def bnb_config(load4Bit, double_quant):
    """Return the ``BitsAndBytesConfig`` used for quantised model loading.

    Args:
        load4Bit: Value passed through as ``load_in_4bit``.
        double_quant: Value passed through as ``bnb_4bit_use_double_quant``.

    Returns:
        A ``BitsAndBytesConfig`` with quant type ``"nf4"`` and ``bfloat16``
        as the 4-bit compute dtype.
    """
    # Collect the settings first so the two caller-controlled switches sit
    # next to the two fixed choices (NF4 quantisation, bfloat16 compute).
    quant_settings = {
        "load_in_4bit": load4Bit,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": bfloat16,
        "bnb_4bit_use_double_quant": double_quant,
    }
    return BitsAndBytesConfig(**quant_settings)
|