"""Streamlit demo: simulate patients flowing through a treatment decision tree.

Each stage in ``outcomes`` has a list of possible results in ``possibilities``.
Pressing "Run Simulation" adds one simulated patient per stage to a randomly
chosen possibility, and the per-possibility counts are shown next to a treemap.
"""
import random
import time

import plotly.graph_objects as go
import streamlit as st

# Set page layout to wide mode
st.set_page_config(layout="wide")

# Treatment stages in display order: (description, alert style), where the
# style is one of "success", "info", "warning", "error".
outcomes = [
    ("Doctor prescribes Treatment 1 (Personalized Drug A)", "info"),
    ("Nurse monitors Patient - Slight improvement", "success"),
    ("Clinician evaluates lab results - Treatment seems effective", "info"),
    ("Doctor prescribes Treatment 2 (Personalized Drug B)", "warning"),
    ("Nurse monitors Patient - Side effects observed", "error"),
    ("Clinician suggests modifying dosage", "warning"),
    ("Doctor prescribes modified Treatment 2", "info"),
    ("Patient shows significant improvement", "success"),
]

# Different possible outcomes of each treatment, keyed by 1-based stage number
# (the same numbering as ``enumerate(outcomes, 1)``).
possibilities = {
    1: [
        ("Possibility 1: Patient responds well to Treatment 1", "success"),
        ("Possibility 2: Slight improvement, but inconclusive results", "info"),
        ("Possibility 3: No response, reevaluation needed", "error"),
    ],
    2: [
        ("Possibility 1: Significant improvement", "success"),
        ("Possibility 2: Mild side effects", "warning"),
    ],
    3: [
        ("Possibility 1: Treatment is effective", "success"),
        ("Possibility 2: Inconclusive lab results", "info"),
    ],
    4: [
        ("Possibility 1: Improvement with Drug B", "success"),
        ("Possibility 2: Significant side effects", "error"),
    ],
    5: [
        ("Possibility 1: Side effects worsen, modify dosage", "warning"),
        ("Possibility 2: Manageable side effects", "info"),
    ],
    6: [
        ("Possibility 1: Dosage adjustment successful", "success"),
        ("Possibility 2: Further modification needed", "warning"),
    ],
    7: [
        ("Possibility 1: Patient responds well to modified treatment", "success"),
        ("Possibility 2: Limited response, consider alternatives", "warning"),
    ],
    8: [
        ("Possibility 1: Complete recovery", "success"),
        ("Possibility 2: Partial improvement, continue monitoring", "info"),
    ],
}

# Patient counts per stage, one counter per possibility.
# Streamlit re-executes this whole script on every interaction. A plain
# module-level dict would therefore be reset to zeros on each rerun, and the
# counts could never exceed 1. Storing the dict in st.session_state lets the
# counts accumulate across "Run Simulation" clicks. ``patient_counts`` is bound
# to that same dict object, so in-place increments elsewhere persist.
if "patient_counts" not in st.session_state:
    st.session_state["patient_counts"] = {
        stage: [0] * len(choices) for stage, choices in possibilities.items()
    }
patient_counts = st.session_state["patient_counts"]
# Create a tree diagram using Plotly

# Treemap fill colours per alert style, as (normal, highlighted). The light
# shade is used for possibilities that were not picked. The saturated shade
# marks the box picked by the most recent simulation run.
_TREE_COLORS = {
    "success": ("#d4edda", "#28a745"),
    "info": ("#cce5ff", "#007bff"),
    "warning": ("#fff3cd", "#ffc107"),
    "error": ("#f8d7da", "#dc3545"),
}
_DEFAULT_POSSIBILITIES = [("No specific possibilities defined", "info")]


