|
import gradio as gr |
|
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.runnables import RunnablePassthrough, RunnableLambda |
|
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage |
|
|
|
from langchain_astradb import AstraDBChatMessageHistory, AstraDBStore, AstraDBVectorStore |
|
|
|
from langchain_openai import OpenAIEmbeddings, ChatOpenAI |
|
|
|
from elevenlabs import VoiceSettings |
|
from elevenlabs.client import ElevenLabs |
|
from openai import OpenAI |
|
|
|
from json import loads as json_loads |
|
import itertools |
|
import time |
|
import os |
|
|
|
# Feature flag: True = real backends (OpenAI, AstraDB, ElevenLabs);
# False = offline echo mode (see chat(), on_audio(), ai_setup()).
AI = True
|
|
|
# Backport of itertools.batched (added in Python 3.12) for older interpreters.
if not hasattr(itertools, "batched"):
    def batched(iterable, n):
        """Batch data into tuples of length n. The last batch may be shorter.

        Matches the stdlib itertools.batched contract (tuples, ValueError on
        n < 1) so behavior is identical across Python versions; the original
        polyfill yielded lists and accepted n < 1.
        """
        if n < 1:
            raise ValueError("n must be at least one")
        it = iter(iterable)
        while True:
            batch = tuple(itertools.islice(it, n))
            if not batch:
                return
            yield batch
    itertools.batched = batched
|
|
|
def ai_setup():
    """Initialize LLM, embeddings, vector store and the prompt chain.

    Populates module-level globals used elsewhere:
      - oai_client: raw OpenAI client (Whisper transcription in on_audio)
      - llm: chat model streamed from in chat()
      - prompt_chain: retriever -> format_context -> system prompt template
    When AI is False, no remote clients are created.
    """
    global llm, prompt_chain, oai_client

    if AI:
        oai_client = OpenAI()
        llm = ChatOpenAI(model = "gpt-4o", temperature=0.8)
        embedding = OpenAIEmbeddings()
        vstore = AstraDBVectorStore(
            embedding=embedding,
            collection_name=os.environ.get("ASTRA_DB_COLLECTION"),
            token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
        )

        # Top-10 similarity search feeds the prompt's context slot.
        retriever = vstore.as_retriever(search_kwargs={'k': 10})

        # System prompt template comes from the environment — presumably it
        # contains {context} and {question} placeholders; TODO confirm.
        prompt_template = os.environ.get("PROMPT_TEMPLATE")
        prompt = ChatPromptTemplate.from_messages([('system', prompt_template)])
        prompt_chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | RunnableLambda(format_context)
            | prompt
        )
    else:
        # NOTE(review): this local is never used, and llm/prompt_chain remain
        # undefined in non-AI mode — get_system_prompt() would fail if called.
        retriever = RunnableLambda(just_read)
|
|
|
def group_and_sort(documents):
    """Group retrieved fragments by title and stitch each group into one text.

    Each document carries metadata["Title"] and metadata["range"] = (start,
    last): the fragment's character span in its source story.  Fragments of
    the same title are sorted by start offset; overlapping spans are
    de-duplicated and gaps are marked with ' [...] '.

    Returns a dict mapping title -> stitched text.
    """
    grouped = {}
    for document in documents:
        title = document.metadata["Title"]
        grouped.setdefault(title, []).append(
            (document.page_content, document.metadata["range"]))

    for values in grouped.values():
        values.sort(key=lambda doc: doc[1][0])

    for title, fragments in grouped.items():
        text = ''
        prev_last = 0
        for fragment, (start, last) in fragments:
            if start < prev_last:
                # Overlap: append only the not-yet-emitted tail.
                text += fragment[prev_last-start:]
            elif start == prev_last:
                text += fragment
            else:
                # Gap between fragments: mark the elision.
                text += ' [...] '
                text += fragment
            # max() so a fragment fully contained in already-emitted text
            # cannot move the high-water mark backwards and cause the next
            # overlapping fragment to duplicate text (the original assigned
            # `last` unconditionally).
            prev_last = max(prev_last, last)

        grouped[title] = text

    return grouped
