File size: 598 Bytes
ee3a625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import os
from typing import List, Optional

from langchain.chains.base import Chain
from langchain.chains.summarize import load_summarize_chain

from app_modules.llm_inference import LLMInference


class SummarizeChain(LLMInference):
    """LLM inference wrapper that summarizes input with a LangChain "refine" chain.

    The chain is built by ``load_summarize_chain`` from the LLM exposed as
    ``self.llm_loader.llm``.
    """

    def __init__(self, llm_loader):
        """Delegate construction to ``LLMInference.__init__``.

        Args:
            llm_loader: Loader object whose ``llm`` attribute is read by
                ``create_chain``. Presumably stored as ``self.llm_loader`` by
                the base class — TODO confirm against ``LLMInference``.
        """
        super().__init__(llm_loader)

    def create_chain(self) -> Chain:
        """Build a summarization chain using the "refine" strategy.

        Returns:
            The chain returned by ``load_summarize_chain`` for
            ``self.llm_loader.llm`` with ``chain_type="refine"``.
        """
        return load_summarize_chain(self.llm_loader.llm, chain_type="refine")

    def run_chain(self, chain, inputs, callbacks: Optional[List] = None):
        """Invoke *chain* on *inputs*, keeping only the chain's output keys.

        Args:
            chain: The chain to call (presumably the one built by
                ``create_chain``).
            inputs: Passed positionally as the first argument to ``chain(...)``.
            callbacks: Presumably accepted to match the base-class signature;
                it is not forwarded to the chain. Defaults to ``None`` rather
                than a shared mutable ``[]`` literal.

        Returns:
            The result of ``chain(inputs, return_only_outputs=True)``.
        """
        # NOTE(review): ``callbacks`` is intentionally not passed through here
        # (kept as in the original). Confirm whether this is deliberate — e.g.
        # to avoid streaming intermediate refine steps — or an omission.
        return chain(inputs, return_only_outputs=True)