Whatson / query.py
gerasdf
bugfix(on weird sequence of chat histories)
7a05c53
import gradio as gr
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_astradb import AstraDBChatMessageHistory, AstraDBStore, AstraDBVectorStore
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from elevenlabs import VoiceSettings
from elevenlabs.client import ElevenLabs
from openai import OpenAI
from json import loads as json_loads
import itertools
import time
import os
# Feature switch. When False, ai_setup() builds no OpenAI/AstraDB pipeline,
# chat() echoes the user's message back and on_audio() returns a placeholder
# text instead of calling the transcription API.
AI = True
if not hasattr(itertools, "batched"):
    def batched(iterable, n):
        """Yield successive lists of up to ``n`` items taken from ``iterable``.

        Every batch holds ``n`` items except possibly the last one, which
        holds whatever is left over.
        """
        # batched('ABCDEFG', 3) --> ABC DEF G
        source = iter(iterable)
        # Keep slicing n items off the iterator; an empty slice means the
        # input is exhausted and ends the generator.
        yield from iter(lambda: list(itertools.islice(source, n)), [])

    # Install the fallback so callers can use itertools.batched everywhere.
    itertools.batched = batched
def ai_setup():
    """Create the module-level OpenAI client, chat model and prompt chain.

    Sets the globals ``oai_client``, ``llm`` and ``prompt_chain``. Nothing is
    created when ``AI`` is false.

    Configuration is read from the environment: ``PROMPT_TEMPLATE`` (system
    prompt template), ``ASTRA_DB_COLLECTION``, ``ASTRA_DB_APPLICATION_TOKEN``
    and ``ASTRA_DB_API_ENDPOINT``.

    Raises:
        RuntimeError: if ``AI`` is enabled and ``PROMPT_TEMPLATE`` is not set.
    """
    global llm, prompt_chain, oai_client
    if not AI:
        return

    # Fail early with a clear message instead of handing None to the
    # prompt template constructor.
    prompt_template = os.environ.get("PROMPT_TEMPLATE")
    if prompt_template is None:
        raise RuntimeError("PROMPT_TEMPLATE environment variable is not set")

    oai_client = OpenAI()
    llm = ChatOpenAI(model = "gpt-4o", temperature=0.8)
    vstore = AstraDBVectorStore(
        embedding=OpenAIEmbeddings(),
        collection_name=os.environ.get("ASTRA_DB_COLLECTION"),
        token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
        api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
    )
    retriever = vstore.as_retriever(search_kwargs={'k': 10})
    prompt = ChatPromptTemplate.from_messages([('system', prompt_template)])
    # The chain stops at the rendered prompt: get_system_prompt() extracts the
    # system message from it and chat() calls the LLM itself.
    prompt_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | RunnableLambda(format_context)
        | prompt
    )
def group_and_sort(documents):
    """Merge retrieved fragments into one text per title.

    Each document must provide ``page_content``, ``metadata["Title"]`` and
    ``metadata["range"]``, a ``(start, last)`` pair. The slicing below treats
    the range as character offsets of the fragment within its source text,
    with ``last`` exclusive.

    Fragments of the same title are ordered by ``start``. Overlapping text is
    emitted only once, and a gap between fragments (or before the first one)
    is marked with ``' [...] '``.

    Returns:
        dict mapping each title, in order of first appearance, to its merged
        text.
    """
    grouped = {}
    for document in documents:
        grouped.setdefault(document.metadata["Title"], []).append(
            (document.page_content, document.metadata["range"])
        )

    merged = {}
    for title, fragments in grouped.items():
        fragments.sort(key=lambda item: item[1][0])
        parts = []
        prev_last = 0
        for fragment, (start, last) in fragments:
            if start < prev_last:
                # Overlap: keep only the part not already emitted.
                parts.append(fragment[prev_last - start:])
            else:
                if start > prev_last:
                    parts.append(' [...] ')
                parts.append(fragment)
            # Never move backwards: a fragment contained in an earlier one
            # must not make the following fragment look like a gap.
            prev_last = max(prev_last, last)
        merged[title] = ''.join(parts)
    return merged
