|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
from transformers import PreTrainedModel, PretrainedConfig, AutoModelForCausalLM, AutoConfig
|
|
import logging
|
|
# Module-level logging setup.
# NOTE(review): calling basicConfig at import time configures the *root*
# logger for the whole process; prefer doing this only in the entry point.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)

# NOTE(review): anomaly detection re-runs graph checks on every backward
# pass and is very slow — this is a debugging aid and should be disabled
# (or made opt-in) for real training runs.
torch.autograd.set_detect_anomaly(True)
|
|
|
|
class JudgeXLConfig(PretrainedConfig):
    """Configuration for the JudgeXL model.

    Stores the sizes of the embedding/transformer stack, the recurrent
    head, and regularization settings. All arguments have defaults, so a
    bare ``JudgeXLConfig()`` is a valid configuration.
    """

    model_type = "judge-xl"

    def __init__(
        self,
        vocab_size=50276,
        hidden_size=768,
        max_len=256,
        n_layer=12,
        n_head=12,
        ff_expansion_factor=4,
        rnn_units=768,
        num_labels=5,
        dropout=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Vocabulary / embedding geometry.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_len = max_len

        # Transformer stack geometry.
        self.n_layer = n_layer
        self.n_head = n_head
        self.ff_expansion_factor = ff_expansion_factor

        # Recurrent head and classifier settings.
        self.rnn_units = rnn_units
        self.num_labels = num_labels
        self.dropout = dropout

        # The model is used autoregressively (causal LM).
        self.is_decoder = True
|
|
|
|
class CustomEmbedding(nn.Module):
    """Thin wrapper around ``nn.Embedding`` mapping token ids to vectors."""

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)

    def forward(self, inputs):
        """Look up embeddings for a tensor of token ids."""
        return self.embedding(inputs)
