Rename cognitivess_model/modeling_Cognitivess.py to cognitivess_model/modeling_cognitivess.py
e942e2b
verified
import torch | |
import torch.nn as nn | |
from torch.nn import CrossEntropyLoss | |
from transformers.modeling_outputs import ( | |
CausalLMOutputWithCrossAttentions, | |
SequenceClassifierOutput, | |
TokenClassifierOutput, | |
QuestionAnsweringModelOutput, | |
) | |
from transformers import LlamaModel, LlamaPreTrainedModel | |
from .configuration_cognitivess import CognitivessConfig | |
class CognitivessModel(LlamaModel):
    """Llama backbone bound to the Cognitivess configuration class.

    Adds no layers or behavior of its own; it only overrides
    ``config_class`` so the model is paired with ``CognitivessConfig``.
    """

    config_class = CognitivessConfig
class CognitivessForCausalLM(LlamaPreTrainedModel):
    """Cognitivess backbone with a linear language-modeling head.

    The final hidden states of ``CognitivessModel`` are projected to
    vocabulary logits by a bias-free ``nn.Linear``.
    """

    config_class = CognitivessConfig

    def __init__(self, config):
        """Build the backbone and LM head from ``config``.

        Reads ``config.hidden_size`` and ``config.vocab_size``.
        """
        super().__init__(config)
        self.model = CognitivessModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # post_init() is the library's hook for weight init + final setup.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        past_key_values=None,
    ):
        """Run the backbone, compute LM logits and, optionally, the LM loss.

        Args:
            input_ids, attention_mask, position_ids, inputs_embeds,
            use_cache, output_attentions, output_hidden_states:
                Forwarded unchanged to the backbone.
            head_mask: Accepted only for signature compatibility. It is
                NOT forwarded: ``LlamaModel.forward`` does not define this
                argument.
            labels: Optional target ids. They are shifted by one position
                against the logits; entries equal to ``-100`` are ignored.
            return_dict: When ``None``, falls back to
                ``self.config.use_return_dict``.
            past_key_values: Optional cache from a previous step, forwarded
                to the backbone.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` when ``return_dict`` is
            truthy, otherwise a tuple ``(loss?, logits, *backbone_extras)``.
        """
        # Resolve once so the backbone and this head agree on the output form.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # A decoder-only backbone produces no cross-attentions, so that
        # field of the output container is left at its default (None).
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class CognitivessForSequenceClassification(LlamaPreTrainedModel):
    """Cognitivess backbone with a linear sequence-classification head.

    One hidden vector per sequence is pooled from the backbone output and
    projected to ``config.num_labels`` scores by a bias-free ``nn.Linear``.
    """

    config_class = CognitivessConfig

    def __init__(self, config):
        """Build the backbone and scoring head from ``config``.

        Reads ``config.hidden_size`` and ``config.num_labels``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = CognitivessModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
        # post_init() is the library's hook for weight init + final setup.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Classify each sequence from its last attended token.

        Args:
            input_ids, attention_mask, position_ids, inputs_embeds,
            output_attentions, output_hidden_states:
                Forwarded unchanged to the backbone.
            head_mask: Accepted only for signature compatibility. It is
                NOT forwarded: ``LlamaModel.forward`` does not define this
                argument.
            labels: Optional targets. With ``num_labels == 1`` a mean
                squared error loss is used; otherwise cross-entropy.
            return_dict: When ``None``, falls back to
                ``self.config.use_return_dict``.

        Returns:
            ``SequenceClassifierOutput`` when ``return_dict`` is truthy,
            otherwise a tuple ``(loss?, logits, *backbone_extras)``.
        """
        # Resolve once so the backbone and this head agree on the output form.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Indexed as (batch, sequence, hidden) below.
        hidden_states = outputs[0]

        # Under causal attention the first position has only seen itself, so
        # the sequence summary has to come from the LAST attended position.
        if attention_mask is not None and attention_mask.dim() == 2:
            seq_len = hidden_states.size(1)
            positions = torch.arange(seq_len, device=hidden_states.device)
            attended = attention_mask.to(
                device=hidden_states.device, dtype=positions.dtype
            )
            # The largest position index that is still attended; this is
            # correct for both left- and right-padded batches.
            last_token = (positions * attended).argmax(dim=-1)
            batch_index = torch.arange(
                hidden_states.size(0), device=hidden_states.device
            )
            pooled = hidden_states[batch_index, last_token]
        else:
            # No usable 2-D mask: treat every sequence as unpadded.
            pooled = hidden_states[:, -1, :]
        logits = self.score(pooled)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.num_labels == 1:
                # Regression: MSELoss needs targets in the logits' dtype.
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1).to(logits.dtype))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class CognitivessForTokenClassification(LlamaPreTrainedModel):
    """Cognitivess backbone with a per-token linear classification head.

    Every position's final hidden state is projected to
    ``config.num_labels`` scores by a bias-free ``nn.Linear``.
    """

    config_class = CognitivessConfig

    def __init__(self, config):
        """Build the backbone and scoring head from ``config``.

        Reads ``config.hidden_size`` and ``config.num_labels``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = CognitivessModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
        # post_init() is the library's hook for weight init + final setup.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Score every token and, optionally, compute a cross-entropy loss.

        Args:
            input_ids, attention_mask, position_ids, inputs_embeds,
            output_attentions, output_hidden_states:
                Forwarded unchanged to the backbone.
            head_mask: Accepted only for signature compatibility. It is
                NOT forwarded: ``LlamaModel.forward`` does not define this
                argument.
            labels: Optional per-token class ids; flattened together with
                the logits before the loss is taken.
            return_dict: When ``None``, falls back to
                ``self.config.use_return_dict``.

        Returns:
            ``TokenClassifierOutput`` when ``return_dict`` is truthy,
            otherwise a tuple ``(loss?, logits, *backbone_extras)``.
        """
        # Resolve once so the backbone and this head agree on the output form.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.score(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.num_labels),
                labels.view(-1).to(logits.device),
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class CognitivessForQuestionAnswering(LlamaPreTrainedModel):
    """Cognitivess backbone with a span-extraction (start/end) head.

    Every position's final hidden state is projected to two scores: one
    for being the answer start and one for being the answer end.
    """

    config_class = CognitivessConfig

    def __init__(self, config):
        """Build the backbone and span head from ``config``.

        Reads ``config.hidden_size``.
        """
        super().__init__(config)
        self.model = CognitivessModel(config)
        # forward() unpacks exactly two slices (start, end), so the head
        # must have exactly two outputs regardless of config.num_labels.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)
        # post_init() is the library's hook for weight init + final setup.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Compute start/end logits and, optionally, the span loss.

        Args:
            input_ids, attention_mask, position_ids, inputs_embeds,
            output_attentions, output_hidden_states:
                Forwarded unchanged to the backbone.
            head_mask: Accepted only for signature compatibility. It is
                NOT forwarded: ``LlamaModel.forward`` does not define this
                argument.
            start_positions, end_positions: Optional gold span indices. A
                trailing singleton dimension is squeezed away. Indices are
                clamped to ``[0, seq_len]``; the value ``seq_len`` is used
                as the ignore index. The caller's tensors are not modified.
            return_dict: When ``None``, falls back to
                ``self.config.use_return_dict``.

        Returns:
            ``QuestionAnsweringModelOutput`` when ``return_dict`` is truthy,
            otherwise ``(loss?, start_logits, end_logits, *backbone_extras)``.
        """
        # Resolve once so the backbone and this head agree on the output form.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        loss = None
        if start_positions is not None and end_positions is not None:
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Out-of-range positions are mapped onto `ignored_index`, which
            # the loss then skips. Use the out-of-place clamp so the tensors
            # passed in by the caller are left untouched.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index).to(
                start_logits.device
            )
            end_positions = end_positions.clamp(0, ignored_index).to(
                end_logits.device
            )

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )