nileshhanotia commited on
Commit
56abc73
·
verified ·
1 Parent(s): 6a74563

Update sql_generator.py

Browse files
Files changed (1) hide show
  1. sql_generator.py +35 -62
sql_generator.py CHANGED
@@ -1,67 +1,40 @@
1
- import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import requests
4
- from config import ACCESS_TOKEN, SHOP_NAME
5
 
6
  class SQLGenerator:
7
- def __init__(self):
8
- self.model_name = "premai-io/prem-1B-SQL"
9
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
10
- self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
11
 
12
- def generate_query(self, natural_language_query):
13
- schema_info = """
14
- CREATE TABLE products (
15
- id DECIMAL(8,2) PRIMARY KEY,
16
- title VARCHAR(255),
17
- body_html VARCHAR(255),
18
- vendor VARCHAR(255),
19
- product_type VARCHAR(255),
20
- created_at VARCHAR(255),
21
- handle VARCHAR(255),
22
- updated_at DATE,
23
- published_at VARCHAR(255),
24
- template_suffix VARCHAR(255),
25
- published_scope VARCHAR(255),
26
- tags VARCHAR(255),
27
- status VARCHAR(255),
28
- admin_graphql_api_id DECIMAL(8,2),
29
- variants VARCHAR(255),
30
- options VARCHAR(255),
31
- images VARCHAR(255),
32
- image VARCHAR(255)
33
- );
34
- """
35
-
36
- prompt = f"""### Task: Generate a SQL query to answer the following question.
37
- ### Database Schema:
38
- {schema_info}
39
- ### Question: {natural_language_query}
40
- ### SQL Query:"""
41
-
42
- inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
43
- outputs = self.model.generate(
44
- inputs["input_ids"],
45
- max_length=256,
46
- do_sample=False,
47
- num_return_sequences=1,
48
- eos_token_id=self.tokenizer.eos_token_id,
49
- pad_token_id=self.tokenizer.pad_token_id
50
- )
51
 
52
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
53
-
54
- def fetch_shopify_data(self, endpoint):
55
- headers = {
56
- 'X-Shopify-Access-Token': ACCESS_TOKEN,
57
- 'Content-Type': 'application/json'
58
- }
59
- url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
60
- response = requests.get(url, headers=headers)
61
-
62
- if response.status_code == 200:
63
- return response.json()
64
- else:
65
- print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
66
- return None
 
 
 
 
 
 
 
 
 
 
 
67
 
 
1
+ import sqlite3
2
+ from typing import List, Dict, Any
3
+ import logging
 
4
 
5
  class SQLGenerator:
6
+ def __init__(self, db_path: str = "shopify.db"):
7
+ self.db_path = db_path
8
+ self.setup_logging()
 
9
 
10
+ def setup_logging(self):
11
+ logging.basicConfig(level=logging.INFO)
12
+ self.logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ def execute_query(self, query: str) -> List[Dict[str, Any]]:
15
+ """
16
+ Execute SQL query and return results as a list of dictionaries
17
+ """
18
+ try:
19
+ with sqlite3.connect(self.db_path) as conn:
20
+ conn.row_factory = sqlite3.Row
21
+ cursor = conn.cursor()
22
+ cursor.execute(query)
23
+ results = [dict(row) for row in cursor.fetchall()]
24
+ self.logger.info(f"Successfully executed query: {query[:100]}...")
25
+ return results
26
+ except sqlite3.Error as e:
27
+ self.logger.error(f"Database error: {e}")
28
+ raise
29
+ except Exception as e:
30
+ self.logger.error(f"Error executing query: {e}")
31
+ raise
32
+
33
+ def validate_query(self, query: str) -> bool:
34
+ """
35
+ Validate SQL query before execution
36
+ """
37
+ # Basic validation - you might want to add more sophisticated validation
38
+ dangerous_keywords = ["DROP", "DELETE", "TRUNCATE", "UPDATE", "INSERT"]
39
+ return not any(keyword in query.upper() for keyword in dangerous_keywords)
40