gorkaartola committed on
Commit b27edec
1 Parent(s): a65e727

Upload run.py

Files changed (1)
  run.py +124 -0
run.py ADDED
@@ -0,0 +1,124 @@
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
+ from datasets import load_dataset
+ import evaluate
+ from torch.utils.data import DataLoader
+ import torch
+ import numpy as np
+ import pandas as pd
+ import options as op
+
+ def tp_tf_test(model_selector, test_dataset, queries_selector, prompt_selector, metric_selector, prediction_strategy_selector):
+
+     #Load test dataset___________________________
+     test_dataset = load_dataset(test_dataset)['test']
+
+     #Load queries________________________________
+     queries_data_files = {'queries': queries_selector}
+     queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files = queries_data_files)['queries']
+
+     #Load prompt_________________________________
+     prompt = prompt_selector
+
+     #Load prediction strategies__________________
+     prediction_strategies = prediction_strategy_selector
+
+     #Load model, tokenizer and collator__________
+     model = AutoModelForSequenceClassification.from_pretrained(model_selector)
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+         model.to(device)
+     tokenizer = AutoTokenizer.from_pretrained(model_selector)
+     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+     #Calculate and save predictions______________
+     #'''
+     def tokenize_function(example, prompt = '', query = ''):
+         # Pair every title with the same prompt + query so the NLI model scores entailment between them.
+         queries = []
+         for i in range(len(example['title'])):
+             queries.append(prompt + query)
+         tokenize = tokenizer(example['title'], queries, truncation='only_first')
+         #tokenize['query'] = queries
+         return tokenize
+
+     results_test = pd.DataFrame()
+     for query_data in queries_dataset:
+         query = query_data['SDGquery']
+         tokenized_test_dataset = test_dataset.map(tokenize_function, batched = True, fn_kwargs = {'prompt' : prompt, 'query' : query})
+         columns_to_remove = test_dataset.column_names
+         for column_name in ['label_ids', 'nli_label']:
+             columns_to_remove.remove(column_name)
+         tokenized_test_dataset_for_inference = tokenized_test_dataset.remove_columns(columns_to_remove)
+         tokenized_test_dataset_for_inference.set_format('torch')
+         dataloader = DataLoader(
+             tokenized_test_dataset_for_inference,
+             batch_size=8,
+             collate_fn = data_collator,
+         )
+         values = []
+         labels = []
+         nli_labels = []
+         for batch in dataloader:
+             if torch.cuda.is_available():
+                 data = {k: v.to(device) for k, v in batch.items() if k not in ['labels', 'nli_label']}
+             else:
+                 data = {k: v for k, v in batch.items() if k not in ['labels', 'nli_label']}
+             with torch.no_grad():
+                 outputs = model(**data)
+             logits = outputs.logits
+             # Keep only the contradiction and entailment logits (indices 0 and 2 in the usual MNLI label order),
+             # softmax over them, and use the entailment probability as the score for this query.
+             entail_contradiction_logits = logits[:,[0,2]]
+             probs = entail_contradiction_logits.softmax(dim=1)
+             predictions = probs[:,1].tolist()
+             label_ids = batch['labels'].tolist()
+             nli_label_ids = batch['nli_label'].tolist()
+             for prediction, label, nli_label in zip(predictions, label_ids, nli_label_ids):
+                 values.append(prediction)
+                 labels.append(label)
+                 nli_labels.append(nli_label)
+         results_test['dataset_labels'] = labels
+         results_test['nli_labels'] = nli_labels
+         results_test[query] = values
+
+     results_test.to_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '.csv', index = False)
+     #'''
+     #Load saved predictions____________________________
+     # Toggle: comment out the block above and uncomment this read_csv to reuse previously saved predictions.
+     '''
+     results_test = pd.read_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '.csv')
+     '''
+
+     #Analyze predictions_______________________________
+     def logits_labels(raw):
+         # Collapse the per-query scores into one score per SDG class (17 in total)
+         # by taking, for each class, the maximum score over that class's queries.
+         raw_logits = raw.iloc[:,2:]
+         logits = np.zeros(shape=(len(raw_logits.index),17))
+         for i in range(17):
+             queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['SDGquery']
+             logits[:,i] = raw_logits[queries].max(axis=1)
+         labels = raw[["dataset_labels","nli_labels"]]
+         labels = np.array(labels).astype(int)
+         return logits, labels
+
+     predictions, references = logits_labels(results_test)
+     prediction_strategies = [op.prediction_strategy_options[x] for x in prediction_strategy_selector]
+
+     metric = evaluate.load(metric_selector)
+     metric.add_batch(predictions = predictions, references = references)
+     results = metric.compute(prediction_strategies = prediction_strategies)
+     prediction_strategies_names = '-'.join(prediction_strategy_selector).replace(" ", "")
+     output_filename = 'Reports/report-Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '_Strategies-' + prediction_strategies_names + '.csv'
+     with open(output_filename, 'a') as results_file:
+         for result in results:
+             results[result].to_csv(results_file, mode='a', index_label = result)
+             print(results[result], '\n')
+     return output_filename
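The function relies on an options module (imported as op) for lookup tables that map the selector arguments to short names and strategy values. The actual options.py is not part of this commit; a minimal sketch of the shape it would need, with purely illustrative entries, is:

# options.py — hypothetical sketch; every key and value below is a placeholder, not taken from the repository.
models = {
    'facebook/bart-large-mnli': 'bart-large-mnli',   # checkpoint id -> short name used in report file names
}
queries = {
    'SDGqueries_titles.csv': 'titles',               # queries data file in gorkaartola/SDG_queries -> short name
}
prompts = {
    'This text is about ': 'is-about',               # prompt string -> short name
}
prediction_strategy_options = {
    'argmax': ['argmax'],                            # strategy key -> value forwarded to metric.compute
}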
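With such an options module in place, a driver call might look like the sketch below. Every argument value is a placeholder: it must match keys actually defined in options.py, a dataset whose test split has title, label_ids and nli_label columns, and a metric whose compute accepts a prediction_strategies argument.

# Hypothetical driver script — all values below are illustrative placeholders.
from run import tp_tf_test

report = tp_tf_test(
    model_selector='facebook/bart-large-mnli',      # assumed NLI checkpoint, key of op.models
    test_dataset='your-namespace/sdg-test-set',     # placeholder dataset id with a 'test' split
    queries_selector='SDGqueries_titles.csv',       # placeholder data file in gorkaartola/SDG_queries
    prompt_selector='This text is about ',          # placeholder prompt, key of op.prompts
    metric_selector='your-namespace/sdg-metric',    # placeholder metric that accepts prediction_strategies
    prediction_strategy_selector=['argmax'],        # placeholder keys of op.prediction_strategy_options
)
print('Report written to', report)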