---
library_name: transformers
language:
- en
base_model:
- google/gemma-2-9b-it
pipeline_tag: text-classification
---

# Model Card for wath5/kgl_lmsys_pref_classif

<!-- Provide a quick summary of what the model is/does. -->

Given a triple (query, model A answer, model B answer), this model outputs a 3-dimensional probability vector in the LMSYS style: (model A win probability, model B win probability, tie probability).
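
The class order matches the inference helper later in this card. As a reference, a minimal mapping (the `ID2LABEL` name is only a convenience used here, not an attribute of the released config):

```python
# Index -> label convention used by the inference code below
ID2LABEL = {0: "winner_model_a", 1: "winner_model_b", 2: "tie"}
```
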
## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.

- **Developed by:** @sayoulala (Yang Zhou)
- **Model type:** Gemma 2 for sequence classification
- **Language(s) (NLP):** English only

### Model Sources

<!-- Provide the basic links for the model. -->

- **Repository:** https://github.com/shyoulala/LMSYS_BlackPearl
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

Mimics human preference: given a query and two candidate answers, the model predicts which answer a human would prefer, or whether the comparison is a tie.

### Direct Use

```python
import torch
from torch import nn
from transformers import Gemma2PreTrainedModel, Gemma2Model, Cache, AutoTokenizer
from typing import Optional, List, Union
from dataclasses import dataclass


@dataclass
class Config:
    gemma_dir: str = 'wath5/kgl_lmsys_pref_classif'
    max_length: int = 2000
    batch_size: int = 8
    device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


cfg = Config()


class Gemma2ForSequenceClassificationV1(Gemma2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.Tensor:
        r"""
        Inference-only forward pass: `labels` is accepted for API compatibility but no loss is
        computed. Returns the pooled classification logits of shape `(batch_size, num_labels)`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(hidden_states.device)
            else:
                sequence_lengths = -1

        # Pool the hidden state of the last non-padding token (the EOS token appended at tokenization time)
        hidden_states = hidden_states[
            torch.arange(batch_size, device=hidden_states.device), sequence_lengths]
        pooled_logits = self.score(hidden_states)

        return pooled_logits


# The tokenizer is assumed to be bundled with the classification checkpoint on the Hub.
tokenizer = AutoTokenizer.from_pretrained(cfg.gemma_dir)

model = Gemma2ForSequenceClassificationV1.from_pretrained(
    cfg.gemma_dir,
    num_labels=3,
    device_map=cfg.device,
    use_cache=False,
)
model.config.pad_token_id = tokenizer.pad_token_id
```
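
The snippet above loads the checkpoint in its default precision. On memory-constrained GPUs you may prefer half precision; a minimal variant, assuming the hardware supports bfloat16:

```python
model = Gemma2ForSequenceClassificationV1.from_pretrained(
    cfg.gemma_dir,
    num_labels=3,
    torch_dtype=torch.bfloat16,  # halves memory use; fall back to torch.float16 on GPUs without bf16 support
    device_map=cfg.device,
    use_cache=False,
)
model.config.pad_token_id = tokenizer.pad_token_id
```
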
## How to Get Started with the Model

```python
from transformers.data.data_collator import pad_without_fast_tokenizer_warning


@torch.no_grad()
def single_prompt_inference(prompt, model, device, max_length=cfg.max_length):
    """
    Perform inference on a single prompt, using the module-level `tokenizer` defined above.

    Args:
        prompt (str): The input prompt for inference.
        model (torch.nn.Module): The model used for inference.
        device (torch.device): The device to run inference on.
        max_length (int): Maximum sequence length for tokenization.

    Returns:
        dict: Probabilities for "winner_model_a", "winner_model_b", and "tie".
    """
    # Tokenize the input prompt and append the EOS token (the model pools its hidden state)
    input_ids = tokenizer(prompt, truncation=True, max_length=max_length)['input_ids']
    input_ids.append(tokenizer.eos_token_id)

    # Pad to max_length and convert to tensors
    inputs = pad_without_fast_tokenizer_warning(
        tokenizer,
        {"input_ids": [input_ids]},  # wrap in a list to form a batch of size 1
        padding="max_length",
        pad_to_multiple_of=None,
        max_length=max_length,
        return_tensors="pt",
    )

    # Move inputs to the appropriate device
    inputs = inputs.to(device)

    # Run the model; the forward pass returns pooled logits of shape (1, 3)
    outputs = model(**inputs)

    # Convert logits to probabilities
    proba = outputs.softmax(-1).cpu().squeeze()

    return {
        "winner_model_a": proba[0].item(),
        "winner_model_b": proba[1].item(),
        "tie": proba[2].item(),
    }


def create_rounds(query: str,
                  answer_a: str,
                  answer_b: str) -> str:
    prompt = f"""User question:
\"""{query}\"""
Answer A:
\"""{answer_a}\"""
Answer B:
\"""{answer_b}\"""
"""
    return prompt


query = "Hello, what is the height of the reassembled blind product?"
answer_a = "Vous pouvez trouver toutes les informations techniques, y compris la hauteur du produit store remonté, directement sur la fiche produit de notre site. Cliquez sur l'onglet 'Produits' dans la barre de navigation ou utilisez le moteur de recherche pour accéder au produit recherché. Avez-vous une autre question ?"
answer_b = "The height of the aluminum Venetian blind is 130 cm."
prompt_direct = create_rounds(query, answer_a, answer_b)

print(single_prompt_inference(prompt_direct, model=model, device=cfg.device))
```
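
`cfg.batch_size` is defined above but not used by `single_prompt_inference`. For scoring many (query, answer A, answer B) prompts at once, here is a minimal batched sketch along the same lines; `batch_inference` is a hypothetical helper, and it assumes the `tokenizer`, `model`, and `cfg` objects defined earlier in this card:

```python
@torch.no_grad()
def batch_inference(prompts, model, device, batch_size=cfg.batch_size, max_length=cfg.max_length):
    """Score a list of prompts in mini-batches; returns one probability dict per prompt."""
    results = []
    for start in range(0, len(prompts), batch_size):
        chunk = prompts[start:start + batch_size]
        # Tokenize each prompt and append the EOS token, as in single_prompt_inference
        batch_ids = [
            tokenizer(p, truncation=True, max_length=max_length)["input_ids"] + [tokenizer.eos_token_id]
            for p in chunk
        ]
        inputs = pad_without_fast_tokenizer_warning(
            tokenizer,
            {"input_ids": batch_ids},
            padding="longest",  # pad only to the longest prompt in the mini-batch
            pad_to_multiple_of=None,
            return_tensors="pt",
        ).to(device)
        probs = model(**inputs).softmax(-1).cpu()
        for row in probs:
            results.append({
                "winner_model_a": row[0].item(),
                "winner_model_b": row[1].item(),
                "tie": row[2].item(),
            })
    return results


# Example: score two prompts in one call
# batch_inference([prompt_direct, prompt_direct], model=model, device=cfg.device)
```
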
## Training Details

Training code and details: https://github.com/shyoulala/LMSYS_BlackPearl