def format_context(pipeline_state):
    """Replace the documents in ``pipeline_state["context"]`` with one string.

    The documents are merged per title by ``group_and_sort``; each title then
    becomes a ``Title:`` header followed by its text and a ``---`` separator.
    The same ``pipeline_state`` mapping is updated in place and returned.
    """
    sections = group_and_sort(pipeline_state["context"])
    pipeline_state["context"] = ''.join(
        f"\nTitle: {title}\n{text}\n\n---\n"
        for title, text in sections.items()
    )
    return pipeline_state
def just_read(pipeline_state):
    """Load and return the object stored in ``docs.pickle``.

    The file is looked up relative to the current working directory.
    ``pipeline_state`` is ignored; it is only accepted so the function can be
    called with a pipeline input.
    """
    import pickle

    fname = "docs.pickle"
    # NOTE: unpickling can execute arbitrary code; docs.pickle must come from
    # a trusted source.
    with open(fname, "rb") as pickled:
        return pickle.load(pickled)
def new_state():
    """Return a ``gr.State`` whose user, system prompt and history are unset."""
    initial = dict.fromkeys(("user", "system", "history"))
    return gr.State(initial)
def session_id(state: dict, request: gr.Request) -> str:
    """Return ``"<user>_<session hash>"`` built from the state and request."""
    user = state["user"]
    return f'{user}_{request.session_hash}'
class History:
    """A named conversation: a metadata record plus its stored chat messages.

    The metadata record (session, user, timestamp, name) lives in an
    ``AstraDBStore`` collection shared by all instances; the messages live in
    an ``AstraDBChatMessageHistory`` keyed by the history id.
    """

    # Shared metadata store, created lazily by get_store().
    store = None

    def __init__(self, name:str, user:str, session_id:str, id:str = None):
        """Wrap an existing history (``id`` given) or register a new one.

        Without ``id`` the id becomes ``"<user>_<session_id>"`` and the
        metadata record is written to the store immediately.
        """
        self.session_id = session_id
        self.name = name
        self.user = user
        self.astra_history = None
        if not id:
            self.id = f"{user}_{session_id}"
            self.create()
        else:
            self.id = id

    @classmethod
    def get_store(cls):
        """Return the shared metadata store, creating it on first use."""
        if cls.store is not None:
            return cls.store
        collection = os.environ.get("ASTRA_DB_COLLECTION")
        cls.store = AstraDBStore(
            collection_name=f'{collection}_sessions',
            token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
        )
        return cls.store

    @classmethod
    def from_dict(cls, id:str, data:dict):
        """Build a History from a stored record; the name defaults to ``":<id>"``."""
        return cls(
            data.get("name", f":{id}"),
            user=data["user"],
            id=id,
            session_id=data["session"],
        )

    @classmethod
    def get_histories(cls, user:str):
        """Return all histories whose store key starts with ``"<user>_"``."""
        store = cls.get_store()
        keys = list(store.yield_keys(prefix=f"{user}_"))
        return [
            cls.from_dict(id=key, data=record)
            for key, record in zip(keys, store.mget(keys))
        ]

    @classmethod
    def load(cls, id:str):
        """Fetch the record stored under ``id`` and wrap it in a History."""
        records = cls.get_store().mget([id])
        return cls.from_dict(id, records[0])

    def __str__(self):
        return f"{self.id}:{self.name}"

    def create(self):
        """Write this history's metadata record to the store."""
        record = dict(
            session=self.session_id,
            user=self.user,
            timestamp=time.asctime(time.gmtime()),
            name=self.name,
        )
        self.get_store().mset([(self.id, record)])

    @staticmethod
    def get_history_collection_name():
        """Return the name of the collection that holds the chat messages."""
        collection = os.environ.get("ASTRA_DB_COLLECTION")
        return f'{collection}_chat_history'

    def get_astra_history(self):
        """Return the message history for this id, connecting on first use."""
        if self.astra_history is not None:
            return self.astra_history
        self.astra_history = AstraDBChatMessageHistory(
            session_id=self.id,
            collection_name=self.get_history_collection_name(),
            token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
        )
        return self.astra_history

    def add(self, type:str, message):
        """Append ``message`` as a "system", "user" or "ai" message.

        Any other ``type`` is ignored.
        """
        adders = {
            "system": "add_message",
            "user": "add_user_message",
            "ai": "add_ai_message",
        }
        method_name = adders.get(type)
        if method_name is not None:
            getattr(self.get_astra_history(), method_name)(message)

    def messages(self):
        """Return the stored messages of this history."""
        return self.get_astra_history().messages

    def clear(self):
        """Remove all stored messages, keeping the metadata record."""
        self.get_astra_history().clear()

    def delete(self):
        """Remove the stored messages and then the metadata record."""
        self.clear()
        self.get_store().mdelete([self.id])
