|
import spaces |
|
import gradio as gr |
|
|
|
|
|
import pandas as pd |
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer |
|
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss |
|
from sentence_transformers.evaluation import InformationRetrievalEvaluator |
|
from sentence_transformers.training_args import SentenceTransformerTrainingArguments, BatchSamplers |
|
|
|
|
|
|
|
|
|
def get_ir_evaluator(eval_ds):
    """Build an InformationRetrievalEvaluator from an anchor/positive pair dataset.

    Each example's ``anchor`` string becomes a query and its ``positive`` string
    becomes the single relevant corpus document; both are keyed by the example's
    index, so relevance is a strict one-to-one mapping (query i -> document i).
    NOTE(review): a better eval set could come from a dedicated corpus or
    LLM-generated queries — this reuses the training pair format.

    Args:
        eval_ds: iterable of dicts with ``anchor`` and ``positive`` text fields.

    Returns:
        An ``InformationRetrievalEvaluator`` named "ir-evaluator" over the
        derived queries/corpus/relevance mappings.
    """
    corpus = {}
    queries = {}
    relevant_docs = {}
    for idx, example in enumerate(eval_ds):
        queries[idx] = example['anchor']
        corpus[idx] = example['positive']
        # One-to-one relevance: the only relevant doc for query idx is doc idx.
        relevant_docs[idx] = {idx}

    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name="ir-evaluator",
    )
    return ir_evaluator
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=3600)
def train(hf_token, dataset_id, model_id, num_epochs, dev=True):
    """Fine-tune a SentenceTransformer on an anchor/positive dataset.

    Loads the dataset, splits 75/25 into train/eval, evaluates the base model
    with an IR evaluator, trains with cached multiple-negatives ranking loss,
    re-evaluates, and (outside dev mode) pushes the model to the Hub.

    Args:
        hf_token: Hugging Face token used to read the dataset and push the model.
        dataset_id: Hub dataset id with 'anchor'/'positive' columns (train split).
        model_id: base SentenceTransformer model id to fine-tune.
        num_epochs: number of training epochs.
        dev: when True, cap the dataset at 1000 examples and skip the Hub push.

    Returns:
        str: a printable DataFrame comparing base vs fine-tuned IR metrics.
    """
    ds = load_dataset(dataset_id, split="train", token=hf_token)
    ds = ds.shuffle(seed=42)
    # Dev-mode cap at 1000 examples (the original range(0, 999) kept only 999).
    if len(ds) > 1000 and dev:
        ds = ds.select(range(1000))
    ds = ds.train_test_split(train_size=0.75)
    train_ds, eval_ds = ds['train'], ds['test']
    print('train: ', len(train_ds), 'eval: ', len(eval_ds))

    model = SentenceTransformer(model_id)

    # In-batch negatives loss; the cached variant trades compute for memory,
    # allowing large effective batch sizes.
    loss = CachedMultipleNegativesRankingLoss(model)

    args = SentenceTransformerTrainingArguments(
        output_dir="outputs",
        num_train_epochs=num_epochs,
        per_device_train_batch_size=16,
        warmup_ratio=0.1,
        # NO_DUPLICATES avoids duplicate anchors in a batch, which would create
        # false negatives for the in-batch negatives loss.
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        save_total_limit=2,
    )

    ir_evaluator = get_ir_evaluator(eval_ds)

    # Baseline metrics before any fine-tuning.
    base_metrics = ir_evaluator(model)
    print(ir_evaluator.primary_metric)
    print(base_metrics[ir_evaluator.primary_metric])

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        loss=loss,
    )
    trainer.train()

    # Metrics after fine-tuning, on the same evaluator for a fair comparison.
    ft_metrics = ir_evaluator(model)
    print(ir_evaluator.primary_metric)
    print(ft_metrics[ir_evaluator.primary_metric])

    if not dev:
        model.push_to_hub("fine-tuned-sentence-transformer", private=True, token=hf_token)

    # Rows = metric names, columns = [base, fine-tuned].
    metrics = pd.DataFrame([base_metrics, ft_metrics]).T
    print(metrics)
    return str(metrics)
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: token, dataset id, model id, epochs, dev-mode toggle -> metrics text.
# "bool" is not a valid Gradio component shortcut; boolean inputs use "checkbox"
# (with the string "bool", gr.Interface raises at construction time).
demo = gr.Interface(
    fn=train,
    inputs=["text", "text", "text", "number", "checkbox"],
    outputs=["text"],
)
demo.launch()
|
|