freddyaboulton's picture
add code
b72e7b5
raw
history blame
6.92 kB
from __future__ import annotations as _annotations
import json
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import AsyncGenerator
import asyncpg
import gradio as gr
import numpy as np
import pydantic_core
from gradio_webrtc import (
AdditionalOutputs,
ReplyOnPause,
WebRTC,
audio_to_bytes,
get_twilio_turn_credentials,
)
from groq import Groq
from openai import AsyncOpenAI
from pydantic import BaseModel
from pydantic_ai import RunContext
from pydantic_ai.agent import Agent
from pydantic_ai.messages import ModelStructuredResponse, ModelTextResponse, ToolReturn
# Doc index read from the working directory; stream_from_agent reads each
# entry's "id" and "path" keys to map retrieved section ids to doc paths.
# A `with` block closes the file promptly, and an explicit UTF-8 encoding
# avoids depending on the platform's locale default.
with open("gradio_docs.json", encoding="utf-8") as _docs_file:
    DOCS = json.load(_docs_file)
# Groq client: used synchronously for speech-to-text in stream_from_agent.
groq_client = Groq()
# Async OpenAI client: passed to the agent tools via Deps (used for embeddings).
openai = AsyncOpenAI()
@dataclass
class Deps:
    """Dependencies handed to the agent's tools through ``RunContext.deps``."""

    # Async OpenAI client; the retrieval tool calls ``.embeddings.create`` on it.
    openai: AsyncOpenAI
    # asyncpg pool; the retrieval tool runs its ``doc_sections`` query through it.
    pool: asyncpg.Pool
# System prompt given to the agent. The "retrieval tool" it mentions is the
# `retrieve` function registered with @agent.tool. This text is sent to the
# model verbatim, so the spelling of "retrieval" and "courteous" is corrected.
SYSTEM_PROMPT = (
    "You are an assistant designed to help users answer questions about Gradio. "
    "You have a retrieval tool that can provide relevant documentation sections based on the user query. "
    "Be courteous and helpful to the user but feel free to refuse answering questions that are not about Gradio. "
)
# The RAG agent: an "openai:gpt-4o" model driven by SYSTEM_PROMPT, whose tools
# receive a `Deps` instance (OpenAI client + DB pool) via RunContext.
agent = Agent(
    "openai:gpt-4o",
    system_prompt=SYSTEM_PROMPT,
    deps_type=Deps,
)
class RetrievalResult(BaseModel):
    """JSON payload produced by the retrieval tool.

    Carries the rendered text of the matched documentation sections together
    with their ids, so the UI can later map the ids back to doc paths.
    """

    # Matched sections joined together, each rendered as "# <title>\n<content>\n".
    content: str
    # The `id` column of each matched `doc_sections` row, in query order.
    ids: list[int]
@asynccontextmanager
async def database_connect(
    create_db: bool = False,
) -> AsyncGenerator[asyncpg.Pool, None]:
    """Open an asyncpg pool on the ``gradio_ai_rag`` database and close it on exit.

    The server DSN comes from the ``DATABASE_URL`` environment variable. It must
    point at the Postgres server without a database path; the database name is
    appended here.

    Args:
        create_db: When True, first connect to the server and create the
            ``gradio_ai_rag`` database if it does not exist yet.

    Yields:
        An asyncpg pool connected to ``<DATABASE_URL>/gradio_ai_rag``.

    Raises:
        RuntimeError: If ``DATABASE_URL`` is unset or empty.
    """
    database = "gradio_ai_rag"
    server_dsn = os.getenv("DATABASE_URL")
    if not server_dsn:
        # Without this check the DSN silently becomes "None/gradio_ai_rag",
        # which fails later with a much less helpful connection error.
        raise RuntimeError(
            "DATABASE_URL is not set; it must contain the Postgres server DSN "
            "(without a database name)."
        )
    # Avoid producing "...//gradio_ai_rag" when the URL ends with a slash.
    server_dsn = server_dsn.rstrip("/")
    if create_db:
        conn = await asyncpg.connect(server_dsn)
        try:
            db_exists = await conn.fetchval(
                "SELECT 1 FROM pg_database WHERE datname = $1", database
            )
            if not db_exists:
                # CREATE DATABASE cannot use bind parameters; `database` is a
                # fixed local constant, so interpolating it is safe here.
                await conn.execute(f"CREATE DATABASE {database}")
        finally:
            await conn.close()
    pool = await asyncpg.create_pool(f"{server_dsn}/{database}")
    try:
        yield pool
    finally:
        await pool.close()
