Ali Elfilali commited on
Commit
f47170a
β€’
1 Parent(s): 9faa299

Update src/display/utils.py

Browse files
Files changed (1) hide show
  1. src/display/utils.py +44 -19
src/display/utils.py CHANGED
@@ -61,11 +61,33 @@ class ModelDetails:
61
  symbol: str = "" # emoji
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  class ModelType(Enum):
65
  PT = ModelDetails(name="pretrained", symbol="🟒")
66
- FT = ModelDetails(name="fine-tuned", symbol="πŸ”Ά")
67
- IFT = ModelDetails(name="instruction-tuned", symbol="β­•")
68
- RL = ModelDetails(name="RL-tuned", symbol="🟦")
 
69
  Unknown = ModelDetails(name="", symbol="?")
70
 
71
  def to_str(self, separator=" "):
@@ -75,41 +97,44 @@ class ModelType(Enum):
75
  def from_str(type):
76
  if "fine-tuned" in type or "πŸ”Ά" in type:
77
  return ModelType.FT
 
 
78
  if "pretrained" in type or "🟒" in type:
79
  return ModelType.PT
80
- if "RL-tuned" in type or "🟦" in type:
81
- return ModelType.RL
82
- if "instruction-tuned" in type or "β­•" in type:
83
- return ModelType.IFT
84
  return ModelType.Unknown
85
 
 
86
  class WeightType(Enum):
87
  Adapter = ModelDetails("Adapter")
88
  Original = ModelDetails("Original")
89
  Delta = ModelDetails("Delta")
90
 
91
  class Precision(Enum):
 
92
  float16 = ModelDetails("float16")
93
  bfloat16 = ModelDetails("bfloat16")
94
- float32 = ModelDetails("float32")
95
- #qt_8bit = ModelDetails("8bit")
96
- #qt_4bit = ModelDetails("4bit")
97
- #qt_GPTQ = ModelDetails("GPTQ")
98
  Unknown = ModelDetails("?")
99
 
100
  def from_str(precision):
 
 
101
  if precision in ["torch.float16", "float16"]:
102
  return Precision.float16
103
  if precision in ["torch.bfloat16", "bfloat16"]:
104
  return Precision.bfloat16
105
- if precision in ["float32"]:
106
- return Precision.float32
107
- #if precision in ["8bit"]:
108
- # return Precision.qt_8bit
109
- #if precision in ["4bit"]:
110
- # return Precision.qt_4bit
111
- #if precision in ["GPTQ", "None"]:
112
- # return Precision.qt_GPTQ
113
  return Precision.Unknown
114
 
115
  # Column selection
 
61
  symbol: str = "" # emoji
62
 
63
 
64
+ # class ModelType(Enum):
65
+ # PT = ModelDetails(name="pretrained", symbol="🟒")
66
+ # FT = ModelDetails(name="fine-tuned", symbol="πŸ”Ά")
67
+ # IFT = ModelDetails(name="instruction-tuned", symbol="β­•")
68
+ # RL = ModelDetails(name="RL-tuned", symbol="🟦")
69
+ # Unknown = ModelDetails(name="", symbol="?")
70
+
71
+ # def to_str(self, separator=" "):
72
+ # return f"{self.value.symbol}{separator}{self.value.name}"
73
+
74
+ # @staticmethod
75
+ # def from_str(type):
76
+ # if "fine-tuned" in type or "πŸ”Ά" in type:
77
+ # return ModelType.FT
78
+ # if "pretrained" in type or "🟒" in type:
79
+ # return ModelType.PT
80
+ # if "RL-tuned" in type or "🟦" in type:
81
+ # return ModelType.RL
82
+ # if "instruction-tuned" in type or "β­•" in type:
83
+ # return ModelType.IFT
84
+ # return ModelType.Unknown
85
  class ModelType(Enum):
86
  PT = ModelDetails(name="pretrained", symbol="🟒")
87
+ CPT = ModelDetails(name="continuously pretrained", symbol="🟩")
88
+ FT = ModelDetails(name="fine-tuned on domain-specific datasets", symbol="πŸ”Ά")
89
+ chat = ModelDetails(name="chat models (RLHF, DPO, IFT, ...)", symbol="πŸ’¬")
90
+ merges = ModelDetails(name="base merges and moerges", symbol="🀝")
91
  Unknown = ModelDetails(name="", symbol="?")
92
 
93
  def to_str(self, separator=" "):
 
97
  def from_str(type):
98
  if "fine-tuned" in type or "πŸ”Ά" in type:
99
  return ModelType.FT
100
+ if "continously pretrained" in type or "🟩" in type:
101
+ return ModelType.CPT
102
  if "pretrained" in type or "🟒" in type:
103
  return ModelType.PT
104
+ if any([k in type for k in ["instruction-tuned", "RL-tuned", "chat", "🟦", "β­•", "πŸ’¬"]]):
105
+ return ModelType.chat
106
+ if "merge" in type or "🀝" in type:
107
+ return ModelType.merges
108
  return ModelType.Unknown
109
 
110
+
111
  class WeightType(Enum):
112
  Adapter = ModelDetails("Adapter")
113
  Original = ModelDetails("Original")
114
  Delta = ModelDetails("Delta")
115
 
116
  class Precision(Enum):
117
+ float32 = ModelDetails("float32")
118
  float16 = ModelDetails("float16")
119
  bfloat16 = ModelDetails("bfloat16")
120
+ qt_8bit = ModelDetails("8bit")
121
+ qt_4bit = ModelDetails("4bit")
122
+ qt_GPTQ = ModelDetails("GPTQ")
 
123
  Unknown = ModelDetails("?")
124
 
125
  def from_str(precision):
126
+ if precision in ["float32"]:
127
+ return Precision.float32
128
  if precision in ["torch.float16", "float16"]:
129
  return Precision.float16
130
  if precision in ["torch.bfloat16", "bfloat16"]:
131
  return Precision.bfloat16
132
+ if precision in ["8bit"]:
133
+ return Precision.qt_8bit
134
+ if precision in ["4bit"]:
135
+ return Precision.qt_4bit
136
+ if precision in ["GPTQ", "None"]:
137
+ return Precision.qt_GPTQ
 
 
138
  return Precision.Unknown
139
 
140
  # Column selection