File size: 1,998 Bytes
c7ce343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Gradio app that summarizes medical text with the Falconsai/medical_summarization model.

Users paste text from a medical document into the interface and receive a
generated summary of that text.
"""

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Pick the inference device once at import time: CUDA when torch reports it
# as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Tokenizer for the summarization checkpoint, using the "fast" implementation.
tokenizer = AutoTokenizer.from_pretrained("Falconsai/medical_summarization", use_fast=True)

# Seq2seq weights are loaded in bfloat16 (2 bytes per parameter instead of
# float32's 4) to reduce memory, then moved onto the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "Falconsai/medical_summarization",
    torch_dtype=torch.bfloat16,
)
model.to(DEVICE)

@spaces.GPU
def summarize(text):
    """
    Generate a summary of the text from a medical document.

    Args:
        text (str): The text of a medical document. Empty or
            whitespace-only input (or ``None``) yields an empty summary
            without running the model.

    Returns:
        str: The generated summary of the medical text.
    """
    # Nothing to summarize: skip generation rather than letting the model
    # produce text from an essentially empty prompt.
    if not text or not text.strip():
        return ""

    # Tokenize and move every returned tensor (ids and attention mask) to the
    # model's device in one step.
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)

    # Forward the tokenizer's attention mask explicitly so generate() does not
    # have to infer one from the pad token id (which would mis-mask any pad
    # token id that happens to appear in the user's text).
    summary_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=500,
    )

    # A single input produces a single output sequence; decode it without
    # special tokens.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Static text shown at the top of the Gradio page.
TITLE = "Medical Text Summarizer"
DESCRIPTION = "\nSummarize medical text using FALCONS.AI's medical_summarization model.\n"

# Input: free-form medical document text. Output: the generated summary.
input_text = gr.Textbox(placeholder="Enter text here", label="Medical document")
output_text = gr.Textbox(label="Summary")

# One-input / one-output interface that routes the textbox contents through
# `summarize` and shows its result in the summary box.
demo = gr.Interface(
    fn=summarize,
    title=TITLE,
    description=DESCRIPTION,
    inputs=[input_text],
    outputs=[output_text],
)

# Launch the Gradio interface only when this file is executed as a script
# (e.g. `python app.py`), so importing the module for tests or tooling does
# not start and block on the web server.
if __name__ == "__main__":
    demo.launch()