import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import sqlparse

# from modelscope import snapshot_download

# Load the model and tokenizer.
# Use the newer Llama-3-based SQLCoder checkpoint for better performance.
model_name = "defog/llama-3-sqlcoder-8b"
# Alternative download source via ModelScope (requires the modelscope import above):
# model_name = snapshot_download("stevie/llama-3-sqlcoder-8b")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
    use_cache=True,
)


def _build_prompt(user_question, instructions, create_table_statements):
    """Render the Llama-3-style chat prompt used to query SQLCoder-8B."""
    return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{user_question}`
{instructions}

DDL statements:
{create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
```sql
"""


def _extract_sql(completion):
    """Return the SQL text of a completion, dropping a closing ``` fence if present."""
    # The prompt opens a ```sql fence, so the model may close it after the query.
    return completion.split("```", 1)[0].strip()


def generate_sql(user_question, instructions, create_table_statements):
    """Generate a SQL query that answers *user_question* for the given schema.

    Args:
        user_question: Natural-language question to answer.
        instructions: Extra guidance inserted into the prompt (may be empty).
        create_table_statements: DDL (and optional join hints) describing the
            database schema.

    Returns:
        The generated SQL text, re-indented by ``sqlparse.format``.
    """
    prompt = _build_prompt(user_question, instructions, create_table_statements)
    # The prompt already starts with <|begin_of_text|>; stop the tokenizer from
    # prepending a second BOS token of its own.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    # generate() returns prompt + continuation for a causal LM; decode only the
    # new tokens so the prompt text is not echoed back to the user.
    prompt_len = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)
    # Free cached GPU memory between requests. torch.cuda.synchronize() raises
    # on machines without CUDA, so only do this when a GPU is available.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    return sqlparse.format(_extract_sql(completion), reindent=True)


# Default question pre-filled in the UI.
question = "What are our top 3 products by revenue in the New York region?"
# Default prompt instructions pre-filled in the UI.
# NOTE(review): the "current date" below is hard-coded; update it (or compute
# it, e.g. with datetime.date.today()) so relative-date questions resolve
# correctly.
instructions = """- if the question cannot be answered given the database schema, return "I do not know"
- recall that the current date in YYYY-MM-DD format is 2024-09-15
"""

# Example schema pre-filled in the UI: table DDL followed by join hints
# written as SQL comments.
schema = """CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER -- Current quantity in stock
);

CREATE TABLE customers (
  customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
  name VARCHAR(50), -- Name of the customer
  address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER, -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id
"""

# Web UI: the three text boxes map positionally onto generate_sql's
# (user_question, instructions, create_table_statements) parameters.
demo = gr.Interface(
    fn=generate_sql,
    title="SQLCoder-8b",
    description=(
        "Defog's SQLCoder-8B is a state-of-the-art model for generating "
        "SQL queries from natural language."
    ),
    inputs=[
        gr.Textbox(label="User Question", placeholder="Enter your question here...", value=question),
        gr.Textbox(label="Instructions (optional)", placeholder="Enter any additional instructions here...", value=instructions),
        gr.Textbox(label="Create Table Statements", placeholder="Enter DDL statements here...", value=schema),
    ],
    outputs="text",
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local server.
    demo.launch(share=True)