bgaspra commited on
Commit
8fa2606
·
verified ·
1 Parent(s): ec29692

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -9
app.py CHANGED
@@ -14,7 +14,6 @@ model_name = "microsoft/Florence-2-base"
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
16
 
17
- # Modify model loading to disable flash attention
18
  model = AutoModelForCausalLM.from_pretrained(
19
  model_name,
20
  torch_dtype=torch_dtype,
@@ -22,21 +21,27 @@ model = AutoModelForCausalLM.from_pretrained(
22
  ).to(device)
23
  processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
24
 
25
- # Load CivitAI dataset (limited to 1000 samples)
26
  print("Loading dataset...")
27
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train[:1000]")
28
  df = pd.DataFrame(dataset)
29
  print("Dataset loaded successfully!")
30
 
31
- # Create cache for embeddings to improve performance
32
  text_embedding_cache = {}
33
 
34
  def get_image_embedding(image):
35
  try:
36
- inputs = processor(images=image, return_tensors="pt").to(device, torch_dtype)
 
 
 
 
 
37
  with torch.no_grad():
38
- outputs = model.get_image_features(**inputs)
39
- return outputs.cpu().numpy()
 
 
40
  except Exception as e:
41
  print(f"Error in get_image_embedding: {str(e)}")
42
  return None
@@ -46,11 +51,17 @@ def get_text_embedding(text):
46
  if text in text_embedding_cache:
47
  return text_embedding_cache[text]
48
 
49
- inputs = processor(text=text, return_tensors="pt").to(device, torch_dtype)
 
 
 
 
 
50
  with torch.no_grad():
51
- outputs = model.get_text_features(**inputs)
 
52
 
53
- embedding = outputs.cpu().numpy()
54
  text_embedding_cache[text] = embedding
55
  return embedding
56
  except Exception as e:
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
16
 
 
17
  model = AutoModelForCausalLM.from_pretrained(
18
  model_name,
19
  torch_dtype=torch_dtype,
 
21
  ).to(device)
22
  processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
23
 
24
+ # Load CivitAI dataset
25
  print("Loading dataset...")
26
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train[:1000]")
27
  df = pd.DataFrame(dataset)
28
  print("Dataset loaded successfully!")
29
 
 
30
  text_embedding_cache = {}
31
 
32
  def get_image_embedding(image):
33
  try:
34
+ inputs = processor(
35
+ images=image,
36
+ text=[""], # Florence-2 requires both image and text inputs
37
+ return_tensors="pt"
38
+ ).to(device, torch_dtype)
39
+
40
  with torch.no_grad():
41
+ outputs = model(**inputs)
42
+ # Get the image embeddings from the last hidden states
43
+ image_embeddings = outputs.last_hidden_state[:, 0, :] # Take CLS token
44
+ return image_embeddings.cpu().numpy()
45
  except Exception as e:
46
  print(f"Error in get_image_embedding: {str(e)}")
47
  return None
 
51
  if text in text_embedding_cache:
52
  return text_embedding_cache[text]
53
 
54
+ inputs = processor(
55
+ text=text,
56
+ images=None,
57
+ return_tensors="pt"
58
+ ).to(device, torch_dtype)
59
+
60
  with torch.no_grad():
61
+ outputs = model(**inputs)
62
+ text_embeddings = outputs.last_hidden_state[:, 0, :] # Take CLS token
63
 
64
+ embedding = text_embeddings.cpu().numpy()
65
  text_embedding_cache[text] = embedding
66
  return embedding
67
  except Exception as e: