import asyncio
from pathlib import Path
from typing import Optional, Tuple

import pandas as pd

import graphrag.api as api
from graphrag.config import GraphRagConfig, load_config, resolve_paths
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.logging import PrintProgressReporter
from graphrag.utils.storage import _create_storage, _load_table_from_storage


class StreamlitProgressReporter(PrintProgressReporter):
    """Progress reporter that mirrors success messages into a Streamlit placeholder."""

    def __init__(self, placeholder):
        super().__init__("")
        self.placeholder = placeholder

    def success(self, message: str):
        self.placeholder.success(message)


def _resolve_parquet_files(
    root_dir: str,
    config: GraphRagConfig,
    parquet_list: list[str],
    optional_list: list[str],
) -> dict[str, pd.DataFrame]:
    """Read parquet files to a dataframe dict."""
    dataframe_dict = {}
    pipeline_config = create_pipeline_config(config)
    storage_obj = _create_storage(root_dir=root_dir, config=pipeline_config.storage)

    for parquet_file in parquet_list:
        df_key = parquet_file.split(".")[0]
        df_value = asyncio.run(
            _load_table_from_storage(name=parquet_file, storage=storage_obj)
        )
        dataframe_dict[df_key] = df_value

    # Optional files are loaded when present; otherwise their key maps to None.
    for optional_file in optional_list:
        file_exists = asyncio.run(storage_obj.has(optional_file))
        df_key = optional_file.split(".")[0]
        if file_exists:
            df_value = asyncio.run(
                _load_table_from_storage(name=optional_file, storage=storage_obj)
            )
            dataframe_dict[df_key] = df_value
        else:
            dataframe_dict[df_key] = None

    return dataframe_dict
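

# Streamlit re-runs the whole script on every widget interaction, so each
# query re-reads the parquet files from storage. One optional way to memoize
# the loads is sketched below (assumptions: Streamlit is importable here; the
# leading underscore on `_config` tells st.cache_data not to hash that
# argument, and `base_dir` is included purely so the cache key changes when
# the storage location does).
try:
    import streamlit as st

    @st.cache_data(show_spinner=False)
    def _resolve_parquet_files_cached(
        root_dir: str,
        base_dir: str,
        _config: GraphRagConfig,
        parquet_list: list[str],
        optional_list: list[str],
    ) -> dict[str, pd.DataFrame]:
        """Cached wrapper; `base_dir` is used only as part of the cache key."""
        return _resolve_parquet_files(root_dir, _config, parquet_list, optional_list)
except ImportError:
    pass  # keep this module importable when Streamlit is absent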
progress_placeholder.error(f"Error during global search: {e}") return "", {} # Graceful fallback def run_local_search( config_filepath: Optional[str], data_dir: Optional[str], root_dir: str, community_level: int, response_type: str, streaming: bool, query: str, progress_placeholder, ) -> Tuple[str, dict]: """Perform a local search with a given query.""" root = Path(root_dir).resolve() config = load_config(root, config_filepath) reporter = StreamlitProgressReporter(progress_placeholder) config.storage.base_dir = data_dir or config.storage.base_dir resolve_paths(config) dataframe_dict = _resolve_parquet_files( root_dir=root_dir, config=config, parquet_list=[ "create_final_nodes.parquet", "create_final_community_reports.parquet", "create_final_text_units.parquet", "create_final_relationships.parquet", "create_final_entities.parquet", ], optional_list=["create_final_covariates.parquet"], ) final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"] final_community_reports: pd.DataFrame = dataframe_dict[ "create_final_community_reports" ] final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"] final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"] final_entities: pd.DataFrame = dataframe_dict["create_final_entities"] final_covariates: Optional[pd.DataFrame] = dataframe_dict["create_final_covariates"] if streaming: async def run_streaming_search(): full_response = "" context_data = None get_context_data = True async for stream_chunk in api.local_search_streaming( config=config, nodes=final_nodes, entities=final_entities, community_reports=final_community_reports, text_units=final_text_units, relationships=final_relationships, covariates=final_covariates, community_level=community_level, response_type=response_type, query=query, ): if get_context_data: context_data = stream_chunk get_context_data = False else: full_response += stream_chunk progress_placeholder.markdown(full_response) return full_response, context_data return asyncio.run(run_streaming_search()) response, context_data = asyncio.run( api.local_search( config=config, nodes=final_nodes, entities=final_entities, community_reports=final_community_reports, text_units=final_text_units, relationships=final_relationships, covariates=final_covariates, community_level=community_level, response_type=response_type, query=query, ) ) reporter.success(f"Local Search Response:\n{response}") return response, context_data def run_drift_search( config_filepath: Optional[str], data_dir: Optional[str], root_dir: str, community_level: int, response_type: str, streaming: bool, query: str, progress_placeholder, ) -> Tuple[str, dict]: """Perform a DRIFT search with a given query.""" root = Path(root_dir).resolve() config = load_config(root, config_filepath) reporter = StreamlitProgressReporter(progress_placeholder) config.storage.base_dir = data_dir or config.storage.base_dir resolve_paths(config) dataframe_dict = _resolve_parquet_files( root_dir=root_dir, config=config, parquet_list=[ "create_final_nodes.parquet", "create_final_entities.parquet", "create_final_community_reports.parquet", "create_final_text_units.parquet", "create_final_relationships.parquet", ], optional_list=[], # Remove covariates as it's not supported ) final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"] final_entities: pd.DataFrame = dataframe_dict["create_final_entities"] final_community_reports: pd.DataFrame = dataframe_dict[ "create_final_community_reports" ] final_text_units: pd.DataFrame = 
dataframe_dict["create_final_text_units"] final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"] # Note: DRIFT search doesn't support streaming if streaming: progress_placeholder.warning( "Streaming is not supported for DRIFT search. Using standard search instead." ) response, context_data = asyncio.run( api.drift_search( config=config, nodes=final_nodes, entities=final_entities, community_reports=final_community_reports, text_units=final_text_units, relationships=final_relationships, community_level=community_level, query=query, ) ) reporter.success(f"DRIFT Search Response:\n{response}") return response, context_data