def create_tree_diagram(stage, selected_box=None):
    """Build a Plotly treemap for one stage and its possible results.

    Args:
        stage: 1-based stage number. It indexes ``outcomes`` and keys
            ``possibilities`` / ``patient_counts``.
        selected_box: index of the possibility to highlight, or None to leave
            every possibility in its light colour.

    Returns:
        A ``go.Figure`` with a single treemap. Its root is the stage, and its
        children are the stage's possibilities, each labelled with the current
        patient count.
    """
    root_label = f"Stage {stage}: {outcomes[stage - 1][0]}"
    stage_possibilities = possibilities.get(stage, _DEFAULT_POSSIBILITIES)
    # Stages without defined possibilities have no counters. Show zeros for the
    # fallback entry instead of raising KeyError.
    counts = patient_counts.get(stage, [0] * len(stage_possibilities))

    labels = [root_label]
    parents = [""]
    values = [1]  # Root node value
    colors = ["lightgrey"]  # Root node color
    for idx, (possibility, outcome_type) in enumerate(stage_possibilities):
        labels.append(f"{possibility} - Patients: {counts[idx]}")
        parents.append(root_label)
        values.append(1)  # Equal weight for all possibilities
        normal, highlighted = _TREE_COLORS.get(outcome_type, ("lightgrey", "lightgrey"))
        colors.append(highlighted if idx == selected_box else normal)

    fig = go.Figure(go.Treemap(
        labels=labels,
        parents=parents,
        values=values,
        marker=dict(colors=colors),
        textinfo="label",
        textfont=dict(size=14),
    ))
    fig.update_layout(
        margin=dict(t=10, l=10, r=10, b=10),
        width=600,
        height=400,
        uniformtext=dict(minsize=12, mode="show"),
    )
    return fig


# Function to increment patient counts in random boxes for each stage
def run_simulation():
    """Add one simulated patient to a randomly chosen possibility of every stage.

    The chosen index for each stage is also stored in
    ``st.session_state[f"selected_box_{stage}"]`` so the diagram can highlight it.
    """
    for stage, choices in possibilities.items():
        selected_box = random.randrange(len(choices))
        patient_counts[stage][selected_box] += 1
        st.session_state[f"selected_box_{stage}"] = selected_box


# Initialize the selected box in session state if it doesn't exist
for stage in possibilities:
    if f"selected_box_{stage}" not in st.session_state:
        st.session_state[f"selected_box_{stage}"] = None

# Streamlit app layout
st.title("Precision Medicine AI Agents - Treatment Decision Tree")

# Add a "Run Simulation" button
run_button = st.button("Run Simulation")

# If the button is pressed, run the simulation
if run_button:
    run_simulation()
# Render the animation with the updated patient counts: for each stage, show
# the stage card and its possibilities on the left and the treemap on the right.

# Light background colour per alert style, used for the HTML cards.
_CARD_COLORS = {
    "success": "#d4edda",
    "info": "#cce5ff",
    "warning": "#fff3cd",
    "error": "#f8d7da",
}
_CARD_STYLE = "background-color: {color}; padding: 10px; border-radius: 5px; margin-bottom: 5px;"

for i, (outcome, outcome_type) in enumerate(outcomes, 1):
    with st.container():
        col1, col2 = st.columns([1, 2])
        with col1:
            # Unknown styles fall back to lightgrey instead of silently
            # rendering no stage card at all.
            stage_style = _CARD_STYLE.format(color=_CARD_COLORS.get(outcome_type, "lightgrey"))
            st.markdown(f"<div style='{stage_style}'>Stage {i}: {outcome}</div>", unsafe_allow_html=True)

            st.markdown("### Possibilities:")
            stage_possibilities = possibilities.get(i, [("No specific possibilities defined", "info")])
            # Stages without defined possibilities have no counters; show zeros.
            counts = patient_counts.get(i, [0] * len(stage_possibilities))
            for idx, (possibility, possibility_type) in enumerate(stage_possibilities):
                color = _CARD_COLORS.get(possibility_type, "lightgrey")
                # Display the possibility and its patient count
                st.markdown(
                    f"<div style='{_CARD_STYLE.format(color=color)}'>"
                    f"{possibility} - Patients: {counts[idx]}</div>",
                    unsafe_allow_html=True,
                )
        with col2:
            # .get: stages outside ``possibilities`` never get a selected_box key.
            selected_box = st.session_state.get(f"selected_box_{i}")
            fig = create_tree_diagram(i, selected_box)
            st.plotly_chart(fig, use_container_width=True)