# judge-xl-model / judge_xl_model.py
# Uploaded by Wonder-Griffin: commit 2d8d1b2 "Upload JudgeXL" (verified), 5.81 kB.
import torch
import torch.nn as nn
import numpy as np
from transformers import PreTrainedModel, PretrainedConfig, AutoModelForCausalLM, AutoConfig
import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Autograd anomaly detection is a debugging aid: it adds checks to every
# backward pass and changes autograd's anomaly mode for whatever program
# imports this module. Enable it only on request.
if os.environ.get("JUDGE_XL_DETECT_ANOMALY", "").strip().lower() in {"1", "true", "yes"}:
    torch.autograd.set_detect_anomaly(True)
class JudgeXLConfig(PretrainedConfig):
    """Configuration for the JudgeXL model (``model_type = "judge-xl"``).

    Extra keyword arguments are forwarded to ``PretrainedConfig``.

    Args:
        vocab_size: Size of the token-embedding table and of the logit layer.
        hidden_size: Width of embeddings and transformer blocks; it is passed
            to ``nn.MultiheadAttention`` together with ``n_head``, so it must
            be divisible by ``n_head``.
        max_len: Number of rows in the precomputed positional-encoding table.
        n_layer: Number of transformer blocks.
        n_head: Attention heads per block.
        ff_expansion_factor: Multiplier for the feed-forward inner width.
        rnn_units: Hidden units per direction of the bidirectional LSTM.
        num_labels: Stored on the config; not read by the modules in this file.
        dropout: Dropout probability used by attention, feed-forward and LSTM.
    """

    model_type = "judge-xl"

    def __init__(self, vocab_size=50276, hidden_size=768, max_len=256, n_layer=12, n_head=12,
                 ff_expansion_factor=4, rnn_units=768, num_labels=5, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        # Assigned in a fixed order through the normal attribute path.
        for attr_name, attr_value in (
            ("vocab_size", vocab_size),
            ("hidden_size", hidden_size),
            ("max_len", max_len),
            ("n_layer", n_layer),
            ("n_head", n_head),
            ("ff_expansion_factor", ff_expansion_factor),
            ("rnn_units", rnn_units),
            ("num_labels", num_labels),
            ("dropout", dropout),
        ):
            setattr(self, attr_name, attr_value)
        # Always a decoder: this overrides any ``is_decoder`` passed in kwargs.
        self.is_decoder = True
class CustomEmbedding(nn.Module):
    """Token-id lookup table wrapping a single ``nn.Embedding``.

    Args:
        vocab_size: Number of rows in the table.
        hidden_size: Length of each embedding vector.
    """

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        # The attribute name ``embedding`` determines the checkpoint key.
        self.embedding = nn.Embedding(vocab_size, hidden_size)

    def forward(self, inputs):
        """Return the embedding vectors for ``inputs`` (adds a trailing ``hidden_size`` dim)."""
        table = self.embedding
        return table(inputs)
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first inputs.

    The table is registered as buffer ``pe`` with shape ``(max_len, 1, n_embd)``.
    ``forward`` slices it by ``x.size(0)``, so ``x`` is expected to be
    ``(seq_len, batch, n_embd)`` with ``seq_len <= max_len``. The singleton
    middle dimension broadcasts over the batch.

    Args:
        n_embd: Embedding dimension (even or odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, n_embd, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.n_embd = n_embd
        self.max_len = max_len
        pe = torch.zeros(max_len, n_embd)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, n_embd, 2).float() * -(np.log(10000.0) / n_embd))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd n_embd there is one fewer odd column than even column, so use
        # only as many frequencies as there are odd columns (no-op when even).
        pe[:, 1::2] = torch.cos(position * div_term[: n_embd // 2])
        # (max_len, n_embd) -> (max_len, 1, n_embd): sequence-first, broadcast over batch.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encodings for the first ``x.size(0)`` positions to ``x``."""
        return x + self.pe[:x.size(0), :]
class TransformerXLBlock(nn.Module):
    """Post-norm transformer layer: self-attention, then a feed-forward MLP.

    Each sub-layer's output is added to its input and normalized with
    ``LayerNorm``. ``nn.MultiheadAttention`` is built with its default
    ``batch_first=False``, so inputs are ``(seq_len, batch, hidden_size)``.
    The block keeps no state between calls.

    Args:
        config: Object exposing ``hidden_size``, ``n_head`` and ``dropout``,
            plus whatever ``FeedForward`` reads.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        # Construction order (attn, ff, ln1, ln2) kept for identical seeded init.
        self.attn = nn.MultiheadAttention(width, config.n_head, dropout=config.dropout)
        self.ff = FeedForward(config)
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)

    def forward(self, x, mask=None):
        """Run attention then feed-forward; ``mask`` is passed as ``attn_mask``."""
        attended = self.attn(x, x, x, attn_mask=mask)[0]
        normed = self.ln1(x + attended)
        return self.ln2(normed + self.ff(normed))
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    The inner width is ``config.hidden_size * config.ff_expansion_factor``,
    and the output has the same last dimension as the input.
    """

    def __init__(self, config):
        super().__init__()
        inner_width = config.hidden_size * config.ff_expansion_factor
        self.dense1 = nn.Linear(config.hidden_size, inner_width)
        self.dense2 = nn.Linear(inner_width, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Expand, activate, drop out and project ``x`` back to ``hidden_size``."""
        activated = torch.nn.functional.gelu(self.dense1(x))
        return self.dense2(self.dropout(activated))
class JudgeXL(PreTrainedModel):
    """Token embedding -> transformer stack -> bidirectional LSTM -> vocab logits.

    Token ids are batch-first ``(batch, seq_len)``: the LSTM uses
    ``batch_first=True``, and ``generate`` grows sequences along dim 1. The
    positional encoding and the attention blocks are sequence-first, so hidden
    states are transposed around them.

    NOTE(review): the LSTM is bidirectional, so each position's logits also
    depend on later tokens; the model is not strictly causal.
    """

    config_class = JudgeXLConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_embedding = CustomEmbedding(config.vocab_size, config.hidden_size)
        self.pos_encoding = PositionalEncoding(config.hidden_size, config.max_len)
        self.transformer_blocks = nn.ModuleList([TransformerXLBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.rnn = nn.LSTM(config.hidden_size, config.rnn_units, num_layers=2, dropout=config.dropout, bidirectional=True, batch_first=True)
        # Maps the concatenated forward/backward LSTM states to vocab logits.
        self.fc = nn.Linear(config.rnn_units * 2, config.vocab_size)
        # Not applied in forward: its input width (rnn_units) does not match the
        # output of ``fc`` (vocab_size). Kept so existing checkpoints still load.
        self.lm_head = nn.Linear(config.rnn_units, config.vocab_size)
        self.post_init()

    def forward(self, x, mask=None):
        """Compute logits for a batch of token ids.

        Args:
            x: Token ids, expected batch-first ``(batch, seq_len)`` with
                ``seq_len <= config.max_len`` (the positional-table size).
            mask: Optional ``attn_mask`` passed to every attention layer. Those
                layers run sequence-first, so a 2-D mask is ``(seq_len, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.
        """
        hidden = self.token_embedding(x)        # (batch, seq, hidden)
        hidden = hidden.transpose(0, 1)         # (seq, batch, hidden) for PE / attention
        hidden = self.pos_encoding(hidden)
        for block in self.transformer_blocks:
            hidden = block(hidden, mask=mask)
        hidden = self.ln_f(hidden)
        hidden = hidden.transpose(0, 1)         # back to (batch, seq, hidden) for the LSTM
        hidden, _ = self.rnn(hidden)            # (batch, seq, 2 * rnn_units)
        return self.fc(hidden)                  # (batch, seq, vocab_size)

    def init_weights(self):
        """
        Initialize weights by delegating to ``PreTrainedModel.init_weights``.
        """
        # No custom initialization is added here; the parent implementation is used as-is.
        super().init_weights()

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        """Build per-step inputs in the Hugging Face generation format.

        NOTE(review): ``forward`` takes ``x`` (not ``input_ids``), accepts no
        ``past_key_values`` and never returns a cache. The custom ``generate``
        below does not call this method.
        """
        if past is None:
            return {"input_ids": input_ids}
        else:
            return {"input_ids": input_ids[:, -1:], "past_key_values": past}

    def _reorder_cache(self, past, beam_idx):
        """Reorder every cached tensor along dim 1 by ``beam_idx`` (for beam search).

        NOTE(review): nothing in this class produces such a cache.
        """
        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)

    def generate(self, prompt, max_len=100):
        """Greedy-decode up to ``max_len`` new tokens after ``prompt`` and return text.

        A tokenizer must be attached as ``self.tokenizer``; this class never
        creates one. Decoding stops early when the tokenizer's ``sep_token_id``
        is produced. At each step only the last ``config.max_len`` tokens are
        fed to the model, because the positional-encoding table has that many
        rows. This overrides ``PreTrainedModel.generate`` with a text-in/text-out
        signature and leaves the model in eval mode.

        Raises:
            AttributeError: if no tokenizer has been attached.
        """
        tokenizer = getattr(self, "tokenizer", None)
        if tokenizer is None:
            raise AttributeError("JudgeXL.generate requires a tokenizer; set `model.tokenizer` first")
        self.eval()
        generated = tokenizer(prompt, return_tensors='pt').input_ids.to(self.device)
        context_len = self.config.max_len
        with torch.no_grad():
            for _ in range(max_len):
                logits = self.forward(generated[:, -context_len:])
                # Greedy pick from the last position's logits -> (batch, 1).
                next_token_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                generated = torch.cat((generated, next_token_id), dim=1)
                if next_token_id.item() == tokenizer.sep_token_id:
                    break
        return tokenizer.decode(generated[0], skip_special_tokens=True)
# Register the custom classes with the Auto* factories, so the saved/pushed
# configuration can point AutoConfig / AutoModelForCausalLM at this file.
JudgeXLConfig.register_for_auto_class(AutoConfig)
JudgeXL.register_for_auto_class(AutoModelForCausalLM)

if __name__ == "__main__":
    # Build a freshly initialised model from the default config and upload it.
    # This runs only when the file is executed as a script: importing the module
    # (for example when the Hub loads it as remote code) must not instantiate a
    # model or push to the Hub.
    config = JudgeXLConfig()
    model = JudgeXL(config)
    model.push_to_hub("Wonder-Griffin/judge-xl-model")