{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyMp8bhKotk3mdZcc3U4qqKP",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "\"Open"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jLMHfRq9kAP9"
      },
      "outputs": [],
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel\n",
        "%pip install -Uq langchain-community\n",
        "%pip install -Uq langchain\n",
        "%pip install -Uq langchainhub\n",
        "%pip install -Uq langgraph\n",
        "%pip install -Uq wikipedia\n",
        "%pip install -Uq scikit-learn\n",
        "%pip install -Uq chromadb\n",
        "%pip install -Uq sentence-transformers\n",
        "%pip install -Uq gpt4all\n",
        "%pip install -qU google-search-results"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Copy the secrets stored in Colab into environment variables of the same name\n",
        "for secret_name in (\"HUGGINGFACEHUB_API_TOKEN\", \"GOOGLE_CSE_ID\", \"GOOGLE_API_KEY\"):\n",
        "    os.environ[secret_name] = userdata.get(secret_name)"
      ],
      "metadata": {
        "id": "kPF-3dzGuAfT"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### LLMs"
      ],
      "metadata": {
        "id": "XTtbWrue9l3E"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# HF libraries\n",
        "from langchain_community.llms import HuggingFaceEndpoint\n",
        "\n",
        "# Generation settings shared by both hosted models\n",
        "GENERATION_KWARGS = dict(\n",
        "    temperature=0.1,\n",
        "    max_new_tokens=1024,\n",
        "    repetition_penalty=1.2,\n",
        "    return_full_text=False,\n",
        ")\n",
        "\n",
        "# Load the models from the Hugging Face Hub\n",
        "llm_mid = HuggingFaceEndpoint(\n",
        "    repo_id=\"mistralai/Mixtral-8x7B-Instruct-v0.1\", **GENERATION_KWARGS\n",
        ")\n",
        "\n",
        "llm_small = HuggingFaceEndpoint(\n",
        "    repo_id=\"mistralai/Mistral-7B-Instruct-v0.2\", **GENERATION_KWARGS\n",
        ")"
      ],
      "metadata": {
        "id": "EDZyRq-wuIuy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Chroma DB"
      ],
      "metadata": {
        "id": "mdMx_T8V9npk"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
        "from langchain_community.document_loaders import WebBaseLoader\n",
        "from langchain_community.vectorstores import Chroma\n",
        "from langchain_community.embeddings import GPT4AllEmbeddings, HuggingFaceEmbeddings\n",
        "\n",
        "# Load\n",
        "url = \"https://lilianweng.github.io/posts/2023-06-23-agent/\"\n",
        "loader = WebBaseLoader(url)\n",
        "docs = loader.load()\n",
        "\n",
        "# Split\n",
        "text_splitter = RecursiveCharacterTextSplitter(\n",
        "    chunk_size=500, chunk_overlap=100\n",
        ")\n",
        "all_splits = text_splitter.split_documents(docs)\n",
        "\n",
        "# Embed and index\n",
        "#embedding = GPT4AllEmbeddings()\n",
        "embedding = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-mpnet-base-v2\")\n",
        "\n",
        "\n",
        "# Index\n",
        "vectorstore = Chroma.from_documents(\n",
        "    documents=all_splits,\n",
        "    collection_name=\"rag-chroma\",\n",
        "    embedding=embedding,\n",
        ")\n",
        "retriever = vectorstore.as_retriever()"
      ],
      "metadata": {
        "id": "LkX9ehoeupSz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### State"
      ],
      "metadata": {
        "id": "0A-7_d3G9b8h"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from typing import Annotated, Any, Dict, TypedDict\n",
        "from langchain_core.messages import BaseMessage\n",
        "\n",
        "class GraphState(TypedDict):\n",
        "    \"\"\"\n",
        "    Represents the state of our graph.\n",
        "\n",
        "    Attributes:\n",
        "        keys: A dictionary where each key is a string.\n",
        "    \"\"\"\n",
        "\n",
        "    # typing.Any (not the builtin function any) is the correct value annotation\n",
        "    keys: Dict[str, Any]"
      ],
      "metadata": {
        "id": "fRzYhmOs7_GJ"
      },
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Nodes and Edges"
      ],
      "metadata": {
        "id": "bPhIdcVD9pgV"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "import operator\n",
        "from typing import Annotated, Sequence, TypedDict\n",
        "\n",
        "from langchain import hub\n",
        "from langchain_core.output_parsers import JsonOutputParser\n",
        "from langchain.prompts import PromptTemplate\n",
        "from langchain.schema import Document\n",
        "from langchain.tools import Tool\n",
        "from langchain_community.utilities import GoogleSearchAPIWrapper\n",
        "from langchain_community.vectorstores import Chroma\n",
        "from langchain_core.output_parsers import StrOutputParser\n",
        "from langchain_core.runnables import RunnablePassthrough\n",
        "\n",
        "### Nodes ###\n",
        "\n",
        "def retrieve(state):\n",
        "    \"\"\"\n",
        "    Retrieve documents\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current graph state\n",
        "\n",
        "    Returns:\n",
        "        state (dict): New key added to state, documents, that contains retrieved documents\n",
        "    \"\"\"\n",
        "    print(\"---RETRIEVE---\")\n",
        "    state_dict = state[\"keys\"]\n",
        "    question = state_dict[\"question\"]\n",
        "    local = state_dict[\"local\"]\n",
        "    # Retrievers are Runnables: invoke() replaces the deprecated get_relevant_documents()\n",
        "    documents = retriever.invoke(question)\n",
        "\n",
        "    return {\"keys\": {\"documents\": documents, \"local\": local, \"question\": question}}\n",
        "\n",
        "def generate(state):\n",
        "    \"\"\"\n",
        "    Generate answer\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current graph state\n",
        "\n",
        "    Returns:\n",
        "        state (dict): New key added to state, generation, that contains generation\n",
        "    \"\"\"\n",
        "    print(\"---GENERATE---\")\n",
        "    state_dict = state[\"keys\"]\n",
        "    question = state_dict[\"question\"]\n",
        "    documents = state_dict[\"documents\"]\n",
        "    local = state_dict[\"local\"]\n",
        "\n",
        "    # Prompt\n",
        "    prompt = hub.pull(\"rlm/rag-prompt\")\n",
        "\n",
        "    # LLM\n",
        "    llm = llm_mid\n",
        "\n",
        "    # Post-processing\n",
        "    def format_docs(docs):\n",
        "        return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
        "\n",
        "    # Chain\n",
        "    rag_chain = prompt | llm | StrOutputParser()\n",
        "\n",
        "    # Run\n",
        "    # Pass the joined page contents as context, not the raw list of Document objects\n",
        "    generation = rag_chain.invoke(\n",
        "        {\"context\": format_docs(documents), \"question\": question}\n",
        "    )\n",
        "\n",
        "    # Carry \"local\" forward like every other node does\n",
        "    return {\n",
        "        \"keys\": {\n",
        "            \"documents\": documents,\n",
        "            \"question\": question,\n",
        "            \"local\": local,\n",
        "            \"generation\": generation,\n",
        "        }\n",
        "    }\n",
        "\n",
        "def grade_documents(state):\n",
        "    \"\"\"\n",
        "    Determines whether the retrieved documents are relevant to the question.\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current graph state\n",
        "\n",
        "    Returns:\n",
        "        state (dict): Update documents key with relevant documents, and add\n",
        "        run_web_search (\"Yes\"/\"No\") to flag whether a web search is needed\n",
        "    \"\"\"\n",
        "\n",
        "    print(\"---CHECK RELEVANCE---\")\n",
        "    state_dict = state[\"keys\"]\n",
        "    question = state_dict[\"question\"]\n",
        "    documents = state_dict[\"documents\"]\n",
        "    local = state_dict[\"local\"]\n",
        "\n",
        "    # LLM\n",
        "    llm = llm_mid\n",
        "\n",
        "    prompt = PromptTemplate(\n",
        "        template=\"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n\n",
        "        Here is the retrieved document: \\n\\n {context} \\n\\n\n",
        "        Here is the user question: {question} \\n\n",
        "        If the document contains keywords related to the user question, grade it as relevant. \\n\n",
        "        It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \\n\n",
        "        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \\n\n",
        "        Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.\n",
        "        \"\"\",\n",
        "        input_variables=[\"question\",\"context\"],\n",
        "    )\n",
        "\n",
        "    chain = prompt | llm | JsonOutputParser()\n",
        "\n",
        "    # Score\n",
        "    filtered_docs = []\n",
        "    search = \"No\"  # Default: do not opt for web search to supplement retrieval\n",
        "    for d in documents:\n",
        "        score = chain.invoke(\n",
        "            {\n",
        "                \"question\": question,\n",
        "                \"context\": d.page_content,\n",
        "            }\n",
        "        )\n",
        "        # Normalise the grade so that \"Yes\" or \" yes \" still counts as relevant,\n",
        "        # and treat anything that is not a dict with a 'score' key as irrelevant\n",
        "        grade = str(score.get(\"score\", \"\")).strip().lower() if isinstance(score, dict) else \"\"\n",
        "        if grade == \"yes\":\n",
        "            print(\"---GRADE: DOCUMENT RELEVANT---\")\n",
        "            filtered_docs.append(d)\n",
        "        else:\n",
        "            print(\"---GRADE: DOCUMENT IRRELEVANT---\")\n",
        "            search = \"Yes\"  # Perform web search\n",
        "\n",
        "    # Nothing relevant to answer from (this also covers an empty retrieval):\n",
        "    # fall back to web search instead of generating from no context\n",
        "    if not filtered_docs:\n",
        "        search = \"Yes\"\n",
        "\n",
        "    return {\n",
        "        \"keys\": {\n",
        "            \"documents\": filtered_docs,\n",
        "            \"question\": question,\n",
        "            \"local\": local,\n",
        "            \"run_web_search\": search,\n",
        "        }\n",
        "    }\n",
        "\n",
        "def transform_query(state):\n",
        "    \"\"\"\n",
        "    Transform the query to produce a better question.\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current graph state\n",
        "\n",
        "    Returns:\n",
        "        state (dict): Updates question key with a re-phrased question\n",
        "    \"\"\"\n",
        "    print(\"---TRANSFORM QUERY---\")\n",
        "    state_dict = state[\"keys\"]\n",
        "    question = state_dict[\"question\"]\n",
        "    documents = state_dict[\"documents\"]\n",
        "    local = state_dict[\"local\"]\n",
        "\n",
        "    # Create a prompt template with format instructions and the query\n",
        "    prompt = PromptTemplate(\n",
        "        template=\"\"\"You are generating questions that are well optimized for retrieval. \\n\n",
        "        Look at the input and try to reason about the underlying semantic intent / meaning. \\n\n",
        "        Here is the initial question:\n",
        "        \\n -------- \\n\n",
        "        {question}\n",
        "        \\n -------- \\n\n",
        "        Provide an improved question without any preamble, only respond with the updated question: \"\"\",\n",
        "        input_variables=[\"question\"],\n",
        "    )\n",
        "\n",
        "    # LLM\n",
        "    llm = llm_mid\n",
        "\n",
        "    # Chain\n",
        "    chain = prompt | llm | StrOutputParser()\n",
        "    better_question = chain.invoke({\"question\": question})\n",
        "\n",
        "    return {\n",
        "        \"keys\": {\"documents\": documents, \"question\": better_question, \"local\": local}\n",
        "    }\n",
        "\n",
        "\n",
        "def web_search(state):\n",
        "    \"\"\"\n",
        "    Web search based on the re-phrased question using google\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current graph state\n",
        "    Returns:\n",
        "        state (dict): Web results appended to documents.\n",
        "    \"\"\"\n",
        "\n",
        "    print(\"---WEB SEARCH---\")\n",
        "    state_dict = state[\"keys\"]\n",
        "    question = state_dict[\"question\"]\n",
        "    documents = state_dict[\"documents\"]\n",
        "    local = state_dict[\"local\"]\n",
        "\n",
        "    websearch = GoogleSearchAPIWrapper(k=3)\n",
        "    google_search = Tool(\n",
        "        name=\"google_search\",\n",
        "        description=\"Search Google for recent results.\",\n",
        "        func=websearch.run,\n",
        "    )\n",
        "    # Named search_text so the local does not shadow this function's own name\n",
        "    search_text = google_search.run(question)\n",
        "    web_results = Document(page_content=search_text)\n",
        "    # Build a new list rather than appending in place, so the list object\n",
        "    # held by the incoming state is not mutated\n",
        "    documents = [*documents, web_results]\n",
        "\n",
        "    return {\"keys\": {\"documents\": documents, \"local\": local, \"question\": question}}"
      ],
      "metadata": {
        "id": "1Sn5NCyl9pRE"
      },
      "execution_count": 88,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "### Edges ###\n",
        "\n",
        "def decide_to_generate(state):\n",
        "    \"\"\"\n",
        "    Determines whether to generate an answer or re-generate a question for web search.\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current state of the agent, including all keys.\n",
        "\n",
        "    Returns:\n",
        "        str: Next node to call\n",
        "    \"\"\"\n",
        "\n",
        "    print(\"---DECIDE TO GENERATE---\")\n",
        "    search = state[\"keys\"][\"run_web_search\"]\n",
        "\n",
        "    if search == \"Yes\":\n",
        "        # The grader flagged that a web search is needed,\n",
        "        # so re-phrase the question before searching\n",
        "        print(\"---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---\")\n",
        "        return \"transform_query\"\n",
        "\n",
        "    # No web search requested, so generate the answer from the graded documents\n",
        "    print(\"---DECISION: GENERATE---\")\n",
        "    return \"generate\""
      ],
      "metadata": {
        "id": "l9djuUIx-_ZK"
      },
      "execution_count": 89,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Graph"
      ],
      "metadata": {
        "id": "Z6g94SltdUEc"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import pprint\n",
        "from langgraph.graph import END, StateGraph\n",
        "\n",
        "workflow = StateGraph(GraphState)\n",
        "\n",
        "# Define the nodes\n",
        "workflow.add_node(\"retrieve\", retrieve)  # retrieve\n",
        "workflow.add_node(\"grade_documents\", grade_documents)  # grade documents\n",
        "workflow.add_node(\"generate\", generate)  # generate the answer\n",
        "workflow.add_node(\"transform_query\", transform_query)  # re-phrase the question\n",
        "workflow.add_node(\"web_search\", web_search)  # web search\n",
        "\n",
        "# Build graph\n",
        "workflow.set_entry_point(\"retrieve\")\n",
        "workflow.add_edge(\"retrieve\", \"grade_documents\")\n",
        "# decide_to_generate returns the name of the next node\n",
        "workflow.add_conditional_edges(\n",
        "    \"grade_documents\",\n",
        "    decide_to_generate,\n",
        "    {\n",
        "        \"transform_query\": \"transform_query\",\n",
        "        \"generate\": \"generate\",\n",
        "    },\n",
        ")\n",
        "workflow.add_edge(\"transform_query\", \"web_search\")\n",
        "workflow.add_edge(\"web_search\", \"generate\")\n",
        "workflow.add_edge(\"generate\", END)\n",
        "\n",
        "# Compile\n",
        "app = workflow.compile()"
      ],
      "metadata": {
        "id": "5pyAWscidTUt"
      },
      "execution_count": 90,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### RUN"
      ],
      "metadata": {
        "id": "Yb4oGR4Dfoud"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Run\n",
        "\n",
        "inputs = {\n",
        "    \"keys\": {\n",
        "        \"question\": \"Explain how the different types of agent memory work?\",\n",
        "        \"local\": \"No\",\n",
        "    }\n",
        "}\n",
        "\n",
        "# Track the answer explicitly instead of relying on the loop variable\n",
        "# still holding the last streamed state once the loop has finished\n",
        "generation = None\n",
        "for output in app.stream(inputs):\n",
        "    for key, value in output.items():\n",
        "        # Node\n",
        "        pprint.pprint(f\"Node '{key}':\")\n",
        "        # Optional: print full state at each node\n",
        "        # pprint.pprint(value[\"keys\"], indent=2, width=80, depth=None)\n",
        "        node_keys = value.get(\"keys\", {}) if isinstance(value, dict) else {}\n",
        "        generation = node_keys.get(\"generation\", generation)\n",
        "    pprint.pprint(\"\\n---\\n\")\n",
        "\n",
        "# Final generation\n",
        "pprint.pprint(generation)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bJH68dQffp_e",
        "outputId": "4318d425-7284-4275-83b1-f1fcd85c9b38"
      },
      "execution_count": 92,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "---RETRIEVE---\n",
            "\"Node 'retrieve':\"\n",
            "'\\n---\\n'\n",
            "---CHECK RELEVANCE---\n",
            "---GRADE: DOCUMENT IRRELEVANT---\n",
            "---GRADE: DOCUMENT RELEVANT---\n",
            "---GRADE: DOCUMENT RELEVANT---\n",
            "---GRADE: DOCUMENT IRRELEVANT---\n",
            "\"Node 'grade_documents':\"\n",
            "'\\n---\\n'\n",
            "---DECIDE TO GENERATE---\n",
            "---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---\n",
            "---TRANSFORM QUERY---\n",
            "\"Node 'transform_query':\"\n",
            "'\\n---\\n'\n",
            "---WEB SEARCH---\n",
            "\"Node 'web_search':\"\n",
            "'\\n---\\n'\n",
            "---GENERATE---\n",
            "\"Node 'generate':\"\n",
            "'\\n---\\n'\n",
            "\"Node '__end__':\"\n",
            "'\\n---\\n'\n",
            "(' \\n'\n",
            " '\\n'\n",
            " 'The functionalities of agent memory include recency, importance, relevance, '\n",
            " 'reflection mechanism, sensory memory, short-term memory, and long-term '\n",
            " 'memory. Recency gives higher scores to recent events, while Importance '\n",
            " 'distinguishes mundane from core memories. Relevance depends on how related '\n",
            " 'the memory is to the current situation or query. Reflection mechanism '\n",
            " 'synthesizes memories into higher-level inferences over time. Sensory memory '\n",
            " 'learns embedding representations for raw inputs, Short-term memory handles '\n",
            " 'in-context learning, and Long-term memory serves as an external vector store '\n",
            " 'attended to at query time.')\n"
          ]
        }
      ]
    }
  ]
}