Rohan Kataria
new files
0365501
raw
history blame
1.84 kB
import os
from typing import Any, Optional, Tuple
from langchain.chains import ConversationChain
from langchain.llms import HuggingFaceHub
from langchain.llms import OpenAI
from threading import Lock
def load_chain_openai(api_key: str):
    """Build a ConversationChain backed by the OpenAI completion LLM.

    The key is handed directly to the ``OpenAI`` constructor instead of being
    staged in the ``OPENAI_API_KEY`` environment variable. Staging it there and
    then resetting it to ``""`` had three problems:
    - it wiped out any key the process already had in its environment;
    - it left the key behind if construction raised;
    - it mutated process-global state that concurrent callers could race on.

    Args:
        api_key: OpenAI API key. Must be non-empty.

    Returns:
        A ``ConversationChain`` wrapping ``OpenAI(temperature=0)``.

    Raises:
        ValueError: If ``api_key`` is empty.
    """
    if not api_key:
        # Fail fast; LangChain may otherwise fall back to OPENAI_API_KEY from
        # the environment and silently use a different credential.
        raise ValueError("api_key must be a non-empty OpenAI API key")
    llm = OpenAI(temperature=0, openai_api_key=api_key)
    return ConversationChain(llm=llm)
def load_chain_falcon(api_key: str):
    """Build a ConversationChain backed by Falcon-7B-Instruct on the Hugging Face Hub.

    The token is handed directly to the ``HuggingFaceHub`` constructor instead
    of being staged in the ``HUGGINGFACEHUB_API_TOKEN`` environment variable.
    Staging it there and resetting it to ``""`` had three problems:
    - it wiped any token already set for the process;
    - it leaked the token if construction raised;
    - it raced with concurrent callers.

    Args:
        api_key: Hugging Face Hub API token. Must be non-empty.

    Returns:
        A ``ConversationChain`` wrapping ``tiiuae/falcon-7b-instruct`` with
        sampling temperature 0.9.

    Raises:
        ValueError: If ``api_key`` is empty.
    """
    if not api_key:
        # Fail fast; LangChain may otherwise fall back to the token in
        # HUGGINGFACEHUB_API_TOKEN and silently use a different credential.
        raise ValueError("api_key must be a non-empty Hugging Face Hub token")
    llm = HuggingFaceHub(
        repo_id="tiiuae/falcon-7b-instruct",
        model_kwargs={"temperature": 0.9},
        huggingfacehub_api_token=api_key,
    )
    return ConversationChain(llm=llm)
class ChatWrapper:
    """Callable that feeds user messages to a LangChain conversation chain.

    Each call appends an ``(input, reply)`` pair to ``self.history`` and
    returns that list. Calls are serialized with a lock. When no API key was
    supplied, no chain is built and every call records a prompt asking the
    user to add a key.
    """

    # chain_type values accepted by __init__.
    SUPPORTED_CHAIN_TYPES = ('openai', 'falcon')

    def __init__(self, chain_type: str, api_key: str = ''):
        """Create the wrapper and, if ``api_key`` is non-empty, build its chain.

        Args:
            chain_type: ``'openai'`` or ``'falcon'``; selects the chain loader.
            api_key: Credential for the selected provider. When empty, no
                chain is built and calls only ask the user to add a key.

        Raises:
            ValueError: If ``chain_type`` is not supported. This is checked
                even when ``api_key`` is empty, so a bad type is rejected at
                construction instead of being silently accepted.
        """
        if chain_type not in self.SUPPORTED_CHAIN_TYPES:
            raise ValueError(f'Invalid chain_type: {chain_type}')
        self.api_key = api_key
        self.chain_type = chain_type
        # (user input, reply) pairs in call order; returned by every call.
        self.history = []
        self.lock = Lock()
        self.chain = None
        if api_key:
            if chain_type == 'openai':
                self.chain = load_chain_openai(api_key)
            elif chain_type == 'falcon':
                self.chain = load_chain_falcon(api_key)

    def __call__(self, inp: str):
        """Send ``inp`` to the chain and return the updated history.

        Args:
            inp: The user's message.

        Returns:
            ``self.history`` (the same list object on every call) with the
            new ``(inp, reply)`` pair appended. If the chain raises an
            ``Exception``, the reply is ``"An error occurred: <error>"``
            instead of the exception propagating.
        """
        with self.lock:
            try:
                if self.chain is None:
                    reply = "Please add your API key to proceed."
                else:
                    reply = self.chain.run(input=inp)
            except Exception as e:
                # Report the failure as the chat reply rather than raising.
                reply = f"An error occurred: {e}"
            finally:
                # Drop the wrapper's copy of the key after every call; the
                # chain built in __init__ keeps whatever credentials it captured.
                self.api_key = ''
            self.history.append((inp, reply))
            return self.history