|
|
|
def format_context(pipeline_state):
    """Render the retrieved documents into one context string and pass on.

    Replaces pipeline_state["context"] (a list of documents) with a single
    formatted string — a title header plus stitched fragments per story,
    separated by '---' — then returns the state for the next Runnable.
    (The original docstring wrongly described this as printing debug state.)
    """
    context = ''
    documents = group_and_sort(pipeline_state["context"])
    for title, text in documents.items():
        context += f"\nTitle: {title}\n"
        context += text
        context += '\n\n---\n'

    pipeline_state["context"] = context
    return pipeline_state
|
|
|
def just_read(pipeline_state):
    """Dev/offline retriever: load previously pickled documents from disk.

    Ignores its input (the pipeline state) and returns whatever was stored
    in docs.pickle in the working directory.
    """
    fname = "docs.pickle"
    import pickle

    # Context manager so the file handle is closed deterministically
    # (the original left it open until garbage collection).
    with open(fname, "rb") as f:
        return pickle.load(f)
|
|
|
def new_state():
    """Create the per-browser-session state: authenticated user, cached
    system prompt, and the active History object (all initially unset)."""
    initial = {
        "user": None,
        "system": None,
        "history": None,
    }
    return gr.State(initial)
|
|
|
def session_id(state: dict, request: gr.Request) -> str:
    """Build a per-user session identifier from the authenticated user name
    and the Gradio session hash."""
    user = state["user"]
    return "{}_{}".format(user, request.session_hash)
|
|
|
class History:
    """A persisted chat session.

    Session metadata (user, name, timestamp) lives in an AstraDBStore keyed
    by "<user>_<session_id>"; the actual messages live in an
    AstraDBChatMessageHistory collection under the same id.  Both stores are
    created lazily on first use.
    """

    # Class-level cache of the shared session-metadata store.
    store = None

    def __init__(self, name:str, user:str, session_id:str, id:str = None):
        """Wrap an existing history (id given) or create and persist a new
        one (id omitted)."""
        self.session_id = session_id
        self.name = name
        self.user = user
        self.astra_history = None  # lazy AstraDBChatMessageHistory

        if id:
            self.id = id
        else:
            self.id = f"{user}_{session_id}"
            self.create()

    @classmethod
    def get_store(cls):
        """Return the shared metadata store, creating it on first use.

        (Fixed: the classmethod's first parameter was named `self`.)
        """
        if cls.store is None:
            cls.store = AstraDBStore(
                collection_name=f'{os.environ.get("ASTRA_DB_COLLECTION")}_sessions',
                token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
                api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
            )
        return cls.store

    @classmethod
    def from_dict(cls, id:str, data:dict):
        """Rebuild a History from a stored metadata dict (no re-persist,
        since the id is passed through)."""
        name = data.get("name", f":{id}")
        return cls(name, user=data["user"], id=id, session_id=data["session"])

    @classmethod
    def get_histories(cls, user:str):
        """Load every history whose store key starts with '<user>_'."""
        store = cls.get_store()
        keys = list(store.yield_keys(prefix=f"{user}_"))
        histories = []
        for id, history in zip(keys, store.mget(keys)):
            histories.append(cls.from_dict(id=id, data=history))
        return histories

    @classmethod
    def load(cls, id:str):
        """Load a single history by its store key."""
        data = cls.get_store().mget([id])
        return cls.from_dict(id, data[0])

    def __str__(self):
        return f"{self.id}:{self.name}"

    def create(self):
        """Persist this history's metadata to the session store."""
        history = {
            'session' : self.session_id,
            'user' : self.user,
            'timestamp' : time.asctime(time.gmtime()),
            'name' : self.name
        }
        self.get_store().mset([(self.id, history)])

    @staticmethod
    def get_history_collection_name():
        """Collection name for the message-level store."""
        return f'{os.environ.get("ASTRA_DB_COLLECTION")}_chat_history'

    def get_astra_history(self):
        """Return this session's message store, creating it lazily."""
        if self.astra_history is None:
            self.astra_history = AstraDBChatMessageHistory(
                session_id=self.id,
                collection_name=self.get_history_collection_name(),
                token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
                api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
            )
        return self.astra_history

    def add(self, type:str, message):
        """Append a message of the given type: 'system' | 'user' | 'ai'.
        Unknown types are silently ignored."""
        if type == "system":
            self.get_astra_history().add_message(message)
        elif type == "user":
            self.get_astra_history().add_user_message(message)
        elif type == "ai":
            self.get_astra_history().add_ai_message(message)

    def messages(self):
        """All stored messages, oldest first."""
        return self.get_astra_history().messages

    def clear(self):
        """Delete all messages (keeps the session metadata entry)."""
        self.get_astra_history().clear()

    def delete(self):
        """Delete both the messages and the session metadata entry."""
        self.clear()
        self.get_store().mdelete([self.id])
