momegas committed on
Commit
b1f1dd7
·
1 Parent(s): bfff554

🤝 Add custom prompt template

Browse files
Files changed (3) hide show
  1. example.ipynb +66 -11
  2. megabots/__init__.py +39 -6
  3. tests/test_bots.py +1 -1
example.ipynb CHANGED
@@ -2,7 +2,30 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 3,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  "metadata": {},
7
  "outputs": [
8
  {
@@ -12,35 +35,67 @@
12
  "Using model: gpt-3.5-turbo\n",
13
  "Loading path from disk...\n"
14
  ]
 
 
 
 
 
 
 
 
 
 
15
  }
16
  ],
17
  "source": [
18
- "from megabots import bot\n",
19
- "from dotenv import load_dotenv\n",
20
- "\n",
21
- "load_dotenv()\n",
22
- "\n",
23
- "qnabot = bot(\"qna-over-docs\", index=\"./index.pkl\")"
24
  ]
25
  },
26
  {
27
  "cell_type": "code",
28
- "execution_count": 2,
29
  "metadata": {},
30
  "outputs": [
 
 
 
 
 
 
 
 
31
  {
32
  "data": {
33
  "text/plain": [
34
- "'The document does not provide an answer to the question \"What is the meaning of life?\".\\nSOURCES:'"
35
  ]
36
  },
37
- "execution_count": 2,
38
  "metadata": {},
39
  "output_type": "execute_result"
40
  }
41
  ],
42
  "source": [
43
- "qnabot.ask(\"What is the meaning of life?\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ]
45
  }
46
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 8,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "text/plain": [
11
+ "True"
12
+ ]
13
+ },
14
+ "execution_count": 8,
15
+ "metadata": {},
16
+ "output_type": "execute_result"
17
+ }
18
+ ],
19
+ "source": [
20
+ "from megabots import bot\n",
21
+ "from dotenv import load_dotenv\n",
22
+ "\n",
23
+ "load_dotenv()"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 9,
29
  "metadata": {},
30
  "outputs": [
31
  {
 
35
  "Using model: gpt-3.5-turbo\n",
36
  "Loading path from disk...\n"
37
  ]
38
+ },
39
+ {
40
+ "data": {
41
+ "text/plain": [
42
+ "'The first roster of the Avengers included Iron Man, Thor, Hulk, Ant-Man, and the Wasp.'"
43
+ ]
44
+ },
45
+ "execution_count": 9,
46
+ "metadata": {},
47
+ "output_type": "execute_result"
48
  }
49
  ],
50
  "source": [
51
+ "qnabot = bot(\"qna-over-docs\", index=\"./index.pkl\")\n",
52
+ "qnabot.ask(\"what was the first roster of the avengers?\")"
 
 
 
 
53
  ]
54
  },
55
  {
56
  "cell_type": "code",
57
+ "execution_count": 10,
58
  "metadata": {},
59
  "outputs": [
60
+ {
61
+ "name": "stdout",
62
+ "output_type": "stream",
63
+ "text": [
64
+ "Using model: gpt-3.5-turbo\n",
65
+ "Loading path from disk...\n"
66
+ ]
67
+ },
68
  {
69
  "data": {
70
  "text/plain": [
71
+ "\"Hmmm! Let me think about that... Ah yes, the original Avengers lineup included Iron Man, Thor, Hulk, Ant-Man, and the Wasp. They were like the ultimate superhero squad, except for maybe the Teenage Mutant Ninja Turtles. But let's be real, they were just a bunch of turtles who liked pizza.\""
72
  ]
73
  },
74
+ "execution_count": 10,
75
  "metadata": {},
76
  "output_type": "execute_result"
77
  }
78
  ],
79
  "source": [
80
+ "prompt_template = \"\"\"\n",
81
+ "Use the following pieces of context to answer the question at the end. \n",
82
+ "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n",
83
+ "Be very playfull and humourous in your responses. always try to make the user laugh.\n",
84
+ "Always start your answers with 'Hmmm! Let me think about that...'\n",
85
+ "{context}\n",
86
+ "\n",
87
+ "Question: {question}\n",
88
+ "Helpful humorous answer:\"\"\"\n",
89
+ "\n",
90
+ "load_dotenv()\n",
91
+ "\n",
92
+ "qnabot = bot(\n",
93
+ " \"qna-over-docs\",\n",
94
+ " index=\"./index.pkl\",\n",
95
+ " prompt_template=prompt_template,\n",
96
+ " prompt_variables=[\"context\", \"question\"],\n",
97
+ ")\n",
98
+ "qnabot.ask(\"what was the first roster of the avengers?\")\n"
99
  ]
100
  }
101
  ],
megabots/__init__.py CHANGED
@@ -9,6 +9,12 @@ from fastapi import FastAPI
9
  import pickle
10
  import os
11
  from dotenv import load_dotenv
 
 
 
 
 
 
12
 
13
  load_dotenv()
14
 
