"""Gradio chatbot that answers questions from a persisted Chroma vector store.

The script loads configuration from a ``.env`` file, opens the Chroma
collection stored under ``PERSIST_DIRECTORY``, wires a "stuff"
``RetrievalQA`` chain around a Google PaLM chat model, and serves it through
a single-textbox Gradio interface.
"""

import gc
import os
from io import BytesIO

import chromadb
import google.generativeai as genai
import gradio as gr
import torch
from dotenv import load_dotenv
from langchain import PromptTemplate, LLMChain
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatGooglePalm
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings
from langchain.llms import CTransformers, HuggingFacePipeline, GooglePalm
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    pipeline,
)

from constants import CHROMA_SETTINGS

# Load .env BEFORE anything reads the environment: the settings below and the
# Google clients all depend on variables that may only be defined there.
if not load_dotenv():
    print("Could not load .env file or it is empty. Please check if it exists and is readable.")
    raise SystemExit(1)

# Release any cached memory left over from a previous run in this process.
gc.collect()
torch.cuda.empty_cache()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Runtime configuration, read from the environment populated above.
embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
persist_directory = os.environ.get('PERSIST_DIRECTORY')
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
google_api_key = os.environ.get('GOOGLE_API_KEY')

# Alternative local backends kept for reference:
# model= AutoModelForCausalLM.from_pretrained(local_llm, device_map= device)
# llm= HuggingFacePipeline.from_model_id(model_id=local_llm, task='text-generation', device=0, pipeline_kwargs={"max_new_tokens": 1000})
llm = ChatGooglePalm()

print("Loading embeddings model...")
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
# embeddings= pipeline("feature-extraction", model="WhereIsAI/UAE-Large-V1")

# Chroma client
chroma_client = chromadb.PersistentClient(
    settings=CHROMA_SETTINGS,
    path=persist_directory,
)
db = Chroma(
    persist_directory=persist_directory,
    embedding_function=embeddings,
    client_settings=CHROMA_SETTINGS,
    client=chroma_client,
)

# Each section sits on its own line so the retrieved context is clearly
# separated from the question and the answer instructions.
prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""

prompt = PromptTemplate(
    template=prompt_template,
    input_variables=['context', 'question'],
)
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})

# Extra arguments forwarded to the "stuff" chain: use the custom prompt above.
chain_type_kwargs = {"prompt": prompt}

# Build the QA chain once at start-up; none of its inputs change per request.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
    chain_type_kwargs=chain_type_kwargs,
    verbose=True,
)

input_gradio = gr.Text(
    label="Prompt",
    show_label=False,
    max_lines=2,
    placeholder="Enter your question here",
    container=False,
)


def get_response(input_gradio: str) -> str:
    """Answer a user question with the retrieval QA chain.

    Args:
        input_gradio: The question text submitted through the Gradio textbox.

    Returns:
        The chain's ``'result'`` string, or a short prompt asking for input
        when the question is missing or blank.
    """
    # Skip retrieval and the LLM call when there is nothing to ask.
    if not input_gradio or not input_gradio.strip():
        return "Please enter a question."

    response = qa_chain(input_gradio)
    return response['result']


iface = gr.Interface(
    fn=get_response,
    inputs=input_gradio,
    outputs="text",
    title="Tsetlin Machine Chatbot",
    description="A chatbot that uses the LLM to answer anything regarding TM",
    allow_flagging='never',
)

# Interactive questions and answers
iface.launch()