DINGOLANI commited on
Commit
57b32e0
·
verified ·
1 Parent(s): 736b778

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -30
app.py CHANGED
@@ -1,47 +1,45 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import CLIPProcessor, CLIPModel
4
  import re
 
5
 
6
- # Load FashionCLIP model
7
- model_name = "patrickjohncyh/fashion-clip"
8
- model = CLIPModel.from_pretrained(model_name)
9
- processor = CLIPProcessor.from_pretrained(model_name)
10
 
11
- # Regex for price extraction
12
  price_pattern = re.compile(r'(\bunder\b|\babove\b|\bbelow\b|\bbetween\b)?\s?(\d{1,5})\s?(AED|USD|EUR)?', re.IGNORECASE)
13
 
14
- def get_text_embedding(text_list):
15
- """
16
- Converts a list of input texts into embeddings using FashionCLIP.
17
- """
18
- inputs = processor(text=text_list, return_tensors="pt", padding=True) # Corrected input format
19
- with torch.no_grad():
20
- text_embedding = model.get_text_features(**inputs)
21
- return text_embedding
22
 
23
  def extract_attributes(query):
24
  """
25
- Extract structured fashion attributes dynamically using FashionCLIP.
26
  """
27
  structured_output = {"Brand": "Unknown", "Category": "Unknown", "Gender": "Unknown", "Price": "Unknown"}
28
 
29
- # Get embedding for the query
30
- query_embedding = get_text_embedding([query])
31
 
32
- # Reference labels for classification
33
- reference_labels = ["Brand", "Category", "Gender", "Price"]
34
- reference_embeddings = get_text_embedding(reference_labels)
35
 
36
- # Compute cosine similarity
37
- similarities = torch.nn.functional.cosine_similarity(query_embedding, reference_embeddings)
38
- best_match_index = similarities.argmax().item()
 
 
 
39
 
40
- # Assign attribute dynamically
41
- attribute_type = reference_labels[best_match_index]
42
- structured_output[attribute_type] = query # Assigns the query text to the detected attribute
 
 
43
 
44
- # Extract price dynamically
45
  price_match = price_pattern.search(query)
46
  if price_match:
47
  condition, amount, currency = price_match.groups()
@@ -52,13 +50,14 @@ def extract_attributes(query):
52
  # Define Gradio UI
53
  def parse_query(user_query):
54
  """
55
- Takes user query and returns structured attributes dynamically.
56
  """
57
  parsed_output = extract_attributes(user_query)
58
- return parsed_output # Returns structured JSON
59
 
 
60
  with gr.Blocks() as demo:
61
- gr.Markdown("# 🛍️ Fashion Query Parser using FashionCLIP")
62
 
63
  query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., Gucci men’s perfume under 200AED")
64
  output_box = gr.JSON(label="Parsed Output")
 
1
  import gradio as gr
2
  import torch
 
3
  import re
4
+ from transformers import pipeline
5
 
6
+ # Load fine-tuned NER model from Hugging Face Hub
7
+ model_name = "luxury-fashion-ner"
8
+ ner_pipeline = pipeline("ner", model=model_name, tokenizer=model_name)
 
9
 
10
+ # Regex for extracting price
11
  price_pattern = re.compile(r'(\bunder\b|\babove\b|\bbelow\b|\bbetween\b)?\s?(\d{1,5})\s?(AED|USD|EUR)?', re.IGNORECASE)
12
 
13
+ # Keywords for gender extraction
14
+ gender_keywords = ["men", "male", "women", "female", "unisex"]
 
 
 
 
 
 
15
 
16
  def extract_attributes(query):
17
  """
18
+ Extract structured fashion attributes dynamically using the fine-tuned NER model.
19
  """
20
  structured_output = {"Brand": "Unknown", "Category": "Unknown", "Gender": "Unknown", "Price": "Unknown"}
21
 
22
+ # Run NER model on query
23
+ entities = ner_pipeline(query)
24
 
25
+ for entity in entities:
26
+ entity_text = entity["word"].replace("##", "") # Fix tokenization artifacts
27
+ entity_label = entity["entity"]
28
 
29
+ if "ORG" in entity_label: # Organization = Brand
30
+ structured_output["Brand"] = entity_text
31
+ elif "MISC" in entity_label: # Miscellaneous = Category
32
+ structured_output["Category"] = entity_text
33
+ elif "LOC" in entity_label: # Locations (sometimes used for brands)
34
+ structured_output["Brand"] = entity_text
35
 
36
+ # Extract gender
37
+ for gender in gender_keywords:
38
+ if gender in query.lower():
39
+ structured_output["Gender"] = gender.capitalize()
40
+ break
41
 
42
+ # Extract price
43
  price_match = price_pattern.search(query)
44
  if price_match:
45
  condition, amount, currency = price_match.groups()
 
50
  # Define Gradio UI
51
  def parse_query(user_query):
52
  """
53
+ Parses fashion-related queries into structured attributes.
54
  """
55
  parsed_output = extract_attributes(user_query)
56
+ return parsed_output # JSON output
57
 
58
+ # Create Gradio Interface
59
  with gr.Blocks() as demo:
60
+ gr.Markdown("# 🛍️ Luxury Fashion Query Parser using Fine-Tuned NER Model")
61
 
62
  query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., Gucci men’s perfume under 200AED")
63
  output_box = gr.JSON(label="Parsed Output")