# jina-v3-rullmarena-judge-300924 / modeling_jina_judge.py
# Uploaded by kaleinaNyan in commit 277d73c ("Upload JinaJudge"); original file size 2.99 kB.
from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers import PreTrainedModel, PretrainedConfig
from transformers import CONFIG_MAPPING, MODEL_MAPPING
import torch
import torch.nn.functional as F
import torch.nn as nn
class JinaJudgeConfig(PretrainedConfig):
    """Hyper-parameters for the ``JinaJudge`` classification head.

    Args:
        n_classes: Number of logits produced by the final linear classifier.
        hidden_dim: Width the encoder states are projected to; also used as the
            decoder's ``d_model`` (its feed-forward width is ``2 * hidden_dim``).
        num_decoder_layers: How many ``nn.TransformerDecoderLayer`` blocks are stacked.
        nhead: Attention heads per decoder layer (PyTorch requires it to divide
            ``hidden_dim``).
        dropout_prob: Dropout rate passed to every decoder layer.
        **kwargs: Generic options forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "jina-judge"

    def __init__(
        self,
        n_classes=3,
        hidden_dim=512,
        num_decoder_layers=5,
        nhead=8,
        dropout_prob=0.2,
        **kwargs,
    ):
        # The base class consumes the generic Hugging Face options first; the
        # judge-specific settings are layered on afterwards, in a fixed order.
        super().__init__(**kwargs)
        head_settings = {
            "n_classes": n_classes,
            "hidden_dim": hidden_dim,
            "num_decoder_layers": num_decoder_layers,
            "nhead": nhead,
            "dropout_prob": dropout_prob,
        }
        for setting_name, setting_value in head_settings.items():
            setattr(self, setting_name, setting_value)
class JinaJudge(PreTrainedModel):
    """Text classifier: jina-embeddings-v3 encoder + small transformer-decoder head.

    Raw prompts are tokenized inside ``forward``. They are encoded by the jina
    encoder and projected to ``config.hidden_dim``. A single learned query token
    then attends over the projected states through an ``nn.TransformerDecoder``.
    The decoded query is mapped to ``config.n_classes`` logits.
    """

    config_class = JinaJudgeConfig

    def __init__(self, config: JinaJudgeConfig):
        """Build the encoder and the decoder-based classification head.

        Args:
            config: Head hyper-parameters (projection width, decoder depth/heads,
                dropout, number of classes).

        Note:
            The tokenizer and encoder architecture are resolved from the Hub repo
            "jinaai/jina-embeddings-v3" with ``trust_remote_code=True``. Construction
            therefore needs that repo (network or local cache) and runs its code.
        """
        super().__init__(config)
        self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)
        jina_config = AutoConfig.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)
        # from_config builds the architecture only (no pretrained weights are loaded
        # here); the encoder's parameters are created in bfloat16.
        self.encoder = AutoModel.from_config(jina_config, trust_remote_code=True, torch_dtype=torch.bfloat16)
        # Flag read by the remote jina-embeddings-v3 code; presumably keeps the base
        # (non-LoRA) weights trainable -- NOTE(review): confirm against the remote model.
        self.encoder.lora_main_params_trainable = True
        # Maps encoder hidden size -> decoder width.
        self.projection = nn.Linear(self.encoder.config.hidden_size, config.hidden_dim)
        # Transformer decoder layer; batch_first defaults to False, so every tensor
        # fed to the decoder is laid out as (seq, batch, dim).
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.nhead,
            dim_feedforward=config.hidden_dim * 2,
            dropout=config.dropout_prob,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=config.num_decoder_layers,
        )
        # Learned query token, stored as (seq=1, batch=1, dim) and broadcast over
        # the batch in forward(); it is the decoder's only target position.
        self.decoder_input_embedding = nn.Parameter(
            torch.randn(1, 1, config.hidden_dim,)
        )
        # Classification head applied to the decoded query token.
        self.classification_head = nn.Linear(config.hidden_dim, config.n_classes)

    def forward(self, prompts):
        """Return class logits for raw text prompts.

        Args:
            prompts: Text accepted by the tokenizer (a string or a list of strings).
                The batch is padded to its longest item and truncated to the
                tokenizer's maximum length.

        Returns:
            Tensor of shape ``(batch_size, config.n_classes)`` holding
            unnormalized logits.
        """
        inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(self.device)
        encoder_outputs = self.encoder(**inputs)
        # Cast the (bfloat16) encoder states to the dtype the head actually uses.
        # A hard-coded float32 cast fails as soon as the head is loaded in half
        # precision. self.dtype is not used because it reports the encoder's dtype.
        head_dtype = self.projection.weight.dtype
        encoder_hidden_states = encoder_outputs.last_hidden_state.to(head_dtype)
        encoder_hidden_states = self.projection(encoder_hidden_states)
        # (batch, seq) boolean mask: True marks padding positions the decoder must
        # ignore. It already lives on self.device because `inputs` was moved there.
        encoder_padding_mask = inputs["attention_mask"] == 0
        batch_size = encoder_hidden_states.size(0)
        # One query token per example: (1, batch, dim) in the seq-first layout.
        decoder_input = self.decoder_input_embedding.expand(1, batch_size, -1).to(self.device)
        decoder_output = self.decoder(
            tgt=decoder_input,
            # Encoder states are batch-first; the decoder expects (seq, batch, dim).
            memory=encoder_hidden_states.transpose(0, 1),
            memory_key_padding_mask=encoder_padding_mask,
        ).squeeze(0)  # drop the length-1 target axis -> (batch, dim)
        logits = self.classification_head(decoder_output)
        return logits
# Make the custom classes resolvable through the Auto* factories (AutoConfig /
# AutoModel). The registry key is taken from JinaJudgeConfig.model_type rather
# than a duplicated literal: AutoConfig.register rejects a config whose
# model_type differs from the key, so the two must always agree.
AutoConfig.register(JinaJudgeConfig.model_type, JinaJudgeConfig)
AutoModel.register(JinaJudgeConfig, JinaJudge)