from langchain import PromptTemplate
from langchain.llms import HuggingFacePipeline
import os
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
import gradio as gr
import chromadb
from dotenv import load_dotenv
from constants import CHROMA_SETTINGS
import torch
from transformers import AutoModelForCausalLM, pipeline
import gc
from langchain.chat_models import ChatGooglePalm
from langchain_google_genai import GoogleGenerativeAIEmbeddings
# Free any leftover memory before loading models
gc.collect()
torch.cuda.empty_cache()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load environment variables before anything reads them
if not load_dotenv():
    print("Could not load .env file or it is empty. Please check if it exists and is readable.")
    exit(1)

embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
persist_directory = os.environ.get('PERSIST_DIRECTORY')
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
google_api_key = os.environ.get('GOOGLE_API_KEY')

# Chat model: Google PaLM via LangChain
#model = AutoModelForCausalLM.from_pretrained(local_llm, device_map=device)
llm = ChatGooglePalm(google_api_key=google_api_key)
#llm = HuggingFacePipeline.from_model_id(model_id=local_llm, task='text-generation', device=0, pipeline_kwargs={"max_new_tokens": 1000})

print("Loading embeddings model...")
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=google_api_key)
#embeddings = pipeline("feature-extraction", model="WhereIsAI/UAE-Large-V1")

# Chroma client and vector store backed by the persisted index
chroma_client = chromadb.PersistentClient(settings=CHROMA_SETTINGS, path=persist_directory)
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS, client=chroma_client)
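
# Prompt template that constrains answers to the retrieved context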
prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question'])
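# Retrieve the top target_source_chunks most similar chunks from the vector store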
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
# Pass the custom prompt into the RetrievalQA chain
chain_type_kwargs = {"prompt": prompt}
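# Gradio text box used as the question input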
input_gradio = gr.Text(
    label="Prompt",
    show_label=False,
    max_lines=2,
    placeholder="Enter your question here",
    container=False,
)
def get_response(input_gradio):
    query = input_gradio
    # Build a "stuff" RetrievalQA chain over the retriever and answer the query
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=False, chain_type_kwargs=chain_type_kwargs, verbose=True)
    response = qa(query)
    return response['result']
iface = gr.Interface(
    fn=get_response,
    inputs=input_gradio,
    outputs="text",
    title="Tsetlin Machine Chatbot",
    description="A chatbot that uses the LLM to answer anything regarding TM",
    allow_flagging='never',
)
# Interactive questions and answers
iface.launch()