|
|
|
def auth(token, state, request: gr.Request):
    """Resolve the access token into a user name stored in state["user"].

    With no APP_TOKENS configured, everyone is "anonymous".  Otherwise the
    token must be a key of the APP_TOKENS JSON mapping (unknown -> None).
    Returns ("", state) so the token textbox is cleared.
    """
    configured = os.environ.get("APP_TOKENS")
    if configured:
        mapping = json_loads(configured)
        state["user"] = mapping.get(token, None)
    else:
        state["user"] = "anonymous"

    return "", state
|
|
|
# Client-side hook run on app load: if the URL carries a #fragment, use it
# as the auth token (and scrub it from the address bar) before auth() runs.
AUTH_JS = """function auth_js(token, state) {
    if (!!document.location.hash) {
        token = document.location.hash
        document.location.hash=""
    }
    return [token, state]
}
"""
|
|
|
def not_authenticated(state):
    """Return True (and show a UI warning) when no user is logged in."""
    missing = state is None or not state['user']
    if missing:
        gr.Warning("You need to authenticate first")
    return missing
|
|
|
def list_histories(state):
    """Populate the history dropdown with (name, id) pairs for this user."""
    if not_authenticated(state):
        return gr.update()

    choices = [
        (history.name, history.id)
        for history in History.get_histories(state["user"])
    ]
    return gr.update(choices=choices, value=None)
|
|
|
def add_history(state, request, type, message, name:str = None):
    """Record a message in the session's History.

    Creates the History lazily on first use; unless an explicit name is
    given, the session is named after the first 60 chars of the message.
    """
    history = state["history"]
    if not history:
        history = History(
            name=name or message[:60],
            user=state["user"],
            session_id=request.session_hash,
        )
        state["history"] = history

    history.add(type, message)
|
|
|
def load_history(state, history_id):
    """Load a stored conversation into the chat UI.

    Rebuilds the Gradio chatbot pair list from the persisted messages,
    restores a trailing unanswered user message into the input textbox,
    and re-derives the cached system prompt from the first user message.

    Returns (state, chatbot pairs, chat-interface state, textbox value).
    """
    state["history"] = History.load(history_id)

    history = []
    for msg in state["history"].messages():
        if type(msg) is HumanMessage:
            history.append([msg.content, ''])
        elif type(msg) is AIMessage:
            if not history:
                history.append(['', ''])
            last = history[-1]
            if last[1]:
                # Consecutive AI messages: new pair with an empty user turn.
                history.append(['', msg.content])
            else:
                last[1] = msg.content

    # A trailing pair with an empty AI slot is an unanswered user message:
    # move it back into the textbox instead of showing a half-finished row.
    # (The original tested len(history[-1]) == 1, which can never be true
    # because every pair is created with two elements.)
    if history and not history[-1][1]:
        user_input = history[-1][0]
        history = history[:-1]
    else:
        user_input = ''

    if history:
        state["system"] = get_system_prompt(history[0][0])

    return state, history, history, user_input
