bertugmirasyedi commited on
Commit
3e788ad
·
1 Parent(s): 936ef88
Files changed (1) hide show
  1. app.py +62 -73
app.py CHANGED
@@ -1,14 +1,7 @@
1
- from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  import os
4
- from transformers import (
5
- AutoModelForSeq2SeqLM,
6
- AutoTokenizer,
7
- AutoModelForSequenceClassification,
8
- )
9
- from optimum.onnxruntime import ORTModelForSeq2SeqLM, ORTModelForSequenceClassification
10
- from sentence_transformers import SentenceTransformer
11
- import torch
12
 
13
  # Define the FastAPI app
14
  app = FastAPI(docs_url="/")
@@ -22,45 +15,14 @@ app.add_middleware(
22
  allow_headers=["*"],
23
  )
24
 
25
- # Define the Google Books API key
26
  key = os.environ.get("GOOGLE_BOOKS_API_KEY")
27
 
28
- # Define summarization models
29
- summary_tokenizer_normal = AutoTokenizer.from_pretrained("lidiya/bart-base-samsum")
30
- summary_model_normal = AutoModelForSeq2SeqLM.from_pretrained("lidiya/bart-base-samsum")
31
- summary_tokenizer_onnx = AutoTokenizer.from_pretrained("optimum/t5-small")
32
- summary_model_onnx = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small")
33
-
34
- # Define classification models
35
- classification_tokenizer_normal = AutoTokenizer.from_pretrained(
36
- "sileod/deberta-v3-base-tasksource-nli"
37
- )
38
- classification_model_normal = AutoModelForSequenceClassification.from_pretrained(
39
- "sileod/deberta-v3-base-tasksource-nli"
40
- )
41
-
42
- audience_classification_tokenizer = AutoTokenizer.from_pretrained(
43
- "bertugmirasyedi/deberta-v3-base-book-classification"
44
- )
45
- audience_classification_model = AutoModelForSequenceClassification.from_pretrained(
46
- "bertugmirasyedi/deberta-v3-base-book-classification"
47
- )
48
-
49
- level_classification_tokenizer = AutoTokenizer.from_pretrained(
50
- "bertugmirasyedi/deberta-v3-base-level-classification"
51
- )
52
- level_classification_model = AutoModelForSequenceClassification.from_pretrained(
53
- "bertugmirasyedi/deberta-v3-base-level-classification"
54
- )
55
-
56
- # Define similarity model
57
- similarity_model = SentenceTransformer("all-MiniLM-L6-v2")
58
-
59
 
60
  @app.get("/search")
61
  async def search(
62
  query: str,
63
  add_chatgpt_results: bool = False,
 
64
  n_results: int = 10,
65
  ):
66
  """
@@ -215,21 +177,22 @@ async def search(
215
 
216
  return titles, authors, publishers, descriptions, images
217
 
218
- # Run the openalex_search function
219
- (
220
- titles_placeholder,
221
- authors_placeholder,
222
- publishers_placeholder,
223
- descriptions_placeholder,
224
- images_placeholder,
225
- ) = openalex_search(query, n_results=n_results)
 
226
 
227
- # Append the results to the lists
228
- [titles.append(title) for title in titles_placeholder]
229
- [authors.append(author) for author in authors_placeholder]
230
- [publishers.append(publisher) for publisher in publishers_placeholder]
231
- [descriptions.append(description) for description in descriptions_placeholder]
232
- [images.append(image) for image in images_placeholder]
233
 
234
  # Calculate the elapsed time between the first and second checkpoints
235
  second_checkpoint = time.time()
@@ -249,7 +212,7 @@ async def search(
249
  images = []
250
 
251
  # Set the OpenAI API key
252
- openai.api_key = os.environ.get("OPENAI_API_KEY")
253
 
254
  # Create ChatGPT query
255
  chatgpt_response = openai.ChatCompletion.create(
@@ -348,7 +311,9 @@ async def search(
348
 
349
 
350
  @app.post("/classify")
351
- async def classify(data: list, runtime: str = "normal"):
 
 
352
  """
