Spaces:
Sleeping
Sleeping
no message
Browse files
main.py
CHANGED
@@ -11,7 +11,7 @@ import nltk
|
|
11 |
import os
|
12 |
import google.protobuf # This line should execute without errors if protobuf is installed correctly
|
13 |
import sentencepiece
|
14 |
-
from transformers import pipeline, AutoTokenizer,
|
15 |
import spacy
|
16 |
|
17 |
|
@@ -139,7 +139,12 @@ def segment_text(text: str, max_tokens=500): # Setting a conservative limit bel
|
|
139 |
return segments
|
140 |
|
141 |
|
142 |
-
tokenizer
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
def robust_segment_text(text: str, max_tokens=510):
|
145 |
doc = nlp(text)
|
@@ -165,16 +170,17 @@ def robust_segment_text(text: str, max_tokens=510):
|
|
165 |
return segments
|
166 |
|
167 |
|
168 |
-
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
169 |
-
|
170 |
def classify_segments(segments):
|
171 |
-
labels = [
|
172 |
-
|
173 |
-
|
|
|
|
|
174 |
classified_segments = []
|
175 |
for segment in segments:
|
176 |
-
|
177 |
-
|
|
|
178 |
return classified_segments
|
179 |
|
180 |
|
|
|
11 |
import os
|
12 |
import google.protobuf # This line should execute without errors if protobuf is installed correctly
|
13 |
import sentencepiece
|
14 |
+
from transformers import pipeline, AutoTokenizer,AutoModelForSequenceClassification
|
15 |
import spacy
|
16 |
|
17 |
|
|
|
139 |
return segments
|
140 |
|
141 |
|
142 |
+
# Load the tokenizer and model.
# NOTE(review): "distilbert-base-uncased" is the *base* pretrained checkpoint —
# AutoModelForSequenceClassification will attach a randomly initialized
# classification head to it, so this pipeline's labels/scores are not
# meaningful until the model is fine-tuned (transformers emits a
# "newly initialized" weights warning for exactly this case). Confirm whether
# a fine-tuned checkpoint was intended here.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

# Set up the pipeline. Unlike the previous "zero-shot-classification"
# pipeline, "text-classification" takes no candidate-label list at call time;
# the label set is fixed by the model itself.
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
148 |
|
149 |
def robust_segment_text(text: str, max_tokens=510):
|
150 |
doc = nlp(text)
|
|
|
170 |
return segments
|
171 |
|
172 |
|
|
|
|
|
173 |
def classify_segments(segments):
    """Classify each text segment with the module-level ``classifier`` pipeline.

    Args:
        segments: Iterable of text segments (strings) to classify.

    Returns:
        A list with one pipeline prediction per segment, in input order.
        An empty input yields an empty list.
    """
    # Fix: the previous body built a `labels` list ("Coverage Details",
    # "Exclusions", "Premiums", "Claims Process", "Policy Limits",
    # "Legal and Regulatory Information", "Renewals and Cancellations",
    # "Discounts and Incentives", "Duties and Responsibilities",
    # "Contact Information") that was never used — the plain
    # text-classification pipeline, unlike the old zero-shot pipeline,
    # accepts no candidate labels at call time. The dead local was removed;
    # the intended categories are preserved here for reference.
    # Note: adjust the input here based on how your model was trained.
    return [classifier(segment) for segment in segments]
|
185 |
|
186 |
|