Spaces:
Sleeping
Sleeping
File size: 1,951 Bytes
0365501 c157cd5 0365501 3193e49 0365501 3193e49 0365501 c6e70f1 3193e49 c6e70f1 0365501 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import os
from typing import Any, Optional, Tuple
from langchain.chains import ConversationChain
from langchain.llms import HuggingFaceHub
from langchain.llms import OpenAI
from threading import Lock
def load_chain_openai(api_key: str) -> "ConversationChain":
    """Build a ConversationChain backed by LangChain's ``OpenAI`` LLM.

    The key is passed to ``OpenAI`` explicitly rather than staged through
    ``os.environ``. Writing it to the process-wide environment would
    overwrite any ``OPENAI_API_KEY`` already configured there. It would also
    leave the key behind if the LLM constructor raised.

    Args:
        api_key: OpenAI API key used to authenticate the chain's LLM.

    Returns:
        A ``ConversationChain`` wrapping an ``OpenAI`` LLM with
        ``temperature=0``.
    """
    llm = OpenAI(temperature=0, openai_api_key=api_key)
    return ConversationChain(llm=llm)
def load_chain_falcon(api_key: str) -> "ConversationChain":
    """Build a ConversationChain backed by Falcon-7B-Instruct on the HF Hub.

    The token is passed to ``HuggingFaceHub`` explicitly rather than staged
    through ``os.environ``. Writing it to the process-wide environment would
    overwrite any ``HUGGINGFACEHUB_API_TOKEN`` already configured there. It
    would also leave the token behind if the LLM constructor raised.

    Args:
        api_key: Hugging Face Hub API token used to authenticate the LLM.

    Returns:
        A ``ConversationChain`` wrapping the ``tiiuae/falcon-7b-instruct``
        model with ``temperature=0.9``.
    """
    llm = HuggingFaceHub(
        repo_id="tiiuae/falcon-7b-instruct",
        model_kwargs={"temperature": 0.9},
        huggingfacehub_api_token=api_key,
    )
    return ConversationChain(llm=llm)
class ChatWrapper:
    """Callable that routes user input through a LangChain conversation chain.

    Each call appends an ``(input, reply)`` pair to ``self.history`` and
    returns that same list object. Calls are serialized with a
    ``threading.Lock``. When no API key was supplied, no chain is built and
    every call records a prompt asking the user to add one.
    """

    # Backends accepted for ``chain_type``.
    _CHAIN_TYPES = ('openai', 'falcon')

    def __init__(self, chain_type: str, api_key: str = ''):
        """Create the wrapper and, when a key is given, build its chain.

        Args:
            chain_type: Backend to use: ``'openai'`` or ``'falcon'``.
            api_key: Credential for that backend. When empty, no chain is
                built and ``self.chain`` is ``None``.

        Raises:
            ValueError: If ``chain_type`` is not a supported backend. This is
                checked even without a key, so a bad value fails immediately
                instead of surfacing only once a key is provided.
        """
        if chain_type not in self._CHAIN_TYPES:
            raise ValueError(f'Invalid chain_type: {chain_type}')
        self.api_key = api_key
        self.chain_type = chain_type
        self.history = []
        self.lock = Lock()
        self.chain = None
        if self.api_key:
            if chain_type == 'openai':
                self.chain = load_chain_openai(self.api_key)
            else:
                self.chain = load_chain_falcon(self.api_key)

    def clear_api_key(self):
        """Remove the stored API key attribute from this instance, if present."""
        if hasattr(self, 'api_key'):
            del self.api_key

    def __call__(self, inp: str):
        """Send one user message through the chain and record the exchange.

        Any exception raised by the chain is caught and recorded as an error
        reply rather than propagated. The stored API key is cleared after
        every call, whether or not it succeeded.

        Args:
            inp: The user's message.

        Returns:
            The full ``self.history`` list of ``(input, reply)`` pairs.
        """
        with self.lock:
            try:
                if self.chain is None:
                    self.history.append((inp, "Please add your API key to proceed."))
                    return self.history
                output = self.chain.run(input=inp)
                self.history.append((inp, output))
            except Exception as e:
                # Best-effort UI behaviour: surface the failure as a chat reply.
                self.history.append((inp, f"An error occurred: {e}"))
            finally:
                # API key is cleared after running each chain in the class.
                self.clear_api_key()
            return self.history