File size: 1,951 Bytes
0365501
 
 
 
 
 
c157cd5
0365501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3193e49
 
0365501
 
 
 
3193e49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0365501
 
 
 
 
 
 
 
 
 
c6e70f1
3193e49
c6e70f1
0365501
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
from typing import Any, Optional, Tuple
from langchain.chains import ConversationChain
from langchain.llms import HuggingFaceHub
from langchain.llms import OpenAI
from threading import Lock

def load_chain_openai(api_key: str):
    """Build a ConversationChain backed by LangChain's ``OpenAI`` LLM.

    The key is exposed through the ``OPENAI_API_KEY`` environment variable only
    while the LLM and chain are constructed (presumably ``OpenAI`` reads it at
    construction time -- confirm against the installed LangChain version).
    Afterwards the variable is set to ``""``, even if construction raises, so
    the key does not linger in ``os.environ``.

    Args:
        api_key: OpenAI API key to use while building the LLM.

    Returns:
        A ``ConversationChain`` wrapping ``OpenAI(temperature=0)``.
    """
    # NOTE: os.environ is process-wide; concurrent callers of this function
    # (or of other code reading this variable) can observe each other's key.
    os.environ["OPENAI_API_KEY"] = api_key
    try:
        llm = OpenAI(temperature=0)
        chain = ConversationChain(llm=llm)
    finally:
        # Blank the key on both success and failure. Any value the variable
        # held before this call is overwritten, as in the original code.
        os.environ["OPENAI_API_KEY"] = ""
    return chain


def load_chain_falcon(api_key: str):
    """Build a ConversationChain backed by Falcon-7B-Instruct via ``HuggingFaceHub``.

    The token is exposed through the ``HUGGINGFACEHUB_API_TOKEN`` environment
    variable only while the LLM and chain are constructed (presumably
    ``HuggingFaceHub`` reads it at construction time -- confirm against the
    installed LangChain version). Afterwards the variable is set to ``""``,
    even if construction raises, so the token does not linger in
    ``os.environ``.

    Args:
        api_key: Hugging Face Hub API token to use while building the LLM.

    Returns:
        A ``ConversationChain`` wrapping the ``tiiuae/falcon-7b-instruct``
        model with ``temperature=0.9``.
    """
    # NOTE: os.environ is process-wide; concurrent callers of this function
    # (or of other code reading this variable) can observe each other's token.
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
    try:
        llm = HuggingFaceHub(repo_id="tiiuae/falcon-7b-instruct", model_kwargs={"temperature": 0.9})
        chain = ConversationChain(llm=llm)
    finally:
        # Blank the token on both success and failure. Any value the variable
        # held before this call is overwritten, as in the original code.
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = ""
    return chain

class ChatWrapper:
    """Callable chat session that feeds user turns to a LangChain conversation chain.

    Each call appends an ``(input, reply)`` pair to ``self.history`` and returns
    that same list object (presumably consumed by a chat UI -- verify against
    the caller). Calls on one instance are serialized by ``self.lock``.

    Args:
        chain_type: ``'openai'`` or ``'falcon'``; selects the loader used to
            build the chain. Only checked when ``api_key`` is non-empty.
        api_key: Provider API key. When empty, no chain is built and every call
            replies with a prompt to add a key.

    Raises:
        ValueError: if ``api_key`` is non-empty and ``chain_type`` is not
            ``'openai'`` or ``'falcon'``.
    """

    def __init__(self, chain_type: str, api_key: str = '') -> None:
        self.api_key = api_key
        self.chain_type = chain_type
        # Conversation log of (user input, reply) pairs, returned on every call.
        self.history: "list[Tuple[str, str]]" = []
        # Serializes __call__ so concurrent turns on one instance don't interleave.
        self.lock = Lock()

        # Stays None when no key is supplied; __call__ checks for that.
        self.chain: Optional[Any] = None
        if self.api_key:
            if chain_type == 'openai':
                self.chain = load_chain_openai(self.api_key)
            elif chain_type == 'falcon':
                self.chain = load_chain_falcon(self.api_key)
            else:
                raise ValueError(f'Invalid chain_type: {chain_type}')

    def clear_api_key(self) -> None:
        """Delete the ``api_key`` attribute if it is still present."""
        if hasattr(self, 'api_key'):
            del self.api_key

    def __call__(self, inp: str) -> list:
        """Run one conversation turn and return the full history.

        Args:
            inp: The user's message.

        Returns:
            ``self.history`` (the same list object), with ``(inp, reply)``
            appended. ``reply`` is the chain's output, a prompt to add an API
            key when no chain exists, or ``"An error occurred: <exc>"`` if the
            chain raised.
        """
        with self.lock:
            try:
                if self.chain is None:
                    reply = "Please add your API key to proceed."
                else:
                    reply = self.chain.run(input=inp)
            except Exception as e:
                # Deliberately broad: any chain failure is shown to the user as
                # a chat reply instead of propagating out of the UI callback.
                reply = f"An error occurred: {e}"
            finally:
                # The key is only needed to build the chain in __init__, so it
                # is dropped after every turn.
                self.clear_api_key()
            self.history.append((inp, reply))
        return self.history