# Enron_case_RAG / app.py
from torch import cuda, bfloat16
import torch
import transformers
from transformers import AutoTokenizer
from time import time
from langchain_community.llms import HuggingFacePipeline
from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
import gradio as gr
#############################################################################
model_id = "marcolorenzi98/tinyllama-enron-v1"
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
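# Note: `device` is only informative here; model placement below is handled
# by `device_map='auto'` (which relies on the `accelerate` library).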
# set quantization configuration to load large model with less GPU memory
# this requires the `bitsandbytes` library
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16
)
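# This 4-bit NF4 config (with double quantization and bfloat16 compute) is
# defined but only takes effect if `quantization_config=bnb_config` is
# uncommented in the `from_pretrained` call below.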
##############################################################################
model_config = transformers.AutoConfig.from_pretrained(model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    config=model_config,
    # quantization_config=bnb_config,
    device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)
##############################################################################
embedding = SpacyEmbeddings(model_name="en_core_web_sm")
# The embeddings were computed and persisted to disk when the vector store
# was built; here the existing store is only reloaded from `persist_directory`.
persist_directory = 'Enron_case_RAG/Langchain_ChromaDB'
# load from disk
db3 = Chroma(
    persist_directory=persist_directory,
    embedding_function=embedding,
    collection_name="Enron_vectorstore"
)
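# The persisted Chroma collection must already exist at `persist_directory`
# and must have been built with the same SpaCy model ("en_core_web_sm");
# otherwise retrieval returns nothing or fails on a dimension mismatch.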
##############################################################################
query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto")
llm = HuggingFacePipeline(pipeline=query_pipeline)
retriever = db3.as_retriever()
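# With no arguments, `as_retriever()` uses plain similarity search and, by
# LangChain's default, returns the top 4 matching chunks per query.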
##############################################################################
def gradio_rag(query):
    # The "stuff" chain simply concatenates the retrieved chunks into the prompt.
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        verbose=True)
    print(f"Query: {query}\n")
    time_1 = time()
    result = qa.run(query)
    time_2 = time()
    print(f"Inference time: {round(time_2 - time_1, 3)} sec.")
    print("\nResult: ", result)
    # Return the answer so Gradio can display it in the output textbox.
    return result
###############################################################################
demo = gr.Interface(
    fn=gradio_rag,
    inputs=gr.Textbox(label="Please, write your request here:",
                      placeholder="example: who is Sheila Chang", lines=5),
    outputs=gr.Textbox(label="Answer:"),
    title='Tiny Llama RAG on Enron Scandal',
    description="This is a RAG system based on the SLM Tiny Llama, fine-tuned on the Enron scandal emails dataset.",
    allow_flagging="never"
)
demo.launch(debug=False)
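# `launch()` starts a local Gradio server; when the script runs inside a
# Hugging Face Space, the hosting environment exposes it automatically.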