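"""Gradio Space that fine-tunes a SentenceTransformer embedding model.

Loads an (anchor, positive) pairs dataset from the Hugging Face Hub, trains with
an in-batch-negatives ranking loss, and reports information-retrieval metrics
before and after fine-tuning.
"""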
import spaces
import gradio as gr
import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.training_args import SentenceTransformerTrainingArguments, BatchSamplers
def get_ir_evaluator(eval_ds):
    """Build an InformationRetrievalEvaluator from an (anchor, positive) pairs dataset.

    Each anchor becomes a query and its positive the single relevant corpus
    document. A richer eval set (e.g. LLM-generated queries with several
    relevant documents per query) would give a stronger signal.
    """
corpus = {}
queries = {}
relevant_docs = {} # relevant documents (qid => set[cid])
for idx, example in enumerate(eval_ds):
query = example['anchor']
queries[idx] = query
document = example['positive']
corpus[idx] = document
        relevant_docs[idx] = {idx}  # only the paired positive is marked relevant; a real corpus may have several
ir_evaluator = InformationRetrievalEvaluator(
queries=queries,
corpus=corpus,
relevant_docs=relevant_docs,
name="ir-evaluator",
)
return ir_evaluator
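
# InformationRetrievalEvaluator returns a dict of IR metrics (accuracy@k, MRR@k,
# NDCG@k, ...); its `primary_metric` attribute names the headline score used below.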
@spaces.GPU(duration=3600)
def train(hf_token, dataset_id, model_id, num_epochs, dev):
ds = load_dataset(dataset_id, split="train", token=hf_token)
ds = ds.shuffle(seed=42)
    if dev and len(ds) > 1000:
        ds = ds.select(range(1000))  # dev mode: quick smoke test on 1000 examples
ds = ds.train_test_split(train_size=0.75)
train_ds, eval_ds = ds['train'], ds['test']
print('train: ', len(train_ds), 'eval: ', len(eval_ds))
# model
model = SentenceTransformer(model_id)
# loss
loss = CachedMultipleNegativesRankingLoss(model)
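    # CachedMultipleNegativesRankingLoss treats the other positives in a batch as
    # negatives and caches gradients (GradCache), so large effective batch sizes
    # fit in limited GPU memory.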
# training args
args = SentenceTransformerTrainingArguments(
output_dir="outputs", # required
num_train_epochs=num_epochs, # optional...
per_device_train_batch_size=16,
warmup_ratio=0.1,
#fp16=True, # Set to False if your GPU can't handle FP16
#bf16=False, # Set to True if your GPU supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # Losses using "in-batch negatives" benefit from no duplicates
save_total_limit=2
        # Optional evaluation / tracking parameters:
        # per_device_eval_batch_size=1,
        # eval_strategy="steps",  # or "epoch"
        # eval_steps=100,
        # save_strategy="steps",  # or "epoch"
        # save_steps=100,
        # logging_steps=100,
        # run_name="jina-code-vechain-pair",  # used in W&B if `wandb` is installed
)
# ir evaluator
ir_evaluator = get_ir_evaluator(eval_ds)
    # metrics for the base model, before fine-tuning
    base_metrics = ir_evaluator(model)
    print('base', ir_evaluator.primary_metric, base_metrics[ir_evaluator.primary_metric])
# train
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_ds,
# eval_dataset=eval_ds,
loss=loss,
# evaluator=ir_evaluator,
)
trainer.train()
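    # Optionally persist the fine-tuned weights (path below is an example):
    # model.save("outputs/final-model")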
    # metrics for the fine-tuned model
    ft_metrics = ir_evaluator(model)
    print('fine-tuned', ir_evaluator.primary_metric, ft_metrics[ir_evaluator.primary_metric])
    metrics = pd.DataFrame([base_metrics, ft_metrics], index=['base', 'fine-tuned']).T
    print(metrics)
    return str(metrics)
# TODO: stream training logs to the UI, see:
# https://github.com/gradio-app/gradio/issues/2362#issuecomment-1424446778
demo = gr.Interface(
    fn=train,
    inputs=["text", "text", "text", "number", "checkbox"],  # hf_token, dataset_id, model_id, num_epochs, dev
    outputs=["text"],  # could be "dataframe" for the metrics table
)
demo.launch()