import os | |
from typing import List, Optional | |
from langchain.chains.base import Chain | |
from langchain.chains.summarize import load_summarize_chain | |
from app_modules.llm_inference import LLMInference | |
class SummarizeChain(LLMInference):
    """LLMInference variant whose chain is a LangChain summarize chain."""

    def __init__(self, llm_loader):
        """Pass ``llm_loader`` straight through to the ``LLMInference`` initializer.

        Args:
            llm_loader: Loader object; ``create_chain`` reads its ``llm``
                attribute via ``self.llm_loader``.
                NOTE(review): presumably the base initializer stores it as
                ``self.llm_loader`` -- confirm in ``LLMInference``.
        """
        super().__init__(llm_loader)

    def create_chain(self) -> Chain:
        """Build the summarize chain for the loader's LLM.

        Returns:
            The chain produced by ``load_summarize_chain`` with
            ``chain_type="refine"``.
        """
        return load_summarize_chain(self.llm_loader.llm, chain_type="refine")

    def run_chain(self, chain, inputs, callbacks: Optional[List] = None):
        """Invoke ``chain`` on ``inputs`` and return its result.

        Args:
            chain: Callable chain, e.g. the object returned by
                ``create_chain``.
            inputs: Inputs handed to the chain unchanged.
            callbacks: Accepted for signature compatibility only; it is not
                forwarded to the chain. Defaults to ``None`` rather than a
                shared mutable ``[]``.

        Returns:
            Whatever ``chain(inputs, return_only_outputs=True)`` returns.
        """
        # `callbacks` is intentionally left unused here, matching the
        # original behaviour: the chain is called without it.
        return chain(inputs, return_only_outputs=True)