bgaspra committed on
Commit
c17d729
1 Parent(s): a651c65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -24
app.py CHANGED
@@ -11,8 +11,12 @@ from sklearn.preprocessing import LabelEncoder
11
  # Load dataset
12
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
13
 
14
- # Text preprocessing function
15
  def preprocess_text(text, max_length=100):
 
 
 
 
16
  # Convert text to lowercase and split into words
17
  words = text.lower().split()
18
  # Truncate or pad to max_length
@@ -29,32 +33,55 @@ class CustomDataset(Dataset):
29
  transforms.Resize((224, 224)),
30
  transforms.ToTensor(),
31
  ])
 
 
 
 
 
32
  self.label_encoder = LabelEncoder()
33
- self.labels = self.label_encoder.fit_transform(dataset['Model'])
34
 
35
  # Create vocabulary from all prompts
36
  self.vocab = set()
37
- for item in dataset['prompt']:
38
- self.vocab.update(preprocess_text(item))
 
 
 
 
 
 
 
39
  self.vocab = list(self.vocab)
40
  self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
41
 
42
  def __len__(self):
43
- return len(self.dataset)
44
 
45
  def text_to_vector(self, text):
46
- words = preprocess_text(text)
47
- vector = torch.zeros(len(self.vocab))
48
- for word in words:
49
- if word in self.word_to_idx:
50
- vector[self.word_to_idx[word]] += 1
51
- return vector
 
 
 
 
52
 
53
  def __getitem__(self, idx):
54
- image = self.transform(self.dataset[idx]['image'])
55
- text_vector = self.text_to_vector(self.dataset[idx]['prompt'])
56
- label = self.labels[idx]
57
- return image, text_vector, label
 
 
 
 
 
 
 
58
 
59
  # Define CNN for image processing
60
  class ImageModel(nn.Module):
@@ -85,11 +112,11 @@ class TextMLP(nn.Module):
85
 
86
  # Combined model
87
  class CombinedModel(nn.Module):
88
- def __init__(self, vocab_size):
89
  super(CombinedModel, self).__init__()
90
  self.image_model = ImageModel()
91
  self.text_model = TextMLP(vocab_size)
92
- self.fc = nn.Linear(1024, len(dataset['Model'].unique()))
93
 
94
  def forward(self, image, text):
95
  image_features = self.image_model(image)
@@ -97,9 +124,15 @@ class CombinedModel(nn.Module):
97
  combined = torch.cat((image_features, text_features), dim=1)
98
  return self.fc(combined)
99
 
100
- # Create dataset instance and model
 
101
  custom_dataset = CustomDataset(dataset)
102
- model = CombinedModel(len(custom_dataset.vocab))
 
 
 
 
 
103
 
104
  def get_recommendations(image):
105
  model.eval()
@@ -111,7 +144,7 @@ def get_recommendations(image):
111
  ])
112
  image_tensor = transform(image).unsqueeze(0)
113
 
114
- # Create dummy text vector (since we're only doing image-based recommendations)
115
  dummy_text = torch.zeros((1, len(custom_dataset.vocab)))
116
 
117
  # Get model output
@@ -121,9 +154,13 @@ def get_recommendations(image):
121
  # Get recommended images and their information
122
  recommendations = []
123
  for idx in indices[0]:
124
- recommended_image = dataset[idx.item()]['image']
125
- model_name = dataset[idx.item()]['Model']
126
- recommendations.append((recommended_image, f"{model_name}"))
 
 
 
 
127
 
128
  return recommendations
129
 
@@ -137,4 +174,5 @@ interface = gr.Interface(
137
  )
138
 
139
  # Launch the app
140
- interface.launch()
 
 
11
  # Load dataset
12
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
13
 
14
+ # Text preprocessing function with None handling
15
  def preprocess_text(text, max_length=100):
16
+ # Handle None or empty text
17
+ if text is None or not isinstance(text, str):
18
+ text = ""
19
+
20
  # Convert text to lowercase and split into words
21
  words = text.lower().split()
22
  # Truncate or pad to max_length
 
