csv2sql / llm.py
mobinln's picture
feat: version 1
8e0a273
raw
history blame
2.31 kB
import streamlit as st
from llama_cpp import Llama
from sql import get_table_schema
@st.cache_resource()
def load_llm(repo_id, filename):
    """
    Load a llama.cpp model via ``Llama.from_pretrained`` and return it.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the returned object instead of rebuilding it on every call.

    Args:
        repo_id (str): Repository identifier forwarded to ``Llama.from_pretrained``.
        filename (str): Model file name forwarded to ``Llama.from_pretrained``.
    Returns:
        Llama: The loaded model instance.
    """
    # Fixed runtime settings, kept together so they are easy to review.
    runtime_options = {
        "verbose": True,
        "use_mmap": True,
        "use_mlock": True,
        "n_threads": 4,
        "n_threads_batch": 4,
        "n_ctx": 8000,
    }
    model = Llama.from_pretrained(
        repo_id=repo_id,
        filename=filename,
        **runtime_options,
    )
    print(f"{repo_id} loaded successfully. ✅")
    return model
def generate_system_prompt(table_name, table_schema):
    """
    Build the system prompt that describes one table to the LLM.

    The prompt consists of a fixed instruction header naming the table, a
    "Columns:" list with one "- name (type)" line per schema entry, and a
    closing instruction asking for the SQL query only.

    Args:
        table_name (str): The name of the table.
        table_schema (list): Indexable rows describing the columns; index 1 is
            used as the column name and index 2 as the column type.
    Returns:
        str: The prompt text to be used as the LLM system message.
    """
    header = (
        "You are an expert in writing SQL queries for relational databases.\n"
        "You will be provided with a database schema and a natural\n"
        "language question, and your task is to generate an accurate SQL query.\n"
        f"The database has a table named '{table_name}' with the following schema:\n\n"
    )
    column_lines = "".join(
        f"- {column[1]} ({column[2]})\n" for column in table_schema
    )
    footer = "\nPlease generate a SQL query based on the following natural language question. ONLY return the SQL query."
    return header + "Columns:\n" + column_lines + footer
# Single-shot SQL generation: despite the name, nothing is streamed here;
# the whole completion is fetched and returned as one string.
def response_generator(llm, messages, question, table_name, db_name):
    """
    Ask the LLM to turn a natural language question into a SQL query.

    The table schema is looked up with ``get_table_schema``, rendered into a
    system prompt, and sent together with the previous chat messages and the
    new question to ``llm.create_chat_completion``.

    Args:
        llm: Object exposing ``create_chat_completion`` (presumably the model
            returned by ``load_llm`` - verify against the caller).
        messages (list): Earlier chat messages; each item is inserted into the
            history unchanged, between the system prompt and the new question.
        question (str): The user's natural language question.
        table_name (str): Name of the table the question refers to.
        db_name: Database identifier passed through to ``get_table_schema``.
    Returns:
        str: The model's answer with Markdown code-fence markers removed and
        surrounding whitespace stripped.
    """
    table_schema = get_table_schema(db_name, table_name)
    system_prompt = generate_system_prompt(table_name, table_schema)
    user_prompt = f"Question: {question}"
    print(messages, system_prompt, user_prompt)

    # The system prompt is already fully rendered. It must NOT be passed
    # through str.format() again: a table or column name containing "{" or
    # "}" would raise KeyError/IndexError/ValueError or be silently altered.
    history = [
        {"content": system_prompt, "role": "system"},
        *messages,
        {"role": "user", "content": user_prompt},
    ]

    response = llm.create_chat_completion(
        messages=history,
        max_tokens=2048,
        temperature=0.7,
        top_p=0.95,
    )
    answer = response["choices"][0]["message"]["content"]

    # Drop the Markdown fence markers the model may wrap the query in.
    query = answer.replace("```sql", "").replace("```", "")
    return query.strip()