Spaces:
Sleeping
Sleeping
no message
Browse files
main.py
CHANGED
@@ -141,7 +141,7 @@ def segment_text(text: str, max_tokens=500): # Setting a conservative limit bel
|
|
141 |
|
142 |
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
|
143 |
|
144 |
-
def robust_segment_text(text: str, max_tokens=510):
|
145 |
doc = nlp(text)
|
146 |
segments = []
|
147 |
current_segment = []
|
@@ -152,26 +152,24 @@ def robust_segment_text(text: str, max_tokens=510): # Slightly less to ensure a
|
|
152 |
sentence_tokens = tokenizer.encode(' '.join(words), add_special_tokens=False)
|
153 |
|
154 |
if len(current_tokens) + len(sentence_tokens) > max_tokens:
|
155 |
-
|
156 |
-
segments.append(tokenizer.decode(current_tokens))
|
157 |
current_segment = words
|
158 |
current_tokens = sentence_tokens
|
159 |
else:
|
160 |
current_segment.extend(words)
|
161 |
current_tokens.extend(sentence_tokens)
|
162 |
|
163 |
-
if current_tokens:
|
164 |
segments.append(tokenizer.decode(current_tokens))
|
165 |
|
166 |
return segments
|
167 |
|
168 |
|
169 |
-
# Load a zero-shot classification model
|
170 |
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
171 |
|
172 |
def classify_segments(segments):
|
173 |
-
labels = ["Coverage Details", "Exclusions", "Premiums", "Claims Process",
|
174 |
-
"Policy Limits", "Legal and Regulatory Information", "Renewals and Cancellations",
|
175 |
"Discounts and Incentives", "Duties and Responsibilities", "Contact Information"]
|
176 |
classified_segments = []
|
177 |
for segment in segments:
|
@@ -181,10 +179,14 @@ def classify_segments(segments):
|
|
181 |
|
182 |
|
183 |
|
|
|
|
|
|
|
|
|
184 |
@app.post("/process_document")
|
185 |
async def process_document(request: TextRequest):
|
186 |
try:
|
187 |
-
processed_text = preprocess_text(request.text)
|
188 |
segments = robust_segment_text(processed_text)
|
189 |
classified_segments = classify_segments(segments)
|
190 |
|
@@ -196,7 +198,6 @@ async def process_document(request: TextRequest):
|
|
196 |
raise HTTPException(status_code=500, detail=str(e))
|
197 |
|
198 |
|
199 |
-
|
200 |
@app.post("/summarize")
|
201 |
async def summarize(request: TextRequest):
|
202 |
try:
|
|
|
141 |
|
142 |
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
|
143 |
|
144 |
+
def robust_segment_text(text: str, max_tokens=510):
|
145 |
doc = nlp(text)
|
146 |
segments = []
|
147 |
current_segment = []
|
|
|
152 |
sentence_tokens = tokenizer.encode(' '.join(words), add_special_tokens=False)
|
153 |
|
154 |
if len(current_tokens) + len(sentence_tokens) > max_tokens:
|
155 |
+
segments.append(tokenizer.decode(current_tokens))
|
|
|
156 |
current_segment = words
|
157 |
current_tokens = sentence_tokens
|
158 |
else:
|
159 |
current_segment.extend(words)
|
160 |
current_tokens.extend(sentence_tokens)
|
161 |
|
162 |
+
if current_tokens: # Add the last segment
|
163 |
segments.append(tokenizer.decode(current_tokens))
|
164 |
|
165 |
return segments
|
166 |
|
167 |
|
|
|
168 |
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
169 |
|
170 |
def classify_segments(segments):
|
171 |
+
labels = ["Coverage Details", "Exclusions", "Premiums", "Claims Process",
|
172 |
+
"Policy Limits", "Legal and Regulatory Information", "Renewals and Cancellations",
|
173 |
"Discounts and Incentives", "Duties and Responsibilities", "Contact Information"]
|
174 |
classified_segments = []
|
175 |
for segment in segments:
|
|
|
179 |
|
180 |
|
181 |
|
182 |
+
|
183 |
+
class TextRequest(BaseModel):
|
184 |
+
text: str
|
185 |
+
|
186 |
@app.post("/process_document")
|
187 |
async def process_document(request: TextRequest):
|
188 |
try:
|
189 |
+
processed_text = preprocess_text(request.text) # Ensure preprocess_text is defined
|
190 |
segments = robust_segment_text(processed_text)
|
191 |
classified_segments = classify_segments(segments)
|
192 |
|
|
|
198 |
raise HTTPException(status_code=500, detail=str(e))
|
199 |
|
200 |
|
|
|
201 |
@app.post("/summarize")
|
202 |
async def summarize(request: TextRequest):
|
203 |
try:
|