Marroco93 committed on
Commit
182943b
1 Parent(s): 131731f

no message

Browse files
Files changed (1) hide show
  1. main.py +10 -8
main.py CHANGED
@@ -140,7 +140,6 @@ def segment_text(text: str, max_tokens=500): # Setting a conservative limit bel
140
 
141
 
142
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
143
- classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
144
 
145
  def robust_segment_text(text: str, max_tokens=510): # Slightly less to ensure a buffer
146
  doc = nlp(text)
@@ -167,15 +166,18 @@ def robust_segment_text(text: str, max_tokens=510): # Slightly less to ensure a
167
  return segments
168
 
169
 
 
 
 
170
  def classify_segments(segments):
171
- results = []
 
 
 
172
  for segment in segments:
173
- try:
174
- result = classifier(segment)
175
- results.append(result)
176
- except Exception as e:
177
- results.append({"error": str(e), "segment": segment[:50]}) # Include a part of the segment to debug if needed
178
- return results
179
 
180
 
181
 
 
140
 
141
 
142
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
 
143
 
144
  def robust_segment_text(text: str, max_tokens=510): # Slightly less to ensure a buffer
145
  doc = nlp(text)
 
166
  return segments
167
 
168
 
169
+ # Load a zero-shot classification model
170
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
171
+
172
  def classify_segments(segments):
173
+ labels = ["Coverage Details", "Exclusions", "Premiums", "Claims Process",
174
+ "Policy Limits", "Legal and Regulatory Information", "Renewals and Cancellations",
175
+ "Discounts and Incentives", "Duties and Responsibilities", "Contact Information"]
176
+ classified_segments = []
177
  for segment in segments:
178
+ result = classifier(segment, candidate_labels=labels, multi_label=True)
179
+ classified_segments.append(result)
180
+ return classified_segments
 
 
 
181
 
182
 
183