353
  Create classifier pipeline and return the results.
354
  """
@@ -368,11 +333,16 @@ async def classify(data: list, runtime: str = "normal"):
368
  pipeline,
369
  )
370
  from optimum.onnxruntime import ORTModelForSequenceClassification
 
371
 
372
- if runtime == "normal":
373
  # Define the zero-shot classifier
374
- tokenizer = classification_tokenizer_normal
375
- model = classification_model_normal
 
 
 
 
376
 
377
  classifier_pipe = pipeline(
378
  "zero-shot-classification",
@@ -401,20 +371,36 @@ async def classify(data: list, runtime: str = "normal"):
401
  }
402
  for doc in combined_data
403
  ]
404
- elif runtime == "local":
 
405
  ### Define the classifier for audience prediction ###
 
 
 
 
 
 
 
 
406
  audience_classifier = pipeline(
407
  "text-classification",
408
- model=audience_classification_model,
409
- tokenizer=audience_classification_tokenizer,
410
  device=-1,
411
  )
412
-
413
  ### Define the classifier for level prediction ###
 
 
 
 
 
 
 
 
414
  level_classifier = pipeline(
415
  "text-classification",
416
- model=level_classification_model,
417
- tokenizer=level_classification_tokenizer,
418
  device=-1,
419
  )
420
 
@@ -457,7 +443,7 @@ async def find_similar(data: list, top_k: int = 5):
457
  for title, description, publisher in zip(titles, descriptions, publishers)
458
  ]
459
 
460
- sentence_transformer = similarity_model
461
  book_embeddings = sentence_transformer.encode(combined_data, convert_to_tensor=True)
462
 
463
  # Make sure that the top_k value is not greater than the number of books
@@ -485,7 +471,10 @@ async def find_similar(data: list, top_k: int = 5):
485
 
486
 
487
  @app.post("/summarize")
488
- async def summarize(descriptions: list, runtime="normal"):
 
 
 
489
  """
490
  Summarize the descriptions and return the results.
491
  """
@@ -499,12 +488,12 @@ async def summarize(descriptions: list, runtime="normal"):
499
 
500
  # Define the summarizer model and tokenizer
501
  if runtime == "normal":
502
- tokenizer = summary_tokenizer_normal
503
- normal_model = summary_model_normal
504
- model = BetterTransformer.transform(normal_model)
505
  elif runtime == "onnxruntime":
506
- tokenizer = summary_tokenizer_onnx
507
- model = summary_model_onnx
508
 
509
  # Create the summarizer pipeline
510
  summarizer_pipe = pipeline("summarization", model=model, tokenizer=tokenizer)
 
1
+ from fastapi import FastAPI, Query
2
  from fastapi.middleware.cors import CORSMiddleware
3
  import os
4
+
 
 
 
 
 
 
 
5
 
6
  # Define the FastAPI app
7
  app = FastAPI(docs_url="/")
 
15
  allow_headers=["*"],
16
  )
17
 
 
18
  key = os.environ.get("GOOGLE_BOOKS_API_KEY")
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  @app.get("/search")
22
  async def search(
23
  query: str,
24
  add_chatgpt_results: bool = False,
25
+ add_articles: bool = False,
26
  n_results: int = 10,
27
  ):
28
  """
 
177
 
178
  return titles, authors, publishers, descriptions, images
179
 
180
+ if add_articles:
181
+ # Run the openalex_search function
182
+ (
183
+ titles_placeholder,
184
+ authors_placeholder,
185
+ publishers_placeholder,
186
+ descriptions_placeholder,
187
+ images_placeholder,
188
+ ) = openalex_search(query, n_results=n_results)
189
 
190
+ # Append the results to the lists
191
+ [titles.append(title) for title in titles_placeholder]
192
+ [authors.append(author) for author in authors_placeholder]
193
+ [publishers.append(publisher) for publisher in publishers_placeholder]
194
+ [descriptions.append(description) for description in descriptions_placeholder]
195
+ [images.append(image) for image in images_placeholder]
196
 
