Commit 44cdc71 by Marroco93 (parent: 1d6eb67)
Commit message: no message

Files changed (1):
  1. main.py (+16, -38)

main.py CHANGED
@@ -1,3 +1,4 @@
+import re
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
 from fastapi.responses import JSONResponse
@@ -101,50 +102,27 @@ tokenizer = AutoTokenizer.from_pretrained("nsi319/legal-pegasus")
 model = AutoModelForSeq2SeqLM.from_pretrained("nsi319/legal-pegasus")


-class SummarizeRequest(BaseModel):
+class TextRequest(BaseModel):
     text: str

-def chunk_text(text, max_length=1024):
-    """Split the text into manageable parts for the model to handle."""
-    words = text.split()
-    current_chunk = ""
-    chunks = []

-    for word in words:
-        if len(tokenizer.encode(current_chunk + word)) < max_length:
-            current_chunk += word + ' '
-        else:
-            chunks.append(current_chunk.strip())
-            current_chunk = word + ' '
-    chunks.append(current_chunk.strip())  # Add the last chunk
-    return chunks
-
-def summarize_legal_text(text):
-    """Generate summaries for each chunk and combine them."""
-    chunks = chunk_text(text, max_length=900)  # A bit less than 1024 to be safe
-    all_summaries = []
-
-    for chunk in chunks:
-        inputs = tokenizer.encode(chunk, return_tensors='pt', max_length=1024, truncation=True)
-        summary_ids = model.generate(
-            inputs,
-            num_beams=5,
-            no_repeat_ngram_size=3,
-            length_penalty=1.0,
-            min_length=150,
-            max_length=300,  # You can adjust this based on your needs
-            early_stopping=True
-        )
-        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-        all_summaries.append(summary)
-
-    return " ".join(all_summaries)
+def preprocess_text(text: str) -> str:
+    # Normalize whitespace
+    text = re.sub(r'\s+', ' ', text.strip())
+
+    # Optional: add additional preprocessing steps,
+    # e.g. handling or stripping special characters, lowercasing, etc.
+    text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation for simplicity
+
+    return text

 @app.post("/summarize")
-async def summarize_text(request: SummarizeRequest):
+async def summarize(request: TextRequest):
     try:
-        summarized_text = summarize_legal_text(request.text)
-        return JSONResponse(content={"summary": summarized_text})
+        processed_text = preprocess_text(request.text)
+
+        return {"summary": processed_text}
+
     except Exception as e:
         print(f"Error during summarization: {e}")
         raise HTTPException(status_code=500, detail=str(e))
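
For reference, a quick standalone check of what the new preprocess_text does to a typical input. This is a sketch run outside the app, using only the two regex steps from the diff; the sample string is illustrative, not from the commit.

import re

def preprocess_text(text: str) -> str:
    # Same two-step cleanup as in the diff: collapse whitespace, then strip punctuation
    text = re.sub(r'\s+', ' ', text.strip())
    text = re.sub(r'[^\w\s]', '', text)
    return text

sample = "  The parties hereto agree,\n  effective 2023-01-01, as follows:  "
print(preprocess_text(sample))
# -> "The parties hereto agree effective 20230101 as follows"
# Note that the punctuation pass also drops hyphens inside dates.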
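
And a minimal client-side call against the updated /summarize endpoint, assuming main.py is served locally (the host and port below are placeholders, not part of the commit).

import requests

# Hypothetical local deployment; adjust the URL to wherever main.py is served.
resp = requests.post(
    "http://127.0.0.1:8000/summarize",
    json={"text": "  The parties hereto agree,\n  effective 2023-01-01, as follows:  "},
)
resp.raise_for_status()
print(resp.json())
# After this commit the endpoint returns the preprocessed text rather than a model summary:
# {'summary': 'The parties hereto agree effective 20230101 as follows'}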