# bwook's picture
# Upload 23 files
# b8a3ef1 verified
import os
import uuid
from typing import List, Dict, Optional
import pandas as pd
from autorag.deploy import GradioRunner
from autorag.deploy.api import RetrievedPassage
from autorag.nodes.generator.base import BaseGenerator
from autorag.utils import fetch_contents
# Placeholder passage yielded alongside each streamed delta: retrieved
# passages are not extracted during streaming (the extraction call in
# stream_run is disabled), so this empty sentinel stands in.
empty_retrieved_passage = RetrievedPassage(
    content="", doc_id="", filepath=None, file_page=None, start_idx=None, end_idx=None
)
class GradioStreamRunner(GradioRunner):
    """Gradio runner that streams the final generator node's output.

    Extends :class:`GradioRunner` by loading the project corpus at init
    time and exposing :meth:`stream_run`, a generator that executes every
    non-generator node eagerly and then yields text deltas from the
    generator node one at a time.
    """

    def __init__(self, config: Dict, project_dir: Optional[str] = None):
        """Initialize the runner and load the project corpus.

        :param config: Pipeline configuration, passed through to GradioRunner.
        :param project_dir: AutoRAG project directory containing
            ``data/corpus.parquet``.
            NOTE(review): ``os.path.join`` raises if this is ``None`` —
            confirm callers always supply a directory.
        """
        super().__init__(config, project_dir)
        data_dir = os.path.join(project_dir, "data")
        self.corpus_df = pd.read_parquet(
            os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
        )

    def stream_run(self, query: str):
        """Run the pipeline on *query*, yielding ``(delta, passages)`` pairs.

        Non-generator nodes run eagerly, with each node's output columns
        replacing same-named columns from earlier nodes.  The generator
        node is then streamed; each yielded tuple carries one text delta
        and a placeholder passage list (passages are not extracted here).

        :param query: User query to feed through the pipeline.
        """
        # Pseudo QA row so node modules can execute outside an evaluation run.
        previous_result = pd.DataFrame(
            {
                "qid": [str(uuid.uuid4())],
                "query": [query],
                "retrieval_gt": [[]],
                "generation_gt": [""],
            }
        )
        for module_instance, module_param in zip(
            self.module_instances, self.module_params
        ):
            if not isinstance(module_instance, BaseGenerator):
                new_result = module_instance.pure(
                    previous_result=previous_result, **module_param
                )
                # A later node's columns override earlier ones of the same name.
                duplicated_columns = previous_result.columns.intersection(
                    new_result.columns
                )
                previous_result = pd.concat(
                    [previous_result.drop(columns=duplicated_columns), new_result],
                    axis=1,
                )
            else:
                # Stream the generator result; the pipeline is assumed to have
                # produced exactly one row with a "prompts" column by now.
                assert len(previous_result) == 1
                prompt: str = previous_result["prompts"].tolist()[0]
                for delta in module_instance.stream(prompt=prompt, **module_param):
                    yield delta, [empty_retrieved_passage]

    def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
        """Build :class:`RetrievedPassage` objects for *df*'s first-row ids.

        Optional corpus columns (``path``, ``start_end_idx``) fall back to
        ``None`` when absent.

        :param df: Result frame whose first ``retrieved_ids`` cell lists the
            corpus doc ids to resolve.
        :return: One passage per retrieved id, in order.
        """
        retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
        contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
        if "path" in self.corpus_df.columns:
            paths = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="path"
            )[0]
        else:
            paths = [None] * len(retrieved_ids)
        metadatas = fetch_contents(
            self.corpus_df, [retrieved_ids], column_name="metadata"
        )[0]
        if "start_end_idx" in self.corpus_df.columns:
            start_end_indices = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="start_end_idx"
            )[0]
        else:
            start_end_indices = [None] * len(retrieved_ids)
        return [
            RetrievedPassage(
                content=content,
                doc_id=doc_id,
                filepath=path,
                # Bug fix: guard against a row with no metadata dict —
                # the original called .get on a possibly-None value.
                file_page=(metadata or {}).get("page", None),
                start_idx=start_end_idx[0] if start_end_idx else None,
                end_idx=start_end_idx[1] if start_end_idx else None,
            )
            for content, doc_id, path, metadata, start_end_idx in zip(
                contents, retrieved_ids, paths, metadatas, start_end_indices
            )
        ]