Update src/display/utils.py
Browse files- src/display/utils.py +7 -6
src/display/utils.py
CHANGED
@@ -240,16 +240,15 @@ class WeightDtype(Enum):
|
|
240 |
int2 = ModelDetails("int2")
|
241 |
int3 = ModelDetails("int3")
|
242 |
int4 = ModelDetails("int4")
|
|
|
243 |
nf4 = ModelDetails("nf4")
|
244 |
fp4 = ModelDetails("fp4")
|
245 |
-
|
246 |
bf16 = ModelDetails("bfloat16")
|
247 |
-
|
248 |
|
249 |
Unknown = ModelDetails("?")
|
250 |
|
251 |
-
|
252 |
-
|
253 |
def from_str(weight_dtype):
|
254 |
if weight_dtype in ["int2"]:
|
255 |
return WeightDtype.int2
|
@@ -257,6 +256,8 @@ class WeightDtype(Enum):
|
|
257 |
return WeightDtype.int3
|
258 |
if weight_dtype in ["int4"]:
|
259 |
return WeightDtype.int4
|
|
|
|
|
260 |
if weight_dtype in ["nf4"]:
|
261 |
return WeightDtype.nf4
|
262 |
if weight_dtype in ["fp4"]:
|
@@ -264,11 +265,11 @@ class WeightDtype(Enum):
|
|
264 |
if weight_dtype in ["All"]:
|
265 |
return WeightDtype.all
|
266 |
if weight_dtype in ["float16"]:
|
267 |
-
return WeightDtype.
|
268 |
if weight_dtype in ["bfloat16"]:
|
269 |
return WeightDtype.bf16
|
270 |
if weight_dtype in ["float32"]:
|
271 |
-
return WeightDtype.
|
272 |
return WeightDtype.Unknown
|
273 |
|
274 |
class ComputeDtype(Enum):
|
|
|
240 |
int2 = ModelDetails("int2")
|
241 |
int3 = ModelDetails("int3")
|
242 |
int4 = ModelDetails("int4")
|
243 |
+
int8 = ModelDetails("int8")
|
244 |
nf4 = ModelDetails("nf4")
|
245 |
fp4 = ModelDetails("fp4")
|
246 |
+
f16 = ModelDetails("float16")
|
247 |
bf16 = ModelDetails("bfloat16")
|
248 |
+
f32 = ModelDetails("float32")
|
249 |
|
250 |
Unknown = ModelDetails("?")
|
251 |
|
|
|
|
|
252 |
def from_str(weight_dtype):
|
253 |
if weight_dtype in ["int2"]:
|
254 |
return WeightDtype.int2
|
|
|
256 |
return WeightDtype.int3
|
257 |
if weight_dtype in ["int4"]:
|
258 |
return WeightDtype.int4
|
259 |
+
if weight_dtype in ["int8"]:
|
260 |
+
return WeightDtype.int8
|
261 |
if weight_dtype in ["nf4"]:
|
262 |
return WeightDtype.nf4
|
263 |
if weight_dtype in ["fp4"]:
|
|
|
265 |
if weight_dtype in ["All"]:
|
266 |
return WeightDtype.all
|
267 |
if weight_dtype in ["float16"]:
|
268 |
+
return WeightDtype.f16
|
269 |
if weight_dtype in ["bfloat16"]:
|
270 |
return WeightDtype.bf16
|
271 |
if weight_dtype in ["float32"]:
|
272 |
+
return WeightDtype.f32
|
273 |
return WeightDtype.Unknown
|
274 |
|
275 |
class ComputeDtype(Enum):
|