@agent.tool
async def retrieve(context: RunContext[Deps], search_query: str) -> str:
    """Retrieve documentation sections based on a search query.

    Args:
        context: The call context.
        search_query: The search query.
    """
    # NOTE(review): pydantic_ai presumably builds the model-facing tool
    # description from the docstring above, so it is kept short and LLM-facing.
    # Returns a RetrievalResult serialized to JSON: the 8 nearest doc sections'
    # rendered text plus their ids (stream_from_agent parses it back).
    print(f"create embedding for {search_query}")
    response = await context.deps.openai.embeddings.create(
        input=search_query,
        model="text-embedding-3-small",
    )
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(response.data) != 1:
        raise RuntimeError(
            f"Expected 1 embedding, got {len(response.data)}, doc query: {search_query!r}"
        )
    # The vector is sent as JSON-array text; presumably the `embedding` column
    # is a pgvector column that accepts this literal form — confirm schema.
    embedding_json = pydantic_core.to_json(response.data[0].embedding).decode()
    # `<->` orders by vector distance (pgvector operator, assumed): 8 nearest rows.
    rows = await context.deps.pool.fetch(
        "SELECT id, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8",
        embedding_json,
    )
    content = "\n\n".join(f'# {row["title"]}\n{row["content"]}\n' for row in rows)
    ids = [row["id"] for row in rows]
    return RetrievalResult(content=content, ids=ids).model_dump_json()
def _relevant_doc_paths(tool_output: str) -> str:
    """Return the doc paths cited by a ``retrieve`` tool result, one per line.

    Args:
        tool_output: JSON text of a serialized ``RetrievalResult``.

    Returns:
        The ``path`` of every ``DOCS`` entry whose ``id`` is in the result's
        ``ids``, de-duplicated, in ``DOCS`` order, joined with newlines.
    """
    # Parse once and test membership against a set (previously the JSON was
    # re-parsed for every DOCS entry and `ids` was scanned as a list).
    wanted_ids = set(RetrievalResult.model_validate_json(tool_output).ids)
    # dict.fromkeys de-duplicates while keeping a stable order, unlike set().
    unique_paths = dict.fromkeys(
        doc["path"] for doc in DOCS if doc["id"] in wanted_ids
    )
    return "\n".join(unique_paths)


