import spaces
import gradio as gr

import pandas as pd
from datasets import load_dataset

# Note: the trainer classes can also be imported straight from the top-level package.
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.training_args import SentenceTransformerTrainingArguments, BatchSamplers


def get_ir_evaluator(eval_ds):
    """Build an InformationRetrievalEvaluator from an (anchor, positive) pair dataset.

    Could be built from a richer dataset later (e.g. LLM-generated queries).
    """
    corpus = {}         # cid => document
    queries = {}        # qid => query
    relevant_docs = {}  # qid => set[cid]

    for idx, example in enumerate(eval_ds):
        queries[idx] = example["anchor"]
        corpus[idx] = example["positive"]
        # Note: ideally each query would have more than one relevant document.
        relevant_docs[idx] = {idx}

    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name="ir-evaluator",
    )
    return ir_evaluator


@spaces.GPU(duration=3600)
def train(hf_token, dataset_id, model_id, num_epochs, dev=True):
    ds = load_dataset(dataset_id, split="train", token=hf_token)
    ds = ds.shuffle(seed=42)
    if len(ds) > 1000 and dev:
        ds = ds.select(range(1000))  # cap dev runs at 1,000 examples
    ds = ds.train_test_split(train_size=0.75)
    train_ds, eval_ds = ds["train"], ds["test"]
    print("train:", len(train_ds), "eval:", len(eval_ds))

    # model
    model = SentenceTransformer(model_id)

    # loss
    loss = CachedMultipleNegativesRankingLoss(model)

    # training args
    args = SentenceTransformerTrainingArguments(
        output_dir="outputs",  # required
        # optional:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=16,
        warmup_ratio=0.1,
        # fp16=True,   # set to False if your GPU can't handle FP16
        # bf16=False,  # set to True if your GPU supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses using in-batch negatives benefit from no duplicates
        save_total_limit=2,
        # per_device_eval_batch_size=1,
        # eval_strategy="epoch",
        # save_strategy="epoch",
        # logging_steps=100,
        # Optional tracking/debugging parameters:
        # eval_strategy="steps",
        # eval_steps=100,
        # save_strategy="steps",
        # save_steps=100,
        # logging_steps=100,
        # run_name="jina-code-vechain-pair",  # used in W&B if `wandb` is installed
    )

    # IR evaluator
    ir_evaluator = get_ir_evaluator(eval_ds)

    # base model metrics
    base_metrics = ir_evaluator(model)
    print(ir_evaluator.primary_metric)
    print(base_metrics[ir_evaluator.primary_metric])

    # train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        # eval_dataset=eval_ds,
        loss=loss,
        # evaluator=ir_evaluator,
    )
    trainer.train()

    # fine-tuned model metrics
    ft_metrics = ir_evaluator(model)
    print(ir_evaluator.primary_metric)
    print(ft_metrics[ir_evaluator.primary_metric])

    if not dev:
        model.push_to_hub("fine-tuned-sentence-transformer", private=True, token=hf_token)

    metrics = pd.DataFrame([base_metrics, ft_metrics]).T
    print(metrics)
    return str(metrics)  # logs to UI


# Streaming logs to the UI:
# https://github.com/gradio-app/gradio/issues/2362#issuecomment-1424446778
demo = gr.Interface(
    fn=train,
    inputs=["text", "text", "text", "number", "checkbox"],  # "checkbox" is the valid shortcut for a boolean input
    outputs=["text"],  # "dataframe"
)
demo.launch()
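# --- Hypothetical local smoke test (not part of the original Space) ---
# A minimal sketch of calling train() directly, assuming a Hub dataset with
# "anchor"/"positive" columns and a small base model; every id below is a
# placeholder for illustration, not a value from the original script.
# if __name__ == "__main__":
#     print(train(
#         hf_token="hf_xxx",                                       # placeholder token
#         dataset_id="your-username/your-anchor-positive-pairs",   # placeholder dataset id
#         model_id="sentence-transformers/all-MiniLM-L6-v2",       # example base model
#         num_epochs=1,
#         dev=True,                                                # keeps the run small and skips push_to_hub
#     ))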