test commited on
Commit
a5ccdb2
1 Parent(s): 793072f

add chatgpt class

Browse files
Files changed (2) hide show
  1. chat.py +65 -0
  2. requirements.txt +2 -1
chat.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Tuple
3
+
4
+ import openai
5
+
6
+ USER_MSG = str
7
+ BOT_MSG = str
8
+
9
+
10
+ class ChatGpt:
11
+ def __init__(self, api_key, max_tokens=4096):
12
+ self.api_key = api_key
13
+ self.max_tokens = max_tokens
14
+ self.message_history = []
15
+ self.total_tokens = 0
16
+
17
+ # Set up the OpenAI API client
18
+ openai.api_key = self.api_key
19
+
20
+ def add_message(self, role, content):
21
+ self.message_history.append({"role": role, "content": content})
22
+ self._truncate_history()
23
+
24
+ def add_system_message(self, content):
25
+ self.add_message("system", content)
26
+
27
+ def generate_response(self, user_input) -> Tuple[USER_MSG, BOT_MSG]:
28
+ self.add_message("user", user_input)
29
+ response = self._call_openai_api(self.message_history)
30
+ self.add_message("assistant", response)
31
+
32
+ return user_input, response
33
+
34
+ def _truncate_history(self):
35
+ while self.total_tokens > self.max_tokens:
36
+ if self.message_history[0]["role"] != "system":
37
+ self.message_history.pop(0)
38
+ else:
39
+ break
40
+
41
+ def _call_openai_api(self, messages) -> str:
42
+ response = openai.ChatCompletion.create(
43
+ model="gpt-3.5-turbo", messages=messages
44
+ )
45
+ self.total_tokens += response["usage"]["total_tokens"]
46
+ return response["choices"][0]["message"]["content"].strip()
47
+
48
+
49
+ if __name__ == "__main__":
50
+ chat = ChatGpt(os.getenv("OPENAI_API_KEY"))
51
+
52
+ chat.add_system_message("The assistant can answer questions and tell jokes.")
53
+ user_input = "Tell me a joke."
54
+ user_msg, bot_response = chat.generate_response(user_input)
55
+ assert user_msg == user_input
56
+ print("User:", user_msg)
57
+ print("Assistant:", bot_response)
58
+ print("Total Tokens:", chat.total_tokens)
59
+
60
+ user_input = "another one"
61
+ user_msg, bot_response = chat.generate_response(user_input)
62
+ assert user_msg == user_input
63
+ print("User:", user_msg)
64
+ print("Assistant:", bot_response)
65
+ print("Total Tokens:", chat.total_tokens)
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  gradio==3.26.0
2
  requests==2.28.2
3
- boto3==1.26.113
 
 
1
  gradio==3.26.0
2
  requests==2.28.2
3
+ boto3==1.26.113
4
+ openai==0.27.4, <1.0