Spaces:
Running
Running
File size: 1,834 Bytes
e931b70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import streamlit as st
from langchain.callbacks.streamlit.streamlit_callback_handler import (
StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult
from sql_formatter.core import format_sql
class VectorSQLSearchDBCallBackHandler(StreamlitCallbackHandler):
    """Streamlit callback that reports progress of a vector-SQL chain.

    A progress bar advances by ``prog_interval`` each time a chain starts and
    each time the LLM returns a ``SELECT`` statement.  A generated ``SELECT``
    statement is also rendered (via ``format_sql``) as a markdown SQL block.

    NOTE(review): ``StreamlitCallbackHandler.__init__`` is deliberately not
    called here (presumably because it expects container arguments).  Hooks
    inherited from the base class and not overridden below may rely on state
    that initializer would set -- confirm before enabling streaming/agent
    callbacks with this handler.
    """

    def __init__(self, prog_interval: float = 0.2) -> None:
        """Create the progress bar and an empty status placeholder.

        Args:
            prog_interval: Fraction of the bar to advance per reported event.
                Defaults to 0.2, the previously hard-coded step.
        """
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = prog_interval

    def _advance_progress(self, text: str) -> None:
        """Advance the progress bar by one step (capped at 100%) with *text*."""
        # st.progress only accepts float values in [0.0, 1.0]; without the cap
        # a run with many (nested) chains would overflow the bar and the
        # update call would raise instead of redrawing.
        self.prog_value = min(self.prog_value + self.prog_interval, 1.0)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        """Ignore LLM start events; progress is reported elsewhere."""
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ) -> None:
        """Render the generated SQL when the first generation is a SELECT.

        Args:
            response: LLM result; only ``generations[0][0].text`` is inspected.
        """
        # Guard against an empty result rather than raising IndexError
        # from inside the callback.
        if not response.generations or not response.generations[0]:
            return
        text = response.generations[0][0].text
        # Tolerate leading whitespace/newlines before the statement.
        if not text.lstrip().upper().startswith("SELECT"):
            return
        st.markdown("### Generated Vector Search SQL Statement \n"
                    "> This sql statement is generated by LLM \n\n")
        st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
        self._advance_progress("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Advance the progress bar and show the id of the starting chain.

        Args:
            serialized: Serialized chain description; its ``"id"`` entry (a
                sequence of strings) is joined with dots for display.
        """
        # Be defensive in case serialized is None or lacks an "id" entry.
        ids = (serialized or {}).get("id") or ["unknown"]
        cid = ".".join(ids)
        self._advance_progress(f"Running Chain `{cid}`...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Ignore chain end events."""
        pass
class VectorSQLSearchLLMCallBackHandler(VectorSQLSearchDBCallBackHandler):
    """Variant of the vector-SQL progress callback with a finer 0.1 step.

    It also stores the ``table`` passed by the caller; no method visible in
    this class reads it.
    """

    def __init__(self, table: str) -> None:
        """Set up the progress widgets and remember *table*.

        Args:
            table: Table name stored on the handler as ``self.table``.
        """
        # Reuse the parent's widget setup (progress bar, status placeholder,
        # counter) instead of duplicating it; override only what differs.
        super().__init__()
        self.prog_interval = 0.1
        self.table = table
|