Marroco93 committed on
Commit fe81f5c
1 Parent(s): cb746f1

no message

Files changed (1)
  1. main.py +40 -33
main.py CHANGED
@@ -108,63 +108,70 @@ def reduce_tokens(text: str):
     token_count = len(reduced_doc)
     return reduced_text, token_count
 
-def segment_text(text: str, max_tokens=500):  # Slightly less than 512 for safety
-    # Use spaCy to divide the document into sentences
+def segment_text(text: str, max_tokens=500):  # Setting a conservative limit below 512
     doc = nlp(text)
-    sentences = [sent.text.strip() for sent in doc.sents]
-
     segments = []
     current_segment = []
     current_length = 0
-
-    for sentence in sentences:
-        sentence_words = sentence.split()
-        sentence_length = len(sentence_words)
-
-        # If sentence exceeds max_tokens, split it further
+
+    for sent in doc.sents:
+        sentence = sent.text.strip()
+        sentence_length = len(sentence.split())  # Counting words for simplicity
+
         if sentence_length > max_tokens:
-            parts = split_into_parts(sentence, max_tokens)
-            segments.extend(parts)  # Add split parts directly to segments
-            continue
-
-        if current_length + sentence_length > max_tokens:
+            # Split long sentences into smaller chunks if a single sentence exceeds max_tokens
+            words = sentence.split()
+            while words:
+                part = ' '.join(words[:max_tokens])
+                segments.append(part)
+                words = words[max_tokens:]
+        elif current_length + sentence_length > max_tokens:
             segments.append(' '.join(current_segment))
             current_segment = [sentence]
             current_length = sentence_length
         else:
             current_segment.append(sentence)
             current_length += sentence_length
-
-    if current_segment:  # Add the last segment if any
-        segments.append(' '.join(current_segment))
-
-    return segments
-
-def split_into_parts(text, max_tokens):
-    words = text.split()
-    parts = []
-    for i in range(0, len(words), max_tokens):
-        part = " ".join(words[i:i + max_tokens])
-        parts.append(part)
-    return parts
+
+    if current_segment:  # Add the last segment
+        segments.append(' '.join(current_segment))
+
+    return segments
 
 
 classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
 
 
 def classify_segments(segments):
-    results = []
+    classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
+    classified_segments = []
+
     for segment in segments:
         try:
-            if len(segment.split()) <= 512:  # Ensure segment is within the limit
+            if len(segment.split()) <= 512:  # Double-check to avoid errors
                 result = classifier(segment)
-                results.append(result)
+                classified_segments.append(result)
             else:
-                results.append({"error": f"Segment too long: {len(segment.split())} tokens"})
+                classified_segments.append({"error": f"Segment too long: {len(segment.split())} tokens"})
         except Exception as e:
-            results.append({"error": str(e)})
-    return results
+            classified_segments.append({"error": str(e)})
+
+    return classified_segments
+
+
+@app.post("/process_document")
+async def process_document(request: TextRequest):
+    try:
+        processed_text = preprocess_text(request.text)
+        segments = segment_text(processed_text)
+        classified_segments = classify_segments(segments)
+
+        return {
+            "classified_segments": classified_segments
+        }
+    except Exception as e:
+        print(f"Error during document processing: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
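
For anyone wanting to exercise the new route by hand, the sketch below shows one way to call it. It assumes the app is served locally with uvicorn on port 8000 and that TextRequest exposes a single `text` field (inferred from `request.text` in the handler); neither detail is confirmed by this commit.

import requests

# Hypothetical client call against the new /process_document endpoint.
# Host, port, and the request body shape are assumptions, not part of this commit.
resp = requests.post(
    "http://localhost:8000/process_document",
    json={"text": "The delivery was slow. The product itself works beautifully."},
    timeout=60,
)
resp.raise_for_status()
for item in resp.json()["classified_segments"]:
    print(item)  # per-segment label/score results, or {"error": ...} entries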
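A caveat on the new segmentation logic: segment_text counts whitespace-separated words, while the 512-position limit of the DistilBERT checkpoint is measured in WordPiece tokens, so a segment that passes the 500-word cutoff can still be longer than the model accepts. A minimal sketch of how to measure the real token count, assuming transformers' AutoTokenizer is available in the environment (this helper is illustrative and not part of the commit):

from transformers import AutoTokenizer

# Illustrative check only - not part of this commit.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

def subword_token_count(segment: str) -> int:
    # Count the WordPiece tokens the classifier will actually see,
    # including the [CLS] and [SEP] special tokens.
    return len(tokenizer.encode(segment, add_special_tokens=True))

sample = "word " * 500
print(len(sample.split()), subword_token_count(sample))  # word count vs. model token count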