from __future__ import annotations as _annotations

import json
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import AsyncGenerator

import asyncpg
import gradio as gr
from gradio.utils import get_space
import numpy as np
import pydantic_core
from gradio_webrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    WebRTC,
    audio_to_bytes,
    get_twilio_turn_credentials,
)
from groq import Groq
from openai import AsyncOpenAI
from pydantic import BaseModel
from pydantic_ai import RunContext
from pydantic_ai.agent import Agent
from pydantic_ai.messages import ModelStructuredResponse, ModelTextResponse, ToolReturn

if not get_space():
    from dotenv import load_dotenv

    load_dotenv()

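# Pre-chunked Gradio documentation sections; each entry is assumed to carry
# the "id" and "path" keys used later to display which docs were retrieved.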
with open("gradio_docs.json") as docs_file:
    DOCS = json.load(docs_file)

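# Groq runs Whisper speech-to-text on the incoming audio; OpenAI provides the
# embedding model and the gpt-4o chat model. Both clients read their API keys
# from the environment.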
groq_client = Groq()
openai = AsyncOpenAI()


@dataclass
class Deps:
    openai: AsyncOpenAI
    pool: asyncpg.Pool


SYSTEM_PROMPT = (
    "You are an assistant designed to help users answer questions about Gradio. "
    "You have a retrieve tool that can provide relevant documentation sections based on the user query. "
    "Be courteous and helpful to the user, but feel free to refuse to answer questions that are not about Gradio. "
)


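# The agent decides on each turn whether to call the `retrieve` tool
# registered below before producing its answer.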
agent = Agent(
    "openai:gpt-4o",
    deps_type=Deps,
    system_prompt=SYSTEM_PROMPT,
)


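# Wire format for the retrieve tool: the concatenated doc text plus the matched
# section ids, serialized to JSON so the ids can be recovered when updating the UI.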
class RetrievalResult(BaseModel):
    content: str
    ids: list[int]


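# Connection-pool lifecycle helper. DB_URL is assumed to point at a Postgres
# server with the pgvector extension installed (the `<->` distance operator
# is used in the retrieval query below).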
@asynccontextmanager
async def database_connect(
    create_db: bool = False,
) -> AsyncGenerator[asyncpg.Pool, None]:
    server_dsn = os.getenv("DB_URL")
    database = "gradio_ai_rag"
    if create_db:
        conn = await asyncpg.connect(server_dsn)
        try:
            db_exists = await conn.fetchval(
                "SELECT 1 FROM pg_database WHERE datname = $1", database
            )
            if not db_exists:
                await conn.execute(f"CREATE DATABASE {database}")
        finally:
            await conn.close()

    pool = await asyncpg.create_pool(f"{server_dsn}/{database}")
    try:
        yield pool
    finally:
        await pool.close()


@agent.tool
async def retrieve(context: RunContext[Deps], search_query: str) -> str:
    """Retrieve documentation sections based on a search query.

    Args:
        context: The call context.
        search_query: The search query.
    """
    print(f"creating embedding for {search_query!r}")
    embedding = await context.deps.openai.embeddings.create(
        input=search_query,
        model="text-embedding-3-small",
    )

    assert (
        len(embedding.data) == 1
    ), f"Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}"
    embedding = embedding.data[0].embedding
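    # pgvector accepts a vector literal in JSON-array form ("[0.1, 0.2, ...]"),
    # so the embedding is serialized to JSON text rather than passed as a list.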
    embedding_json = pydantic_core.to_json(embedding).decode()
    rows = await context.deps.pool.fetch(
        "SELECT id, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8",
        embedding_json,
    )
    content = "\n\n".join(f'# {row["title"]}\n{row["content"]}\n' for row in rows)
    ids = [row["id"] for row in rows]
    return RetrievalResult(content=content, ids=ids).model_dump_json()


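# Main voice handler: transcribe the user's audio with Whisper, run the agent
# (which may call `retrieve`), and stream incremental UI updates back through
# AdditionalOutputs.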
async def stream_from_agent(
    audio: tuple[int, np.ndarray], chatbot: list[dict], past_messages: list
):
    question = groq_client.audio.transcriptions.create(
        file=("audio-file.mp3", audio_to_bytes(audio)),
        model="whisper-large-v3-turbo",
        response_format="verbose_json",
    ).text

    print("transcribed question:", question)

    chatbot.append({"role": "user", "content": question})
    yield AdditionalOutputs(chatbot, gr.skip())

    async with database_connect(False) as pool:
        deps = Deps(openai=openai, pool=pool)
        async with agent.run_stream(
            question, deps=deps, message_history=past_messages
        ) as result:
            for message in result.new_messages():
                past_messages.append(message)
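                # A structured response means the model issued tool calls; show
                # a "retrieving docs" status bubble keyed by the tool call id so
                # the matching ToolReturn can fill it in below.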
                if isinstance(message, ModelStructuredResponse):
                    for call in message.calls:
                        gr_message = {
                            "role": "assistant",
                            "content": "",
                            "metadata": {
                                "title": "🔍 Retrieving relevant docs",
                                "id": call.tool_id,
                            },
                        }
                        chatbot.append(gr_message)
                if isinstance(message, ToolReturn):
                    for gr_message in chatbot:
                        if (
                            gr_message.get("metadata", {}).get("id", "")
                            == message.tool_id
                        ):
                            # Parse the tool result once, then map the returned
                            # section ids back to their documentation paths.
                            tool_result = RetrievalResult.model_validate_json(
                                message.content
                            )
                            paths = "\n".join(
                                {d["path"] for d in DOCS if d["id"] in tool_result.ids}
                            )
                            gr_message["content"] = f"Relevant Context:\n {paths}"
                yield AdditionalOutputs(chatbot, gr.skip())
            chatbot.append({"role": "assistant", "content": ""})
            async for message in result.stream_text():
                chatbot[-1]["content"] = message
                yield AdditionalOutputs(chatbot, gr.skip())
            data = await result.get_data()
            past_messages.append(ModelTextResponse(content=data))
            yield AdditionalOutputs(gr.skip(), past_messages)


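# Gradio UI: a chat transcript plus a WebRTC widget that streams microphone
# audio to the server; ReplyOnPause invokes the handler whenever the speaker
# pauses.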
with gr.Blocks() as demo:
    placeholder = """
<div style="display: flex; justify-content: center; align-items: center; gap: 1rem; padding: 1rem; width: 100%">
    <img src="/gradio_api/file=gradio_logo.png" style="max-width: 200px; height: auto">
    <div>
        <h1 style="margin: 0 0 1rem 0">Chat with Gradio Docs 🗣️</h1>
        <h3 style="margin: 0 0 0.5rem 0">
            Simple RAG agent over Gradio docs built with Pydantic AI.
        </h3>
        <h3 style="margin: 0">
            Ask any question about Gradio with your natural voice and get an answer!
        </h3>
    </div>
</div>
"""
    # Conversation history in pydantic-ai message format, carried across turns.
    past_messages = gr.State([])
    chatbot = gr.Chatbot(
        label="Gradio Docs Bot",
        type="messages",
        placeholder=placeholder,
        avatar_images=(None, "gradio_logo.png"),
    )
    audio = WebRTC(
        label="Talk with the Agent",
        modality="audio",
        rtc_configuration=get_twilio_turn_credentials(),
        mode="send",
    )
    audio.stream(
        ReplyOnPause(stream_from_agent),
        inputs=[audio, chatbot, past_messages],
        outputs=[audio],
        time_limit=90,
        concurrency_limit=5
    )
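    # AdditionalOutputs yielded inside the handler arrive here; this callback
    # simply forwards the updated chatbot and message history to the UI.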
    audio.on_additional_outputs(
        lambda c, s: (c, s),
        outputs=[chatbot, past_messages],
        queue=False,
        show_progress="hidden",
    )


if __name__ == "__main__":
    demo.launch(allowed_paths=["gradio_logo.png"], ssr_mode=False)