nileshhanotia commited on
Commit
7b9f70d
·
verified ·
1 Parent(s): 81884e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -60
app.py CHANGED
@@ -1,66 +1,54 @@
1
- import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import requests
4
- from config import ACCESS_TOKEN, SHOP_NAME
5
 
6
- class SQLGenerator:
7
  def __init__(self):
8
- self.model_name = "premai-io/prem-1B-SQL"
9
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
10
- self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
11
 
12
- def generate_query(self, natural_language_query):
13
- schema_info = """
14
- CREATE TABLE products (
15
- id DECIMAL(8,2) PRIMARY KEY,
16
- title VARCHAR(255),
17
- body_html VARCHAR(255),
18
- vendor VARCHAR(255),
19
- product_type VARCHAR(255),
20
- created_at VARCHAR(255),
21
- handle VARCHAR(255),
22
- updated_at DATE,
23
- published_at VARCHAR(255),
24
- template_suffix VARCHAR(255),
25
- published_scope VARCHAR(255),
26
- tags VARCHAR(255),
27
- status VARCHAR(255),
28
- admin_graphql_api_id DECIMAL(8,2),
29
- variants VARCHAR(255),
30
- options VARCHAR(255),
31
- images VARCHAR(255),
32
- image VARCHAR(255)
33
- );
34
- """
35
 
36
- prompt = f"""### Task: Generate a SQL query to answer the following question.
37
- ### Database Schema:
38
- {schema_info}
39
- ### Question: {natural_language_query}
40
- ### SQL Query:"""
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
43
- outputs = self.model.generate(
44
- inputs["input_ids"],
45
- max_length=256,
46
- do_sample=False,
47
- num_return_sequences=1,
48
- eos_token_id=self.tokenizer.eos_token_id,
49
- pad_token_id=self.tokenizer.pad_token_id
50
- )
51
-
52
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
53
 
54
- def fetch_shopify_data(self, endpoint):
55
- headers = {
56
- 'X-Shopify-Access-Token': ACCESS_TOKEN,
57
- 'Content-Type': 'application/json'
58
- }
59
- url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
60
- response = requests.get(url, headers=headers)
61
-
62
- if response.status_code == 200:
63
- return response.json()
64
- else:
65
- print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
66
- return None
 
 
 
 
1
+ import gradio as gr
2
+ from sql_generator import SQLGenerator
3
+ from intent_classifier import IntentClassifier
4
+ from rag_system import RAGSystem
5
 
6
+ class UnifiedSystem:
7
  def __init__(self):
8
+ self.sql_generator = SQLGenerator()
9
+ self.intent_classifier = IntentClassifier()
10
+ self.rag_system = RAGSystem()
11
 
12
+ def process_query(self, query):
13
+ intent, confidence = self.intent_classifier.classify(query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ if intent == "database_query":
16
+ sql_query = self.sql_generator.generate_query(query)
17
+ products = self.sql_generator.fetch_shopify_data("products")
18
+
19
+ if products and 'products' in products:
20
+ results = "\n".join([
21
+ f"Title: {p['title']}, Vendor: {p['vendor']}"
22
+ for p in products['products']
23
+ ])
24
+ return f"Intent: Database Query (Confidence: {confidence:.2f})\n\n" \
25
+ f"SQL Query: {sql_query}\n\nResults:\n{results}"
26
+ else:
27
+ return "No results found or error fetching data from Shopify."
28
+
29
+ elif intent == "product_description":
30
+ rag_response = self.rag_system.process_query(query)
31
+ return f"Intent: Product Description (Confidence: {confidence:.2f})\n\n" \
32
+ f"Response: {rag_response}"
33
 
34
+ return "Intent not recognized."
35
+
36
+ def create_interface():
37
+ system = UnifiedSystem()
 
 
 
 
 
 
 
38
 
39
+ iface = gr.Interface(
40
+ fn=system.process_query,
41
+ inputs=gr.Textbox(
42
+ label="Enter your query",
43
+ placeholder="e.g., 'Show me all T-shirts' or 'Describe the product features'"
44
+ ),
45
+ outputs=gr.Textbox(label="Response"),
46
+ title="Unified Query Processing System",
47
+ description="Enter a natural language query to search products or get descriptions."
48
+ )
49
+
50
+ return iface
51
+
52
+ if __name__ == "__main__":
53
+ iface = create_interface()
54
+ iface.launch()