shahzaib201 committed · Commit d223493 · verified · 1 Parent(s): 872a00a

Update main.py

Files changed (1):
  1. main.py +14 -4
main.py CHANGED
@@ -1,9 +1,11 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
-from transformers import pipeline
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
-# Load the summarization model
-summarizer = pipeline("summarization", model="t5-small")
+# Load the model and tokenizer
+model_name = "shahzaib201/AI_OEL"
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 # Pydantic model for input validation
 class TextInput(BaseModel):
@@ -16,5 +18,13 @@ app = FastAPI()
 # Endpoint for text summarization
 @app.post("/summarize_text")
 async def summarize_text_endpoint(item: TextInput):
-    summary = summarizer(item.text, max_length=item.max_length, min_length=30, do_sample=False)[0]['summary_text']
+    # Tokenize the input text
+    inputs = tokenizer(item.text, return_tensors="pt", max_length=1024, truncation=True)
+
+    # Generate the summary
+    summary_ids = model.generate(inputs.input_ids, max_length=item.max_length, num_beams=4, length_penalty=2.0, early_stopping=True)
+
+    # Decode the generated summary
+    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
     return {"summary": summary}
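
Since the TextInput field definitions and the app = FastAPI() line fall outside the diff context, the following is a minimal sketch of what the full main.py plausibly looks like after this commit. The text: str and max_length: int = 150 fields are assumptions inferred from how item.text and item.max_length are used in the endpoint; everything else mirrors the diff above.

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the model and tokenizer once at import time
model_name = "shahzaib201/AI_OEL"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Pydantic model for input validation
# (field names inferred from the endpoint; the default max_length is an assumption)
class TextInput(BaseModel):
    text: str
    max_length: int = 150

app = FastAPI()

# Endpoint for text summarization
@app.post("/summarize_text")
async def summarize_text_endpoint(item: TextInput):
    # Tokenize the input text, truncating to 1024 tokens
    inputs = tokenizer(item.text, return_tensors="pt", max_length=1024, truncation=True)

    # Generate the summary with beam search
    summary_ids = model.generate(
        inputs.input_ids,
        max_length=item.max_length,
        num_beams=4,
        length_penalty=2.0,
        early_stopping=True,
    )

    # Decode the generated token IDs back to text
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    return {"summary": summary}

One design note: generate() is called with only inputs.input_ids, so the attention mask is not forwarded. For a single unpadded sequence this is usually harmless, but model.generate(**inputs, ...) would pass it along as well.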
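
A quick way to exercise the endpoint without a running server is FastAPI's TestClient (assumes the sketch above is saved as main.py and the httpx dependency is installed):

from fastapi.testclient import TestClient

from main import app  # the app from the sketch above

client = TestClient(app)

# POST a passage and ask for a summary of at most 60 tokens
response = client.post(
    "/summarize_text",
    json={"text": "Long article text to be summarized ...", "max_length": 60},
)
print(response.status_code)        # 200 on success
print(response.json()["summary"])  # the generated summary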