async def stream_from_agent(
    audio: tuple[int, np.ndarray], chatbot: list[dict], past_messages: list
):
    """Answer one spoken question with the RAG agent, streaming UI updates.

    Transcribes ``audio`` with Groq (whisper-large-v3-turbo), appends it as a
    user message, then runs ``agent`` against a fresh DB pool. Each tool call
    is shown as an assistant placeholder that is filled with the doc paths the
    tool returned; the final answer text is streamed into a new assistant
    message.

    Args:
        audio: Audio chunk from the WebRTC component, as accepted by
            ``audio_to_bytes``.
        chatbot: Chatbot history in "messages" format; mutated in place.
        past_messages: pydantic_ai message history; mutated in place.

    Yields:
        ``AdditionalOutputs(chatbot, gr.skip())`` after each chatbot change, and
        finally ``AdditionalOutputs(gr.skip(), past_messages)``.
    """
    import asyncio

    # The Groq client is synchronous; run it in a worker thread so the
    # transcription request does not block the event loop.
    transcription = await asyncio.to_thread(
        groq_client.audio.transcriptions.create,
        file=("audio-file.mp3", audio_to_bytes(audio)),
        model="whisper-large-v3-turbo",
        response_format="verbose_json",
    )
    question = transcription.text
    print("text", question)
    chatbot.append({"role": "user", "content": question})
    yield AdditionalOutputs(chatbot, gr.skip())
    # NOTE(review): a new pool is opened and closed for every question; a
    # long-lived shared pool would avoid the per-request connection cost.
    async with database_connect(False) as pool:
        deps = Deps(openai=openai, pool=pool)
        async with agent.run_stream(
            question, deps=deps, message_history=past_messages
        ) as result:
            for message in result.new_messages():
                past_messages.append(message)
                if isinstance(message, ModelStructuredResponse):
                    # One placeholder per tool call, tagged with the call id so
                    # the matching ToolReturn can fill it in below.
                    for call in message.calls:
                        chatbot.append(
                            {
                                "role": "assistant",
                                "content": "",
                                "metadata": {
                                    "title": "🔍 Retrieving relevant docs",
                                    "id": call.tool_id,
                                },
                            }
                        )
                if isinstance(message, ToolReturn):
                    for gr_message in chatbot:
                        # `metadata` may be missing or None on chatbot messages.
                        metadata = gr_message.get("metadata") or {}
                        if metadata.get("id", "") == message.tool_id:
                            # Joined outside the f-string: a backslash inside an
                            # f-string expression is a SyntaxError before 3.12.
                            paths = _relevant_doc_paths(message.content)
                            gr_message["content"] = f"Relevant Context:\n {paths}"
                yield AdditionalOutputs(chatbot, gr.skip())
            chatbot.append({"role": "assistant", "content": ""})
            async for message in result.stream_text():
                chatbot[-1]["content"] = message
                yield AdditionalOutputs(chatbot, gr.skip())
            data = await result.get_data()
            past_messages.append(ModelTextResponse(content=data))
            yield AdditionalOutputs(gr.skip(), past_messages)
# Voice chat UI: the WebRTC component streams microphone audio to
# stream_from_agent (via ReplyOnPause); the handler's AdditionalOutputs are
# routed back into the chatbot and the message-history state.
with gr.Blocks() as demo:
    # Shown in the empty chatbot. The heading emoji was mis-encoded
    # ("πŸ—£οΈ") and is restored to the intended speaking-head emoji.
    placeholder = """
<div style="display: flex; justify-content: center; align-items: center; gap: 1rem; padding: 1rem; width: 100%">
<img src="/gradio_api/file=logo.svg" style="max-width: 200px; height: auto">
<div>
<h1 style="margin: 0 0 1rem 0">Chat with Gradio Docs 🗣️</h1>
<h3 style="margin: 0 0 0.5rem 0">
Simple RAG agent over Gradio docs built with Pydantic AI.
</h3>
<h3 style="margin: 0">
Ask any question about Gradio with your natural voice and get an answer!
</h3>
</div>
</div>
"""
    # pydantic_ai message history carried between turns.
    past_messages = gr.State([])
    chatbot = gr.Chatbot(
        label="Gradio Docs Bot",
        type="messages",
        placeholder=placeholder,
        avatar_images=(None, "logo.svg"),
    )
    audio = WebRTC(
        label="Talk with the Agent",
        modality="audio",
        rtc_configuration=get_twilio_turn_credentials(),
        mode="send",
    )
    audio.stream(
        ReplyOnPause(stream_from_agent),
        inputs=[audio, chatbot, past_messages],
        outputs=[audio],
    )
    # Forward (chatbot, past_messages) pairs yielded as AdditionalOutputs.
    audio.on_additional_outputs(
        lambda c, s: (c, s),
        outputs=[chatbot, past_messages],
        queue=False,
        show_progress="hidden",
    )

if __name__ == "__main__":
    # logo.svg must be allow-listed so /gradio_api/file=logo.svg can be served.
    demo.launch(allowed_paths=["logo.svg"])