Marroco93 commited on
Commit
021d564
1 Parent(s): 9530b69

no message

Browse files
Files changed (1) hide show
  1. main.py +27 -11
main.py CHANGED
@@ -108,31 +108,33 @@ def reduce_tokens(text: str):
108
  token_count = len(reduced_doc)
109
  return reduced_text, token_count
110
 
111
- def segment_text(text: str, max_length=512):
112
  # Use spaCy to divide the document into sentences
113
  doc = nlp(text)
114
- sentences = [sent.text for sent in doc.sents]
115
 
116
- # Group sentences into segments of approximately max_length tokens
117
  segments = []
118
  current_segment = []
119
  current_length = 0
120
 
121
  for sentence in sentences:
122
- sentence_length = len(sentence.split())
123
- if current_length + sentence_length > max_length:
124
- segments.append(' '.join(current_segment))
 
125
  current_segment = [sentence]
126
  current_length = sentence_length
127
  else:
128
  current_segment.append(sentence)
129
  current_length += sentence_length
130
 
 
131
  if current_segment:
132
  segments.append(' '.join(current_segment))
133
 
134
  return segments
135
 
 
136
  classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
137
 
138
  def classify_segments(segments):
@@ -146,11 +148,25 @@ async def summarize(request: TextRequest):
146
  processed_text = preprocess_text(request.text)
147
  segments = segment_text(processed_text)
148
 
149
- # Classify each segment
150
- classified_segments = classify_segments(segments)
151
-
152
- # Optionally, reduce tokens for some specific task or summarize
153
- reduced_texts = [reduce_tokens(segment)[0] for segment in segments]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  return {
156
  "classified_segments": classified_segments,
 
108
  token_count = len(reduced_doc)
109
  return reduced_text, token_count
110
 
111
+ def segment_text(text: str, max_tokens=500): # Slightly less than 512 for safety
112
  # Use spaCy to divide the document into sentences
113
  doc = nlp(text)
114
+ sentences = [sent.text.strip() for sent in doc.sents]
115
 
 
116
  segments = []
117
  current_segment = []
118
  current_length = 0
119
 
120
  for sentence in sentences:
121
+ sentence_length = len(sentence.split()) # Simple word count
122
+ if current_length + sentence_length > max_tokens:
123
+ if current_segment: # Make sure there's something to add
124
+ segments.append(' '.join(current_segment))
125
  current_segment = [sentence]
126
  current_length = sentence_length
127
  else:
128
  current_segment.append(sentence)
129
  current_length += sentence_length
130
 
131
+ # Add the last segment if any
132
  if current_segment:
133
  segments.append(' '.join(current_segment))
134
 
135
  return segments
136
 
137
+
138
  classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
139
 
140
  def classify_segments(segments):
 
148
  processed_text = preprocess_text(request.text)
149
  segments = segment_text(processed_text)
150
 
151
+ # Classify each segment safely
152
+ classified_segments = []
153
+ for segment in segments:
154
+ try:
155
+ result = classifier(segment)
156
+ classified_segments.append(result)
157
+ except Exception as e:
158
+ print(f"Error classifying segment: {e}")
159
+ classified_segments.append({"error": str(e)})
160
+
161
+ # Optional: Reduce tokens or summarize
162
+ reduced_texts = []
163
+ for segment in segments:
164
+ try:
165
+ reduced_text, token_count = reduce_tokens(segment)
166
+ reduced_texts.append((reduced_text, token_count))
167
+ except Exception as e:
168
+ print(f"Error during token reduction: {e}")
169
+ reduced_texts.append(("Error", 0))
170
 
171
  return {
172
  "classified_segments": classified_segments,