|
from __future__ import annotations as _annotations |
|
|
|
import json |
|
import os |
|
from contextlib import asynccontextmanager |
|
from dataclasses import dataclass |
|
from typing import AsyncGenerator |
|
|
|
import asyncpg |
|
import gradio as gr |
|
import numpy as np |
|
import pydantic_core |
|
from gradio_webrtc import ( |
|
AdditionalOutputs, |
|
ReplyOnPause, |
|
WebRTC, |
|
audio_to_bytes, |
|
get_twilio_turn_credentials, |
|
) |
|
from groq import Groq |
|
from openai import AsyncOpenAI |
|
from pydantic import BaseModel |
|
from pydantic_ai import RunContext |
|
from pydantic_ai.agent import Agent |
|
from pydantic_ai.messages import ModelStructuredResponse, ModelTextResponse, ToolReturn |
|
|
|
# Documentation index loaded once at import time. Entries are read by their
# "id" and "path" keys when stream_from_agent shows the retrieved context.
# A context manager closes the handle, and the encoding is pinned so the load
# does not depend on the platform's default locale encoding.
with open("gradio_docs.json", encoding="utf-8") as _docs_file:
    DOCS = json.load(_docs_file)

# Groq client, used synchronously for Whisper speech-to-text.
groq_client = Groq()

# Async OpenAI client, handed to agent tools through Deps (used for embeddings).
openai = AsyncOpenAI()
|
|
|
|
|
@dataclass
class Deps:
    """Dependencies that the agent injects into its tools via ``RunContext.deps``.

    Attributes:
        openai: Async OpenAI client; ``retrieve`` uses it to embed queries.
        pool: asyncpg pool that ``retrieve`` queries for documentation rows.
    """

    openai: AsyncOpenAI  # embedding client
    pool: asyncpg.Pool  # database pool for doc_sections lookups
|
|
|
|
|
# System prompt for the Gradio docs agent. Implicit string concatenation keeps
# one sentence per source line; each piece except the last ends with a space.
SYSTEM_PROMPT = (
    "You are an assistant designed to help users answer questions about Gradio. "
    "You have a retrieval tool that can provide relevant documentation sections based on the user query. "
    "Be courteous and helpful to the user but feel free to refuse answering questions that are not about Gradio."
)
|
|
|
|
|
# Module-level agent backed by GPT-4o. Tools registered with @agent.tool receive
# a Deps instance through RunContext; SYSTEM_PROMPT frames every conversation.
agent = Agent(
    "openai:gpt-4o",
    system_prompt=SYSTEM_PROMPT,
    deps_type=Deps,
)
|
|
|
|
|
class RetrievalResult(BaseModel):
    """JSON payload returned by the ``retrieve`` tool.

    The chat UI parses it back to list the documentation paths that were used.
    """

    # Retrieved sections rendered as "# <title>\n<content>\n", joined by blank lines.
    content: str
    # ``id`` column of each retrieved doc_sections row, in query result order.
    ids: list[int]
|
|
|
|
|
@asynccontextmanager
async def database_connect(
    create_db: bool = False,
) -> AsyncGenerator[asyncpg.Pool, None]:
    """Yield an asyncpg pool connected to the ``gradio_ai_rag`` database.

    The server DSN comes from the ``DATABASE_URL`` environment variable, and the
    database name is appended to it as ``{DATABASE_URL}/gradio_ai_rag``.

    Args:
        create_db: If true, first connect to the server DSN and run
            ``CREATE DATABASE`` when ``gradio_ai_rag`` is missing from
            ``pg_database``.

    Yields:
        The connection pool, which is closed when the ``async with`` exits.

    Raises:
        RuntimeError: If ``DATABASE_URL`` is unset or empty.
    """
    database = "gradio_ai_rag"
    server_dsn = os.getenv("DATABASE_URL")
    if not server_dsn:
        # Fail fast with a clear message. Otherwise the pool DSN below would be
        # built from the literal string "None/gradio_ai_rag".
        raise RuntimeError("DATABASE_URL environment variable is not set")

    if create_db:
        conn = await asyncpg.connect(server_dsn)
        try:
            db_exists = await conn.fetchval(
                "SELECT 1 FROM pg_database WHERE datname = $1", database
            )
            if not db_exists:
                # CREATE DATABASE cannot use bind parameters. `database` is a
                # fixed constant here, so interpolating it is safe.
                await conn.execute(f"CREATE DATABASE {database}")
        finally:
            await conn.close()

    pool = await asyncpg.create_pool(f"{server_dsn}/{database}")
    try:
        yield pool
    finally:
        await pool.close()
