<a href="https://colab.research.google.com/github/GHuyHuynh/fine-tune-flanT5-lawyer-/blob/main/ml_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Assignment for Shifkey Labs Gen AI

This project fine-tunes the Flan-T5 Base model from Hugging Face on a lawyer interaction dataset.

The goal is to make the model better at text summarization. The Lawyer-Instruct dataset is used to steer the model toward very short, concise summaries.

Link to dataset [here](https://huggingface.co/datasets/Alignment-Lab-AI/Lawyer-Instruct)

Thank you Vansh Sood for this tutorial.

In [3]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.17-py310-none-any.whl.metadata (7.2 kB)
INFO: pip is looking at multiple versions of multiprocess to determine which version is compatible with other requirements. This could take a while.
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.0.1-py3-none-any.whl (471 kB)
Downloading dill-0.3.8-py3-none-any.whl (116 kB)

In [4]:
from datasets import load_dataset

## Load Flan T5 model

In [5]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Load the Flan-T5 base model for text summarization
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

## Check device for GPU (in this case cuda)

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

print(device)

cuda


## Test

This gives us a baseline for how the base model performs.

We can then compare it with the fine-tuned model.

In [7]:
def summarize(text):
  inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device)
  summary_ids = model.generate(inputs["input_ids"], max_length=128, num_beams=4, early_stopping=True)
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

In [8]:
# Define a sample text for summarization
sample_text = """
Person A: Hey, did you hear about the new project management software our company is planning to implement?

Person B: Yeah, I heard a bit about it. What’s the deal with it?

Person A: It’s called "TaskFlow." The management thinks it’s going to streamline our workflow, especially with remote teams. It’s supposed to integrate all the tools we use, like Slack, Trello, and Google Drive, into one platform.

Person B: That sounds interesting. But I’m a bit concerned about the learning curve. Is it user-friendly?

Person A: From what I’ve seen, it looks pretty intuitive. They’re also planning to run a couple of training sessions to get everyone up to speed. The first one is next Monday.

Person B: Okay, that helps. I guess I’ll have to attend that session. How does it compare to what we’re using now?

Person A: It’s supposed to be much more efficient. We’ll be able to track project progress more easily and get real-time updates. Plus, it has built-in analytics to help us with performance tracking.

Person B: That sounds promising. I just hope it doesn’t come with too many bugs at launch.

Person A: Yeah, that’s always a concern with new software. But they’ve been testing it for a while now, so fingers crossed it goes smoothly.

Person B: Let’s hope for the best. Thanks for the info!

Person A: No problem. See you at the training!
"""

# Summarize the sample text using the pre-trained model (without fine-tuning)
pre_finetuned_summary = summarize(sample_text)
print("Summary before fine-tuning:", pre_finetuned_summary)

Summary before fine-tuning: The new project management software called TaskFlow is being implemented by the company. It's called "TaskFlow" and it's supposed to integrate all the tools we use, like Slack, Trello, and Google Drive, into one platform. The first training session is next Monday.


## Load the lawyer instruct dataset

In [9]:
from datasets import load_dataset

dataset = load_dataset("Alignment-Lab-AI/Lawyer-Instruct", split="train")

README.md:   0%|          | 0.00/2.00k [00:00<?, ?B/s]

alpacmygavel.json:   0%|          | 0.00/6.12M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9241 [00:00<?, ? examples/s]

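## Split the dataset

Hold out 10% of the data for evaluation, then keep only 1% of the remaining training examples as a small training set.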

In [10]:
dataset_split = dataset.train_test_split(test_size=0.1)

small_train_dataset = dataset_split["train"].train_test_split(test_size=0.99)["train"]
eval_dataset = dataset_split["test"]
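The sizes produced by these two calls can be checked by hand. The sketch below is a minimal illustration, assuming that `train_test_split` rounds the test share up and gives the remainder to the training split; `split_sizes` is a hypothetical helper, not part of the `datasets` library. That assumed rounding is consistent with the 83 and 925 examples shown in the `Map` progress bars further down.

```python
import math

def split_sizes(n_examples, test_size):
    """Return (n_train, n_test) for a fractional test_size.

    Assumption: the test share is rounded up and the rest goes to train.
    """
    n_test = math.ceil(n_examples * test_size)
    return n_examples - n_test, n_test

# 9241 examples in the full dataset, 10% held out for evaluation
n_train, n_eval = split_sizes(9241, 0.1)

# Keep only 1% of the remaining training examples
n_small_train, _ = split_sizes(n_train, 0.99)

print(n_train, n_eval, n_small_train)
```

Under that assumption, fine-tuning here uses 83 examples and evaluation uses 925, so the evaluation set is roughly eleven times larger than the training set.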

# Preprocessing the Dataset

The Lawyer-Instruct dataset has two columns:
- 'instruction': the text input
- 'output': the lawyer's interpretation of the input

In [11]:
def preprocess_function(examples):
  # Extract the instruction texts (the model inputs) from the batch
  inputs = [doc for doc in examples['instruction']]

  # Tokenize the instructions, padding/truncating to 512 tokens
  model_inputs = tokenizer(inputs, max_length=512, padding="max_length", truncation=True)

  # Tokenize the outputs (the targets), padding/truncating to 128 tokens
  # text_target replaces the deprecated tokenizer.as_target_tokenizer() context manager
  labels = tokenizer(text_target=examples['output'], max_length=128, padding="max_length", truncation=True)

  # Attach the tokenized outputs as labels to the model inputs
  model_inputs["labels"] = labels["input_ids"]

  # No manual .to(device) here: Dataset.map stores plain lists, and the Trainer
  # moves each batch to the GPU/CPU itself
  return model_inputs
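A common refinement for padded targets, not used in the original notebook, is to replace padding token ids in the labels with -100 so that padded positions are ignored by the loss. The sketch below shows the idea on plain Python lists; `mask_label_padding` is a hypothetical helper, the token ids are made up, and the pad id of 0 is the value T5 tokenizers use (in the notebook it would come from `tokenizer.pad_token_id`).

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default

def mask_label_padding(label_ids, pad_token_id, ignore_index=IGNORE_INDEX):
    """Replace padding ids in each label sequence with ignore_index."""
    return [
        [ignore_index if token == pad_token_id else token for token in sequence]
        for sequence in label_ids
    ]

# Toy batch: two padded label sequences (the ids are made up for illustration)
padded_labels = [
    [363, 19, 8, 1, 0, 0],
    [27, 1, 0, 0, 0, 0],
]

masked = mask_label_padding(padded_labels, pad_token_id=0)
print(masked)
```

Inside `preprocess_function`, this could be applied to `labels["input_ids"]` before it is assigned to `model_inputs["labels"]`. Doing so changes the loss values, so the numbers reported later in this notebook would no longer be directly comparable.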


# Tokenizing Dataset

In [12]:
# Tokenize the small training dataset
tokenized_train_dataset = small_train_dataset.map(preprocess_function, batched=True)

# Tokenize the evaluation dataset
tokenized_eval_dataset = eval_dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/83 [00:00<?, ? examples/s]



Map:   0%|          | 0/925 [00:00<?, ? examples/s]

# Setting Training Arguments

In [13]:
from transformers import Seq2SeqTrainingArguments

# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir='./results',              # Directory to save the model checkpoints
    evaluation_strategy="epoch",         # Evaluate the model at the end of every epoch
    learning_rate=2e-5,                  # Learning rate for the optimizer
    per_device_train_batch_size=8,       # Batch size for training
    per_device_eval_batch_size=8,        # Batch size for evaluation
    weight_decay=0.01,                   # Regularization to prevent overfitting
    save_total_limit=3,                  # Only keep the last 3 checkpoints
    num_train_epochs=5,                  # Number of training epochs
    predict_with_generate=True,          # Enable text generation during evaluation
    logging_dir="./logs"                 # Directory for storing training logs
)




# Trainer

In [14]:
from transformers import Seq2SeqTrainer

# Create the trainer object
trainer = Seq2SeqTrainer(
    model=model,                            # The model to be trained
    args=training_args,                     # The training arguments defined earlier
    train_dataset=tokenized_train_dataset,  # The tokenized training dataset
    eval_dataset=tokenized_eval_dataset,    # The tokenized evaluation dataset
    tokenizer=tokenizer                     # The tokenizer to handle input and output
)


# Train model

In [15]:
trainer.train()

| Epoch | Training Loss | Validation Loss |
|-------|---------------|-----------------|
| 1 | No log | 21.553352 |
| 2 | No log | 18.918873 |
| 3 | No log | 17.439026 |
| 4 | No log | 16.559143 |
| 5 | No log | 16.236048 |


TrainOutput(global_step=55, training_loss=20.093738902698863, metrics={'train_runtime': 358.3321, 'train_samples_per_second': 1.158, 'train_steps_per_second': 0.153, 'total_flos': 284174301265920.0, 'train_loss': 20.093738902698863, 'epoch': 5.0})
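The `global_step=55` value and the `No log` entries can be reproduced with a little arithmetic. The sketch below assumes a single device, no gradient accumulation, and the default `logging_steps` of 500 (none of these are overridden in the training arguments above); `total_optimizer_steps` is a hypothetical helper.

```python
import math

def total_optimizer_steps(n_examples, batch_size, epochs):
    """Optimizer steps for a run where the last partial batch is kept."""
    steps_per_epoch = math.ceil(n_examples / batch_size)
    return steps_per_epoch * epochs

# Values from this notebook: 83 training examples, batch size 8, 5 epochs
steps = total_optimizer_steps(n_examples=83, batch_size=8, epochs=5)

# Assumed default: the Trainer logs the training loss every 500 steps
DEFAULT_LOGGING_STEPS = 500
loss_was_logged = steps >= DEFAULT_LOGGING_STEPS

print(steps, loss_was_logged)
```

Because the run ends after 55 steps, the training loss is never logged, which is why the Training Loss column shows `No log`. Passing a smaller `logging_steps` (for example 10) to `Seq2SeqTrainingArguments` should make the training loss appear.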

# Evaluate

In [21]:

# Evaluate the model on the evaluation dataset
metrics = trainer.evaluate()

# Print the evaluation metrics
print(metrics)


{'eval_loss': 16.236047744750977, 'eval_runtime': 53.0288, 'eval_samples_per_second': 17.443, 'eval_steps_per_second': 2.187, 'epoch': 5.0}


# Summary Function

In [22]:

def summarize(text):
  # Tokenize the input text and move it to the correct device
  inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device)

  # Generate the summary using the fine-tuned model
  summary_ids = model.generate(inputs["input_ids"], max_length=128, num_beams=4, early_stopping=True)

  # Decode the generated summary back into text and return it
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Test

In [23]:
print(summarize(
    """
Person A: Hey, did you hear about the new project management software our company is planning to implement?

Person B: Yeah, I heard a bit about it. What’s the deal with it?

Person A: It’s called "TaskFlow." The management thinks it’s going to streamline our workflow, especially with remote teams. It’s supposed to integrate all the tools we use, like Slack, Trello, and Google Drive, into one platform.

Person B: That sounds interesting. But I’m a bit concerned about the learning curve. Is it user-friendly?

Person A: From what I’ve seen, it looks pretty intuitive. They’re also planning to run a couple of training sessions to get everyone up to speed. The first one is next Monday.

Person B: Okay, that helps. I guess I’ll have to attend that session. How does it compare to what we’re using now?

Person A: It’s supposed to be much more efficient. We’ll be able to track project progress more easily and get real-time updates. Plus, it has built-in analytics to help us with performance tracking.

Person B: That sounds promising. I just hope it doesn’t come with too many bugs at launch.

Person A: Yeah, that’s always a concern with new software. But they’ve been testing it for a while now, so fingers crossed it goes smoothly.

Person B: Let’s hope for the best. Thanks for the info!

Person A: No problem. See you at the training!
"""
))

Person B is excited about the new project management software "TaskFlow" that the company is planning to implement.


As expected, the fine-tuned model produces a much more concise summary, but it leaves out some details.
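One rough way to quantify the difference is to compare the word counts of the two summaries printed in this notebook. This is only a sketch of a length comparison, not a quality metric; `word_count` is a hypothetical helper, and the two strings are copied from the cell outputs above.

```python
def word_count(text):
    """Count whitespace-separated words."""
    return len(text.split())

# Summary printed before fine-tuning
before = (
    "The new project management software called TaskFlow is being implemented by the company. "
    "It's called \"TaskFlow\" and it's supposed to integrate all the tools we use, "
    "like Slack, Trello, and Google Drive, into one platform. "
    "The first training session is next Monday."
)

# Summary printed after fine-tuning
after = (
    'Person B is excited about the new project management software "TaskFlow" '
    "that the company is planning to implement."
)

before_words = word_count(before)
after_words = word_count(after)

# Ratio below 1.0 means the fine-tuned summary is shorter
length_ratio = after_words / before_words
print(before_words, after_words, round(length_ratio, 2))
```

A length ratio says nothing about faithfulness, so an overlap metric such as ROUGE against reference summaries would be a natural next step.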