Spaces:
Sleeping
Sleeping
ajaynagotha
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
from datasets import load_dataset
|
3 |
-
from transformers import
|
4 |
import torch
|
5 |
import logging
|
6 |
from fastapi import FastAPI, HTTPException
|
@@ -19,9 +19,11 @@ logger.info("Dataset loaded successfully")
|
|
19 |
# Load model and tokenizer
|
20 |
logger.info("Loading the model and tokenizer")
|
21 |
model_name = "facebook/bart-large-cnn"
|
22 |
-
tokenizer =
|
23 |
-
model =
|
24 |
-
|
|
|
|
|
25 |
|
26 |
def clean_answer(answer):
|
27 |
special_tokens = set(tokenizer.all_special_tokens)
|
@@ -51,7 +53,7 @@ def answer_question(question):
|
|
51 |
padding='max_length'
|
52 |
)
|
53 |
|
54 |
-
inputs = {k: v.to(
|
55 |
|
56 |
logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
|
57 |
|
@@ -131,7 +133,7 @@ iface = gr.Interface(
|
|
131 |
# Mount Gradio app to FastAPI
|
132 |
app = gr.mount_gradio_app(app, iface, path="/")
|
133 |
|
134 |
-
# For
|
135 |
if __name__ == "__main__":
|
136 |
import uvicorn
|
137 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
1 |
import gradio as gr
|
2 |
from datasets import load_dataset
|
3 |
+
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
|
4 |
import torch
|
5 |
import logging
|
6 |
from fastapi import FastAPI, HTTPException
|
|
|
19 |
# Load model and tokenizer
|
20 |
logger.info("Loading the model and tokenizer")
|
21 |
model_name = "facebook/bart-large-cnn"
|
22 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
23 |
+
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
|
24 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
25 |
+
model.to(device)
|
26 |
+
logger.info(f"Model and tokenizer loaded successfully. Using device: {device}")
|
27 |
|
28 |
def clean_answer(answer):
|
29 |
special_tokens = set(tokenizer.all_special_tokens)
|
|
|
53 |
padding='max_length'
|
54 |
)
|
55 |
|
56 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
57 |
|
58 |
logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
|
59 |
|
|
|
133 |
# Mount Gradio app to FastAPI
|
134 |
app = gr.mount_gradio_app(app, iface, path="/")
|
135 |
|
136 |
+
# For Hugging Face Spaces
|
137 |
if __name__ == "__main__":
|
138 |
import uvicorn
|
139 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|