import os
import uuid
from typing import List, Dict, Optional

import pandas as pd

from autorag.deploy import GradioRunner
from autorag.deploy.api import RetrievedPassage
from autorag.nodes.generator.base import BaseGenerator
from autorag.utils import fetch_contents

# Placeholder passage yielded alongside every streamed delta while real
# passage extraction is disabled in ``GradioStreamRunner.stream_run``.
empty_retrieved_passage = RetrievedPassage(
    content="", doc_id="", filepath=None, file_page=None, start_idx=None, end_idx=None
)


def _split_start_end(start_end_idx):
    """Return ``(start, end)`` from a start/end pair, or ``(None, None)``.

    The check is an explicit ``None``/length test instead of truthiness:
    truth-testing a multi-element NumPy array raises ``ValueError``, and list
    columns loaded from parquet may arrive as arrays rather than lists.
    """
    if start_end_idx is None or len(start_end_idx) < 2:
        return None, None
    return start_end_idx[0], start_end_idx[1]


class GradioStreamRunner(GradioRunner):
    """Gradio runner that streams the generator output chunk by chunk."""

    def __init__(self, config: Dict, project_dir: Optional[str] = None):
        """Initialise the parent runner and load the corpus.

        :param config: Pipeline configuration, forwarded to the parent runner.
        :param project_dir: Project directory that contains
            ``data/corpus.parquet``. Falls back to the current working
            directory when omitted.
        """
        super().__init__(config, project_dir)
        # ``os.path.join(None, ...)`` raises TypeError, so the documented
        # default must be resolved before building the path.
        # NOTE(review): presumably the parent runner resolves a missing
        # project_dir the same way -- confirm against GradioRunner.
        if project_dir is None:
            project_dir = os.getcwd()
        corpus_path = os.path.join(project_dir, "data", "corpus.parquet")
        self.corpus_df = pd.read_parquet(corpus_path, engine="pyarrow")

    def stream_run(self, query: str):
        """Run the pipeline for ``query`` and stream the generator output.

        Every non-generator module is executed through its ``pure`` method
        and its result columns replace same-named columns of the running
        result. When a generator module is reached, its ``stream`` method is
        called with the single prompt and every delta is yielded.

        :param query: The user query.
        :return: A generator of ``(delta, passages)`` tuples, where
            ``passages`` is currently always ``[empty_retrieved_passage]``.
        """
        # Pseudo QA data: a single-row frame shaped like an evaluation dataset.
        previous_result = pd.DataFrame(
            {
                "qid": [str(uuid.uuid4())],
                "query": [query],
                "retrieval_gt": [[]],
                "generation_gt": [""],
            }
        )

        for module_instance, module_param in zip(
            self.module_instances, self.module_params
        ):
            if isinstance(module_instance, BaseGenerator):
                # NOTE: retrieved passages are intentionally not surfaced here
                # yet; ``extract_retrieve_passage`` exists for that purpose
                # but a placeholder passage is yielded instead.
                assert len(previous_result) == 1
                prompt: str = previous_result["prompts"].tolist()[0]
                for delta in module_instance.stream(prompt=prompt, **module_param):
                    yield delta, [empty_retrieved_passage]
                continue

            new_result = module_instance.pure(
                previous_result=previous_result, **module_param
            )
            # Columns produced by the new module override the older ones.
            duplicated_columns = previous_result.columns.intersection(
                new_result.columns
            )
            previous_result = pd.concat(
                [previous_result.drop(columns=duplicated_columns), new_result],
                axis=1,
            )

    def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
        """Build ``RetrievedPassage`` objects for the first row of ``df``.

        Contents and metadata are looked up in ``self.corpus_df``. The
        ``path`` and ``start_end_idx`` columns are optional; when the corpus
        lacks one, the corresponding passage fields are ``None``.

        :param df: Result frame with a ``retrieved_ids`` column; only the
            first row is used.
        :return: One passage per retrieved id, in the same order.
        """
        retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
        corpus_columns = self.corpus_df.columns
        missing = [None] * len(retrieved_ids)

        contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
        metadatas = fetch_contents(
            self.corpus_df, [retrieved_ids], column_name="metadata"
        )[0]
        if "path" in corpus_columns:
            paths = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="path"
            )[0]
        else:
            paths = missing
        if "start_end_idx" in corpus_columns:
            start_end_indices = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="start_end_idx"
            )[0]
        else:
            start_end_indices = missing

        passages = []
        for content, doc_id, path, metadata, start_end_idx in zip(
            contents, retrieved_ids, paths, metadatas, start_end_indices
        ):
            start_idx, end_idx = _split_start_end(start_end_idx)
            passages.append(
                RetrievedPassage(
                    content=content,
                    doc_id=doc_id,
                    filepath=path,
                    # Tolerate a missing (None/empty) metadata entry.
                    file_page=metadata.get("page", None) if metadata else None,
                    start_idx=start_idx,
                    end_idx=end_idx,
                )
            )
        return passages