license: other
license_name: gemma-terms-of-use
license_link: https://ai.google.dev/gemma/terms
library_name: transformers
base_model: google/gemma-2b
tags:
  - trl
  - orpo
  - generated_from_trainer
model-index:
  - name: gemma-2b-orpo
    results: []
datasets:
  - alvarobartt/dpo-mix-7k-simplified
language:
  - en

gemma-2b-orpo

This is an ORPO fine-tune of google/gemma-2b with alvarobartt/dpo-mix-7k-simplified.

ORPO

ORPO (Odds Ratio Preference Optimization) is a new training paradigm that merges two phases usually run separately: SFT (Supervised Fine-Tuning) and preference alignment (typically performed with RLHF or simpler methods such as DPO). Compared with the two-phase approach, it offers:

  • Faster training
  • Less memory usage (no reference model needed)
  • Good results!
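
To make the "no reference model" point concrete, here is a minimal sketch of the ORPO objective for a single preference pair, written in plain Python. It assumes the usual formulation: an SFT term on the chosen answer plus a weighted odds-ratio term. The function names and the default weight `lam=0.1` are illustrative assumptions, not values taken from this model's training run.

```python
import math


def log_odds(avg_logp: float) -> float:
    """Return log(p / (1 - p)) for p = exp(avg_logp), where avg_logp < 0.

    avg_logp is the length-normalized (mean per-token) log-probability
    that the policy assigns to a response.
    """
    return avg_logp - math.log1p(-math.exp(avg_logp))


def orpo_loss(chosen_avg_logp: float, rejected_avg_logp: float, lam: float = 0.1) -> float:
    """ORPO loss for one (chosen, rejected) pair. `lam` is a hypothetical weight."""
    # SFT term: negative log-likelihood of the chosen response.
    sft_term = -chosen_avg_logp
    # Odds-ratio term: -log(sigmoid(log_odds_ratio)) == log(1 + exp(-log_odds_ratio)).
    log_odds_ratio = log_odds(chosen_avg_logp) - log_odds(rejected_avg_logp)
    or_term = math.log1p(math.exp(-log_odds_ratio))
    return sft_term + lam * or_term


# The loss is lower when the policy prefers the chosen response.
print(orpo_loss(-0.5, -2.0) < orpo_loss(-2.0, -0.5))
```

Both terms depend only on the log-probabilities of the model being trained, which is why no frozen reference model has to be kept in memory.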

๐Ÿ† Evaluation

Nous

gemma-2b-orpo performs well for its size on Nous' benchmark suite: it has the highest average score among the Gemma 2B models compared below (evaluation conducted using LLM AutoEval).

| Model | Average | AGIEval | GPT4All | TruthfulQA | Bigbench |
|---|---|---|---|---|---|
| anakin87/gemma-2b-orpo | 39.45 | 23.76 | 58.25 | 44.47 | 31.32 |
| mlabonne/Gemmalpaca-2B | 38.39 | 24.48 | 51.22 | 47.02 | 30.85 |
| google/gemma-2b-it | 36.1 | 23.76 | 43.6 | 47.64 | 29.41 |
| google/gemma-2b | 34.26 | 22.7 | 43.35 | 39.96 | 31.03 |
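
The Average column appears to be the unweighted mean of the four benchmark scores. The card does not state this, so treat it as an assumption; the short check below recomputes the means from the per-benchmark scores so they can be compared with the reported values.

```python
# Per-benchmark scores copied from the table: AGIEval, GPT4All, TruthfulQA, Bigbench.
scores = {
    "anakin87/gemma-2b-orpo": [23.76, 58.25, 44.47, 31.32],
    "mlabonne/Gemmalpaca-2B": [24.48, 51.22, 47.02, 30.85],
    "google/gemma-2b-it": [23.76, 43.6, 47.64, 29.41],
    "google/gemma-2b": [22.7, 43.35, 39.96, 31.03],
}

# Unweighted mean per model, rounded to two decimals like the table.
averages = {model: round(sum(vals) / len(vals), 2) for model, vals in scores.items()}

for model, avg in averages.items():
    print(f"{model}: {avg}")
```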

๐Ÿ™ Dataset

alvarobartt/dpo-mix-7k-simplified is a simplified version of argilla/dpo-mix-7k. You can find more information in the dataset card.

🎮 Model in action

Usage notebook

📓 Chat and RAG using Haystack

Simple text generation with Transformers

The model is small, so it runs smoothly on Colab. It can also be loaded with quantization to reduce memory usage further.

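# pip install transformers accelerate
import torch
from transformers import pipeline

# Load the model in bfloat16 and let accelerate place it on the available device(s)
pipe = pipeline("text-generation", model="anakin87/gemma-2b-orpo", torch_dtype=torch.bfloat16, device_map="auto")

messages = [{"role": "user", "content": "Write a rap song on Vim vs VSCode."}]
# Render the chat as a string; add_generation_prompt=True appends the model-turn marker so the model replies
prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

outputs = pipe(prompt, max_new_tokens=500, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
# generated_text contains the prompt followed by the completion
print(outputs[0]["generated_text"])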
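
The card mentions quantization without naming a method. As one possible approach, the sketch below loads the model in 4-bit with bitsandbytes through `BitsAndBytesConfig`. The helper name `build_4bit_pipeline` and the NF4 settings are assumptions chosen for illustration, and the code needs a CUDA GPU plus `pip install bitsandbytes`.

```python
MODEL_ID = "anakin87/gemma-2b-orpo"


def build_4bit_pipeline(model_id: str = MODEL_ID):
    """Return a text-generation pipeline backed by a 4-bit quantized model.

    Hypothetical helper: the quantization settings are illustrative,
    not the author's recommended configuration.
    """
    # Heavy imports are deferred so that defining this helper stays cheap.
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        pipeline,
    )

    # 4-bit NF4 weights with bfloat16 compute (assumed settings).
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quant_config,
        device_map="auto",
    )
    return pipeline("text-generation", model=model, tokenizer=tokenizer)
```

Calling `pipe = build_4bit_pipeline()` gives a pipeline that can be used in the same way as the bfloat16 one in the example above.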

Training

The model was trained using HF TRL. 📓 Training notebook

Framework versions

  • Transformers 4.39.1
  • PyTorch 2.2.0+cu121
  • Datasets 2.18.0
  • Tokenizers 0.15.2