import gradio as gr
from transformers import pipeline
import spaces
import torch

# ZeroGPU-style sanity check, guarded so the app also starts on CPU-only hardware
if torch.cuda.is_available():
    zero = torch.Tensor([0]).cuda()
    print(zero.device)

# Check if GPU is available for FP16 inference
device = 0 if torch.cuda.is_available() else -1
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
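
# For transformers pipelines, device=0 selects the first CUDA GPU and device=-1 runs on CPU;
# torch_dtype is passed to each pipeline below so weights load in FP16 when a GPU is present.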

# Load pipelines in FP16 when a GPU is available (falls back to FP32 on CPU)
question_answering = pipeline("question-answering", model="deepset/roberta-base-squad2", device=device, torch_dtype=torch_dtype)
code_generation = pipeline("text-generation", model="Salesforce/codegen-350M-mono", device=device, torch_dtype=torch_dtype)
summarization = pipeline("summarization", model="facebook/bart-large-cnn", device=device, torch_dtype=torch_dtype)
# translation = pipeline("translation_en_to_fr", model="Helsinki-NLP/opus-mt-en-fr", device=device)
# translation = pipeline("translation", model="facebook/m2m100_418M", device=device)
text_generation = pipeline("text-generation", model="gpt2", device=device, torch_dtype=torch_dtype)
text_classification = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english", device=device, torch_dtype=torch_dtype)

# Define Functions for Each Task
def answer_question(context, question):
    result = question_answering(question=question, context=context)
    return result["answer"]
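
# Note: `spaces` is imported above but never used. On a ZeroGPU Space, the GPU-bound entry
# points would normally be wrapped with the @spaces.GPU decorator, roughly as sketched below
# for answer_question (an assumption about the intended deployment, not something the
# original code does):
#
# @spaces.GPU
# def answer_question(context, question):
#     result = question_answering(question=question, context=context)
#     return result["answer"]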

def generate_code(prompt):
    output = code_generation(prompt, max_length=50)
    return output[0]['generated_text']

def summarize_text(text):
    output = summarization(text, max_length=100, min_length=30, do_sample=False)
    return output[0]['summary_text']

def generate_text(prompt):
    output = text_generation(prompt, max_length=100)
    return output[0]['generated_text']

def classify_text(text):
    output = text_classification(text)
    return f"Label: {output[0]['label']} | Score: {output[0]['score']:.4f}"

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Transformers Pipeline with FP16 Inference")

    with gr.Tab("1️⃣ Question Answering"):
        with gr.Row():
            context = gr.Textbox(label="Context", lines=4, placeholder="Paste your paragraph here...")
            question = gr.Textbox(label="Question", placeholder="Ask a question...")
        answer_btn = gr.Button("Get Answer")
        answer_output = gr.Textbox(label="Answer")
        answer_btn.click(answer_question, inputs=[context, question], outputs=answer_output)

    with gr.Tab("2️⃣ Code Generation"):
        code_input = gr.Textbox(label="Code Prompt", placeholder="Write code snippet...")
        code_btn = gr.Button("Generate Code")
        code_output = gr.Textbox(label="Generated Code")
        code_btn.click(generate_code, inputs=code_input, outputs=code_output)

    with gr.Tab("3️⃣ Summarization"):
        summary_input = gr.Textbox(label="Text to Summarize", lines=5, placeholder="Paste long text here...")
        summary_btn = gr.Button("Summarize")
        summary_output = gr.Textbox(label="Summary")
        summary_btn.click(summarize_text, inputs=summary_input, outputs=summary_output)

    with gr.Tab("5️⃣ Text Generation"):
        text_input = gr.Textbox(label="Text Prompt", placeholder="Start your text...")
        text_btn = gr.Button("Generate Text")
        text_output = gr.Textbox(label="Generated Text")
        text_btn.click(generate_text, inputs=text_input, outputs=text_output)

    with gr.Tab("6️⃣ Text Classification"):
        classify_input = gr.Textbox(label="Enter Text", placeholder="Enter a sentence...")
        classify_btn = gr.Button("Classify Sentiment")
        classify_output = gr.Textbox(label="Classification Result")
        classify_btn.click(classify_text, inputs=classify_input, outputs=classify_output)

# Launch App
demo.launch()