Spaces:
Sleeping
Sleeping
bertugmirasyedi
commited on
Commit
·
3e788ad
1
Parent(s):
936ef88
Overhaul
Browse files
app.py
CHANGED
@@ -1,14 +1,7 @@
|
|
1 |
-
from fastapi import FastAPI
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
import os
|
4 |
-
|
5 |
-
AutoModelForSeq2SeqLM,
|
6 |
-
AutoTokenizer,
|
7 |
-
AutoModelForSequenceClassification,
|
8 |
-
)
|
9 |
-
from optimum.onnxruntime import ORTModelForSeq2SeqLM, ORTModelForSequenceClassification
|
10 |
-
from sentence_transformers import SentenceTransformer
|
11 |
-
import torch
|
12 |
|
13 |
# Define the FastAPI app
|
14 |
app = FastAPI(docs_url="/")
|
@@ -22,45 +15,14 @@ app.add_middleware(
|
|
22 |
allow_headers=["*"],
|
23 |
)
|
24 |
|
25 |
-
# Define the Google Books API key
|
26 |
key = os.environ.get("GOOGLE_BOOKS_API_KEY")
|
27 |
|
28 |
-
# Define summarization models
|
29 |
-
summary_tokenizer_normal = AutoTokenizer.from_pretrained("lidiya/bart-base-samsum")
|
30 |
-
summary_model_normal = AutoModelForSeq2SeqLM.from_pretrained("lidiya/bart-base-samsum")
|
31 |
-
summary_tokenizer_onnx = AutoTokenizer.from_pretrained("optimum/t5-small")
|
32 |
-
summary_model_onnx = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small")
|
33 |
-
|
34 |
-
# Define classification models
|
35 |
-
classification_tokenizer_normal = AutoTokenizer.from_pretrained(
|
36 |
-
"sileod/deberta-v3-base-tasksource-nli"
|
37 |
-
)
|
38 |
-
classification_model_normal = AutoModelForSequenceClassification.from_pretrained(
|
39 |
-
"sileod/deberta-v3-base-tasksource-nli"
|
40 |
-
)
|
41 |
-
|
42 |
-
audience_classification_tokenizer = AutoTokenizer.from_pretrained(
|
43 |
-
"bertugmirasyedi/deberta-v3-base-book-classification"
|
44 |
-
)
|
45 |
-
audience_classification_model = AutoModelForSequenceClassification.from_pretrained(
|
46 |
-
"bertugmirasyedi/deberta-v3-base-book-classification"
|
47 |
-
)
|
48 |
-
|
49 |
-
level_classification_tokenizer = AutoTokenizer.from_pretrained(
|
50 |
-
"bertugmirasyedi/deberta-v3-base-level-classification"
|
51 |
-
)
|
52 |
-
level_classification_model = AutoModelForSequenceClassification.from_pretrained(
|
53 |
-
"bertugmirasyedi/deberta-v3-base-level-classification"
|
54 |
-
)
|
55 |
-
|
56 |
-
# Define similarity model
|
57 |
-
similarity_model = SentenceTransformer("all-MiniLM-L6-v2")
|
58 |
-
|
59 |
|
60 |
@app.get("/search")
|
61 |
async def search(
|
62 |
query: str,
|
63 |
add_chatgpt_results: bool = False,
|
|
|
64 |
n_results: int = 10,
|
65 |
):
|
66 |
"""
|
@@ -215,21 +177,22 @@ async def search(
|
|
215 |
|
216 |
return titles, authors, publishers, descriptions, images
|
217 |
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
|
|
226 |
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
|
234 |
# Calculate the elapsed time between the first and second checkpoints
|
235 |
second_checkpoint = time.time()
|
@@ -249,7 +212,7 @@ async def search(
|
|
249 |
images = []
|
250 |
|
251 |
# Set the OpenAI API key
|
252 |
-
openai.api_key =
|
253 |
|
254 |
# Create ChatGPT query
|
255 |
chatgpt_response = openai.ChatCompletion.create(
|
@@ -348,7 +311,9 @@ async def search(
|
|
348 |
|
349 |
|
350 |
@app.post("/classify")
|
351 |
-
async def classify(
|
|
|
|
|
352 |
"""
|
353 |
Create classifier pipeline and return the results.
|
354 |
"""
|
@@ -368,11 +333,16 @@ async def classify(data: list, runtime: str = "normal"):
|
|
368 |
pipeline,
|
369 |
)
|
370 |
from optimum.onnxruntime import ORTModelForSequenceClassification
|
|
|
371 |
|
372 |
-
if runtime == "
|
373 |
# Define the zero-shot classifier
|
374 |
-
tokenizer =
|
375 |
-
|
|
|
|
|
|
|
|
|
376 |
|
377 |
classifier_pipe = pipeline(
|
378 |
"zero-shot-classification",
|
@@ -401,20 +371,36 @@ async def classify(data: list, runtime: str = "normal"):
|
|
401 |
}
|
402 |
for doc in combined_data
|
403 |
]
|
404 |
-
|
|
|
405 |
### Define the classifier for audience prediction ###
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
406 |
audience_classifier = pipeline(
|
407 |
"text-classification",
|
408 |
-
model=
|
409 |
-
tokenizer=
|
410 |
device=-1,
|
411 |
)
|
412 |
-
|
413 |
### Define the classifier for level prediction ###
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
level_classifier = pipeline(
|
415 |
"text-classification",
|
416 |
-
model=
|
417 |
-
tokenizer=
|
418 |
device=-1,
|
419 |
)
|
420 |
|
@@ -457,7 +443,7 @@ async def find_similar(data: list, top_k: int = 5):
|
|
457 |
for title, description, publisher in zip(titles, descriptions, publishers)
|
458 |
]
|
459 |
|
460 |
-
sentence_transformer =
|
461 |
book_embeddings = sentence_transformer.encode(combined_data, convert_to_tensor=True)
|
462 |
|
463 |
# Make sure that the top_k value is not greater than the number of books
|
@@ -485,7 +471,10 @@ async def find_similar(data: list, top_k: int = 5):
|
|
485 |
|
486 |
|
487 |
@app.post("/summarize")
|
488 |
-
async def summarize(
|
|
|
|
|
|
|
489 |
"""
|
490 |
Summarize the descriptions and return the results.
|
491 |
"""
|
@@ -499,12 +488,12 @@ async def summarize(descriptions: list, runtime="normal"):
|
|
499 |
|
500 |
# Define the summarizer model and tokenizer
|
501 |
if runtime == "normal":
|
502 |
-
tokenizer =
|
503 |
-
|
504 |
-
model = BetterTransformer.transform(
|
505 |
elif runtime == "onnxruntime":
|
506 |
-
tokenizer =
|
507 |
-
model =
|
508 |
|
509 |
# Create the summarizer pipeline
|
510 |
summarizer_pipe = pipeline("summarization", model=model, tokenizer=tokenizer)
|
|
|
1 |
+
from fastapi import FastAPI, Query
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
import os
|
4 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
# Define the FastAPI app
|
7 |
app = FastAPI(docs_url="/")
|
|
|
15 |
allow_headers=["*"],
|
16 |
)
|
17 |
|
|
|
18 |
key = os.environ.get("GOOGLE_BOOKS_API_KEY")
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
@app.get("/search")
|
22 |
async def search(
|
23 |
query: str,
|
24 |
add_chatgpt_results: bool = False,
|
25 |
+
add_articles: bool = False,
|
26 |
n_results: int = 10,
|
27 |
):
|
28 |
"""
|
|
|
177 |
|
178 |
return titles, authors, publishers, descriptions, images
|
179 |
|
180 |
+
if add_articles:
|
181 |
+
# Run the openalex_search function
|
182 |
+
(
|
183 |
+
titles_placeholder,
|
184 |
+
authors_placeholder,
|
185 |
+
publishers_placeholder,
|
186 |
+
descriptions_placeholder,
|
187 |
+
images_placeholder,
|
188 |
+
) = openalex_search(query, n_results=n_results)
|
189 |
|
190 |
+
# Append the results to the lists
|
191 |
+
[titles.append(title) for title in titles_placeholder]
|
192 |
+
[authors.append(author) for author in authors_placeholder]
|
193 |
+
[publishers.append(publisher) for publisher in publishers_placeholder]
|
194 |
+
[descriptions.append(description) for description in descriptions_placeholder]
|
195 |
+
[images.append(image) for image in images_placeholder]
|
196 |
|
197 |
# Calculate the elapsed time between the first and second checkpoints
|
198 |
second_checkpoint = time.time()
|
|
|
212 |
images = []
|
213 |
|
214 |
# Set the OpenAI API key
|
215 |
+
openai.api_key = "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"
|
216 |
|
217 |
# Create ChatGPT query
|
218 |
chatgpt_response = openai.ChatCompletion.create(
|
|
|
311 |
|
312 |
|
313 |
@app.post("/classify")
|
314 |
+
async def classify(
|
315 |
+
data: list, runtime: str = Query(default="trained", enum=["trained", "zero-shot"])
|
316 |
+
):
|
317 |
"""
|
318 |
Create classifier pipeline and return the results.
|
319 |
"""
|
|
|
333 |
pipeline,
|
334 |
)
|
335 |
from optimum.onnxruntime import ORTModelForSequenceClassification
|
336 |
+
from optimum.bettertransformer import BetterTransformer
|
337 |
|
338 |
+
if runtime == "zero-shot":
|
339 |
# Define the zero-shot classifier
|
340 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
341 |
+
"sileod/deberta-v3-base-tasksource-nli"
|
342 |
+
)
|
343 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
344 |
+
"sileod/deberta-v3-base-tasksource-nli"
|
345 |
+
)
|
346 |
|
347 |
classifier_pipe = pipeline(
|
348 |
"zero-shot-classification",
|
|
|
371 |
}
|
372 |
for doc in combined_data
|
373 |
]
|
374 |
+
|
375 |
+
elif runtime == "trained":
|
376 |
### Define the classifier for audience prediction ###
|
377 |
+
audience_tokenizer = AutoTokenizer.from_pretrained(
|
378 |
+
"bertugmirasyedi/deberta-v3-base-book-classification",
|
379 |
+
max_len=512,
|
380 |
+
)
|
381 |
+
audience_model = AutoModelForSequenceClassification.from_pretrained(
|
382 |
+
"bertugmirasyedi/deberta-v3-base-book-classification"
|
383 |
+
)
|
384 |
+
|
385 |
audience_classifier = pipeline(
|
386 |
"text-classification",
|
387 |
+
model=audience_model,
|
388 |
+
tokenizer=audience_tokenizer,
|
389 |
device=-1,
|
390 |
)
|
|
|
391 |
### Define the classifier for level prediction ###
|
392 |
+
level_tokenizer = AutoTokenizer.from_pretrained(
|
393 |
+
"bertugmirasyedi/deberta-v3-base-level-classification",
|
394 |
+
max_len=512,
|
395 |
+
)
|
396 |
+
level_model = AutoModelForSequenceClassification.from_pretrained(
|
397 |
+
"bertugmirasyedi/deberta-v3-base-level-classification"
|
398 |
+
)
|
399 |
+
|
400 |
level_classifier = pipeline(
|
401 |
"text-classification",
|
402 |
+
model=level_model,
|
403 |
+
tokenizer=level_tokenizer,
|
404 |
device=-1,
|
405 |
)
|
406 |
|
|
|
443 |
for title, description, publisher in zip(titles, descriptions, publishers)
|
444 |
]
|
445 |
|
446 |
+
sentence_transformer = SentenceTransformer("all-MiniLM-L6-v2")
|
447 |
book_embeddings = sentence_transformer.encode(combined_data, convert_to_tensor=True)
|
448 |
|
449 |
# Make sure that the top_k value is not greater than the number of books
|
|
|
471 |
|
472 |
|
473 |
@app.post("/summarize")
|
474 |
+
async def summarize(
|
475 |
+
descriptions: list,
|
476 |
+
runtime: str = Query(default="normal", enum=["normal", "onnxruntime"]),
|
477 |
+
):
|
478 |
"""
|
479 |
Summarize the descriptions and return the results.
|
480 |
"""
|
|
|
488 |
|
489 |
# Define the summarizer model and tokenizer
|
490 |
if runtime == "normal":
|
491 |
+
tokenizer = AutoTokenizer.from_pretrained("lidiya/bart-base-samsum")
|
492 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("lidiya/bart-base-samsum")
|
493 |
+
model = BetterTransformer.transform(model)
|
494 |
elif runtime == "onnxruntime":
|
495 |
+
tokenizer = AutoTokenizer.from_pretrained("optimum/t5-small")
|
496 |
+
model = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small")
|
497 |
|
498 |
# Create the summarizer pipeline
|
499 |
summarizer_pipe = pipeline("summarization", model=model, tokenizer=tokenizer)
|