DoctorSlimm committed on
Commit
d655f51
1 Parent(s): 281995c

add train code and requirements text file...

Files changed (2):
  1. app.py +119 -3
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,7 +1,123 @@
+ import spaces
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"
+ # code
+ import pandas as pd
+ from datasets import load_dataset

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+ # from sentence_transformers import (
+ #     SentenceTransformer,
+ #     SentenceTransformerTrainer,
+ #     SentenceTransformerTrainingArguments,
+ #     SentenceTransformerModelCardData
+ # )  ### we can import everything from the main module...
+
+ from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
+ from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
+ from sentence_transformers.evaluation import InformationRetrievalEvaluator
+ from sentence_transformers.training_args import SentenceTransformerTrainingArguments, BatchSamplers
+
+
+
+
+ def get_ir_evaluator(eval_ds):
+     """Build an IR evaluator from an (anchor, positive) pairs dataset... could make from a better dataset... LLM generate?"""
+
+     corpus = {}
+     queries = {}
+     relevant_docs = {}  # relevant documents (qid => set[cid])
+     for idx, example in enumerate(eval_ds):
+         query = example['anchor']
+         queries[idx] = query
+
+         document = example['positive']
+         corpus[idx] = document
+
+         relevant_docs[idx] = set([idx])  # note: should have more relevant docs here
+
+     ir_evaluator = InformationRetrievalEvaluator(
+         queries=queries,
+         corpus=corpus,
+         relevant_docs=relevant_docs,
+         name="ir-evaluator",
+     )
+     return ir_evaluator
+
+
+
+
+ @spaces.GPU(duration=3600)
+ def train(hf_token, dataset_id, model_id, num_epochs, dev):
+
+     ds = load_dataset(dataset_id, split="train", token=hf_token)
+     ds = ds.shuffle(seed=42)
+     if len(ds) > 1000 and dev: ds = ds.select(range(0, 999))  # small subset for dev runs
+     ds = ds.train_test_split(train_size=0.75)
+     train_ds, eval_ds = ds['train'], ds['test']
+     print('train: ', len(train_ds), 'eval: ', len(eval_ds))
+
+     # model
+     model = SentenceTransformer(model_id)
+
+     # loss
+     loss = CachedMultipleNegativesRankingLoss(model)
+
+     # training args
+     args = SentenceTransformerTrainingArguments(
+         output_dir="outputs",  # required
+         num_train_epochs=num_epochs,  # optional...
+         per_device_train_batch_size=16,
+         warmup_ratio=0.1,
+         # fp16=True,   # Set to False if your GPU can't handle FP16
+         # bf16=False,  # Set to True if your GPU supports BF16
+         batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses using "in-batch negatives" benefit from no duplicates
+         save_total_limit=2
+         # per_device_eval_batch_size=1,
+         # eval_strategy="epoch",
+         # save_strategy="epoch",
+         # logging_steps=100,
+         # Optional tracking/debugging parameters:
+         # eval_strategy="steps",
+         # eval_steps=100,
+         # save_strategy="steps",
+         # save_steps=100,
+         # logging_steps=100,
+         # run_name="jina-code-vechain-pair",  # used in W&B if `wandb` is installed
+     )
+
+     # ir evaluator
+     ir_evaluator = get_ir_evaluator(eval_ds)
+
+     # base model metrics
+     base_metrics = ir_evaluator(model)
+     print(ir_evaluator.primary_metric)
+     print(base_metrics[ir_evaluator.primary_metric])
+
+
+     # train
+     trainer = SentenceTransformerTrainer(
+         model=model,
+         args=args,
+         train_dataset=train_ds,
+         # eval_dataset=eval_ds,
+         loss=loss,
+         # evaluator=ir_evaluator,
+     )
+     trainer.train()
+
+     # fine-tuned model metrics
+     ft_metrics = ir_evaluator(model)
+     print(ir_evaluator.primary_metric)
+     print(ft_metrics[ir_evaluator.primary_metric])
+
+
+     metrics = pd.DataFrame([base_metrics, ft_metrics]).T
+     print(metrics)
+     return str(metrics)
+
+
+ ## logs to UI
+ # https://github.com/gradio-app/gradio/issues/2362#issuecomment-1424446778
+
+ demo = gr.Interface(fn=train, inputs=["text", "text", "text", "number", "checkbox"], outputs=["text"])  # "dataframe"
  demo.launch()
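
The final `gr.Interface` maps its `inputs` list positionally onto `train(hf_token, dataset_id, model_id, num_epochs, dev)`. A minimal sketch of the same wiring with explicit, labeled components follows; the labels and component choices are illustrative assumptions, not part of the commit:

import gradio as gr

# Hypothetical labeled equivalent of the string shortcuts used above.
# type="password" keeps the HF token masked in the browser.
demo = gr.Interface(
    fn=train,  # the train() defined in app.py above
    inputs=[
        gr.Textbox(label="HF token", type="password"),
        gr.Textbox(label="Dataset ID"),
        gr.Textbox(label="Model ID"),
        gr.Number(label="Epochs", value=1),
        gr.Checkbox(label="Dev mode (subsample the dataset)"),
    ],
    outputs=gr.Textbox(label="IR metrics (base vs. fine-tuned)"),
)
demo.launch()

Either form returns `str(metrics)`: a stringified two-column DataFrame holding the base-model and fine-tuned InformationRetrievalEvaluator scores side by side.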
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ datasets
+ accelerate
+ sentence-transformers
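
With `sentence-transformers` as the core dependency, checkpoints the trainer writes under `output_dir="outputs"` can later be reloaded for inference. A minimal sketch, assuming a checkpoint folder name of the usual `checkpoint-<step>` form (the exact path is illustrative):

from sentence_transformers import SentenceTransformer

# Load a checkpoint saved by SentenceTransformerTrainer during train();
# the step number here is an assumption.
model = SentenceTransformer("outputs/checkpoint-500")
embeddings = model.encode(["def add(a, b): return a + b"])
print(embeddings.shape)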