from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
# The GGUF model below is loaded through ctransformers; the LangChain
# CTransformers wrapper makes it usable as a Runnable in the chain.
from langchain_community.llms import CTransformers
import os

# Redirect the Hugging Face download cache to a writable path inside the container.
os.environ['TRANSFORMERS_CACHE'] = '/code/model/cache/'
# trust_remote_code is required by the nomic embedding model.
model_kwargs = {'trust_remote_code': True}
# embedding = HuggingFaceEmbeddings(
#     model_name="nomic-ai/nomic-embed-text-v1.5",
#     model_kwargs=model_kwargs
# )

# Note: with embedding_function left out, this Chroma wrapper falls back to
# chromadb's bundled default embedding when embedding queries, so the
# ./chroma_db collection must have been built with that same embedding.
db = Chroma(
    persist_directory="./chroma_db",
    # embedding_function=embedding,
    collection_name='CVE'
)
retriever = db.as_retriever()
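
# Optional sanity check (a minimal sketch; the query string is only a
# hypothetical example, not part of the app): confirm the persisted CVE
# collection returns documents before wiring the retriever into the chain.
# docs = retriever.invoke("CVE-2021-44228 log4j remote code execution")
# print(len(docs), docs[0].page_content[:200] if docs else "no results")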
template = """Answer the question based only on the following context:
{context}
Do not reveal the source of the data.
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)
# Load the quantized Zephyr GGUF model. `model_type` and `threads` are
# ctransformers options, so the LangChain CTransformers wrapper is used here
# instead of transformers' AutoModelForCausalLM, which does not accept them.
model = CTransformers(
    model="zephyr-7b-beta.Q4_K_S.gguf",
    model_type='mistral',
    config={'threads': 3},
)
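
# Quick standalone check of the LLM (a sketch; the prompt text is arbitrary):
# print(model.invoke("What is a CVE?"))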
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

# Uncomment and use the following for testing
# for chunk in chain.stream("Your question here"):
#     print(chunk, end="", flush=True)
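
# One-shot invocation, as an alternative to streaming (the question is a
# placeholder, same as above):
# answer = chain.invoke("Your question here")
# print(answer)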