wenjiao commited on
Commit
6dd427b
1 Parent(s): d981095

update WeightDtype params

Browse files
Files changed (1) hide show
  1. src/display/utils.py +7 -1
src/display/utils.py CHANGED
@@ -237,19 +237,25 @@ class WeightDtype(Enum):
237
  nf4 = ModelDetails("nf4")
238
  fp4 = ModelDetails("fp4")
239
 
 
240
  Unknown = ModelDetails("?")
241
 
 
 
 
242
  def from_str(weight_dtype):
243
  if weight_dtype in ["int2"]:
244
  return WeightDtype.int2
245
  if weight_dtype in ["int3"]:
246
- return WeightDtype.int3
247
  if weight_dtype in ["int4"]:
248
  return WeightDtype.int4
249
  if weight_dtype in ["nf4"]:
250
  return WeightDtype.nf4
251
  if weight_dtype in ["fp4"]:
252
  return WeightDtype.fp4
 
 
253
  return WeightDtype.Unknown
254
 
255
  class ComputeDtype(Enum):
 
237
  nf4 = ModelDetails("nf4")
238
  fp4 = ModelDetails("fp4")
239
 
240
+
241
  Unknown = ModelDetails("?")
242
 
243
+ all = ModelDetails("All")
244
+
245
+
246
  def from_str(weight_dtype):
247
  if weight_dtype in ["int2"]:
248
  return WeightDtype.int2
249
  if weight_dtype in ["int3"]:
250
+ return WeightDtype.int3
251
  if weight_dtype in ["int4"]:
252
  return WeightDtype.int4
253
  if weight_dtype in ["nf4"]:
254
  return WeightDtype.nf4
255
  if weight_dtype in ["fp4"]:
256
  return WeightDtype.fp4
257
+ if weight_dtype in ["All"]:
258
+ return WeightDtype.all
259
  return WeightDtype.Unknown
260
 
261
  class ComputeDtype(Enum):