Update modeling_custom.py
Browse files- modeling_custom.py +4 -3
modeling_custom.py
CHANGED
@@ -85,10 +85,11 @@ class CustomOutput(ModelOutput):
|
|
85 |
|
86 |
class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
|
87 |
def __init__(self, config):
|
|
|
88 |
super().__init__(config)
|
89 |
-
self.model = AutoModelForSequenceClassification.from_pretrained(
|
90 |
-
|
91 |
-
|
92 |
self.num_labels = config.num_labels
|
93 |
config_dict = config.to_dict()
|
94 |
self.num_objectives = config_dict.get("num_objectives", 19)
|
|
|
85 |
|
86 |
class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
|
87 |
def __init__(self, config):
|
88 |
+
config.torch_dtype = torch.bfloat16
|
89 |
super().__init__(config)
|
90 |
+
#self.model = AutoModelForSequenceClassification.from_pretrained(
|
91 |
+
# "sfairXC/FsfairX-LLaMA3-RM-v0.1", num_labels=1, torch_dtype=torch.bfloat16, use_flash_attention_2=True, trust_remote_code=True,)
|
92 |
+
self.model = LlamaModel(config)
|
93 |
self.num_labels = config.num_labels
|
94 |
config_dict = config.to_dict()
|
95 |
self.num_objectives = config_dict.get("num_objectives", 19)
|