Singularity666 committed on
Commit 369ad4d
1 Parent(s): 52b49e5

Update app.py

Files changed (1):
  1. app.py +107 -18
app.py CHANGED
@@ -1,51 +1,140 @@
- import streamlit as st
- import pickle
- import pandas as pd
  import torch
  from PIL import Image
+ import streamlit as st
  import numpy as np
- from main import predict_caption, CLIPModel , get_text_embeddings
+ import pandas as pd
+ from main import predict_caption, CLIPModel, get_text_embeddings
+ import openai
+ import base64
+ from docx import Document
+ from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
+ from io import BytesIO
+ import re

+ openai.api_key = "sk-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"

  st.markdown(
      """
-     <style>
+     <style>
      body {
          background-color: transparent;
      }
-     </style>
-     """,
+     .container {
+         display: flex;
+         justify-content: center;
+         align-items: center;
+         background-color: rgba(255, 255, 255, 0.7);
+         border-radius: 15px;
+         padding: 20px;
+     }
+     .stApp {
+         background-color: transparent;
+     }
+     .stText, .stMarkdown, .stTextInput>label, .stButton>button>span {
+         color: #1c1c1c !important; /* Set the dark text color for text elements */
+     }
+     .stButton>button>span {
+         color: initial !important; /* Reset the text color for the 'Generate Caption' button */
+     }
+     .stMarkdown h1, .stMarkdown h2 {
+         color: #ff6b81 !important; /* Set the text color of h1 and h2 elements to soft red-pink */
+         font-weight: bold; /* Set the font weight to bold */
+         border: 2px solid #ff6b81; /* Add a bold border around the headers */
+         padding: 10px; /* Add padding to the headers */
+         border-radius: 5px; /* Add border-radius to the headers */
+     }
+     </style>
+     """,
      unsafe_allow_html=True,
  )


+
  device = torch.device("cpu")

  testing_df = pd.read_csv("testing_df.csv")
- model = CLIPModel().to(device)
+ model = CLIPModel()  # Create an instance of CLIPModel
  model.load_state_dict(torch.load("weights.pt", map_location=torch.device('cpu')))
+ # ...)
  text_embeddings = torch.load('saved_text_embeddings.pt', map_location=device)

+ def download_link(content, filename, link_text):
+     b64 = base64.b64encode(content).decode()
+     href = f'<a href="data:application/octet-stream;base64,{b64}" download="{filename}">{link_text}</a>'
+     return href

- def show_predicted_caption(image):
+ def show_predicted_caption(image, top_k=8):
      matches = predict_caption(
          image, model, text_embeddings, testing_df["caption"]
-     )[0]
-     return matches
+     )[:top_k]
+     cleaned_matches = [re.sub(r'\s\(ROCO_\d+\)', '', match) for match in matches]  # Add this line to clean the matches
+     return cleaned_matches  # Return the cleaned_matches instead of matches
+
+ def generate_radiology_report(prompt):
+     response = openai.Completion.create(
+         engine="text-davinci-003",
+         prompt=prompt,
+         max_tokens=800,
+         n=1,
+         stop=None,
+         temperature=1,
+     )
+     report = response.choices[0].text.strip()
+     # Remove reference string from the report
+     report = re.sub(r'\(ROCO_\d+\)', '', report).strip()
+     return report

- st.title("Medical Image Captioning")
- st.write("Upload an image to get a caption:")

+ def save_as_docx(text, filename):
+     document = Document()
+     document.add_paragraph(text)
+     with BytesIO() as output:
+         document.save(output)
+         output.seek(0)
+         return output.getvalue()
+
+ st.title("RadiXGPT: An Evolution of machine doctors towards Radiology")
+
+
+ # Collect user's personal information
+ st.subheader("Personal Information")
+ first_name = st.text_input("First Name")
+ last_name = st.text_input("Last Name")
+ age = st.number_input("Age", min_value=0, max_value=120, value=25, step=1)
+ gender = st.selectbox("Gender", ["Male", "Female", "Other"])
+
+ st.write("Upload Scan to get Radiological Report:")
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
  if uploaded_file is not None:
      image = Image.open(uploaded_file)
-     st.image(image, caption="Uploaded Image", use_column_width=True)
-     st.write("")
-
      if st.button("Generate Caption"):
          with st.spinner("Generating caption..."):
              image_np = np.array(image)
-             caption = show_predicted_caption(image_np)
+             caption = show_predicted_caption(image_np)[0]
+
          st.success(f"Caption: {caption}")

-
+         # Generate the radiology report
+         radiology_report = generate_radiology_report(f"Write Complete Radiology Report for this with clinical info, subjective, Assessment, Finding, Impressions, Conclusion and more in proper order : {caption}")
+
+         # Add personal information to the radiology report
+         radiology_report_with_personal_info = f"Patient Name: {first_name} {last_name}\nAge: {age}\nGender: {gender}\n\n{radiology_report}"
+
+         st.header("Radiology Report")
+         st.write(radiology_report_with_personal_info)
+         st.markdown(download_link(save_as_docx(radiology_report_with_personal_info, "radiology_report.docx"), "radiology_report.docx", "Download Report as DOCX"), unsafe_allow_html=True)
+
+         feedback_options = ["Satisfied", "Not Satisfied"]
+         selected_feedback = st.radio("Please provide feedback on the generated report:", feedback_options)
+
+         if selected_feedback == "Not Satisfied":
+             if st.button("Regenerate Report"):
+                 with st.spinner("Regenerating report..."):
+                     alternative_caption = get_alternative_caption(image_np, model, text_embeddings, testing_df["caption"])
+                     regenerated_radiology_report = generate_radiology_report(f"Write Complete Radiology Report for this with clinical info, subjective, Assessment, Finding, Impressions, Conclusion and more in proper order : {alternative_caption}")
+
+                 regenerated_radiology_report_with_personal_info = f"Patient Name: {first_name} {last_name}\nAge: {age}\nGender: {gender}\n\n{regenerated_radiology_report}"
+
+                 st.header("Regenerated Radiology Report")
+                 st.write(regenerated_radiology_report_with_personal_info)
+                 st.markdown(download_link(save_as_docx(regenerated_radiology_report_with_personal_info, "regenerated_radiology_report.docx"), "regenerated_radiology_report.docx", "Download Regenerated Report as DOCX"), unsafe_allow_html=True)
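
Note: the "Regenerate Report" branch added above calls get_alternative_caption(), which this commit neither defines in app.py nor imports from main. A minimal sketch of what such a helper could look like, reusing predict_caption and the same "(ROCO_xxxx)" cleanup that show_predicted_caption applies; the name, signature, and pick-a-non-top-match strategy are assumptions, not part of the commit:

import random
import re

from main import predict_caption  # assumed to be the same ranking helper used elsewhere in app.py

def get_alternative_caption(image, model, text_embeddings, captions, top_k=8):
    # Hypothetical helper: rank candidate captions the same way
    # show_predicted_caption does, strip the "(ROCO_xxxx)" reference tags,
    # then return a candidate other than the top match so the regenerated
    # report starts from a different caption.
    matches = predict_caption(image, model, text_embeddings, captions)[:top_k]
    cleaned = [re.sub(r'\s\(ROCO_\d+\)', '', m) for m in matches]
    alternatives = cleaned[1:] or cleaned  # fall back to the top match if it is the only candidate
    return random.choice(alternatives)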