# RadiXGPT_ / app.py
# Author: Singularity666 (commit 1c4591c, "Update app.py")
import base64
import os
import re
from io import BytesIO

import numpy as np
import openai
import pandas as pd
import streamlit as st
import torch
from docx import Document
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
from PIL import Image

from main import predict_caption, CLIPModel, get_text_embeddings
# The OpenAI key is read from the environment (e.g. a deployment secret)
# instead of being hard-coded: a key committed to the repository is public,
# so the key previously embedded here must be treated as leaked and revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    st.warning("OPENAI_API_KEY is not set; radiology report generation will fail.")
# Page-wide stylesheet injected once at start-up: transparent app background,
# dark text for common widgets, and bordered pink h1/h2 headings.
# Note: the second `.stButton>button>span` rule comes later with the same
# selector, so it overrides the dark colour assigned in the first rule.
_PAGE_CSS = """
<style>
body {
background-color: transparent;
}
.container {
display: flex;
justify-content: center;
align-items: center;
background-color: rgba(255, 255, 255, 0.7);
border-radius: 15px;
padding: 20px;
}
.stApp {
background-color: transparent;
}
.stText, .stMarkdown, .stTextInput>label, .stButton>button>span {
color: #1c1c1c !important; /* Set the dark text color for text elements */
}
.stButton>button>span {
color: initial !important; /* Reset the text color for the 'Generate Caption' button */
}
.stMarkdown h1, .stMarkdown h2 {
color: #ff6b81 !important; /* Set the text color of h1 and h2 elements to soft red-pink */
font-weight: bold; /* Set the font weight to bold */
border: 2px solid #ff6b81; /* Add a bold border around the headers */
padding: 10px; /* Add padding to the headers */
border-radius: 5px; /* Add border-radius to the headers */
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# All inference runs on the CPU.
device = torch.device("cpu")

# Caption bank: predict_caption ranks captions from this table's "caption" column.
testing_df = pd.read_csv("testing_df.csv")

# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the CSV, weights and embeddings below are reloaded on each rerun. Consider
# moving them into a cached loader (e.g. st.cache_resource) if the installed
# Streamlit version provides it.

# Build the model once and load the fine-tuned weights directly onto `device`.
model = CLIPModel().to(device)
state_dict = torch.load("weights.pt", map_location=device)
# strict=False tolerates checkpoint/model key mismatches; report any mismatch
# instead of silently discarding it so a wrong checkpoint is noticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Missing keys in checkpoint:", load_result.missing_keys)
    print("Unexpected keys in checkpoint:", load_result.unexpected_keys)
# Inference only: disable training-mode behaviour such as dropout.
model.eval()

# Precomputed caption embeddings consumed by predict_caption.
text_embeddings = torch.load('saved_text_embeddings.pt', map_location=device)
def download_link(content, filename, link_text):
    """Return an HTML anchor that lets the browser download ``content``.

    The bytes are embedded in the link itself as a base64 ``data:`` URI, so no
    file is written on the server. The result must be rendered with
    ``st.markdown(..., unsafe_allow_html=True)``.

    Args:
        content: Raw bytes of the file offered for download.
        filename: Name suggested to the browser via the ``download`` attribute.
        link_text: Visible text of the link.

    Returns:
        The ``<a>`` element as a string.
    """
    encoded = base64.b64encode(content).decode()
    return (
        f'<a href="data:application/octet-stream;base64,{encoded}" '
        f'download="{filename}">{link_text}</a>'
    )
def show_predicted_caption(image, top_k=8):
    """Return up to ``top_k`` captions predicted for ``image``.

    Captions come from ``predict_caption`` using the module-level ``model``,
    ``text_embeddings`` and ``testing_df["caption"]``, in the order it returns
    them. Each caption has any whitespace-prefixed ``(ROCO_<digits>)``
    identifier removed.

    Args:
        image: The scan passed straight through to ``predict_caption``.
        top_k: Maximum number of captions to keep.

    Returns:
        A list of cleaned caption strings.
    """
    candidates = predict_caption(image, model, text_embeddings, testing_df["caption"])
    return [
        re.sub(r'\s\(ROCO_\d+\)', '', candidate)
        for candidate in candidates[:top_k]
    ]
def generate_radiology_report(
    prompt,
    *,
    model_name="gpt-3.5-turbo-instruct",
    max_tokens=800,
    temperature=1,
):
    """Generate a radiology report for ``prompt`` via the OpenAI Completions API.

    ``text-davinci-003``, which this function originally used, has been retired
    by OpenAI; ``gpt-3.5-turbo-instruct`` is its designated replacement on the
    same Completions endpoint, so the request shape is unchanged.

    Args:
        prompt: Full instruction text sent to the model.
        model_name: Completions-capable model to query.
        max_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (1 matches the original behaviour).

    Returns:
        The generated text, stripped, with ``(ROCO_<digits>)`` identifiers removed.
    """
    response = openai.Completion.create(
        model=model_name,
        prompt=prompt,
        max_tokens=max_tokens,
        n=1,
        temperature=temperature,
    )
    report = response.choices[0].text.strip()
    # Drop any "(ROCO_<digits>)" identifier that was echoed into the output.
    return re.sub(r'\(ROCO_\d+\)', '', report).strip()
def save_as_docx(text, filename):
    """Render ``text`` into a one-paragraph DOCX document and return its bytes.

    The document is written to an in-memory buffer only; ``filename`` is
    accepted by the signature but is not used.

    Args:
        text: Content placed in the document's single paragraph.
        filename: Unused.

    Returns:
        The serialized ``.docx`` file as ``bytes``.
    """
    doc = Document()
    doc.add_paragraph(text)
    with BytesIO() as buffer:
        doc.save(buffer)
        # getvalue() returns the whole buffer regardless of the current position.
        payload = buffer.getvalue()
    return payload
def _report_prompt(caption):
    """Return the instruction sent to the LLM for a predicted ``caption``."""
    return f"Write Complete Radiology Report for this with clinical info, subjective, Assessment, Finding, Impressions, Conclusion and more in proper order : {caption}"


def _with_personal_info(report, first_name, last_name, age, gender):
    """Prefix ``report`` with the patient details entered in the form."""
    return f"Patient Name: {first_name} {last_name}\nAge: {age}\nGender: {gender}\n\n{report}"


# Session-state keys that hold generated results across Streamlit reruns.
_RESULT_KEYS = ("image_key", "captions", "caption_index", "report", "regenerated_report")


def _clear_results():
    """Remove results generated for a previous upload from session state."""
    for key in _RESULT_KEYS:
        if key in st.session_state:
            del st.session_state[key]


st.title("RadiXGPT: An Evolution of machine doctors towards Radiology")
# Collect user's personal information
st.subheader("Personal Information")
first_name = st.text_input("First Name")
last_name = st.text_input("Last Name")
age = st.number_input("Age", min_value=0, max_value=120, value=25, step=1)
gender = st.selectbox("Gender", ["Male", "Female", "Other"])
st.write("Upload Scan to get Radiological Report:")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

# Streamlit re-runs this whole script on every widget interaction, and a
# button returns True only on the rerun triggered by its own click. Generated
# results therefore live in st.session_state. Previously the feedback radio and
# "Regenerate Report" button were nested inside the "Generate Caption" branch,
# so they disappeared (and could never act) on the next interaction.
if uploaded_file is None:
    _clear_results()
else:
    # Forget results that belong to a previously uploaded scan.
    image_key = (uploaded_file.name, len(uploaded_file.getvalue()))
    if st.session_state.get("image_key") != image_key:
        _clear_results()
        st.session_state["image_key"] = image_key

    image = Image.open(uploaded_file)
    # NOTE(review): np.array keeps the file's colour mode (e.g. RGBA PNGs or
    # single-channel scans); confirm predict_caption accepts every mode.
    image_np = np.array(image)

    if st.button("Generate Caption"):
        with st.spinner("Generating caption..."):
            captions = show_predicted_caption(image_np)
            if captions:
                st.session_state["captions"] = captions
                st.session_state["caption_index"] = 0
                # Generate the radiology report from the top-ranked caption.
                st.session_state["report"] = generate_radiology_report(_report_prompt(captions[0]))
                if "regenerated_report" in st.session_state:
                    del st.session_state["regenerated_report"]
            else:
                st.error("No caption could be predicted for this image.")

    if "report" in st.session_state:
        captions = st.session_state["captions"]
        st.success(f"Caption: {captions[0]}")
        # Personal info is applied at display time so edits to the form show up.
        report_text = _with_personal_info(
            st.session_state["report"], first_name, last_name, age, gender
        )
        st.header("Radiology Report")
        st.write(report_text)
        st.markdown(
            download_link(
                save_as_docx(report_text, "radiology_report.docx"),
                "radiology_report.docx",
                "Download Report as DOCX",
            ),
            unsafe_allow_html=True,
        )

        feedback_options = ["Satisfied", "Not Satisfied"]
        selected_feedback = st.radio("Please provide feedback on the generated report:", feedback_options)
        if selected_feedback == "Not Satisfied":
            if st.button("Regenerate Report"):
                with st.spinner("Regenerating report..."):
                    # The former get_alternative_caption was never defined or
                    # imported; instead advance to the next-ranked predicted
                    # caption, wrapping around when the list is exhausted.
                    next_index = (st.session_state["caption_index"] + 1) % len(captions)
                    st.session_state["caption_index"] = next_index
                    st.session_state["regenerated_report"] = generate_radiology_report(
                        _report_prompt(captions[next_index])
                    )
            if "regenerated_report" in st.session_state:
                regenerated_text = _with_personal_info(
                    st.session_state["regenerated_report"], first_name, last_name, age, gender
                )
                st.header("Regenerated Radiology Report")
                st.write(regenerated_text)
                st.markdown(
                    download_link(
                        save_as_docx(regenerated_text, "regenerated_radiology_report.docx"),
                        "regenerated_radiology_report.docx",
                        "Download Regenerated Report as DOCX",
                    ),
                    unsafe_allow_html=True,
                )