Spaces:
Sleeping
Sleeping
File size: 1,998 Bytes
c7ce343 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
"""
This module provides an interface for summarizing medical text using FALCONS.AI's medical_summarization model.
The interface allows users to enter text from a medical document.
The user will receive a generated summary of the medical text.
"""
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Use CUDA when PyTorch reports it as available; otherwise stay on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub identifier shared by the tokenizer and the model loaded below.
_CHECKPOINT = "Falconsai/medical_summarization"

# Tokenizer for the checkpoint; use_fast=True requests the fast implementation.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT, use_fast=True)

# Sequence-to-sequence model, loaded with bfloat16 as the requested dtype.
model = AutoModelForSeq2SeqLM.from_pretrained(
    _CHECKPOINT,
    torch_dtype=torch.bfloat16,
)

# Place the model on the device selected above.
model.to(DEVICE)
@spaces.GPU
def summarize(text: str) -> str:
    """
    Generate a summary of the text from a medical document.

    The text is tokenized with the module-level ``tokenizer``, moved to
    ``DEVICE``, passed to ``model.generate`` with a limit of 500 new tokens,
    and the first decoded sequence is returned with special tokens removed.

    Args:
        text (str): The text of a medical document.

    Returns:
        str: The generated summary of the medical text, or an empty string
            when ``text`` is empty or contains only whitespace.
    """
    # Nothing to summarize: skip the model call rather than generating
    # output from an input that carries no content.
    if not text or not text.strip():
        return ""

    # Tokenize the text for summarization and move the tensors to DEVICE.
    # NOTE: no truncation is requested here, so long documents are handed
    # to the model at their full token length.
    encoded = tokenizer(
        text,
        return_tensors="pt",
    ).to(DEVICE)

    # Generate a summary prediction using the model. The attention mask
    # produced by the tokenizer is passed explicitly alongside the input ids.
    summary_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=500,
    )

    # Decode the generated summary; only one sequence was submitted, so the
    # first decoded entry is the result.
    summary = tokenizer.batch_decode(
        summary_ids,
        skip_special_tokens=True,
    )
    return summary[0]
# Text shown at the top of the page.
TITLE = "Medical Text Summarizer"
DESCRIPTION = """
Summarize medical text using FALCONS.AI's medical_summarization model.
"""

# Textbox where the user enters the document to summarize.
input_text = gr.Textbox(label="Medical document", placeholder="Enter text here")

# Textbox that displays the value returned by summarize().
output_text = gr.Textbox(label="Summary")

# Wire summarize() to the two textboxes in a single Gradio interface.
demo = gr.Interface(
    summarize,
    [input_text],
    [output_text],
    title=TITLE,
    description=DESCRIPTION,
)

# Start serving the interface.
demo.launch()
|