Erva Ulusoy committed
Commit 8f4b741 · 1 Parent(s): dbed3d3

added kg visualization feature

Files changed (4)
  1. ProtHGT_app.py +288 -171
  2. requirements.txt +2 -1
  3. run_prothgt_app.py +3 -4
  4. visualize_kg.py +242 -0
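The change threads the heterogeneous graph (heterodata) from prediction through to visualization: generate_prediction_df now returns the graph alongside the predictions dataframe, and the new visualize_kg.py renders a selected protein's neighborhood as an interactive pyvis HTML page that the app embeds in a "View Knowledge Graphs" tab. A minimal sketch of the new flow, using only names that appear in the diffs below; the accession and the model path variables are illustrative placeholders, not values from this commit:

from run_prothgt_app import generate_prediction_df
from visualize_kg import visualize_protein_subgraph

protein_ids = ["P12345"]  # placeholder UniProt accession

heterodata, predictions_df = generate_prediction_df(
    protein_ids=protein_ids,
    model_paths=model_paths,                # assumed to be defined elsewhere in the app
    model_config_paths=model_config_paths,  # assumed to be defined elsewhere in the app
    go_category=['GO_term_F', 'GO_term_P', 'GO_term_C'],
)

# Render one protein's subgraph to a standalone HTML file (path returned by pyvis save).
html_path, visualized_edges = visualize_protein_subgraph(
    heterodata, protein_ids[0], predictions_df, limit=10
)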
ProtHGT_app.py CHANGED
@@ -25,8 +25,8 @@ import random
25
  # # ❌ Remove the info message after initialization is complete
26
  # loading_placeholder.empty()
27
 
28
-
29
  from run_prothgt_app import *
 
30
 
31
  def convert_df(df):
32
  return df.to_csv(index=False).encode('utf-8')
@@ -34,19 +34,31 @@ def convert_df(df):
34
  # Initialize session state variables
35
  if 'predictions_df' not in st.session_state:
36
  st.session_state.predictions_df = None
 
 
37
  if 'submitted' not in st.session_state:
38
  st.session_state.submitted = False
39
  if 'previous_inputs' not in st.session_state:
40
  st.session_state.previous_inputs = None
41
- # Initialize session state variables
42
  if 'generating_predictions' not in st.session_state:
43
  st.session_state.generating_predictions = False
 
 
 
44
 
45
  def reset_prediction_state():
46
  st.session_state.generating_predictions = False
47
  st.session_state.submitted = False
48
  st.session_state.predictions_df = None
49
  st.session_state.previous_inputs = None
 
 
 
 
 
 
 
 
50
 
51
  def set_generating_predictions():
52
  st.session_state.generating_predictions = True
@@ -130,7 +142,6 @@ with st.sidebar:
130
  )
131
 
132
  elif selection_method == "Search proteins":
133
-
134
  # User enters search term
135
  search_query = st.text_input(
136
  "1\\. Start typing a protein ID (at least 3 characters) and press Enter to see search results in the dropdown menu below (2)",
@@ -138,6 +149,10 @@ with st.sidebar:
138
  disabled=disabled
139
  )
140
 
 
 
 
 
141
  # Apply fuzzy search only if query length is >= 3
142
  filtered_proteins = []
143
  if len(search_query) >= 3:
@@ -150,14 +165,22 @@ with st.sidebar:
150
  filtered_proteins = [match[0] for match in matches] # Show top 50 matches
151
 
152
  with st.container():
 
 
 
153
  selected_proteins = st.multiselect(
154
  "2\\. Select proteins from search results",
155
- options=filtered_proteins,
 
156
  placeholder="Start typing a protein ID above (1) to see search results...",
157
  max_selections=100,
158
  disabled=disabled,
159
  key="protein_selector"
160
  )
 
 
 
 
161
  # Apply custom CSS to make container scrollable
162
  st.markdown("""
163
  <style>
@@ -167,7 +190,7 @@ with st.sidebar:
167
  }
168
  </style>
169
  """, unsafe_allow_html=True)
170
-
171
  else: # Upload file option
172
  uploaded_file = st.file_uploader(
173
  "Upload a text file with UniProt IDs (one per line, max 100)*",
@@ -328,193 +351,287 @@ if st.session_state.submitted:
328
  go_categories = ['GO_term_F', 'GO_term_P', 'GO_term_C']
329
 
330
  # Generate predictions
331
- predictions_df = generate_prediction_df(
332
  protein_ids=selected_proteins,
333
  model_paths=model_paths,
334
  model_config_paths=model_config_paths,
335
  go_category=go_categories
336
  )
337
 
 
338
  st.session_state.predictions_df = predictions_df
339
-
340
  # Reset only the generating_predictions flag to release the sidebar
341
  st.session_state.generating_predictions = False
342
  st.rerun()
343
 
344
  # Display and filter predictions
345
  st.success("Predictions generated successfully!")
346
- st.markdown("### Filter and View Predictions")
347
-
348
- # Create filters
349
- col1, col2, col3, col4 = st.columns(4)
350
-
351
- with col1:
352
- # Extract UniProt IDs from URLs for the selectbox
353
- uniprot_ids = st.session_state.predictions_df['UniProt_ID'].apply(
354
- lambda x: x.split('/')[-2] # Gets the ID part from the URL
355
- ).unique().tolist()
356
-
357
- # Protein filter
358
- selected_protein = st.selectbox(
359
- "Filter by Protein",
360
- options=['All'] + sorted(uniprot_ids)
361
- )
362
 
363
- with col2:
364
- # GO category filter
365
- selected_category = st.selectbox(
366
- "Filter by GO Category",
367
- options=['All'] + sorted(st.session_state.predictions_df['GO_category'].unique().tolist())
368
- )
369
 
370
- with col3:
371
- # GO term filter
372
- go_term_filter = st.text_input(
373
- "Filter by GO Term ID",
374
- placeholder="e.g., GO:0003674",
375
- help="Enter a GO term ID to filter results"
376
- ).strip()
377
 
378
- with col4:
379
- # Probability threshold
380
- min_probability_threshold = st.slider(
381
- "Minimum Probability",
382
- min_value=0.0,
383
- max_value=1.0,
384
- value=0.5,
385
- step=0.05
386
- )
387
 
388
- max_probability_threshold = st.slider(
389
- "Maximum Probability",
390
- min_value=0.0,
391
- max_value=1.0,
392
- value=1.0,
393
- step=0.05
394
- )
395
-
396
- # Filter the dataframe using session state data
397
- filtered_df = st.session_state.predictions_df.copy()
398
-
399
- if selected_protein != 'All':
400
- filtered_df = filtered_df[filtered_df['UniProt_ID'].str.contains(selected_protein)]
401
-
402
- if selected_category != 'All':
403
- filtered_df = filtered_df[filtered_df['GO_category'] == selected_category]
404
-
405
- if go_term_filter:
406
- filtered_df = filtered_df[filtered_df['GO_ID'].str.contains(go_term_filter, case=False, na=False)]
407
 
408
- filtered_df = filtered_df[(filtered_df['Probability'] >= min_probability_threshold) &
409
- (filtered_df['Probability'] <= max_probability_threshold)]
410
 
411
- # Custom CSS to increase table width and improve layout
412
- st.markdown("""
413
- <style>
414
- .stDataFrame {
415
- width: 100%;
416
- }
417
- .stDataFrame > div {
418
- width: 100%;
419
- }
420
- .stDataFrame [data-testid="stDataFrameResizable"] {
421
- width: 100%;
422
- min-width: 100%;
423
  }
424
- .pagination-info {
425
- font-size: 14px;
426
- color: #666;
427
- padding: 10px 0;
428
- }
429
- .page-controls {
430
- display: flex;
431
- align-items: center;
432
- justify-content: center;
433
- gap: 20px;
434
- padding: 10px 0;
435
- }
436
- </style>
437
- """, unsafe_allow_html=True)
 
438
 
439
- # Add pagination controls
440
- col1, col2, col3 = st.columns([2, 1, 2])
441
- with col2:
442
- rows_per_page = st.selectbox("Rows per page", [50, 100, 200, 500], index=1)
443
-
444
- total_rows = len(filtered_df)
445
- total_pages = (total_rows + rows_per_page - 1) // rows_per_page
446
-
447
- # Initialize page number in session state
448
- if "page_number" not in st.session_state:
449
- st.session_state.page_number = 0
450
-
451
- # Calculate start and end indices for current page
452
- start_idx = st.session_state.page_number * rows_per_page
453
- end_idx = min(start_idx + rows_per_page, total_rows)
454
-
455
- st.dataframe(
456
- filtered_df.iloc[start_idx:end_idx],
457
- hide_index=True,
458
- use_container_width=True,
459
- column_config={
460
- "UniProt_ID": st.column_config.LinkColumn(
461
- "UniProt ID",
462
- help="Click to view protein in UniProt",
463
- validate="^https://www\\.uniprot\\.org/uniprotkb/[A-Z0-9]+/entry$",
464
- display_text="^https://www\\.uniprot\\.org/uniprotkb/([A-Z0-9]+)/entry$"
465
- ),
466
- "GO_ID": st.column_config.LinkColumn(
467
- "GO ID",
468
- help="Click to view GO term in QuickGO",
469
- validate="^https://www\\.ebi\\.ac\\.uk/QuickGO/term/GO:[0-9]+$",
470
- display_text="^https://www\\.ebi\\.ac\\.uk/QuickGO/term/(GO:[0-9]+)$"
471
- ),
472
- "Probability": st.column_config.ProgressColumn(
473
- "Probability",
474
- format="%.2f",
475
- min_value=0,
476
- max_value=1,
477
- ),
478
- "Protein": st.column_config.TextColumn(
479
- "Protein",
480
- help="Protein Name",
481
- ),
482
- "GO_category": st.column_config.TextColumn(
483
- "GO Category",
484
- help="Gene Ontology Category",
485
- ),
486
- "GO_term": st.column_config.TextColumn(
487
- "GO Term",
488
- help="Gene Ontology Term Name",
489
- ),
490
- }
491
- )
492
- # Pagination controls with better layout
493
- col1, col2, col3 = st.columns([1, 3, 1])
494
- with col1:
495
- if st.button("Previous", disabled=st.session_state.page_number == 0):
496
- st.session_state.page_number -= 1
497
- st.rerun()
498
-
499
- with col2:
500
- st.markdown(f"""
501
- <div class="pagination-info" style="text-align: center">
502
- Page {st.session_state.page_number + 1} of {total_pages}<br>
503
- Showing rows {start_idx + 1} to {end_idx} of {total_rows}
504
- </div>
505
- """, unsafe_allow_html=True)
506
 
507
- with col3:
508
- if st.button("Next", disabled=st.session_state.page_number >= total_pages - 1):
509
- st.session_state.page_number += 1
510
- st.rerun()
511
 
512
 
513
- # Download filtered results
514
- st.download_button(
515
- label="Download Filtered Results",
516
- data=convert_df(filtered_df),
517
- file_name="filtered_predictions.csv",
518
- mime="text/csv",
519
- key="download_filtered_predictions"
520
- )
 
25
  # # ❌ Remove the info message after initialization is complete
26
  # loading_placeholder.empty()
27
 
 
28
  from run_prothgt_app import *
29
+ from visualize_kg import *
30
 
31
  def convert_df(df):
32
  return df.to_csv(index=False).encode('utf-8')
 
34
  # Initialize session state variables
35
  if 'predictions_df' not in st.session_state:
36
  st.session_state.predictions_df = None
37
+ if 'heterodata' not in st.session_state:
38
+ st.session_state.heterodata = None
39
  if 'submitted' not in st.session_state:
40
  st.session_state.submitted = False
41
  if 'previous_inputs' not in st.session_state:
42
  st.session_state.previous_inputs = None
 
43
  if 'generating_predictions' not in st.session_state:
44
  st.session_state.generating_predictions = False
45
+ if 'protein_visualizations' not in st.session_state:
46
+ st.session_state.protein_visualizations = {}
47
+
48
 
49
  def reset_prediction_state():
50
  st.session_state.generating_predictions = False
51
  st.session_state.submitted = False
52
  st.session_state.predictions_df = None
53
  st.session_state.previous_inputs = None
54
+ # Clean up visualization files
55
+ if 'protein_visualizations' in st.session_state:
56
+ for viz_info in st.session_state.protein_visualizations.values():
57
+ try:
58
+ os.unlink(viz_info['path'])
59
+ except:
60
+ pass
61
+ st.session_state.protein_visualizations = {}
62
 
63
  def set_generating_predictions():
64
  st.session_state.generating_predictions = True
 
142
  )
143
 
144
  elif selection_method == "Search proteins":
 
145
  # User enters search term
146
  search_query = st.text_input(
147
  "1\\. Start typing a protein ID (at least 3 characters) and press Enter to see search results in the dropdown menu below (2)",
 
149
  disabled=disabled
150
  )
151
 
152
+ # Initialize selected_proteins in session state if not exists
153
+ if 'selected_proteins_search' not in st.session_state:
154
+ st.session_state.selected_proteins_search = []
155
+
156
  # Apply fuzzy search only if query length is >= 3
157
  filtered_proteins = []
158
  if len(search_query) >= 3:
 
165
  filtered_proteins = [match[0] for match in matches] # Show top 50 matches
166
 
167
  with st.container():
168
+ # Include previously selected proteins in options
169
+ all_options = list(set(filtered_proteins + st.session_state.selected_proteins_search))
170
+
171
  selected_proteins = st.multiselect(
172
  "2\\. Select proteins from search results",
173
+ options=all_options,
174
+ default=st.session_state.selected_proteins_search,
175
  placeholder="Start typing a protein ID above (1) to see search results...",
176
  max_selections=100,
177
  disabled=disabled,
178
  key="protein_selector"
179
  )
180
+
181
+ # Update session state with current selection
182
+ st.session_state.selected_proteins_search = selected_proteins
183
+
184
  # Apply custom CSS to make container scrollable
185
  st.markdown("""
186
  <style>
 
190
  }
191
  </style>
192
  """, unsafe_allow_html=True)
193
+
194
  else: # Upload file option
195
  uploaded_file = st.file_uploader(
196
  "Upload a text file with UniProt IDs (one per line, max 100)*",
 
351
  go_categories = ['GO_term_F', 'GO_term_P', 'GO_term_C']
352
 
353
  # Generate predictions
354
+ heterodata, predictions_df = generate_prediction_df(
355
  protein_ids=selected_proteins,
356
  model_paths=model_paths,
357
  model_config_paths=model_config_paths,
358
  go_category=go_categories
359
  )
360
 
361
+ st.session_state.heterodata = heterodata
362
  st.session_state.predictions_df = predictions_df
363
+
364
  # Reset only the generating_predictions flag to release the sidebar
365
  st.session_state.generating_predictions = False
366
  st.rerun()
367
 
368
  # Display and filter predictions
369
  st.success("Predictions generated successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
+ # tabs for predictions and visualizations
372
+ predictions_tab, kg_viz_tab = st.tabs(["View Predictions", "View Knowledge Graphs"])
 
 
 
 
373
 
374
+ with predictions_tab:
375
+ st.markdown("### Filter and View Predictions")
 
 
 
 
 
376
 
 
 
 
 
 
 
 
 
 
377
 
378
+ # Create filters
379
+ col1, col2, col3, col4 = st.columns(4)
380
 
381
+ with col1:
382
+ # Extract UniProt IDs from URLs for the selectbox
383
+ uniprot_ids = st.session_state.predictions_df['UniProt_ID'].unique().tolist()
384
+
385
+ # Protein filter
386
+ selected_protein = st.selectbox(
387
+ "Filter by Protein",
388
+ options=['All'] + sorted(uniprot_ids)
389
+ )
390
+
391
+ with col2:
392
+ # GO category filter
393
+ selected_category = st.selectbox(
394
+ "Filter by GO Category",
395
+ options=['All'] + sorted(st.session_state.predictions_df['GO_category'].unique().tolist())
396
+ )
397
+
398
+ with col3:
399
+ # GO term filter
400
+ go_term_filter = st.text_input(
401
+ "Filter by GO Term ID",
402
+ placeholder="e.g., GO:0003674",
403
+ help="Enter a GO term ID to filter results"
404
+ ).strip()
405
+
406
+ with col4:
407
+ # Probability threshold range slider
408
+ probability_range = st.slider(
409
+ "Probability Range",
410
+ min_value=0.0,
411
+ max_value=1.0,
412
+ value=(0.5, 1.0), # (min, max) default values
413
+ step=0.05
414
+ )
415
+ min_probability_threshold, max_probability_threshold = probability_range
416
+
417
+ # Filter the dataframe using session state data
418
+ filtered_df = st.session_state.predictions_df.copy()
419
+
420
+ if selected_protein != 'All':
421
+ filtered_df = filtered_df[filtered_df['UniProt_ID'].str.contains(selected_protein)]
422
+
423
+ if selected_category != 'All':
424
+ filtered_df = filtered_df[filtered_df['GO_category'] == selected_category]
425
+
426
+ if go_term_filter:
427
+ filtered_df = filtered_df[filtered_df['GO_ID'] == go_term_filter]
428
+
429
+ filtered_df = filtered_df[(filtered_df['Probability'] >= min_probability_threshold) &
430
+ (filtered_df['Probability'] <= max_probability_threshold)]
431
+
432
+ filtered_df['UniProt_ID'] = [f"https://www.uniprot.org/uniprotkb/{pid}/entry" for pid in filtered_df['UniProt_ID']]
433
+ filtered_df['GO_ID'] = [f"https://www.ebi.ac.uk/QuickGO/term/{go_id}" for go_id in filtered_df['GO_ID']]
434
+
435
+ # Custom CSS to increase table width and improve layout
436
+ st.markdown("""
437
+ <style>
438
+ .stDataFrame {
439
+ width: 100%;
440
+ }
441
+ .stDataFrame > div {
442
+ width: 100%;
443
+ }
444
+ .stDataFrame [data-testid="stDataFrameResizable"] {
445
+ width: 100%;
446
+ min-width: 100%;
447
+ }
448
+ .pagination-info {
449
+ font-size: 14px;
450
+ color: #666;
451
+ padding: 10px 0;
452
+ }
453
+ .page-controls {
454
+ display: flex;
455
+ align-items: center;
456
+ justify-content: center;
457
+ gap: 20px;
458
+ padding: 10px 0;
459
+ }
460
+ </style>
461
+ """, unsafe_allow_html=True)
462
 
463
+ # Add pagination controls
464
+ col1, col2, col3 = st.columns([2, 1, 2])
465
+ with col2:
466
+ rows_per_page = st.selectbox("Rows per page", [50, 100, 200, 500], index=1)
467
+
468
+ total_rows = len(filtered_df)
469
+ total_pages = (total_rows + rows_per_page - 1) // rows_per_page
470
+
471
+ # Initialize page number in session state
472
+ if "page_number" not in st.session_state:
473
+ st.session_state.page_number = 0
474
+
475
+ # Calculate start and end indices for current page
476
+ start_idx = st.session_state.page_number * rows_per_page
477
+ end_idx = min(start_idx + rows_per_page, total_rows)
478
+
479
+ st.dataframe(
480
+ filtered_df.iloc[start_idx:end_idx],
481
+ hide_index=True,
482
+ use_container_width=True,
483
+ column_config={
484
+ "UniProt_ID": st.column_config.LinkColumn(
485
+ "UniProt ID",
486
+ help="Click to view protein in UniProt",
487
+ validate="^https://www\\.uniprot\\.org/uniprotkb/[A-Z0-9]+/entry$",
488
+ display_text="^https://www\\.uniprot\\.org/uniprotkb/([A-Z0-9]+)/entry$"
489
+ ),
490
+ "GO_ID": st.column_config.LinkColumn(
491
+ "GO ID",
492
+ help="Click to view GO term in QuickGO",
493
+ validate="^https://www\\.ebi\\.ac\\.uk/QuickGO/term/GO:[0-9]+$",
494
+ display_text="^https://www\\.ebi\\.ac\\.uk/QuickGO/term/(GO:[0-9]+)$"
495
+ ),
496
+ "Probability": st.column_config.ProgressColumn(
497
+ "Probability",
498
+ format="%.2f",
499
+ min_value=0,
500
+ max_value=1,
501
+ ),
502
+ "Protein": st.column_config.TextColumn(
503
+ "Protein",
504
+ help="Protein Name",
505
+ ),
506
+ "GO_category": st.column_config.TextColumn(
507
+ "GO Category",
508
+ help="Gene Ontology Category",
509
+ ),
510
+ "GO_term": st.column_config.TextColumn(
511
+ "GO Term",
512
+ help="Gene Ontology Term Name",
513
+ ),
514
  }
515
+ )
516
+ # Pagination controls with better layout
517
+ col1, col2, col3 = st.columns([1, 3, 1])
518
+ with col1:
519
+ if st.button("Previous", disabled=st.session_state.page_number == 0):
520
+ st.session_state.page_number -= 1
521
+ st.rerun()
522
+
523
+ with col2:
524
+ st.markdown(f"""
525
+ <div class="pagination-info" style="text-align: center">
526
+ Page {st.session_state.page_number + 1} of {total_pages}<br>
527
+ Showing rows {start_idx + 1} to {end_idx} of {total_rows}
528
+ </div>
529
+ """, unsafe_allow_html=True)
530
 
531
+ with col3:
532
+ if st.button("Next", disabled=st.session_state.page_number >= total_pages - 1):
533
+ st.session_state.page_number += 1
534
+ st.rerun()
535
+
536
+ downloadable_df = filtered_df.copy()
537
+ downloadable_df['UniProt_ID'] = downloadable_df['UniProt_ID'].apply(
538
+ lambda x: x.split('/')[-2] # Gets the ID part from the URL
539
+ )
540
+ downloadable_df['GO_ID'] = downloadable_df['GO_ID'].apply(
541
+ lambda x: x.split('/')[-1] # Gets the ID part from the URL
542
+ )
543
+ # Download filtered results
544
+ st.download_button(
545
+ label="Download Filtered Results",
546
+ data=convert_df(downloadable_df),
547
+ file_name="filtered_predictions.csv",
548
+ mime="text/csv",
549
+ key="download_filtered_predictions"
550
+ )
551
+
552
+ with kg_viz_tab:
553
+ st.markdown("### Knowledge Graph Visualization")
554
+
555
+ if not selected_proteins:
556
+ st.info("Please select proteins from the sidebar to visualize their knowledge graphs.")
557
+ elif len(selected_proteins) <= 10:
558
+ st.text("Visualize the knowledge graph for each protein to understand the biological relationships that contributed to the predictions.")
559
+
560
+ protein_tabs = st.tabs([f"{protein_id}" for protein_id in selected_proteins])
561
+
562
+ # Create visualizations in each tab
563
+ for idx, protein_id in enumerate(selected_proteins):
564
+ with protein_tabs[idx]:
565
+ max_node_count = st.slider(
566
+ "Maximum neighbors per edge type",
567
+ min_value=5,
568
+ max_value=50,
569
+ value=10,
570
+ step=5,
571
+ help="Control the maximum number of neighboring nodes shown for each relationship type",
572
+ key=f"slider_{protein_id}"
573
+ )
574
 
575
+ # Check if visualization exists for this protein
576
+ viz_exists = (protein_id in st.session_state.protein_visualizations and
577
+ os.path.exists(st.session_state.protein_visualizations[protein_id]['path']))
578
+
579
+ if not viz_exists:
580
+ if st.button(f"Generate Visualization", key=f"viz_{protein_id}"):
581
+ # Generate visualization with selected max_node_count
582
+ html_path, visualized_edges = visualize_protein_subgraph(
583
+ st.session_state.heterodata,
584
+ protein_id,
585
+ st.session_state.predictions_df,
586
+ limit=max_node_count
587
+ )
588
+
589
+ # Store visualization info in session state
590
+ st.session_state.protein_visualizations[protein_id] = {
591
+ 'path': html_path,
592
+ 'edges': visualized_edges
593
+ }
594
+ st.rerun()
595
+
596
+ # If visualization exists, display it
597
+ if viz_exists:
598
+ viz_info = st.session_state.protein_visualizations[protein_id]
599
+
600
+ # Add download button for edges
601
+ formatted_edges = {}
602
+ for edge_type, edges in viz_info['edges'].items():
603
+ edge_type_str = f"{edge_type[0]}_{edge_type[1]}_{edge_type[2]}"
604
+ formatted_edges[edge_type_str] = [
605
+ {"source": edge[0][0], "target": edge[0][1], "probability": edge[1]}
606
+ for edge in edges
607
+ ]
608
+
609
+ kg_viz_button_columns = st.columns([1, 1, 1])
610
+
611
+ with kg_viz_button_columns[0]:
612
+ st.download_button(
613
+ label='Download Visualized Edges',
614
+ data=json.dumps(formatted_edges, indent=2),
615
+ file_name=f'{protein_id}_visualized_edges.json',
616
+ mime='application/json'
617
+ )
618
+
619
+ with kg_viz_button_columns[1]:
620
+ if st.button("Regenerate Visualization", key=f"regenerate_{protein_id}"):
621
+ # Clean up old file
622
+ try:
623
+ os.unlink(viz_info['path'])
624
+ except FileNotFoundError:
625
+ pass
626
+ # Remove from session state
627
+ del st.session_state.protein_visualizations[protein_id]
628
+ st.rerun()
629
+
630
+ with open(viz_info['path'], 'r', encoding='utf-8') as f:
631
+ html_content = f.read()
632
+
633
+ st.components.v1.html(html_content, height=600)
634
 
635
 
636
+ else:
637
+ st.warning("Knowledge graph visualization is only available when 10 or fewer proteins are selected.")
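Each generated graph is cached in st.session_state.protein_visualizations, keyed by protein ID, so widget-triggered reruns reuse the HTML file already on disk instead of rebuilding it, and reset_prediction_state() unlinks those files when the inputs change. A compressed sketch of that caching pattern; get_or_create_visualization is a hypothetical helper written for illustration, not a function in the app:

import os
import streamlit as st
from visualize_kg import visualize_protein_subgraph

def get_or_create_visualization(protein_id, limit=10):
    # Hypothetical helper mirroring the logic of the knowledge-graph tab above.
    cached = st.session_state.protein_visualizations.get(protein_id)
    if cached and os.path.exists(cached['path']):
        return cached  # reuse the pyvis HTML generated on a previous rerun
    html_path, visualized_edges = visualize_protein_subgraph(
        st.session_state.heterodata,
        protein_id,
        st.session_state.predictions_df,
        limit=limit,
    )
    st.session_state.protein_visualizations[protein_id] = {
        'path': html_path,
        'edges': visualized_edges,
    }
    return st.session_state.protein_visualizations[protein_id]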
 
 
 
 
 
 
requirements.txt CHANGED
@@ -7,4 +7,5 @@ torch_sparse==0.6.15
  torch_scatter==2.1.0
  torch_geometric==2.2.0
  gdown
- rapidfuzz
+ rapidfuzz
+ pyvis
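pyvis is the dependency added for the graph rendering in visualize_kg.py. A minimal, stand-alone check that the library builds and saves a network the same way the new module does; the node names are arbitrary demo values:

from pyvis.network import Network

net = Network(height="600px", width="100%", directed=True, notebook=False)
net.add_node("P12345", label="Protein: P12345")      # arbitrary demo nodes
net.add_node("GO:0003674", label="GO:0003674")
net.add_edge("P12345", "GO:0003674", label="demo edge")
net.save_graph("demo_graph.html")                    # writes a standalone HTML file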
run_prothgt_app.py CHANGED
@@ -130,9 +130,9 @@ def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
 
  # Create DataFrame
  prediction_df = pd.DataFrame({
- 'UniProt_ID': [f"https://www.uniprot.org/uniprotkb/{pid}/entry" for pid in all_proteins],
+ 'UniProt_ID': all_proteins,
  'Protein': all_protein_names,
- 'GO_ID': [f"https://www.ebi.ac.uk/QuickGO/term/{go_id}" for go_id in all_go_terms],
+ 'GO_ID': all_go_terms,
  'GO_term': all_go_term_names,
  'GO_category': all_categories,
  'Probability': all_probabilities
@@ -204,7 +204,6 @@ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_cate
  del predictions
  torch.cuda.empty_cache() # Clear CUDA cache if using GPU
 
- del heterodata
 
  # Combine all predictions
  final_df = pd.concat(all_predictions, ignore_index=True)
@@ -213,4 +212,4 @@ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_cate
  del all_predictions
  torch.cuda.empty_cache()
 
- return final_df
+ return heterodata, final_df
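Because UniProt_ID and GO_ID are now kept as raw identifiers, ProtHGT_app.py builds the UniProt and QuickGO URLs only where the table is rendered and strips them again before the CSV download, and the returned heterodata lets the visualizer look proteins up by ID. A small sketch of that URL round trip as the app performs it; the accession is a hypothetical example and GO:0003674 is the placeholder shown in the app's filter field:

pid = "P12345"                                   # hypothetical accession
uniprot_url = f"https://www.uniprot.org/uniprotkb/{pid}/entry"
assert uniprot_url.split('/')[-2] == pid         # recovered for the CSV download

go_id = "GO:0003674"
quickgo_url = f"https://www.ebi.ac.uk/QuickGO/term/{go_id}"
assert quickgo_url.split('/')[-1] == go_id       # recovered the same way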
visualize_kg.py ADDED
@@ -0,0 +1,242 @@
1
+ from pyvis.network import Network
2
+ import os
3
+
4
+ NODE_TYPE_COLORS = {
5
+ 'Disease': '#079dbb',
6
+ 'HPO': '#58d0e8',
7
+ 'Drug': '#815ac0',
8
+ 'Compound': '#d2b7e5',
9
+ 'Domain': '#6bbf59',
10
+ 'GO_term_P': '#ff8800',
11
+ 'GO_term_F': '#ffaa00',
12
+ 'GO_term_C': '#ffc300',
13
+ 'Pathway': '#720026',
14
+ 'kegg_Pathway': '#720026',
15
+ 'EC_number': '#ce4257',
16
+ 'Protein': '#3aa6a4'
17
+ }
18
+
19
+ GO_CATEGORY_MAPPING = {
20
+ 'Biological Process': 'GO_term_P',
21
+ 'Molecular Function': 'GO_term_F',
22
+ 'Cellular Component': 'GO_term_C'
23
+ }
24
+
25
+ def _gather_protein_edges(data, protein_id):
26
+
27
+ protein_idx = data['Protein']['id_mapping'][protein_id]
28
+ reverse_id_mapping = {}
29
+ for node_type in data.node_types:
30
+ reverse_id_mapping[node_type] = {v:k for k, v in data[node_type]['id_mapping'].items()}
31
+
32
+ protein_edges = {}
33
+
34
+ print(f'Gathering edges for {protein_id}...')
35
+
36
+ for edge_type in data.edge_types:
37
+ if 'rev' not in edge_type[1]:
38
+ if edge_type not in protein_edges:
39
+ protein_edges[edge_type] = []
40
+ if edge_type[0] == 'Protein':
41
+ print(f'Gathering edges for {edge_type}...')
42
+ # append the edges with protein_idx as source node
43
+ edges = data[edge_type].edge_index[:, data[edge_type].edge_index[0] == protein_idx]
44
+ protein_edges[edge_type].extend(edges.T.tolist())
45
+ elif edge_type[2] == 'Protein':
46
+ print(f'Gathering edges for {edge_type}...')
47
+ # append the edges with protein_idx as target node
48
+ edges = data[edge_type].edge_index[:, data[edge_type].edge_index[1] == protein_idx]
49
+ protein_edges[edge_type].extend(edges.T.tolist())
50
+
51
+ for edge_type in protein_edges.keys():
52
+ if protein_edges[edge_type]:
53
+ mapped_edges = set()
54
+ for edge in protein_edges[edge_type]:
55
+ # Get source and target node types from edge_type
56
+ source_type, _, target_type = edge_type
57
+ # Map indices back to original IDs
58
+ source_id = reverse_id_mapping[source_type][edge[0]]
59
+ target_id = reverse_id_mapping[target_type][edge[1]]
60
+ mapped_edges.add((source_id, target_id))
61
+ protein_edges[edge_type] = mapped_edges
62
+
63
+ return protein_edges
64
+
65
+ def _filter_edges(protein_id, protein_edges, prediction_df, limit=10):
66
+
67
+ filtered_edges = {}
68
+
69
+ prediction_categories = prediction_df['GO_category'].unique()
70
+ prediction_categories = [GO_CATEGORY_MAPPING[category] for category in prediction_categories]
71
+ go_category_reverse_mapping = {v:k for k, v in GO_CATEGORY_MAPPING.items()}
72
+
73
+ for edge_type, edges in protein_edges.items():
74
+ # Skip if edges is empty
75
+ if edges is None or len(edges) == 0:
76
+ continue
77
+
78
+ if edge_type[2] in prediction_categories:
79
+ category_mask = (prediction_df['GO_category'] == go_category_reverse_mapping[edge_type[2]]) & (prediction_df['UniProt_ID'] == protein_id)
80
+ category_predictions = prediction_df[category_mask]
81
+
82
+ if len(category_predictions) > 0:
83
+ category_predictions = category_predictions.sort_values(by='Probability', ascending=False)
84
+
85
+ # Convert set to list for easier filtering
86
+ edges_list = list(edges)
87
+
88
+ # Filter valid edges and store with probabilities
89
+ valid_edges = []
90
+ for _, row in category_predictions.iterrows():
91
+ term = row['GO_ID']
92
+ prob = row['Probability']
93
+ matching_edges = [(edge, prob) for edge in edges_list if edge[1] == term]
94
+ valid_edges.extend(matching_edges)
95
+ if len(valid_edges) >= limit:
96
+ break
97
+ filtered_edges[edge_type] = valid_edges # Remove set conversion to preserve probabilities
98
+ else:
99
+ # If no predictions, include all edges up to limit without probabilities
100
+ filtered_edges[edge_type] = [(edge, None) for edge in list(edges)[:limit]]
101
+ else:
102
+ # For non-GO edges, include all edges up to limit without probabilities
103
+ filtered_edges[edge_type] = [(edge, None) for edge in list(edges)[:limit]]
104
+
105
+ return filtered_edges
106
+
107
+
108
+ def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
109
+ protein_edges = _gather_protein_edges(data, protein_id)
110
+ visualized_edges = _filter_edges(protein_id, protein_edges, prediction_df, limit)
111
+ print(f'Edges to be visualized: {visualized_edges}')
112
+
113
+ net = Network(height="600px", width="100%", directed=True, notebook=False)
114
+
115
+ # Create groups configuration from NODE_TYPE_COLORS
116
+ groups_config = {}
117
+ for node_type, color in NODE_TYPE_COLORS.items():
118
+ groups_config[node_type] = {
119
+ "color": {"background": color, "border": color}
120
+ }
121
+
122
+ # Convert groups_config to a JSON-compatible string
123
+ import json
124
+ groups_json = json.dumps(groups_config)
125
+
126
+ # Configure physics options with settings for better clustering
127
+ net.set_options("""{
128
+ "physics": {
129
+ "enabled": true,
130
+ "barnesHut": {
131
+ "gravitationalConstant": -1000,
132
+ "springLength": 250,
133
+ "springConstant": 0.001,
134
+ "damping": 0.09,
135
+ "avoidOverlap": 0
136
+ },
137
+ "forceAtlas2Based": {
138
+ "gravitationalConstant": -50,
139
+ "centralGravity": 0.01,
140
+ "springLength": 100,
141
+ "springConstant": 0.08,
142
+ "damping": 0.4,
143
+ "avoidOverlap": 0
144
+ },
145
+ "solver": "barnesHut",
146
+ "stabilization": {
147
+ "enabled": true,
148
+ "iterations": 1000,
149
+ "updateInterval": 25
150
+ }
151
+ },
152
+ "layout": {
153
+ "improvedLayout": true,
154
+ "hierarchical": {
155
+ "enabled": false
156
+ }
157
+ },
158
+ "interaction": {
159
+ "hover": true,
160
+ "navigationButtons": true,
161
+ "multiselect": true
162
+ },
163
+ "configure": {
164
+ "enabled": true,
165
+ "filter": ["physics", "layout", "manipulation"],
166
+ "showButton": true
167
+ },
168
+ "groups": """ + groups_json + "}")
169
+
170
+ # Add the main protein node
171
+ net.add_node(protein_id,
172
+ label=f"Protein: {protein_id}",
173
+ color={'background': 'white', 'border': '#c1121f'},
174
+ borderWidth=4,
175
+ shape="dot",
176
+ font={'color': '#000000', 'size': 15},
177
+ group='Protein',
178
+ size=30,
179
+ mass=2.5)
180
+
181
+ # Track added nodes to avoid duplication
182
+ added_nodes = {protein_id}
183
+
184
+ # Add edges and target nodes
185
+ for edge_type, edges in visualized_edges.items():
186
+ source_type, relation_type, target_type = edge_type
187
+
188
+ for edge_info in edges:
189
+ edge, probability = edge_info
190
+ source, target = edge[0], edge[1]
191
+ source_str = str(source)
192
+ target_str = str(target)
193
+
194
+ # Add source node if not present
195
+ if source_str not in added_nodes:
196
+ net.add_node(source_str,
197
+ label=f"{source_str}",
198
+ shape="dot",
199
+ font={'color': '#000000', 'size': 12},
200
+ title=f"{source_type}: {source_str}",
201
+ group=source_type,
202
+ size=15,
203
+ mass=1.5)
204
+ added_nodes.add(source_str)
205
+
206
+ # Add target node if not present
207
+ if target_str not in added_nodes:
208
+ net.add_node(target_str,
209
+ label=f"{target_str}",
210
+ shape="dot",
211
+ font={'color': '#000000', 'size': 12},
212
+ title=f"{target_type}: {target_str}",
213
+ group=target_type,
214
+ size=15,
215
+ mass=1.5)
216
+ added_nodes.add(target_str)
217
+
218
+ # Add edge with relationship type and probability as label
219
+ edge_label = f"{relation_type}"
220
+ if probability is not None:
221
+ edge_label += f"(P={probability:.2f})"
222
+ net.add_edge(source_str, target_str,
223
+ label=edge_label,
224
+ color='#666666',
225
+ title=edge_label,
226
+ length=200,
227
+ smooth={'type': 'curvedCW', 'roundness': 0.1})
228
+ else:
229
+ net.add_edge(source_str, target_str,
230
+ label=edge_label,
231
+ font={'size': 0},
232
+ color='#666666',
233
+ title=edge_label,
234
+ length=200,
235
+ smooth={'type': 'curvedCW', 'roundness': 0.1})
236
+
237
+ # Save graph to a protein-specific file in a temporary directory
238
+ os.makedirs('temp_viz', exist_ok=True)
239
+ file_path = os.path.join('temp_viz', f'{protein_id}_graph.html')
240
+ net.save_graph(file_path)
241
+
242
+ return file_path, visualized_edges
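visualize_protein_subgraph writes a standalone pyvis page under temp_viz/ and returns the file path together with the edges it drew, leaving display to the caller; in ProtHGT_app.py the file is read back and embedded with st.components.v1.html. A short usage sketch under those assumptions; the accession is hypothetical, and heterodata and predictions_df are the objects returned by generate_prediction_df:

import streamlit as st
from visualize_kg import visualize_protein_subgraph

html_path, visualized_edges = visualize_protein_subgraph(
    heterodata, "P12345", predictions_df, limit=10   # hypothetical accession
)

# Embed the saved pyvis page in the Streamlit app, as the knowledge-graph tab does.
with open(html_path, 'r', encoding='utf-8') as f:
    st.components.v1.html(f.read(), height=600)

# visualized_edges maps (source_type, relation, target_type) tuples to lists of
# ((source_id, target_id), probability) pairs; probability is None for non-GO edges.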