|
import re

import streamlit as st
from llama_cpp import Llama

from sql import get_table_schema
|
|
|
|
|
@st.cache_resource()
def load_llm(repo_id, filename, *, n_ctx=8000, n_threads=4, n_threads_batch=4):
    """
    Load a llama.cpp model with ``Llama.from_pretrained`` and cache it.

    ``st.cache_resource`` keeps the returned object across Streamlit reruns,
    keyed on the arguments. The model is therefore loaded once per distinct
    (repo_id, filename, n_ctx, n_threads, n_threads_batch) combination.

    Args:
        repo_id (str): Repository identifier forwarded to ``from_pretrained``.
        filename (str): Model file within that repository.
        n_ctx (int): Context window size in tokens. Defaults to 8000, the
            value previously hard-coded here.
        n_threads (int): Number of threads used for generation.
        n_threads_batch (int): Number of threads used for batch (prompt)
            processing.

    Returns:
        Llama: The loaded model instance.
    """
    llm = Llama.from_pretrained(
        repo_id=repo_id,
        filename=filename,
        verbose=True,
        # Memory-map the weights and ask the OS to lock them in RAM so they
        # are not paged out between requests.
        use_mmap=True,
        use_mlock=True,
        n_threads=n_threads,
        n_threads_batch=n_threads_batch,
        n_ctx=n_ctx,
    )
    print(f"{repo_id} loaded successfully. ✅")
    return llm
|
|
|
|
|
def generate_system_prompt(table_name, table_schema):
    """
    Build the system prompt that tells the LLM which table it is querying.

    The prompt does four things, in order:

    - frames the model as a SQL expert;
    - names the table;
    - lists one ``- <name> (<type>)`` line per column;
    - asks for the SQL query only.

    Args:
        table_name (str): Name of the table the questions refer to.
        table_schema (list): Per-column tuples. Index 1 is read as the column
            name and index 2 as the column type. This matches the row layout
            of SQLite's ``PRAGMA table_info``, which is presumably what
            ``get_table_schema`` returns (confirm).

    Returns:
        str: The assembled system prompt.
    """
    intro = f"""You are an expert in writing SQL queries for relational databases.
You will be provided with a database schema and a natural
language question, and your task is to generate an accurate SQL query.

The database has a table named '{table_name}' with the following schema:\n\n"""

    # One bullet per column: "- <name> (<type>)".
    column_lines = [f"- {entry[1]} ({entry[2]})\n" for entry in table_schema]

    closing = "\nPlease generate a SQL query based on the following natural language question. ONLY return the SQL query."

    return "".join([intro, "Columns:\n", *column_lines, closing])
|
|
|
|
|
|
|
# Matches a Markdown code fence with an optional language tag (```sql, ```SQL,
# ```sqlite, bare ```) and captures everything up to the closing fence.
_CODE_FENCE_RE = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)


def _strip_code_fences(answer):
    """Return the query inside the first fenced code block of *answer*.

    When no complete fenced block is found, fall back to deleting stray
    ```sql / ``` markers. This covers a single-line fence or a reply cut off
    before its closing fence.
    """
    match = _CODE_FENCE_RE.search(answer)
    if match:
        return match.group(1).strip()
    return answer.replace("```sql", "").replace("```", "").strip()


def response_generator(
    llm,
    messages,
    question,
    table_name,
    db_name,
    *,
    max_tokens=2048,
    temperature=0.7,
    top_p=0.95,
):
    """
    Ask the LLM to translate a natural-language question into a SQL query.

    The chat history sent to the model is, in order:

    1. a system prompt describing *table_name*'s schema;
    2. the prior *messages*, unchanged;
    3. the new question as a user turn.

    Args:
        llm: Model exposing ``create_chat_completion``. Presumably the Llama
            returned by ``load_llm``; confirm against the caller.
        messages (list): Prior chat turns, appended verbatim. Assumed to be
            ``{"role": ..., "content": ...}`` dicts. NOTE(review): if the
            caller already appended *question* to this list, the question is
            sent twice. Verify.
        question (str): The user's natural-language question.
        table_name (str): Table the question refers to.
        db_name (str): Database passed to ``get_table_schema``.
        max_tokens (int): Completion token limit (default 2048).
        temperature (float): Sampling temperature (default 0.7).
        top_p (float): Nucleus-sampling threshold (default 0.95).

    Returns:
        str: The generated SQL. Markdown code fences are removed, and text
        outside the first fenced block is dropped.
    """
    table_schema = get_table_schema(db_name, table_name)
    system_prompt = generate_system_prompt(table_name, table_schema)
    user_prompt = f"""Question: {question}"""

    print(messages, system_prompt, user_prompt)

    # The system prompt is already fully interpolated. Calling .format() on it
    # again would raise (or mangle the text) on any literal brace in a table
    # name, column name or column type.
    history = [{"content": system_prompt, "role": "system"}]
    history.extend(messages)
    history.append({"role": "user", "content": user_prompt})

    response = llm.create_chat_completion(
        messages=history,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    # Chat-completion message content may be None; treat that as an empty reply.
    answer = response["choices"][0]["message"]["content"] or ""

    return _strip_code_fences(answer)
|
|