# Lawyer-ChatBot / app.py
# Source: Hugging Face Space by Krishnachaitanya2004 (commit df59db4, verified).
# !pip install accelerate
# !pip install chromadb
# !pip install "unstructured[all-docs]"
from langchain.vectorstores import Chroma
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import pipeline
import torch
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.chains import RetrievalQA
import streamlit as st
# Embedding model used to vectorize both the stored documents and user queries.
embeddings = SentenceTransformerEmbeddings(model_name="multi-qa-mpnet-base-dot-v1")
# Directory where the Chroma vector store was persisted to disk.
persist_directory = "chroma"
# Load the persisted vector database.
# BUG FIX: Chroma's first positional parameter is `collection_name`, not
# `persist_directory`. The original call Chroma(persist_directory, embeddings)
# therefore created a fresh (empty) store named "chroma" with no on-disk
# persistence instead of loading the saved one. Pass keywords explicitly.
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
# To save and load the saved vector db (if needed in the future)
# Persist the database to disk
# db.persist()
# db = Chroma(persist_directory="db", embedding_function=embeddings)
# Checkpoint of the seq2seq model that generates the answers.
checkpoint = "MBZUAI/LaMini-Flan-T5-783M"

# Load the tokenizer and the base generation model.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint,
    device_map="auto",
    torch_dtype=torch.float32,
)

# Wrap model + tokenizer in a text2text generation pipeline with sampling
# enabled (low temperature keeps answers mostly deterministic).
pipe = pipeline(
    'text2text-generation',
    model=base_model,
    tokenizer=tokenizer,
    max_length=512,
    do_sample=True,
    temperature=0.3,
    top_p=0.95,
)

# Expose the local pipeline through LangChain's LLM interface.
local_llm = HuggingFacePipeline(pipeline=pipe)
# Build a retrieval-augmented QA chain: fetch the 2 most similar document
# chunks from the vector store and "stuff" them into the LLM prompt.
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})
qa_chain = RetrievalQA.from_chain_type(
    llm=local_llm,
    chain_type='stuff',
    retriever=retriever,
    return_source_documents=True,
)
st.title("Lawyer Bot")
st.subheader("A chatbot to answer your legal questions trained on IPC")

# Initialize the chat history on the first run of the session.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far (Streamlit reruns this script on every event).
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input.
if prompt := st.chat_input("What is up?"):
    # Record and display the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Run the retrieval-QA chain and display the generated answer.
    with st.chat_message("assistant"):
        response = qa_chain(prompt)
        st.markdown(response["result"])

    # BUG FIX: store only the answer text, not the entire response dict
    # (which also carries source documents). The original appended `response`
    # itself, so the history-replay loop above rendered a raw dict on the
    # next rerun instead of the assistant's answer.
    st.session_state.messages.append(
        {"role": "assistant", "content": response["result"]}
    )