# File size: 2,696 Bytes
# c877b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import pprint

import chromadb
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
from dotenv import load_dotenv
from halo import Halo
from openai import OpenAI

# Load variables from a local .env file (if present) into os.environ BEFORE
# any client is constructed, so OPENAI_API_KEY defined there is visible below.
# API keys must come from the environment; never hard-code them or leave them
# in commented-out code under version control.
load_dotenv()

# OpenAI client used by generate_response(); the key is read from the
# environment instead of being a literal in the source.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Pretty-printer used to echo the outgoing request messages.
pp = pprint.PrettyPrinter(indent=4)

def generate_response(messages, *, model_name='gpt-3.5-turbo', temperature=0.5, max_tokens=250):
    """Send *messages* to the OpenAI chat-completions API and return the reply.

    A terminal spinner is shown while the request is in flight. Afterwards the
    request messages and the token usage reported by the API are printed.

    Args:
        messages: Chat messages in OpenAI format (a list of dicts with
            "role" and "content" keys).
        model_name: Chat model to use. The previously pinned snapshot
            ``gpt-3.5-turbo-0301`` has been retired by OpenAI, so the default
            is the ``gpt-3.5-turbo`` alias.
        temperature: Sampling temperature passed to the API.
        max_tokens: Upper bound on the number of tokens generated for the reply.

    Returns:
        The ``message`` of the first returned choice.

    Raises:
        Whatever ``client.chat.completions.create`` raises propagates to the
        caller; the spinner is stopped first.
    """
    spinner = Halo(text='Loading...', spinner='dots')
    spinner.start()
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
    finally:
        # Stop the spinner even when the request fails, otherwise it keeps
        # animating over the error output.
        spinner.stop()

    print("Request:")
    pp.pprint(messages)

    usage = response.usage
    print(f"Completion tokens: {usage.completion_tokens}, Prompt tokens: {usage.prompt_tokens}, Total tokens: {usage.total_tokens}")
    return response.choices[0].message

def _retrieve_context(collection, query_text, n_results=2):
    """Return up to *n_results* stored assistant replies most similar to *query_text*."""
    stored = collection.count()
    # Querying an empty collection is an error or a warning depending on the
    # Chroma version, so skip the lookup until something has been stored.
    if stored == 0:
        return []
    results = collection.query(
        query_texts=[query_text],
        where={"role": "assistant"},
        n_results=min(n_results, stored),
    )
    return results['documents'][0]


def _store_exchange(collection, last_id, user_text, assistant_text):
    """Add one user/assistant exchange to *collection* and return the last id number used."""
    user_id = last_id + 1
    assistant_id = last_id + 2
    collection.add(
        documents=[user_text, assistant_text],
        metadatas=[{"role": "user"}, {"role": "assistant"}],
        ids=[f"id_{user_id}", f"id_{assistant_id}"],
    )
    return assistant_id


def main():
    """Run an interactive chat loop with retrieval over past assistant replies.

    Each turn reads a line from stdin and looks up up to two stored assistant
    replies similar to it in a Chroma collection. It sends those, plus the new
    input, to the model, prints the reply, and stores both sides of the
    exchange for later turns. Typing "quit" (any case) or sending EOF ends
    the loop; blank input is ignored.
    """
    chroma_client = chromadb.Client()
    embedding_function = OpenAIEmbeddingFunction(
        api_key=os.getenv("OPENAI_API_KEY"),
        model_name='text-embedding-3-small',
    )
    # get_or_create so re-running main() in the same process does not fail
    # because the "conversations" collection already exists.
    collection = chroma_client.get_or_create_collection(
        name="conversations", embedding_function=embedding_function
    )
    # Continue numbering after any documents already present so ids never collide.
    current_id = collection.count()

    while True:
        try:
            input_text = input("You: ").strip()
        except EOFError:
            # stdin closed (Ctrl-D / end of piped input): exit cleanly.
            print()
            break
        if input_text.lower() == "quit":
            break
        if not input_text:
            # Nothing to embed or send to the model.
            continue

        messages = [{"role": "system", "content": "You are a kind and wise wizard"}]
        # Prepend similar past assistant replies as extra context.
        messages.extend(
            {"role": "user", "content": f"previous chat: {res}"}
            for res in _retrieve_context(collection, input_text)
        )
        # The new user input goes last in the conversation chain.
        messages.append({"role": "user", "content": input_text})
        response = generate_response(messages)

        current_id = _store_exchange(collection, current_id, input_text, response.content)
        print(f"Wizard: {response.content}")

# Start the interactive chat loop only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()