@@ -17,10 +23,11 @@ class Bot:
17
  def __init__(
18
  self,
19
  model: str | None = None,
20
- prompt: str | None = None,
 
21
  memory: str | None = None,
22
  index: str | None = None,
23
- source: str | None = None,
24
  verbose: bool = False,
25
  temperature: int = 0,
26
  ):
@@ -29,7 +36,24 @@ class Bot:
29
  self.load_or_create_index(index)
30
 
31
  # Load the question-answering chain for the selected model
32
- self.chain = load_qa_with_sources_chain(self.llm, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def select_model(self, model: str | None, temperature: float):
35
  # Select and set the appropriate model based on the provided input
@@ -43,6 +67,13 @@ class Bot:
43
 
44
  def create_loader(self, index: str | None):
45
  # Create a loader based on the provided directory (either local or S3)
 
 
 
 
 
 
 
46
  self.loader = DirectoryLoader(index, recursive=True)
47
 
48
  def load_or_create_index(self, index_path: str):
@@ -100,10 +131,9 @@ SUPPORTED_MODELS = {}
100
  def bot(
101
  task: str | None = None,
102
  model: str | None = None,
103
- prompt: str | None = None,
104
- memory: str | None = None,
105
  index: str | None = None,
106
- source: str | None = None,
107
  verbose: bool = False,
108
  temperature: int = 0,
109
  **kwargs,
@@ -135,6 +165,9 @@ def bot(
135
  return SUPPORTED_TASKS[task]["impl"](
136
  model=model or task_defaults["model"],
137
  index=index or task_defaults["index"],
 
 
 
138
  verbose=verbose,
139
  **kwargs,
140
  )
 
9
  import pickle
10
  import os
11
  from dotenv import load_dotenv
12
+ from langchain.prompts import PromptTemplate
13
+ from langchain.chains.question_answering import load_qa_chain
14
+ from langchain.chains.conversational_retrieval.prompts import (
15
+ CONDENSE_QUESTION_PROMPT,
16
+ QA_PROMPT,
17
+ )
18
 
19
  load_dotenv()
20
 
 
23
  def __init__(
24
  self,
25
  model: str | None = None,
26
+ prompt_template: str | None = None,
27
+ prompt_variables: list[str] | None = None,
28
  memory: str | None = None,
29
  index: str | None = None,
30
+ sources: bool | None = False,
31
  verbose: bool = False,
32
  temperature: int = 0,
33
  ):
 
36
  self.load_or_create_index(index)
37
 
38
  # Load the question-answering chain for the selected model
39
+ self.chain = self.create_chain(
40
+ prompt_template, prompt_variables, verbose=verbose
41
+ )
42
+
43
+ def create_chain(
44
+ self,
45
+ prompt_template: str | None = None,
46
+ prompt_variables: list[str] | None = None,
47
+ verbose: bool = False,
48
+ ):
49
+ prompt = (
50
+ PromptTemplate(template=prompt_template, input_variables=prompt_variables)
51
+ if prompt_template is not None and prompt_variables is not None
52
+ else QA_PROMPT
53
+ )
54
+ return load_qa_chain(
55
+ self.llm, chain_type="stuff", verbose=verbose, prompt=prompt
56
+ )
57
 
58
  def select_model(self, model: str | None, temperature: float):
59
  # Select and set the appropriate model based on the provided input
 
67
 
68
  def create_loader(self, index: str | None):
69
  # Create a loader based on the provided directory (either local or S3)
70
+ if index is None:
71
+ raise RuntimeError(
72
+ """
73
+ Impossible to find a valid index.
74
+ Either provide a valid path to a pickle file or a directory.
75
+ """
76
+ )
77
  self.loader = DirectoryLoader(index, recursive=True)
78
 
79
  def load_or_create_index(self, index_path: str):
 
131
  def bot(
132
  task: str | None = None,
133
  model: str | None = None,
134
+ prompt_template: str | None = None,
135
+ prompt_variables: list[str] | None = None,
136
  index: str | None = None,
 
137
  verbose: bool = False,
138
  temperature: int = 0,
139
  **kwargs,
 
165
  return SUPPORTED_TASKS[task]["impl"](
166
  model=model or task_defaults["model"],
167
  index=index or task_defaults["index"],
168
+ prompt_template=prompt_template,
169
+ prompt_variables=prompt_variables,
170
+ temperature=temperature,
171
  verbose=verbose,
172
  **kwargs,
173
  )
tests/test_bots.py CHANGED
@@ -19,7 +19,7 @@ def test_ask():
19
  # Assert that the answer contains the correct answer
20
  assert correct_answer in answer
21
  # Assert that the answer contains the sources
22
- assert sources in answer
23
 
24
 
25
  def test_save_load_index():
 
19
  # Assert that the answer contains the correct answer
20
  assert correct_answer in answer
21
  # Assert that the answer contains the sources
22
+ assert sources not in answer
23
 
24
 
25
  def test_save_load_index():