|
|
|
|
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Expects batch-first input of shape ``(batch, seq_len, n_embd)`` —
    matching the batch-first tensors produced by the token embedding and
    consumed by the ``batch_first=True`` LSTM in this file — and adds the
    encoding along the sequence dimension.
    """

    def __init__(self, n_embd, max_len=5000):
        super().__init__()
        self.n_embd = n_embd
        self.max_len = max_len
        pe = torch.zeros(max_len, n_embd)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # Geometric frequency progression: 10000^(-2i / n_embd).
        div_term = torch.exp(torch.arange(0, n_embd, 2).float() * -(np.log(10000.0) / n_embd))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Shape (1, max_len, n_embd) so it broadcasts over the batch dim.
        # BUG FIX: the original stored (max_len, 1, n_embd) and sliced by
        # x.size(0) — the *batch* dimension — so encodings were added per
        # batch element rather than per sequence position for the
        # batch-first tensors this model uses.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, n_embd); slice to the actual sequence length.
        return x + self.pe[:, :x.size(1), :]
|
|
|
|
class TransformerXLBlock(nn.Module):
    """One transformer layer: self-attention then feed-forward, each with a
    residual connection followed by LayerNorm (post-norm arrangement).

    NOTE(review): ``nn.MultiheadAttention`` is constructed without
    ``batch_first=True``, so it treats dim 0 as the sequence axis; the
    surrounding model passes batch-first tensors — confirm the intended
    layout.
    """

    def __init__(self, config):
        super().__init__()
        self.attn = nn.MultiheadAttention(config.hidden_size, config.n_head, dropout=config.dropout)
        self.ff = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x, mask=None):
        """Apply the attention and feed-forward sub-layers to ``x``."""
        # Self-attention sub-layer (query = key = value = x).
        attended, _ = self.attn(x, x, x, attn_mask=mask)
        normed = self.ln1(x + attended)
        # Position-wise feed-forward sub-layer.
        return self.ln2(normed + self.ff(normed))
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise feed-forward net: expand, GELU, dropout, project back.

    The inner width is ``hidden_size * ff_expansion_factor``.
    """

    def __init__(self, config):
        super().__init__()
        inner = config.hidden_size * config.ff_expansion_factor
        self.dense1 = nn.Linear(config.hidden_size, inner)
        self.dense2 = nn.Linear(inner, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Map ``(..., hidden_size)`` to ``(..., hidden_size)``."""
        hidden = self.dropout(torch.nn.functional.gelu(self.dense1(x)))
        return self.dense2(hidden)
|
|
|
|
class JudgeXL(PreTrainedModel):
    """Causal language model: transformer blocks followed by a bidirectional
    LSTM and a linear projection onto the vocabulary.
    """

    config_class = JudgeXLConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_embedding = CustomEmbedding(config.vocab_size, config.hidden_size)
        self.pos_encoding = PositionalEncoding(config.hidden_size, config.max_len)
        self.transformer_blocks = nn.ModuleList([TransformerXLBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        # Bidirectional, so the LSTM output feature size is 2 * rnn_units.
        self.rnn = nn.LSTM(config.hidden_size, config.rnn_units, num_layers=2, dropout=config.dropout, bidirectional=True, batch_first=True)
        # Projection from the LSTM output to vocabulary logits.
        self.fc = nn.Linear(config.rnn_units * 2, config.vocab_size)
        # Kept for attribute/checkpoint compatibility; no longer used in
        # forward() — see the bug-fix note there.
        self.lm_head = nn.Linear(config.rnn_units, config.vocab_size)
        self.post_init()

    def forward(self, x, mask=None):
        """Return vocabulary logits of shape (batch, seq_len, vocab_size).

        x: (batch, seq_len) token ids; mask: optional attention mask
        forwarded to every transformer block.
        """
        x = self.token_embedding(x)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x, mask=mask)
        x = self.ln_f(x)
        x, _ = self.rnn(x)
        # BUG FIX: the original chained self.fc and self.lm_head. fc already
        # emits vocab-size logits, so feeding them into lm_head (which
        # expects rnn_units features) raised a shape error whenever
        # vocab_size != rnn_units. A single projection is applied instead.
        return self.fc(x)

    def init_weights(self):
        """Delegate to PreTrainedModel's default weight initialization."""
        super().init_weights()

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        """Feed only the newest token once a cache (`past`) is available."""
        if past is None:
            return {"input_ids": input_ids}
        return {"input_ids": input_ids[:, -1:], "past_key_values": past}

    def _reorder_cache(self, past, beam_idx):
        # Reorder cached states along the batch axis to follow beam search.
        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)

    def generate(self, prompt, max_len=100, tokenizer=None):
        """Greedy-decode up to ``max_len`` tokens from a text prompt.

        NOTE(review): this shadows PreTrainedModel.generate with a
        different signature. ``self.tokenizer`` is never assigned in
        __init__, so it must be attached externally or passed via the
        (new, backward-compatible) ``tokenizer=`` argument.
        """
        tok = tokenizer if tokenizer is not None else self.tokenizer
        self.eval()
        input_ids = tok(prompt, return_tensors='pt').input_ids
        generated = input_ids
        with torch.no_grad():
            for _ in range(max_len):
                logits = self.forward(generated)
                # BUG FIX: use only the logits at the *last* position. The
                # original argmax over the whole (batch, seq, vocab) tensor,
                # followed by unsqueeze(0), produced a (1, batch, seq)
                # tensor that broke torch.cat along dim=1.
                next_token_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                generated = torch.cat((generated, next_token_id), dim=1)
                # Stop on the separator token (single-sequence decoding).
                if next_token_id.item() == tok.sep_token_id:
                    break
        return tok.decode(generated[0], skip_special_tokens=True)
|
|
|
|
# Build a default-config model and register the custom classes so
# AutoConfig / AutoModelForCausalLM can resolve "judge-xl".
config = JudgeXLConfig()
model = JudgeXL(config)

JudgeXLConfig.register_for_auto_class(AutoConfig)
JudgeXL.register_for_auto_class(AutoModelForCausalLM)

if __name__ == "__main__":
    # BUG FIX: pushing to the Hub is a network side effect and previously
    # ran on *import*; only publish when this file is executed as a script.
    model.push_to_hub("Wonder-Griffin/judge-xl-model")
|
|
|