# Hugging Face Space page header (status when captured: Sleeping)
""" | |
This module provides a Gradio web interface for summarizing medical text with
Falcons.ai's medical_summarization model.

Users enter text from a medical document and receive a generated summary of it.
"""
import gradio as gr | |
import spaces | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# Pick the compute device once at import time: the first CUDA device when
# PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Hub checkpoint shared by the tokenizer and the model below.
_CHECKPOINT = "Falconsai/medical_summarization"

# Tokenizer for the checkpoint, requesting the fast implementation.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT, use_fast=True)

# Seq2seq model with bfloat16 weights to reduce memory, placed on DEVICE.
# nn.Module.to() returns the module itself, so `model` is the moved model.
# NOTE(review): bfloat16 is also used on the CPU fallback path; on CPUs
# without native bf16 support this may be slow — consider float32 there
# after profiling.
model = AutoModelForSeq2SeqLM.from_pretrained(
    _CHECKPOINT,
    torch_dtype=torch.bfloat16,
).to(DEVICE)
# `spaces` is imported by this app specifically for this decorator: on ZeroGPU
# hardware a GPU is only attached while a @spaces.GPU function runs, and the
# Hugging Face docs describe the decorator as effect-free on other hardware.
@spaces.GPU
def summarize(text):
    """
    Generate a summary of the text from a medical document.

    Blank (empty or whitespace-only) input returns an empty string without
    invoking the model.

    The input is truncated to the tokenizer's configured maximum length
    before generation. This bounds the memory and latency of a single request.
    Text beyond that limit does not contribute to the summary.

    Args:
        text (str): The text of a medical document.

    Returns:
        str: The generated summary of the medical text, produced from at most
        500 newly generated tokens with special tokens removed.
    """
    # Nothing to summarize: skip tokenization and generation entirely.
    if not text or not text.strip():
        return ""

    # Keep the whole encoding (input_ids plus attention_mask) rather than only
    # the ids, and cap the sequence at the tokenizer's maximum length instead
    # of feeding an arbitrarily long user paste to the model.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
    ).to(DEVICE)

    # Generate a summary prediction using the model.
    summary_ids = model.generate(
        **inputs,
        max_new_tokens=500,
    )

    # One input sequence yields one output sequence; decode just that one.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
TITLE = "Medical Text Summarizer"
DESCRIPTION = """
Summarize medical text using FALCONS.AI's medical_summarization model.
"""

# Gradio components: a free-text input for the document and a text output
# for the generated summary.
input_text = gr.Textbox(
    label="Medical document",
    placeholder="Enter text here",
)
output_text = gr.Textbox(label="Summary")

# Define the Gradio interface: the input textbox is passed to summarize(),
# whose return value is shown in the output textbox.
demo = gr.Interface(
    fn=summarize,
    inputs=[input_text],
    outputs=[output_text],
    title=TITLE,
    description=DESCRIPTION,
)

# Launch only when run as a script (as `python app.py` does). Importing the
# module, e.g. from tests or tooling, then builds `demo` without starting the
# web server as a side effect.
if __name__ == "__main__":
    demo.launch()