def auth(token, state, request: gr.Request):
    """Resolve ``token`` to a user name and store it in ``state["user"]``.

    ``APP_TOKENS`` is read from the environment as a JSON object mapping
    tokens to user names. When it is unset or empty every visitor becomes
    "anonymous"; otherwise an unknown token leaves the user as ``None``.

    Returns an empty string (the new token value) and the updated state.
    """
    raw_tokens = os.environ.get("APP_TOKENS")
    if raw_tokens:
        known_tokens = json_loads(raw_tokens)
        state["user"] = known_tokens.get(token, None)
    else:
        state["user"] = "anonymous"
    return "", state
# JavaScript passed to app.load() together with auth(): when the page URL has
# a fragment, it is used as the token and the fragment is reset to "".
# NOTE(review): document.location.hash includes the leading "#", so the keys
# in APP_TOKENS presumably have to include it as well -- confirm.
AUTH_JS = """function auth_js(token, state) {
if (!!document.location.hash) {
token = document.location.hash
document.location.hash=""
}
return [token, state]
}
"""
def not_authenticated(state):
    """Return True, after showing a warning, when no user is set in ``state``."""
    if state is not None and state['user']:
        return False
    gr.Warning("You need to authenticate first")
    return True
def list_histories(state):
    """Return a dropdown update listing the current user's stored histories.

    Each choice is a ``(name, id)`` pair and the selection is reset. When the
    user is not authenticated the dropdown is left unchanged.
    """
    if not_authenticated(state):
        return gr.update()
    choices = [
        (item.name, item.id)
        for item in History.get_histories(state["user"])
    ]
    return gr.update(choices=choices, value=None)
def add_history(state, request, type, message, name:str = None):
    """Append ``message`` to the session's history, creating it if needed.

    A new History is named ``name``, or the first 60 characters of
    ``message`` when no name is given, and is stored in ``state["history"]``.
    """
    current = state["history"]
    if not current:
        current = History(
            name = name or message[:60],
            user = state["user"],
            session_id = request.session_hash
        )
        state["history"] = current
    current.add(type, message)
def load_history(state, history_id):
    """Load a stored conversation and return the values that refresh the UI.

    The stored messages are folded into ``[user, ai]`` pairs. An AI message
    fills the reply of the last pair when that reply is still empty; otherwise
    it starts a new pair with an empty user side. Messages that are neither
    ``HumanMessage`` nor ``AIMessage`` are skipped.

    Returns:
        ``(state, history, history, user_input)`` for the state, the Chatbot,
        the ChatInterface state and the ChatInterface textbox.
    """
    state["history"] = History.load(history_id)
    history = []
    for msg in state["history"].messages():
        if type(msg) is HumanMessage:
            history.append([msg.content, ''])
        elif type(msg) is AIMessage:
            if not history or history[-1][1]:
                # No open user turn to answer: start a pair of its own.
                history.append(['', msg.content])
            else:
                history[-1][1] = msg.content
    # NOTE(review): every pair built above has two elements, so this branch
    # never runs and user_input is always ''. Kept as written; confirm whether
    # a trailing unanswered user message should go back to the textbox.
    if history and len(history[-1]) == 1:
        user_input = history[-1][0]
        history = history[:-1]
    else:
        user_input = ''
    # The prompt chain only exists when AI is enabled (see ai_setup), so only
    # rebuild the system prompt in that mode, as chat() does.
    if history and AI:
        state["system"] = get_system_prompt(history[0][0])
    return state, history, history, user_input # state, Chatbot, ChatInterface.state, ChatInterface.textbox