|
|
|
|
|
@agent.tool
async def retrieve(context: RunContext[Deps], search_query: str) -> str:
    """Retrieve documentation sections based on a search query.

    Args:
        context: The call context.
        search_query: The search query.
    """
    # NOTE: pydantic_ai derives the tool description shown to the model from the
    # docstring above, so that text is kept exactly as-is.
    print(f"create embedding for {search_query}")
    response = await context.deps.openai.embeddings.create(
        input=search_query,
        model="text-embedding-3-small",
    )

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(response.data) != 1:
        raise RuntimeError(
            f"Expected 1 embedding, got {len(response.data)}, doc query: {search_query!r}"
        )
    # The vector is bound to the query as a JSON array string.
    embedding_json = pydantic_core.to_json(response.data[0].embedding).decode()

    # Eight sections ordered by `<->` distance from the query embedding
    # (presumably the pgvector operator on doc_sections.embedding).
    rows = await context.deps.pool.fetch(
        "SELECT id, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8",
        embedding_json,
    )
    content = "\n\n".join(f'# {row["title"]}\n{row["content"]}\n' for row in rows)
    ids = [row["id"] for row in rows]
    # Returned as JSON so stream_from_agent can parse the ids back out.
    return RetrievalResult(content=content, ids=ids).model_dump_json()
|
|
|
|
|
def _relevant_context(tool_content: str) -> str:
    """Render the doc paths referenced by a ``retrieve`` tool result as chat text."""
    # Parse the payload once per tool result rather than once per DOCS entry,
    # and use a set for O(1) id membership tests.
    retrieved_ids = set(RetrievalResult.model_validate_json(tool_content).ids)
    # dict.fromkeys de-duplicates while keeping DOCS order, so the listing is
    # stable between runs (iterating a set() is not).
    paths = dict.fromkeys(
        doc["path"] for doc in DOCS if doc["id"] in retrieved_ids
    )
    # Joined outside any f-string: a backslash inside an f-string replacement
    # field is a SyntaxError before Python 3.12.
    return "Relevant Context:\n " + "\n".join(paths)