33
  transforms.Resize((224, 224)),
34
  transforms.ToTensor(),
35
  ])
36
+
37
+ # Filter out None values from Model column
38
+ valid_indices = [i for i, model in enumerate(dataset['Model']) if model is not None]
39
+ self.valid_dataset = dataset.select(valid_indices)
40
+
41
  self.label_encoder = LabelEncoder()
42
+ self.labels = self.label_encoder.fit_transform(self.valid_dataset['Model'])
43
 
44
  # Create vocabulary from all prompts
45
  self.vocab = set()
46
+ for item in self.valid_dataset['prompt']:
47
+ try:
48
+ self.vocab.update(preprocess_text(item))
49
+ except Exception as e:
50
+ print(f"Error processing prompt: {e}")
51
+ continue
52
+
53
+ # Remove empty string from vocabulary if present
54
+ self.vocab.discard('')
55
  self.vocab = list(self.vocab)
56
  self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
57
 
58
  def __len__(self):
59
+ return len(self.valid_dataset)
60
 
61
  def text_to_vector(self, text):
62
+ try:
63
+ words = preprocess_text(text)
64
+ vector = torch.zeros(len(self.vocab))
65
+ for word in words:
66
+ if word in self.word_to_idx:
67
+ vector[self.word_to_idx[word]] += 1
68
+ return vector
69
+ except Exception as e:
70
+ print(f"Error converting text to vector: {e}")
71
+ return torch.zeros(len(self.vocab))
72
 
73
  def __getitem__(self, idx):
74
+ try:
75
+ image = self.transform(self.valid_dataset[idx]['image'])
76
+ text_vector = self.text_to_vector(self.valid_dataset[idx]['prompt'])
77
+ label = self.labels[idx]
78
+ return image, text_vector, label
79
+ except Exception as e:
80
+ print(f"Error getting item at index {idx}: {e}")
81
+ # Return zero tensors as fallback
82
+ return (torch.zeros((3, 224, 224)),
83
+ torch.zeros(len(self.vocab)),
84
+ 0)
85
 
86
  # Define CNN for image processing
87
  class ImageModel(nn.Module):
 
112
 
113
  # Combined model
114
  class CombinedModel(nn.Module):
115
+ def __init__(self, vocab_size, num_classes):
116
  super(CombinedModel, self).__init__()
117
  self.image_model = ImageModel()
118
  self.text_model = TextMLP(vocab_size)
119
+ self.fc = nn.Linear(1024, num_classes)
120
 
121
  def forward(self, image, text):
122
  image_features = self.image_model(image)
 
124
  combined = torch.cat((image_features, text_features), dim=1)
125
  return self.fc(combined)
126
 
127
+ # Create dataset instance
128
+ print("Creating dataset...")
129
  custom_dataset = CustomDataset(dataset)
130
+ print(f"Vocabulary size: {len(custom_dataset.vocab)}")
131
+ print(f"Number of valid samples: {len(custom_dataset)}")
132
+
133
+ # Create model
134
+ num_classes = len(custom_dataset.label_encoder.classes_)
135
+ model = CombinedModel(len(custom_dataset.vocab), num_classes)
136
 
137
  def get_recommendations(image):
138
  model.eval()
 
144
  ])
145
  image_tensor = transform(image).unsqueeze(0)
146
 
147
+ # Create dummy text vector
148
  dummy_text = torch.zeros((1, len(custom_dataset.vocab)))
149
 
150
  # Get model output
 
154
  # Get recommended images and their information
155
  recommendations = []
156
  for idx in indices[0]:
157
+ try:
158
+ recommended_image = custom_dataset.valid_dataset[idx.item()]['image']
159
+ model_name = custom_dataset.valid_dataset[idx.item()]['Model']
160
+ recommendations.append((recommended_image, f"{model_name}"))
161
+ except Exception as e:
162
+ print(f"Error getting recommendation for index {idx}: {e}")
163
+ continue
164
 
165
  return recommendations
166
 
 
174
  )
175
 
176
  # Launch the app
177
+ if __name__ == "__main__":
178
+ interface.launch()