def get_system_prompt(message):
    """Run the prompt chain on ``message`` and return its first message."""
    return prompt_chain.invoke(message).messages[0]
def chat(message, history, state, request:gr.Request):
    """Stream the reply to ``message`` and record both sides in the history.

    Partial replies are yielded with a trailing '…'; the final yield is the
    complete answer without it. With ``AI`` disabled the reply is an echo of
    the message, streamed word by word.
    """
    if not_authenticated(state):
        yield "You need to authenticate first"
        return

    if AI:
        # The system prompt is built from the first message of a conversation
        # and reused for the following turns.
        if not history:
            state["system"] = get_system_prompt(message)
        system_prompt = state["system"]
        add_history(state, request, "user", message)
        messages = [system_prompt]
        for human, ai in history:
            messages += [HumanMessage(human), AIMessage(ai)]
        messages.append(HumanMessage(message))
        answer = ''
        for response in llm.stream(messages):
            answer += response.content
            yield answer+'…'
    else:
        add_history(state, request, "user", message)
        msg = f"{time.ctime()}: You said: {message}"
        answer = ' '
        for word in msg.split():
            answer = f'{answer} {word}'
            yield answer+'…'
            time.sleep(0.05)

    yield answer
    add_history(state, request, "ai", answer)
def on_audio(path, state):
    """Transcribe the recording at ``path`` into text for the chat textbox.

    Returns:
        ``(text, None)``; the ``None`` resets the microphone component. When
        the user is not authenticated or there is no recording, the textbox
        is left unchanged via ``gr.update()``.
    """
    if not_authenticated(state) or not path:
        return (gr.update(), None)
    if AI:
        # Close the audio file once the upload is done instead of leaking
        # the handle.
        with open(path, "rb") as audio_file:
            text = oai_client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                response_format="text"
            )
    else:
        text = f"{time.ctime()}: You said something"
    return (text, None)
def play_last(history, state):
    """Stream ElevenLabs speech for the last reply in ``history``.

    Yields nothing when the user is not authenticated or the history is empty.
    """
    if not_authenticated(state) or not len(history):
        return
    voice_id = "IINmogebEQykLiDoSkd0"
    last_reply = history[-1][1]
    client = ElevenLabs()
    voice = client.voices.get(voice_id)
    yield from client.generate(text=last_reply, voice=voice, stream=True)
def chat_change(history):
    """Return the update for the "Play last" button after a chat change.

    The button is disabled while the last reply is empty and enabled once the
    reply no longer ends with the streaming marker '…'. In every other case
    it is left unchanged.
    """
    if not history:
        return gr.update() # play_last_btn
    last_reply = history[-1][1]
    if not last_reply:
        return gr.update(interactive=False)
    if last_reply[-1] != '…':
        return gr.update(interactive=True)
    return gr.update() # play_last_btn
