Alexvatti's picture
Update app.py
fd46fd2 verified
raw
history blame
3.7 kB
import gradio as gr
from transformers import pipeline
import spaces
import torch
# Debug probe: report which device a freshly created tensor lands on.
# Only move it to CUDA when CUDA is available: an unconditional `.cuda()`
# raises on CPU-only hosts, which would make the CPU fallback below dead code.
zero = torch.Tensor([0])
if torch.cuda.is_available():
    zero = zero.cuda()
print(zero.device)

# Use GPU 0 with FP16 when CUDA is available, otherwise CPU (-1) with FP32.
device = 0 if torch.cuda.is_available() else -1
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32


def _load_pipeline(task, model):
    """Build a transformers pipeline for *task* / *model* on the selected device.

    ``torch_dtype`` is passed through explicitly: computing it is not enough,
    and without it each model loads in transformers' default dtype rather
    than the FP16 chosen above.
    """
    return pipeline(task, model=model, device=device, torch_dtype=torch_dtype)


# Load Pipelines (FP16 on GPU, FP32 on CPU)
question_answering = _load_pipeline("question-answering", "deepset/roberta-base-squad2")
code_generation = _load_pipeline("text-generation", "Salesforce/codegen-350M-mono")
summarization = _load_pipeline("summarization", "facebook/bart-large-cnn")
# NOTE: translation is disabled; candidate models were
# "Helsinki-NLP/opus-mt-en-fr" (translation_en_to_fr) and "facebook/m2m100_418M".
text_generation = _load_pipeline("text-generation", "gpt2")
text_classification = _load_pipeline(
    "text-classification", "distilbert-base-uncased-finetuned-sst-2-english"
)
# Define Functions for Each Task
@spaces.GPU
def answer_question(context, question):
    """Answer *question* using the passage in *context*.

    Delegates to the module-level question-answering pipeline and returns
    only the ``"answer"`` entry of its result.
    """
    prediction = question_answering(question=question, context=context)
    answer_text = prediction["answer"]
    return answer_text
@spaces.GPU
def generate_code(prompt):
    """Continue a code *prompt* with the code-generation pipeline.

    Returns the pipeline's ``generated_text`` for the first (only) sequence.
    """
    # `max_length` counts the prompt's own tokens, so a prompt of 50+ tokens
    # left no budget for new code. Bound only the newly generated tokens.
    output = code_generation(prompt, max_new_tokens=50)
    return output[0]["generated_text"]
@spaces.GPU
def summarize_text(text):
    """Summarize *text* deterministically (no sampling) into 30-100 tokens.

    Returns the ``summary_text`` of the pipeline's first result.
    """
    # The model has a fixed maximum input length and the UI invites pasting
    # long text; without truncation an over-long input fails instead of
    # being summarized from its leading portion.
    output = summarization(
        text,
        max_length=100,
        min_length=30,
        do_sample=False,
        truncation=True,
    )
    return output[0]["summary_text"]
@spaces.GPU
def generate_text(prompt):
    """Continue a free-text *prompt* with the text-generation pipeline.

    Returns the pipeline's ``generated_text`` for the first (only) sequence.
    """
    # `max_length` counts the prompt's own tokens, so long prompts left little
    # or no room for new text. Bound only the newly generated tokens.
    output = text_generation(prompt, max_new_tokens=100)
    return output[0]["generated_text"]
@spaces.GPU
def classify_text(text):
    """Classify *text* and format the top label and its score.

    Returns a string like ``"Label: POSITIVE | Score: 0.9876"`` (score shown
    to four decimal places).
    """
    # The classifier has a fixed maximum input length; truncate rather than
    # fail on long inputs.
    output = text_classification(text, truncation=True)
    top = output[0]
    return f"Label: {top['label']} | Score: {top['score']:.4f}"
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Transformers Pipeline with FP16 Inference")

    # Tab numbers are kept consecutive (there is no translation tab).
    with gr.Tab("1️⃣ Question Answering"):
        with gr.Row():
            context = gr.Textbox(label="Context", lines=4, placeholder="Paste your paragraph here...")
            question = gr.Textbox(label="Question", placeholder="Ask a question...")
        answer_btn = gr.Button("Get Answer")
        answer_output = gr.Textbox(label="Answer")
        answer_btn.click(answer_question, inputs=[context, question], outputs=answer_output)

    with gr.Tab("2️⃣ Code Generation"):
        code_input = gr.Textbox(label="Code Prompt", placeholder="Write code snippet...")
        code_btn = gr.Button("Generate Code")
        code_output = gr.Textbox(label="Generated Code")
        code_btn.click(generate_code, inputs=code_input, outputs=code_output)

    with gr.Tab("3️⃣ Summarization"):
        summary_input = gr.Textbox(label="Text to Summarize", lines=5, placeholder="Paste long text here...")
        summary_btn = gr.Button("Summarize")
        summary_output = gr.Textbox(label="Summary")
        summary_btn.click(summarize_text, inputs=summary_input, outputs=summary_output)

    with gr.Tab("4️⃣ Text Generation"):
        text_input = gr.Textbox(label="Text Prompt", placeholder="Start your text...")
        text_btn = gr.Button("Generate Text")
        text_output = gr.Textbox(label="Generated Text")
        text_btn.click(generate_text, inputs=text_input, outputs=text_output)

    with gr.Tab("5️⃣ Text Classification"):
        classify_input = gr.Textbox(label="Enter Text", placeholder="Enter a sentence...")
        classify_btn = gr.Button("Classify Sentiment")
        classify_output = gr.Textbox(label="Classification Result")
        classify_btn.click(classify_text, inputs=classify_input, outputs=classify_output)

# Launch App
demo.launch()