Marroco93 committed
Commit 3717137
1 Parent(s): 182943b

no message

Files changed (1):
  1. main.py +10 -9
main.py CHANGED
@@ -141,7 +141,7 @@ def segment_text(text: str, max_tokens=500):  # Setting a conservative limit bel
 
 tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
 
-def robust_segment_text(text: str, max_tokens=510):  # Slightly less to ensure a buffer
+def robust_segment_text(text: str, max_tokens=510):
     doc = nlp(text)
     segments = []
     current_segment = []
@@ -152,26 +152,24 @@ def robust_segment_text(text: str, max_tokens=510):  # Slightly less to ensure a
         sentence_tokens = tokenizer.encode(' '.join(words), add_special_tokens=False)
 
         if len(current_tokens) + len(sentence_tokens) > max_tokens:
-            if current_tokens:
-                segments.append(tokenizer.decode(current_tokens))
+            segments.append(tokenizer.decode(current_tokens))
             current_segment = words
             current_tokens = sentence_tokens
         else:
             current_segment.extend(words)
             current_tokens.extend(sentence_tokens)
 
-    if current_tokens:
+    if current_tokens:  # Add the last segment
         segments.append(tokenizer.decode(current_tokens))
 
     return segments
 
 
-# Load a zero-shot classification model
 classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
 
 def classify_segments(segments):
-    labels = ["Coverage Details", "Exclusions", "Premiums", "Claims Process",
-        "Policy Limits", "Legal and Regulatory Information", "Renewals and Cancellations",
+    labels = ["Coverage Details", "Exclusions", "Premiums", "Claims Process",
+              "Policy Limits", "Legal and Regulatory Information", "Renewals and Cancellations",
               "Discounts and Incentives", "Duties and Responsibilities", "Contact Information"]
     classified_segments = []
     for segment in segments:
@@ -181,10 +179,14 @@ def classify_segments(segments):
 
 
 
+
+class TextRequest(BaseModel):
+    text: str
+
 @app.post("/process_document")
 async def process_document(request: TextRequest):
     try:
-        processed_text = preprocess_text(request.text)
+        processed_text = preprocess_text(request.text)  # Ensure preprocess_text is defined
         segments = robust_segment_text(processed_text)
         classified_segments = classify_segments(segments)
 
@@ -196,7 +198,6 @@ async def process_document(request: TextRequest):
         raise HTTPException(status_code=500, detail=str(e))
 
 
-
 @app.post("/summarize")
 async def summarize(request: TextRequest):
     try:
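
For reference, the hunks above rely on names defined earlier in main.py that the diff does not show: the spaCy pipeline behind nlp(text), the FastAPI app object, and the preprocess_text helper flagged in the new comment. A minimal sketch of that surrounding setup, assuming a small English spaCy model and a simple whitespace-normalizing preprocess_text (both are illustrative guesses, not the repository's actual code):

# Hypothetical supporting setup for main.py. Names match what the diff
# references; the bodies are illustrative assumptions, not the actual code.
import re

import spacy
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, pipeline

app = FastAPI()
nlp = spacy.load("en_core_web_sm")  # assumed sentence splitter behind nlp(text)

def preprocess_text(text: str) -> str:
    # Placeholder: collapse runs of whitespace; the real helper may do more.
    return re.sub(r"\s+", " ", text).strip()

With the app running (e.g. uvicorn main:app), the endpoint added in this commit can be smoke-tested as below; the host, port, and sample text are assumptions:

import requests

response = requests.post(
    "http://localhost:8000/process_document",
    json={"text": "Your policy covers fire damage. Claims must be filed within 30 days."},
)
print(response.json())  # the endpoint's JSON reply (its exact shape is not shown in the diff)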