|
|
|
def get_system_prompt(message):
    """Run the retrieval chain on the user's question and return the
    rendered system message (first message of the prompt value)."""
    rendered = prompt_chain.invoke(message)
    return rendered.messages[0]
|
|
|
def chat(message, history, state, request:gr.Request):
    """Gradio chat handler: stream the assistant's reply.

    Yields partial answers suffixed with '…' while streaming — chat_change()
    keys off that suffix to keep the play button disabled — and one final
    clean answer when done.  Persists both the user message and the reply.
    """
    if not_authenticated(state):
        yield "You need to authenticate first"
    else:
        if AI:
            # First turn: build the retrieval-augmented system prompt once
            # and cache it in the session state.
            if not history:
                state["system"] = get_system_prompt(message)
            system_prompt = state["system"]

            add_history(state, request, "user", message)

            # Rebuild the full LLM message list from the UI pair history.
            messages = [system_prompt]
            for human, ai in history:
                messages.append(HumanMessage(human))
                messages.append(AIMessage(ai))
            messages.append(HumanMessage(message))

            answer = ''
            for response in llm.stream(messages):
                answer += response.content
                yield answer+'…'
        else:
            # Offline echo mode: replay the message word by word.
            add_history(state, request, "user", message)

            msg = f"{time.ctime()}: You said: {message}"
            answer = ' '
            for word in msg.split():
                answer += f' {word}'
                yield answer+'…'
                time.sleep(0.05)
        # Final yield without the '…' marker signals completion to the UI.
        yield answer

        add_history(state, request, "ai", answer)
|
|
|
def on_audio(path, state):
    """Handle a finished microphone recording.

    Transcribes the audio (Whisper when AI is on, a canned string otherwise)
    and returns (textbox value, None) — the None resets the mic component.
    Consistently returns a tuple (the original mixed tuple and list).
    """
    if not_authenticated(state):
        return (gr.update(), None)
    if not path:
        # Recording was cleared rather than submitted.
        return (gr.update(), None)
    if AI:
        # Context manager closes the audio file deterministically
        # (the original leaked the handle).
        with open(path, "rb") as audio_file:
            text = oai_client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                response_format="text"
            )
    else:
        text = f"{time.ctime()}: You said something"

    return (text, None)
|
|
|
def play_last(history, state):
    """Stream ElevenLabs TTS audio for the most recent assistant reply.

    Yields nothing (empty stream) when unauthenticated or the chat is empty.
    """
    if not_authenticated(state):
        return
    if not history:
        return
    voice_id = "IINmogebEQykLiDoSkd0"
    last_reply = history[-1][1]
    client = ElevenLabs()
    voice = client.voices.get(voice_id)
    yield from client.generate(text=last_reply, voice=voice, stream=True)
|
|
|
def chat_change(history):
    """Toggle the play button as replies stream in.

    Disabled while the last reply is empty; enabled once it no longer ends
    with the '…' streaming marker; otherwise left unchanged.
    """
    if not history:
        return gr.update()
    last_reply = history[-1][1]
    if not last_reply:
        return gr.update(interactive=False)
    if last_reply[-1] != '…':
        return gr.update(interactive=True)
    return gr.update()
