nileshhanotia commited on
Commit
1fcfc73
·
verified ·
1 Parent(s): 6126ba8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -61
app.py CHANGED
@@ -1,66 +1,66 @@
1
- import logging
2
- import gradio as gr
3
- from sql_generator import SQLGenerator
4
- from intent_classifier import IntentClassifier
5
- from rag_system import RAGSystem
6
 
7
- # Initialize logging
8
- logging.basicConfig(
9
- level=logging.INFO, # Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
10
- format='%(asctime)s - %(levelname)s - %(message)s' # Format for the log messages
11
- )
12
-
13
- class UnifiedSystem:
14
- def __init__(self):
15
- self.sql_generator = SQLGenerator()
16
- self.intent_classifier = IntentClassifier()
17
- self.rag_system = RAGSystem()
18
 
19
- def process_query(self, query):
20
- logging.info(f"Processing query: {query}") # Log the incoming query
21
- intent, confidence = self.intent_classifier.classify(query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- logging.info(f"Classified intent: {intent} with confidence: {confidence:.2f}")
24
-
25
- if intent == "database_query":
26
- sql_query = self.sql_generator.generate_query(query) # Assuming this method is correct
27
- products = self.sql_generator.fetch_shopify_data("products")
28
-
29
- if products and 'products' in products:
30
- results = "\n".join([
31
- f"Title: {p['title']}, Vendor: {p['vendor']}"
32
- for p in products['products']
33
- ])
34
- return f"Intent: Database Query (Confidence: {confidence:.2f})\n\n" \
35
- f"SQL Query: {sql_query}\n\nResults:\n{results}"
36
- else:
37
- logging.warning("No results found or error fetching data from Shopify.")
38
- return "No results found or error fetching data from Shopify."
39
-
40
- elif intent == "product_description":
41
- rag_response = self.rag_system.process_query(query)
42
- return f"Intent: Product Description (Confidence: {confidence:.2f})\n\n" \
43
- f"Response: {rag_response}"
44
 
45
- logging.error("Intent not recognized.")
46
- return "Intent not recognized."
47
-
48
- def create_interface():
49
- system = UnifiedSystem()
50
-
51
- iface = gr.Interface(
52
- fn=system.process_query,
53
- inputs=gr.Textbox(
54
- label="Enter your query",
55
- placeholder="e.g., 'Show me all T-shirts' or 'Describe the product features'"
56
- ),
57
- outputs=gr.Textbox(label="Response"),
58
- title="Unified Query Processing System",
59
- description="Enter a natural language query to search products or get descriptions."
60
- )
61
 
62
- return iface
63
-
64
- if __name__ == "__main__":
65
- iface = create_interface()
66
- iface.launch()
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import requests
4
+ from config import ACCESS_TOKEN, SHOP_NAME
 
5
 
6
+ class SQLGenerator:
7
+ def __init__(self):
8
+ self.model_name = "premai-io/prem-1B-SQL"
9
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
10
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
 
 
 
 
 
 
11
 
12
+ def generate_query(self, natural_language_query):
13
+ schema_info = """
14
+ CREATE TABLE products (
15
+ id DECIMAL(8,2) PRIMARY KEY,
16
+ title VARCHAR(255),
17
+ body_html VARCHAR(255),
18
+ vendor VARCHAR(255),
19
+ product_type VARCHAR(255),
20
+ created_at VARCHAR(255),
21
+ handle VARCHAR(255),
22
+ updated_at DATE,
23
+ published_at VARCHAR(255),
24
+ template_suffix VARCHAR(255),
25
+ published_scope VARCHAR(255),
26
+ tags VARCHAR(255),
27
+ status VARCHAR(255),
28
+ admin_graphql_api_id DECIMAL(8,2),
29
+ variants VARCHAR(255),
30
+ options VARCHAR(255),
31
+ images VARCHAR(255),
32
+ image VARCHAR(255)
33
+ );
34
+ """
35
 
36
+ prompt = f"""### Task: Generate a SQL query to answer the following question.
37
+ ### Database Schema:
38
+ {schema_info}
39
+ ### Question: {natural_language_query}
40
+ ### SQL Query:"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
43
+ outputs = self.model.generate(
44
+ inputs["input_ids"],
45
+ max_length=256,
46
+ do_sample=False,
47
+ num_return_sequences=1,
48
+ eos_token_id=self.tokenizer.eos_token_id,
49
+ pad_token_id=self.tokenizer.pad_token_id
50
+ )
51
+
52
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
 
 
 
 
 
53
 
54
+ def fetch_shopify_data(self, endpoint):
55
+ headers = {
56
+ 'X-Shopify-Access-Token': ACCESS_TOKEN,
57
+ 'Content-Type': 'application/json'
58
+ }
59
+ url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
60
+ response = requests.get(url, headers=headers)
61
+
62
+ if response.status_code == 200:
63
+ return response.json()
64
+ else:
65
+ print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
66
+ return None