sohomghosh commited on
Commit
4855ea0
·
1 Parent(s): 9777bb1

Upload 2 files

Browse files
fin_readability_sustainability.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from transformers import RobertaModel, RobertaTokenizer, BertModel, BertTokenizer
5
+ import pandas as pd
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ MAX_LEN = 128
10
+ BATCH_SIZE = 20
11
+ text_col_name = 'sentence'
12
+
13
+ def scoring_data_prep(dataset):
14
+ out = []
15
+ target = []
16
+ mask = []
17
+
18
+ for i in range(len(dataset)):
19
+ rec = dataset[i]
20
+ out.append(rec['ids'].reshape(-1,MAX_LEN))
21
+ mask.append(rec['mask'].reshape(-1,MAX_LEN))
22
+
23
+ out_stack = torch.cat(out, dim = 0)
24
+ mask_stack = torch.cat(mask, dim =0 )
25
+ out_stack = out_stack.to(device, dtype = torch.long)
26
+ mask_stack = mask_stack.to(device, dtype = torch.long)
27
+
28
+ return out_stack, mask_stack
29
+
30
+ class Triage(Dataset):
31
+ """
32
+ This is a subclass of torch packages Dataset class. It processes input to create ids, masks and targets required for model training.
33
+ """
34
+
35
+ def __init__(self, dataframe, tokenizer, max_len, text_col_name):
36
+ self.len = len(dataframe)
37
+ self.data = dataframe
38
+ self.tokenizer = tokenizer
39
+ self.max_len = max_len
40
+ self.text_col_name = text_col_name
41
+
42
+
43
+ def __getitem__(self, index):
44
+ title = str(self.data[self.text_col_name][index])
45
+ title = " ".join(title.split())
46
+ inputs = self.tokenizer.encode_plus(
47
+ title,
48
+ None,
49
+ add_special_tokens=True,
50
+ max_length=self.max_len,
51
+ pad_to_max_length=True, #padding='max_length' #For future version use `padding='max_length'`
52
+ return_token_type_ids=True,
53
+ truncation=True,
54
+ )
55
+ ids = inputs["input_ids"]
56
+ mask = inputs["attention_mask"]
57
+
58
+ return {
59
+ "ids": torch.tensor(ids, dtype=torch.long),
60
+ "mask": torch.tensor(mask, dtype=torch.long),
61
+
62
+ }
63
+
64
+ def __len__(self):
65
+ return self.len
66
+
67
+ class BERTClass(torch.nn.Module):
68
+ def __init__(self, num_class, task):
69
+ super(BERTClass, self).__init__()
70
+ self.num_class = num_class
71
+ if task =="sustanability":
72
+ self.l1 = RobertaModel.from_pretrained("roberta-base")
73
+ else:
74
+ self.l1 = BertModel.from_pretrained("ProsusAI/finbert")
75
+ self.pre_classifier = torch.nn.Linear(768, 768)
76
+ self.dropout = torch.nn.Dropout(0.3)
77
+ self.classifier = torch.nn.Linear(768, self.num_class)
78
+ self.history = dict()
79
+
80
+ def forward(self, input_ids, attention_mask):
81
+ output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
82
+ hidden_state = output_1[0]
83
+ pooler = hidden_state[:, 0]
84
+ pooler = self.pre_classifier(pooler)
85
+ pooler = torch.nn.ReLU()(pooler)
86
+ pooler = self.dropout(pooler)
87
+ output = self.classifier(pooler)
88
+ return output
89
+
90
+ def do_predict(model, tokenizer, test_df):
91
+ test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
92
+ test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
93
+ test_loader = DataLoader(test_set, **test_params)
94
+ out_stack, mask_stack = scoring_data_prep(dataset = test_set)
95
+ n = 0
96
+ combined_output = []
97
+ model.eval()
98
+ with torch.no_grad():
99
+ while n < test_df.shape[0]:
100
+ output = model(out_stack[n:n+BATCH_SIZE,:],mask_stack[n:n+BATCH_SIZE,:])
101
+ n = n + BATCH_SIZE
102
+ combined_output.append(output)
103
+ combined_output = torch.cat(combined_output, dim = 0)
104
+ preds = torch.argsort(combined_output, axis = 1, descending = True)
105
+ preds = preds.to('cpu')
106
+ actual_predictions = [i[0] for i in preds.tolist()]
107
+ combined_output = combined_output.to('cpu')
108
+ prob_predictions= [i[1] for i in combined_output.tolist()]
109
+ return (actual_predictions, prob_predictions)
110
+
readability_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d360ee86e83576331a3db9a59f7998f7c00107d95ac860f463232834150596e0
3
+ size 1316371909