|
|
|
# Labels for the record toggle button (also matched by the inline JS in
# gr_setup() to decide whether to start or stop recording).
TEXT_TALK = "🎤 Talk"
TEXT_STOP = "⏹ Stop"
|
|
|
def gr_setup():
    """Build and return the Gradio Blocks app: chat interface, voice
    input/output, saved-history dropdown, and URL-hash token auth."""
    theme = gr.Theme.from_hub("freddyaboulton/dracula_revamped@0.3.9")
    theme.set(
        color_accent_soft="#818eb6",
        background_fill_secondary="#6272a4",
        background_fill_primary="#818eb6",
        button_primary_text_color="*button_secondary_text_color",
        button_primary_background_fill="*button_secondary_background_fill")

    with gr.Blocks(
        title="Sherlock Holmes stories",
        fill_height=True,
        theme=theme,
        css="footer {visibility: hidden}"
    ) as app:
        state = new_state()
        chatbot = gr.Chatbot(show_label=False, render=False, scale=1)
        gr.HTML('<h1 style="text-align: center">Sherlock Holmes stories</h1>')
        # Dropdown of saved conversations; repopulated on focus.
        history_choice = gr.Dropdown(
            choices=[("History", "History")],
            value="History",
            show_label=False,
            container=False,
            interactive=True,
            filterable=True)

        iface = gr.ChatInterface(
            chat,
            chatbot=chatbot,
            title=None,
            submit_btn=gr.Button(
                "Send",
                variant="primary",
                scale=1,
                min_width=150,
                elem_id="submit_btn",
                render=False
            ),
            undo_btn=None,
            clear_btn=None,
            retry_btn=None,
            additional_inputs=[state])

        with gr.Row():
            # Hidden streaming player fed by play_last().
            player = gr.Audio(
                visible=False,
                show_label=False,
                show_download_button=False,
                show_share_button=False,
                autoplay=True,
                streaming=True,
                interactive=False)

            # Hidden microphone; its record/stop controls are clicked via JS
            # from the single Talk/Stop button below.
            mic = gr.Audio(
                sources=["microphone"],
                type="filepath",
                show_label=False,
                format="mp3",
                elem_id="microphone",
                visible=False,
                waveform_options=gr.WaveformOptions(sample_rate=16000, show_recording_waveform=False))

            start_stop_rec = gr.Button(TEXT_TALK, size = "lg")
            play_last_btn = gr.Button("🔊 Play last", size = "lg", interactive=False)

        play_last_btn.click(
            play_last,
            [chatbot, state], player)

        # Re-enable the play button once the last reply finished streaming.
        chatbot.change(chat_change, inputs=chatbot, outputs=play_last_btn)
        # One button toggles the hidden mic between record and stop,
        # flipping its own label between TEXT_TALK and TEXT_STOP.
        start_stop_rec.click(
            lambda x:x,
            inputs=start_stop_rec,
            outputs=start_stop_rec,
            js=f'''function (text) {{
                if (text == "{TEXT_TALK}") {{
                    document.getElementById("microphone").querySelector(".record-button").click()
                    return ["{TEXT_STOP}"]
                }} else {{
                    document.getElementById("microphone").querySelector(".stop-button").click()
                    return ["{TEXT_TALK}"]
                }}
            }}'''
        )
        # When a recording lands: transcribe into the textbox, reset the
        # mic, then auto-submit via JS if the transcription is non-empty.
        mic.change(
            on_audio, [mic, state], [iface.textbox, mic]
        ).then(
            lambda x:None,
            inputs=iface.textbox,
            js='function (text){if (text) document.getElementById("submit_btn").click(); return [text]}'
        )

        history_choice.focus(
            list_histories,
            inputs=state,
            outputs=history_choice
        )
        history_choice.input(
            load_history,
            inputs=[state, history_choice],
            outputs=[state, chatbot, iface.chatbot_state, iface.textbox])

        # Hidden textbox that carries the token from the URL hash (AUTH_JS)
        # into auth() on page load.
        token = gr.Textbox(visible=False)

        app.load(auth,
            [token,state],
            [token,state],
            js=AUTH_JS)

    app.queue(default_concurrency_limit=None, api_open=False)
    return app
|
|
|
if __name__ == "__main__":
    # Initialize model/DB clients first, then build and serve the UI.
    ai_setup()
    app = gr_setup()
    app.launch(show_api=False)
|
|