ajaynagotha commited on
Commit
5be946a
·
verified ·
1 Parent(s): cd4ab5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from datasets import load_dataset
3
- from transformers import BartForQuestionAnswering, BartTokenizer
4
  import torch
5
  import logging
6
  from fastapi import FastAPI, HTTPException
@@ -19,9 +19,11 @@ logger.info("Dataset loaded successfully")
19
  # Load model and tokenizer
20
  logger.info("Loading the model and tokenizer")
21
  model_name = "facebook/bart-large-cnn"
22
- tokenizer = BartTokenizer.from_pretrained(model_name)
23
- model = BartForQuestionAnswering.from_pretrained(model_name)
24
- logger.info("Model and tokenizer loaded successfully")
 
 
25
 
26
  def clean_answer(answer):
27
  special_tokens = set(tokenizer.all_special_tokens)
@@ -51,7 +53,7 @@ def answer_question(question):
51
  padding='max_length'
52
  )
53
 
54
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
55
 
56
  logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
57
 
@@ -131,7 +133,7 @@ iface = gr.Interface(
131
  # Mount Gradio app to FastAPI
132
  app = gr.mount_gradio_app(app, iface, path="/")
133
 
134
- # For local development and testing
135
  if __name__ == "__main__":
136
  import uvicorn
137
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import gradio as gr
2
  from datasets import load_dataset
3
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
4
  import torch
5
  import logging
6
  from fastapi import FastAPI, HTTPException
 
19
  # Load model and tokenizer
20
  logger.info("Loading the model and tokenizer")
21
  model_name = "facebook/bart-large-cnn"
22
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
23
+ model = AutoModelForQuestionAnswering.from_pretrained(model_name)
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ model.to(device)
26
+ logger.info(f"Model and tokenizer loaded successfully. Using device: {device}")
27
 
28
  def clean_answer(answer):
29
  special_tokens = set(tokenizer.all_special_tokens)
 
53
  padding='max_length'
54
  )
55
 
56
+ inputs = {k: v.to(device) for k, v in inputs.items()}
57
 
58
  logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
59
 
 
133
  # Mount Gradio app to FastAPI
134
  app = gr.mount_gradio_app(app, iface, path="/")
135
 
136
+ # For Hugging Face Spaces
137
  if __name__ == "__main__":
138
  import uvicorn
139
  uvicorn.run(app, host="0.0.0.0", port=7860)