# Labels of the record button; its click handler in gr_setup() switches
# between the two.
TEXT_TALK = "🎤 Talk"
TEXT_STOP = "⏹ Stop"
def gr_setup():
    """Build the Gradio Blocks app (layout and event wiring) and return it.

    The app is not launched here; the caller does that.
    """
    # Start from a hub theme and override a few colours.
    theme = gr.Theme.from_hub("freddyaboulton/dracula_revamped@0.3.9")
    theme.set(
        color_accent_soft="#818eb6", # ChatBot.svelte / .user / .message-row.panel.user-row . neutral_500 -> neutral_200
        background_fill_secondary="#6272a4", # ChatBot.svelte / .bot / .message-row.panel.bot-row . neutral_500 -> neutral_400
        background_fill_primary="#818eb6", # DropdownOptions.svelte / item
        button_primary_text_color="*button_secondary_text_color",
        button_primary_background_fill="*button_secondary_background_fill")
    with gr.Blocks(
        title="Sherlock Holmes stories",
        fill_height=True,
        theme=theme,
        css="footer {visibility: hidden}"
    ) as app:
        # Per-session state: user, system prompt and current History.
        state = new_state()
        # Created with render=False and handed to gr.ChatInterface below.
        chatbot = gr.Chatbot(show_label=False, render=False, scale=1)
        gr.HTML('<h1 style="text-align: center">Sherlock Holmes stories</h1>')
        # Starts with a placeholder entry; the real choices are filled in by
        # list_histories when the dropdown gets focus.
        history_choice = gr.Dropdown(
            choices=[("History", "History")],
            value="History",
            show_label=False,
            container=False,
            interactive=True,
            filterable=True)
        # state is passed as an additional input so chat() receives it.
        iface = gr.ChatInterface(
            chat,
            chatbot=chatbot,
            title=None,
            submit_btn=gr.Button(
                "Send",
                variant="primary",
                scale=1,
                min_width=150,
                elem_id="submit_btn",
                render=False
            ),
            undo_btn=None,
            clear_btn=None,
            retry_btn=None,
            # examples=[
            #     ["I arrived late last night and found a dead goose in my bed"],
            #     ["Help please sir. I'm about to get married, to the most lovely lady,"
            #      "and I just received a letter threatening me to make public some things"
            #      "of my past I'd rather keep quiet, unless I don't marry"],
            # ],
            additional_inputs=[state])
        with gr.Row():
            # Hidden audio output that receives the speech streamed by play_last.
            player = gr.Audio(
                visible=False,
                show_label=False,
                show_download_button=False,
                show_share_button=False,
                autoplay=True,
                streaming=True,
                interactive=False)
            # Hidden microphone; it is operated from JavaScript by the
            # talk/stop button below.
            mic = gr.Audio(
                sources=["microphone"],
                type="filepath",
                show_label=False,
                format="mp3",
                elem_id="microphone",
                visible=False,
                waveform_options=gr.WaveformOptions(sample_rate=16000, show_recording_waveform=False))
            start_stop_rec = gr.Button(TEXT_TALK, size = "lg")
            # Disabled until chat_change sees a finished reply.
            play_last_btn = gr.Button("🔊 Play last", size = "lg", interactive=False)
        play_last_btn.click(
            play_last,
            [chatbot, state], player)
        chatbot.change(chat_change, inputs=chatbot, outputs=play_last_btn)
        # The JavaScript clicks the hidden microphone's record or stop button,
        # depending on the current label, and returns the other label.
        start_stop_rec.click(
            lambda x:x,
            inputs=start_stop_rec,
            outputs=start_stop_rec,
            js=f'''function (text) {{
if (text == "{TEXT_TALK}") {{
document.getElementById("microphone").querySelector(".record-button").click()
return ["{TEXT_STOP}"]
}} else {{
document.getElementById("microphone").querySelector(".stop-button").click()
return ["{TEXT_TALK}"]
}}
}}'''
        )
        # When a recording arrives: put its transcription in the textbox and
        # reset the microphone, then click the submit button if there is text.
        mic.change(
            on_audio, [mic, state], [iface.textbox, mic]
        ).then(
            lambda x:None,
            inputs=iface.textbox,
            js='function (text){if (text) document.getElementById("submit_btn").click(); return [text]}'
        )
        # Refresh the list of stored histories whenever the dropdown is focused.
        history_choice.focus(
            list_histories,
            inputs=state,
            outputs=history_choice
        )
        # Picking a history reloads the chatbot, the ChatInterface state and
        # the textbox from storage.
        history_choice.input(
            load_history,
            inputs=[state, history_choice],
            outputs=[state, chatbot, iface.chatbot_state, iface.textbox])
        # Hidden field that carries the token from AUTH_JS into auth() on page load.
        token = gr.Textbox(visible=False)
        app.load(auth,
            [token,state],
            [token,state],
            js=AUTH_JS)
        app.queue(default_concurrency_limit=None, api_open=False)
    return app
# Script entry point: set up the AI pipeline, build the UI and start the server.
if __name__ == "__main__":
    ai_setup()
    app = gr_setup()
    app.launch(show_api=False)