# csv2sql/llm.py
import streamlit as st
from llama_cpp import Llama
from sql import get_table_schema


@st.cache_resource()
def load_llm(repo_id, filename):
    """Download (if needed) a GGUF model from the Hugging Face Hub and load it
    with llama-cpp-python; st.cache_resource keeps one instance across reruns."""
    llm = Llama.from_pretrained(
        repo_id=repo_id,
        filename=filename,
        verbose=True,
        use_mmap=True,
        use_mlock=True,
        n_threads=4,
        n_threads_batch=4,
        n_ctx=8000,
    )
    print(f"{repo_id} loaded successfully. ✅")
    return llm
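
# Usage sketch — the repo/file names below are placeholders for any GGUF model
# on the Hugging Face Hub, not values taken from this app:
#     llm = load_llm(
#         repo_id="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
#         filename="qwen2.5-1.5b-instruct-q4_k_m.gguf",
#     )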


def generate_llm_prompt(table_name, table_schema):
    """
    Generate a prompt that gives an LLM the context it needs to convert a
    natural language question about a table into SQL.

    Args:
        table_name (str): The name of the table.
        table_schema (list): A list of tuples, one per column, as returned by
            get_table_schema (e.g. SQLite PRAGMA table_info rows).

    Returns:
        str: The generated system prompt for the LLM.
    """
    prompt = f"""You are an expert in writing SQL queries for relational databases.
You will be provided with a database schema and a natural
language question, and your task is to generate an accurate SQL query.
The database has a table named '{table_name}' with the following schema:\n\n"""
    prompt += "Columns:\n"
    for col in table_schema:
        # PRAGMA table_info rows look like (cid, name, type, notnull, dflt_value, pk)
        column_name = col[1]
        column_type = col[2]
        prompt += f"- {column_name} ({column_type})\n"
    prompt += "\nPlease generate a SQL query based on the following natural language question. ONLY return the SQL query."
    return prompt
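
# Example of the generated prompt, assuming get_table_schema returns SQLite
# PRAGMA table_info rows, i.e. (cid, name, type, notnull, dflt_value, pk):
#     schema = [(0, "title", "TEXT", 0, None, 0),
#               (1, "year", "INTEGER", 0, None, 0)]
#     print(generate_llm_prompt("movies", schema))
# The column section of the prompt then reads:
#     Columns:
#     - title (TEXT)
#     - year (INTEGER)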


def generate_sql_query(question, table_name, db_name):
    # Not implemented yet; the block below is a legacy hosted-API version
    # (presumably litellm's completion()) kept for reference.
    pass
    # table_name = 'movies'
    # db_name = 'movies_db.db'
    # table_schema = get_table_schema(db_name, table_name)
    # llm_prompt = generate_llm_prompt(table_name, table_schema)
    # user_prompt = """Question: {question}"""
    # response = completion(
    #     api_key=OPENAI_API_KEY,
    #     model="gpt-4o-mini",
    #     messages=[
    #         {"content": llm_prompt, "role": "system"},
    #         {"content": user_prompt.format(question=question), "role": "user"},
    #     ],
    #     max_tokens=1000,
    # )
    # answer = response.choices[0].message.content
    # query = answer.replace("```sql", "").replace("```", "")
    # query = query.strip()
    # return query


# Builds the chat context (schema prompt + prior turns + the new question) and
# asks the local model for a SQL query. Note: despite the name, this returns
# the full answer at once rather than streaming tokens.
def response_generator(llm, messages, question, table_name, db_name):
    table_schema = get_table_schema(db_name, table_name)
    llm_prompt = generate_llm_prompt(table_name, table_schema)
    user_prompt = """Question: {question}"""
    # Rebuild the message list from scratch: the schema prompt goes in as the
    # system message, followed by the prior chat history from session state.
    # (The incoming `messages` argument is ignored and overwritten here.)
    messages = [{"content": llm_prompt, "role": "system"}]
    for val in st.session_state.messages:
        messages.append(val)
    messages.append({"role": "user", "content": user_prompt.format(question=question)})
    response = llm.create_chat_completion(
        messages, max_tokens=2048, temperature=0.7, top_p=0.95
    )
    answer = response["choices"][0]["message"]["content"]
    # Strip any markdown code fences the model wraps around the query.
    query = answer.replace("```sql", "").replace("```", "")
    query = query.strip()
    return query
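
# Usage sketch inside the Streamlit app — assumes st.session_state.messages
# holds the prior chat turns and that movies_db.db contains a 'movies' table
# (matching the names in the commented-out code above):
#     llm = load_llm(repo_id, filename)
#     sql = response_generator(
#         llm,
#         st.session_state.messages,
#         "How many movies were released after 2000?",
#         table_name="movies",
#         db_name="movies_db.db",
#     )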