async def stream_from_agent(
    audio: tuple[int, np.ndarray], chatbot: list[dict], past_messages: list
):
    """Transcribe one spoken question and stream the agent's answer into the chat.

    Registered with ``ReplyOnPause`` in the UI. Updates are yielded as
    ``AdditionalOutputs(chatbot, past_messages)``, with ``gr.skip()`` in the
    slot that is not being updated.

    Args:
        audio: Recorded utterance, converted with ``audio_to_bytes`` and sent to
            Groq Whisper (presumably ``(sample_rate, samples)``).
        chatbot: Chat history as Gradio "messages" dicts; mutated in place.
        past_messages: pydantic_ai message history; mutated in place.
    """
    import asyncio

    # groq_client is the synchronous Groq SDK client. Running the request in a
    # worker thread keeps it from blocking the event loop while it is in flight.
    transcription = await asyncio.to_thread(
        groq_client.audio.transcriptions.create,
        file=("audio-file.mp3", audio_to_bytes(audio)),
        model="whisper-large-v3-turbo",
        response_format="verbose_json",
    )
    question = transcription.text
    print("text", question)
    if not question.strip():
        # Nothing was transcribed, so don't send an empty prompt to the agent.
        return

    chatbot.append({"role": "user", "content": question})
    yield AdditionalOutputs(chatbot, gr.skip())

    # NOTE(review): a pool is opened and closed for every utterance. A shared,
    # long-lived pool would avoid reconnecting each time.
    async with database_connect(False) as pool:
        deps = Deps(openai=openai, pool=pool)
        async with agent.run_stream(
            question, deps=deps, message_history=past_messages
        ) as result:
            for message in result.new_messages():
                past_messages.append(message)
                if isinstance(message, ModelStructuredResponse):
                    # One collapsible bubble per tool call. It is tagged with the
                    # call's tool_id so the matching ToolReturn can fill it in.
                    for call in message.calls:
                        chatbot.append(
                            {
                                "role": "assistant",
                                "content": "",
                                "metadata": {
                                    "title": "🔍 Retrieving relevant docs",
                                    "id": call.tool_id,
                                },
                            }
                        )
                if isinstance(message, ToolReturn):
                    targets = [
                        gr_message
                        for gr_message in chatbot
                        # `or {}` also covers messages whose metadata is None.
                        if (gr_message.get("metadata") or {}).get("id", "")
                        == message.tool_id
                    ]
                    if targets:
                        context_text = _relevant_context(message.content)
                        for gr_message in targets:
                            gr_message["content"] = context_text
                yield AdditionalOutputs(chatbot, gr.skip())

            chatbot.append({"role": "assistant", "content": ""})
            # Each streamed item replaces the bubble's content, i.e. it is
            # treated as the full answer text generated so far.
            async for text in result.stream_text():
                chatbot[-1]["content"] = text
                yield AdditionalOutputs(chatbot, gr.skip())
            data = await result.get_data()
            past_messages.append(ModelTextResponse(content=data))
            yield AdditionalOutputs(gr.skip(), past_messages)
|
|
|
|
|
with gr.Blocks() as demo:
    # Shown in the empty chat window. The HTML has no blank lines, so it renders
    # as a single block. The heading emoji is restored from mis-encoded text
    # ("π£οΈ" was the UTF-8 bytes of 🗣️ decoded with the wrong codec).
    placeholder = """
<div style="display: flex; justify-content: center; align-items: center; gap: 1rem; padding: 1rem; width: 100%">
    <img src="/gradio_api/file=logo.svg" style="max-width: 200px; height: auto">
    <div>
        <h1 style="margin: 0 0 1rem 0">Chat with Gradio Docs 🗣️</h1>
        <h3 style="margin: 0 0 0.5rem 0">
            Simple RAG agent over Gradio docs built with Pydantic AI.
        </h3>
        <h3 style="margin: 0">
            Ask any question about Gradio with your natural voice and get an answer!
        </h3>
    </div>
</div>
"""
    # pydantic_ai message history for the session, threaded through each turn.
    past_messages = gr.State([])
    chatbot = gr.Chatbot(
        label="Gradio Docs Bot",
        type="messages",
        placeholder=placeholder,
        avatar_images=(None, "logo.svg"),
    )
    audio = WebRTC(
        label="Talk with the Agent",
        modality="audio",
        rtc_configuration=get_twilio_turn_credentials(),
        mode="send",
    )
    # ReplyOnPause hands the recorded audio plus the current chat state to
    # stream_from_agent.
    audio.stream(
        ReplyOnPause(stream_from_agent),
        inputs=[audio, chatbot, past_messages],
        outputs=[audio],
    )
    # AdditionalOutputs(chatbot, past_messages) yielded by the handler are
    # forwarded unchanged to the two components here.
    audio.on_additional_outputs(
        lambda c, s: (c, s),
        outputs=[chatbot, past_messages],
        queue=False,
        show_progress="hidden",
    )
|
|
|
|
|
if __name__ == "__main__":
    # logo.svg is requested through /gradio_api/file=... (placeholder image and
    # chatbot avatar), so it must be listed in allowed_paths to be served.
    demo.launch(
        allowed_paths=["logo.svg"],
    )
|
|