bgaspra commited on
Commit
d46f971
1 Parent(s): 7c20f88

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms, models
4
+ import gradio as gr
5
+ from PIL import Image
6
+
7
+ # Model Architecture (sama seperti sebelumnya)
8
+ class ModelRecommender(nn.Module):
9
+ def __init__(self, num_models, text_embedding_dim=768):
10
+ super(ModelRecommender, self).__init__()
11
+
12
+ # CNN for image processing
13
+ self.cnn = models.resnet18(pretrained=True)
14
+ self.cnn.fc = nn.Linear(512, 256)
15
+
16
+ # MLP for text processing
17
+ self.text_mlp = nn.Sequential(
18
+ nn.Linear(text_embedding_dim, 512),
19
+ nn.ReLU(),
20
+ nn.Linear(512, 256),
21
+ nn.ReLU()
22
+ )
23
+
24
+ # Combined layers
25
+ self.combined = nn.Sequential(
26
+ nn.Linear(512, 256),
27
+ nn.ReLU(),
28
+ nn.Dropout(0.5),
29
+ nn.Linear(256, num_models)
30
+ )
31
+
32
+ def forward(self, image, text_features):
33
+ # Process image
34
+ img_features = self.cnn(image)
35
+
36
+ # Process text
37
+ text_features = self.text_mlp(text_features)
38
+
39
+ # Combine features
40
+ combined = torch.cat((img_features, text_features), dim=1)
41
+
42
+ # Final prediction
43
+ output = self.combined(combined)
44
+ return output
45
+
46
+ # Load model dan dataset info
47
+ def load_model():
48
+ # Load dataset info
49
+ dataset_info = torch.load('dataset_info.pth')
50
+ model_names = dataset_info['model_names']
51
+
52
+ # Initialize model
53
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+ model = ModelRecommender(len(model_names))
55
+
56
+ # Load model weights
57
+ checkpoint = torch.load('sd_recommender_model.pth', map_location=device)
58
+ model.load_state_dict(checkpoint['model_state_dict'])
59
+ model.to(device)
60
+ model.eval()
61
+
62
+ return model, model_names, device
63
+
64
+ # Inference function
65
+ def predict_image(image):
66
+ # Load model if not loaded
67
+ if not hasattr(predict_image, "model"):
68
+ predict_image.model, predict_image.model_names, predict_image.device = load_model()
69
+
70
+ # Preprocess image
71
+ transform = transforms.Compose([
72
+ transforms.Resize((224, 224)),
73
+ transforms.ToTensor(),
74
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
75
+ std=[0.229, 0.224, 0.225])
76
+ ])
77
+
78
+ image_tensor = transform(image).unsqueeze(0).to(predict_image.device)
79
+ dummy_text_features = torch.zeros(1, 768).to(predict_image.device)
80
+
81
+ # Get predictions
82
+ with torch.no_grad():
83
+ outputs = predict_image.model(image_tensor, dummy_text_features)
84
+ probs = torch.nn.functional.softmax(outputs, dim=1)
85
+ top5_prob, top5_indices = torch.topk(probs, 5)
86
+
87
+ # Format results
88
+ results = []
89
+ for prob, idx in zip(top5_prob[0], top5_indices[0]):
90
+ model_name = predict_image.model_names[idx.item()]
91
+ confidence = f"{prob.item():.2%}"
92
+ results.append(f"Model: {model_name}\nConfidence: {confidence}")
93
+
94
+ return "\n\n".join(results)
95
+
96
+ # Gradio Interface
97
+ demo = gr.Interface(
98
+ fn=predict_image,
99
+ inputs=gr.Image(type="pil"),
100
+ outputs=gr.Textbox(label="Model Recommendations"),
101
+ title="Stable Diffusion Model Recommender",
102
+ description="Upload an AI-generated image to get model recommendations",
103
+ examples=[["example1.jpg"], ["example2.jpg"]] # Tambahkan contoh gambar jika ada
104
+ )
105
+
106
+ demo.launch()