bgaspra commited on
Commit
0f91f48
1 Parent(s): d46f971

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -15
app.py CHANGED
@@ -43,6 +43,8 @@ class ModelRecommender(nn.Module):
43
  output = self.combined(combined)
44
  return output
45
 
 
 
46
  # Load model dan dataset info
47
  def load_model():
48
  # Load dataset info
@@ -61,13 +63,13 @@ def load_model():
61
 
62
  return model, model_names, device
63
 
64
- # Inference function
 
 
65
  def predict_image(image):
66
- # Load model if not loaded
67
  if not hasattr(predict_image, "model"):
68
  predict_image.model, predict_image.model_names, predict_image.device = load_model()
69
 
70
- # Preprocess image
71
  transform = transforms.Compose([
72
  transforms.Resize((224, 224)),
73
  transforms.ToTensor(),
@@ -78,29 +80,85 @@ def predict_image(image):
78
  image_tensor = transform(image).unsqueeze(0).to(predict_image.device)
79
  dummy_text_features = torch.zeros(1, 768).to(predict_image.device)
80
 
81
- # Get predictions
82
  with torch.no_grad():
 
83
  outputs = predict_image.model(image_tensor, dummy_text_features)
84
- probs = torch.nn.functional.softmax(outputs, dim=1)
85
- top5_prob, top5_indices = torch.topk(probs, 5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- # Format results
88
- results = []
89
- for prob, idx in zip(top5_prob[0], top5_indices[0]):
90
- model_name = predict_image.model_names[idx.item()]
91
- confidence = f"{prob.item():.2%}"
92
- results.append(f"Model: {model_name}\nConfidence: {confidence}")
93
 
94
- return "\n\n".join(results)
95
 
96
  # Gradio Interface
97
  demo = gr.Interface(
98
  fn=predict_image,
99
  inputs=gr.Image(type="pil"),
100
- outputs=gr.Textbox(label="Model Recommendations"),
101
  title="Stable Diffusion Model Recommender",
102
  description="Upload an AI-generated image to get model recommendations",
103
- examples=[["example1.jpg"], ["example2.jpg"]] # Tambahkan contoh gambar jika ada
104
  )
105
 
106
  demo.launch()
 
43
  output = self.combined(combined)
44
  return output
45
 
46
+
47
+
48
  # Load model dan dataset info
49
  def load_model():
50
  # Load dataset info
 
63
 
64
  return model, model_names, device
65
 
66
+ def calculate_euclidean_distance(features1, features2):
67
+ return np.linalg.norm(features1 - features2)
68
+
69
  def predict_image(image):
 
70
  if not hasattr(predict_image, "model"):
71
  predict_image.model, predict_image.model_names, predict_image.device = load_model()
72
 
 
73
  transform = transforms.Compose([
74
  transforms.Resize((224, 224)),
75
  transforms.ToTensor(),
 
80
  image_tensor = transform(image).unsqueeze(0).to(predict_image.device)
81
  dummy_text_features = torch.zeros(1, 768).to(predict_image.device)
82
 
83
+ # Get image features
84
  with torch.no_grad():
85
+ img_features = predict_image.model.cnn(image_tensor).cpu().numpy()
86
  outputs = predict_image.model(image_tensor, dummy_text_features)
87
+ top5_prob, top5_indices = torch.topk(outputs, 5)
88
+
89
+ # Create HTML gallery
90
+ html_output = """
91
+ <style>
92
+ .model-gallery {
93
+ display: grid;
94
+ grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
95
+ gap: 20px;
96
+ padding: 20px;
97
+ }
98
+ .model-card {
99
+ border: 1px solid #ddd;
100
+ border-radius: 8px;
101
+ overflow: hidden;
102
+ background: white;
103
+ box-shadow: 0 2px 5px rgba(0,0,0,0.1);
104
+ }
105
+ .model-img {
106
+ width: 100%;
107
+ height: 200px;
108
+ object-fit: cover;
109
+ }
110
+ .model-info {
111
+ padding: 15px;
112
+ }
113
+ .model-name {
114
+ color: #2563eb;
115
+ text-decoration: none;
116
+ font-weight: bold;
117
+ font-size: 1.1em;
118
+ margin-bottom: 8px;
119
+ display: block;
120
+ }
121
+ .model-name:hover {
122
+ text-decoration: underline;
123
+ }
124
+ .distance {
125
+ color: #666;
126
+ font-size: 0.9em;
127
+ }
128
+ </style>
129
+ <div class="model-gallery">
130
+ """
131
+
132
+ # Generate cards for each model
133
+ for idx, (score, model_idx) in enumerate(zip(top5_prob[0], top5_indices[0])):
134
+ model_name = predict_image.model_names[model_idx.item()]
135
+ distance = calculate_euclidean_distance(img_features[0],
136
+ torch.randn(512).numpy()) # Placeholder for actual features
137
+
138
+ civitai_url = f"https://civitai.com/search/models?sortBy=models_v9&query={model_name}"
139
+
140
+ html_output += f"""
141
+ <div class="model-card">
142
+ <img class="model-img" src="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' width='250' height='200' viewBox='0 0 250 200'><rect width='100%' height='100%' fill='%23f0f0f0'/><text x='50%' y='50%' dominant-baseline='middle' text-anchor='middle' font-family='Arial' font-size='16' fill='%23666'>Model Preview</text></svg>" alt="{model_name}">
143
+ <div class="model-info">
144
+ <a href="{civitai_url}" target="_blank" class="model-name">{model_name}</a>
145
+ <div class="distance">Euclidean Distance: {distance:.4f}</div>
146
+ </div>
147
+ </div>
148
+ """
149
 
150
+ html_output += "</div>"
 
 
 
 
 
151
 
152
+ return html_output
153
 
154
  # Gradio Interface
155
  demo = gr.Interface(
156
  fn=predict_image,
157
  inputs=gr.Image(type="pil"),
158
+ outputs=gr.HTML(),
159
  title="Stable Diffusion Model Recommender",
160
  description="Upload an AI-generated image to get model recommendations",
161
+ examples=[["example1.jpg"], ["example2.jpg"]]
162
  )
163
 
164
  demo.launch()