# Author: Daryl Lim
# Commit c7ce343: "Add application file"
"""
Gradio app that summarizes medical text with FALCONS.AI's medical_summarization
model (Hugging Face Hub ID: Falconsai/medical_summarization).
Users paste text from a medical document and receive a generated summary.
"""
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hub checkpoint shared by the tokenizer and the model below.
_MODEL_ID = "Falconsai/medical_summarization"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Tokenizer, requesting the fast tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, use_fast=True)

# Seq2seq model with weights loaded as bfloat16 (half the memory of float32),
# then placed on the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch.bfloat16,
)
model.to(DEVICE)
@spaces.GPU
def summarize(text):
    """
    Generate a summary of the text from a medical document.

    Args:
        text (str): The text of a medical document. ``None`` or a
            whitespace-only string is treated as empty.

    Returns:
        str: The generated summary of the medical text, or an empty string
        when ``text`` is empty.
    """
    # Nothing to summarize: skip tokenization and generation so an empty
    # submission does not make the model generate text from an empty prompt
    # (and ``None`` does not reach the tokenizer).
    if text is None or not text.strip():
        return ""

    # Tokenize the text and move the token ids to the model's device.
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(DEVICE)

    # Generate at most 500 new tokens for the summary.
    summary_ids = model.generate(input_ids=input_ids, max_new_tokens=500)

    # One prompt in, one sequence out: decode it without special tokens.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
TITLE = "Medical Text Summarizer"
DESCRIPTION = """
Summarize medical text using FALCONS.AI's medical_summarization model.
"""

# Gradio components: a text box for the source document and one for the
# generated summary.
input_text = gr.Textbox(
    label="Medical document",
    placeholder="Enter text here",
)
output_text = gr.Textbox(label="Summary")

# Wire summarize() to the UI: one text input, one text output.
demo = gr.Interface(
    fn=summarize,
    inputs=[input_text],
    outputs=[output_text],
    title=TITLE,
    description=DESCRIPTION,
)

# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()