CodeHima committed on
Commit
8766819
·
1 Parent(s): 34e855f

feat: Add utility functions for text processing and model prediction

Browse files
utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # utils/__init__.py
2
+
3
+ from .text_processing import extract_text_from_pdf, split_into_clauses
4
+ from .model_utils import predict_unfairness
utils/model_utils.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def predict_unfairness(text, model, tokenizer):
    """Classify a clause as fair, potentially unfair, or clearly unfair.

    Args:
        text: Clause text handed to ``tokenizer``. Inputs longer than 512
            tokens are truncated.
        model: Callable classifier whose output exposes a ``logits`` tensor
            (presumably a Hugging Face sequence-classification model --
            confirm against the caller). It is switched to evaluation mode
            as a side effect.
        tokenizer: Callable that returns a mapping of PyTorch tensors when
            invoked with ``return_tensors="pt"``.

    Returns:
        A ``(label, probabilities)`` tuple. ``label`` is the name of the
        highest-scoring class; ``probabilities`` is the softmax distribution
        over the classes converted to plain Python values.

    Raises:
        KeyError: If the predicted class index is not one of the three
            known labels.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # The tokenizer builds CPU tensors; move them next to the model weights so
    # a model living on an accelerator does not fail with a device mismatch.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        # Not an nn.Module, or it has no parameters: leave the inputs as-is.
        device = None
    if device is not None:
        inputs = {
            name: value.to(device) if hasattr(value, "to") else value
            for name, value in inputs.items()
        }

    model.eval()
    with torch.no_grad():  # Inference only: skip gradient bookkeeping.
        outputs = model(**inputs)

    # squeeze() drops the batch dimension of the single-text batch.
    probabilities = torch.softmax(outputs.logits, dim=-1).squeeze()
    predicted_class = torch.argmax(probabilities).item()

    label_mapping = {0: 'clearly_fair', 1: 'potentially_unfair', 2: 'clearly_unfair'}
    predicted_label = label_mapping[predicted_class]

    return predicted_label, probabilities.tolist()
utils/text_processing.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PyPDF2
2
+ import spacy
3
+ import re
4
+
5
# Load the "en_core_web_sm" spaCy pipeline once at import time;
# split_into_clauses reuses this shared instance for sentence segmentation.
+ nlp = spacy.load("en_core_web_sm")
6
+
7
def extract_text_from_pdf(pdf_file):
    """Return the text of every page of a PDF, concatenated in page order.

    Args:
        pdf_file: Whatever ``PyPDF2.PdfReader`` accepts (presumably a path
            or a binary file object -- confirm against the caller).

    Returns:
        A single string with the page texts joined without any separator.
        Pages that yield no text contribute nothing.
    """
    reader = PyPDF2.PdfReader(pdf_file)
    # ``or ""`` keeps the join safe if a page yields no text object at all.
    return "".join(page.extract_text() or "" for page in reader.pages)
13
+
14
# A sentence carrying an enumeration marker ("3.", "(a)", "(iv)") closes the
# clause currently being assembled.
_CLAUSE_MARKER_RE = re.compile(r'\d+\.|\([a-z]\)|\([iv]+\)')
# Leading clause number, with or without a trailing period.
_LEADING_NUMBER_RE = re.compile(r'^\s*\d+\.?\s*')


def split_into_clauses(text, max_clause_length=200):
    """Split contract text into clause-sized chunks of whole sentences.

    Whitespace is collapsed to single spaces, the text is segmented into
    sentences with the module-level spaCy pipeline ``nlp``, and consecutive
    sentences are grouped. A group is closed as soon as its latest sentence
    contains an enumeration marker or the group grows past
    ``max_clause_length`` characters.

    Args:
        text: Raw text, e.g. as extracted from a PDF.
        max_clause_length: Character count above which the current group is
            closed even without an enumeration marker. Defaults to 200.

    Returns:
        A list of non-empty clause strings with surrounding whitespace and
        any leading clause number removed. Empty or whitespace-only input
        yields an empty list without running the spaCy pipeline.
    """
    if not text or not text.strip():
        return []

    # Collapsing all whitespace also removes every newline, so no separate
    # newline-normalisation step is needed afterwards.
    text = re.sub(r'\s+', ' ', text)

    doc = nlp(text)

    clauses = []
    pending = []
    pending_length = 0  # Equals len(' '.join(pending)) at every step.

    for sent in doc.sents:
        sentence = sent.text
        # Account for the joining space once the group is non-empty.
        pending_length += len(sentence) + (1 if pending else 0)
        pending.append(sentence)

        if _CLAUSE_MARKER_RE.search(sentence) or pending_length > max_clause_length:
            clauses.append(' '.join(pending))
            pending = []
            pending_length = 0

    # Whatever is left over forms the final clause.
    if pending:
        clauses.append(' '.join(pending))

    # NOTE(review): the leading-number pattern also strips a number that is
    # not followed by a period (e.g. "30 days ..." -> "days ..."); confirm
    # this is intended.
    stripped = (_LEADING_NUMBER_RE.sub('', clause.strip()) for clause in clauses)
    return [clause for clause in stripped if clause]