Prathamesh1420 commited on
Commit
c7109aa
·
verified ·
1 Parent(s): fe86392

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -44
app.py CHANGED
@@ -1,85 +1,75 @@
1
  import streamlit as st
2
-
3
  from streamlit_chat import message
4
- # from langchain.llms import OpenAI #This import has been replaced by the below import please
5
  from langchain_groq import ChatGroq
6
  from langchain.chains import ConversationChain
7
- from langchain.chains.conversation.memory import (ConversationBufferMemory,
8
- ConversationSummaryMemory,
9
- ConversationBufferWindowMemory
10
-
11
- )
12
 
 
13
  if 'conversation' not in st.session_state:
14
- st.session_state['conversation'] =None
15
  if 'messages' not in st.session_state:
16
- st.session_state['messages'] =[]
17
  if 'API_Key' not in st.session_state:
18
- st.session_state['API_Key'] =''
19
 
20
  # Setting page title and header
21
  st.set_page_config(page_title="Chat GPT Clone", page_icon=":robot_face:")
22
  st.markdown("<h1 style='text-align: center;'>How can I assist you? </h1>", unsafe_allow_html=True)
23
 
24
-
25
  st.sidebar.title("😎")
26
- groq_api_key=st.sidebar.text_input(label="Groq API Key",type="password")
27
  summarise_button = st.sidebar.button("Summarise the conversation", key="summarise")
28
  if summarise_button:
29
- summarise_placeholder = st.sidebar.write("Nice chatting with you my friend ❤️:\n\n"+st.session_state['conversation'].memory.buffer)
30
- #summarise_placeholder.write("Nice chatting with you my friend ❤️:\n\n"+st.session_state['conversation'].memory.buffer)
31
-
32
 
 
33
  def getresponse(userInput, api_key):
 
 
 
 
 
34
 
 
35
  if st.session_state['conversation'] is None:
36
-
37
- '''llm = OpenAI(
38
- temperature=0,
39
- openai_api_key=api_key,
40
- model_name='gpt-3.5-turbo-instruct' # 'text-davinci-003' model is depreciated now, so we are using the openai's recommended model
41
- )'''
42
-
43
- llm=ChatGroq(model="Gemma2-9b-It",groq_api_key=groq_api_key)
44
-
45
-
46
  st.session_state['conversation'] = ConversationChain(
47
  llm=llm,
48
  verbose=True,
49
- memory=ConversationSummaryMemory(llm=llm)
50
  )
51
 
52
- response=st.session_state['conversation'].predict(input=userInput)
53
- print(st.session_state['conversation'].memory.buffer)
54
-
55
-
56
  return response
57
 
58
-
59
-
60
  response_container = st.container()
61
- # Here we will have a container for user input text box
62
  container = st.container()
63
 
64
-
65
  with container:
66
  with st.form(key='my_form', clear_on_submit=True):
67
  user_input = st.text_area("Your question goes here:", key='input', height=100)
68
  submit_button = st.form_submit_button(label='Send')
69
 
70
  if submit_button:
 
71
  st.session_state['messages'].append(user_input)
72
- model_response=getresponse(user_input,st.session_state['API_Key'])
 
 
73
  st.session_state['messages'].append(model_response)
74
-
75
 
 
76
  with response_container:
77
  for i in range(len(st.session_state['messages'])):
78
- if (i % 2) == 0:
79
- message(st.session_state['messages'][i], is_user=True, key=str(i) + '_user')
80
- else:
81
- message(st.session_state['messages'][i], key=str(i) + '_AI')
82
-
83
-
84
-
85
-
 
1
  import streamlit as st
 
2
  from streamlit_chat import message
 
3
  from langchain_groq import ChatGroq
4
  from langchain.chains import ConversationChain
5
+ from langchain.chains.conversation.memory import ConversationSummaryMemory
6
+ from transformers import pipeline
7
+
8
+ # Initialize the text classifier for guardrails
9
+ classifier = pipeline("text-classification", model="meta-llama/Prompt-Guard-86M")
10
 
11
+ # Set session state variables
12
  if 'conversation' not in st.session_state:
13
+ st.session_state['conversation'] = None
14
  if 'messages' not in st.session_state:
15
+ st.session_state['messages'] = []
16
  if 'API_Key' not in st.session_state:
17
+ st.session_state['API_Key'] = ''
18
 
19
  # Setting page title and header
20
  st.set_page_config(page_title="Chat GPT Clone", page_icon=":robot_face:")
21
  st.markdown("<h1 style='text-align: center;'>How can I assist you? </h1>", unsafe_allow_html=True)
22
 
23
+ # Sidebar configuration
24
  st.sidebar.title("😎")
25
+ groq_api_key = st.sidebar.text_input(label="Groq API Key", type="password")
26
  summarise_button = st.sidebar.button("Summarise the conversation", key="summarise")
27
  if summarise_button:
28
+ st.sidebar.write("Nice chatting with you my friend ❤️:\n\n" + st.session_state['conversation'].memory.buffer)
 
 
29
 
30
+ # Function to get response from the chatbot
31
  def getresponse(userInput, api_key):
32
+ # Classify the input using guardrails
33
+ classification = classifier(userInput)[0] # Get the first result
34
+ if classification['label'] == "JAILBREAK":
35
+ # If classified as Jailbreak, return a predefined safe response
36
+ return "You are attempting jailbreak/prompt injection. I can't help you with that. Please ask another question."
37
 
38
+ # Initialize the conversation chain if not already initialized
39
  if st.session_state['conversation'] is None:
40
+ llm = ChatGroq(model="Gemma2-9b-It", groq_api_key=groq_api_key)
 
 
 
 
 
 
 
 
 
41
  st.session_state['conversation'] = ConversationChain(
42
  llm=llm,
43
  verbose=True,
44
+ memory=ConversationSummaryMemory(llm=llm),
45
  )
46
 
47
+ # Generate a response using the conversation chain
48
+ response = st.session_state['conversation'].predict(input=userInput)
 
 
49
  return response
50
 
51
+ # Response container
 
52
  response_container = st.container()
53
+ # User input container
54
  container = st.container()
55
 
 
56
  with container:
57
  with st.form(key='my_form', clear_on_submit=True):
58
  user_input = st.text_area("Your question goes here:", key='input', height=100)
59
  submit_button = st.form_submit_button(label='Send')
60
 
61
  if submit_button:
62
+ # Append user input to message history
63
  st.session_state['messages'].append(user_input)
64
+ # Get response from the chatbot or guardrails
65
+ model_response = getresponse(user_input, st.session_state['API_Key'])
66
+ # Append model response to message history
67
  st.session_state['messages'].append(model_response)
 
68
 
69
+ # Display the conversation
70
  with response_container:
71
  for i in range(len(st.session_state['messages'])):
72
+ if (i % 2) == 0:
73
+ message(st.session_state['messages'][i], is_user=True, key=str(i) + '_user')
74
+ else:
75
+ message(st.session_state['messages'][i], key=str(i) + '_AI')