197
  # Calculate the elapsed time between the first and second checkpoints
198
  second_checkpoint = time.time()
 
212
  images = []
213
 
214
  # Set the OpenAI API key
215
+ openai.api_key = "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"
216
 
217
  # Create ChatGPT query
218
  chatgpt_response = openai.ChatCompletion.create(
 
311
 
312
 
313
  @app.post("/classify")
314
+ async def classify(
315
+ data: list, runtime: str = Query(default="trained", enum=["trained", "zero-shot"])
316
+ ):
317
  """
318
  Create classifier pipeline and return the results.
319
  """
 
333
  pipeline,
334
  )
335
  from optimum.onnxruntime import ORTModelForSequenceClassification
336
+ from optimum.bettertransformer import BetterTransformer
337
 
338
+ if runtime == "zero-shot":
339
  # Define the zero-shot classifier
340
+ tokenizer = AutoTokenizer.from_pretrained(
341
+ "sileod/deberta-v3-base-tasksource-nli"
342
+ )
343
+ model = AutoModelForSequenceClassification.from_pretrained(
344
+ "sileod/deberta-v3-base-tasksource-nli"
345
+ )
346
 
347
  classifier_pipe = pipeline(
348
  "zero-shot-classification",
 
371
  }
372
  for doc in combined_data
373
  ]
374
+
375
+ elif runtime == "trained":
376
  ### Define the classifier for audience prediction ###
377
+ audience_tokenizer = AutoTokenizer.from_pretrained(
378
+ "bertugmirasyedi/deberta-v3-base-book-classification",
379
+ max_len=512,
380
+ )
381
+ audience_model = AutoModelForSequenceClassification.from_pretrained(
382
+ "bertugmirasyedi/deberta-v3-base-book-classification"
383
+ )
384
+
385
  audience_classifier = pipeline(
386
  "text-classification",
387
+ model=audience_model,
388
+ tokenizer=audience_tokenizer,
389
  device=-1,
390
  )
 
391
  ### Define the classifier for level prediction ###
392
+ level_tokenizer = AutoTokenizer.from_pretrained(
393
+ "bertugmirasyedi/deberta-v3-base-level-classification",
394
+ max_len=512,
395
+ )
396
+ level_model = AutoModelForSequenceClassification.from_pretrained(
397
+ "bertugmirasyedi/deberta-v3-base-level-classification"
398
+ )
399
+
400
  level_classifier = pipeline(
401
  "text-classification",
402
+ model=level_model,
403
+ tokenizer=level_tokenizer,
404
  device=-1,
405
  )
406
 
 
443
  for title, description, publisher in zip(titles, descriptions, publishers)
444
  ]
445
 
446
+ sentence_transformer = SentenceTransformer("all-MiniLM-L6-v2")
447
  book_embeddings = sentence_transformer.encode(combined_data, convert_to_tensor=True)
448
 
449
  # Make sure that the top_k value is not greater than the number of books
 
471
 
472
 
473
  @app.post("/summarize")
474
+ async def summarize(
475
+ descriptions: list,
476
+ runtime: str = Query(default="normal", enum=["normal", "onnxruntime"]),
477
+ ):
478
  """
479
  Summarize the descriptions and return the results.
480
  """
 
488
 
489
  # Define the summarizer model and tokenizer
490
  if runtime == "normal":
491
+ tokenizer = AutoTokenizer.from_pretrained("lidiya/bart-base-samsum")
492
+ model = AutoModelForSeq2SeqLM.from_pretrained("lidiya/bart-base-samsum")
493
+ model = BetterTransformer.transform(model)
494
  elif runtime == "onnxruntime":
495
+ tokenizer = AutoTokenizer.from_pretrained("optimum/t5-small")
496
+ model = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small")
497
 
498
  # Create the summarizer pipeline
499
  summarizer_pipe = pipeline("summarization", model=model, tokenizer=tokenizer)