Kvikontent committed on
Commit
1cbcd7d
1 Parent(s): 3c59ec7

Create app.py

Files changed (1)
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
+import os
+
+import gradio as gr
+import spaces  # needed for the @spaces.GPU() decorator used below
+import torch
+from PIL import Image
+from torch import nn
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from transformers import CLIPModel, CLIPTokenizer
+
+@spaces.GPU()
+def train_image_generation_model(image_folder, text_folder, model_name="image_generation_model"):
+    """Fine-tunes a CLIP image-text model on the provided dataset.
+
+    Args:
+        image_folder (str): Path to the folder containing training images.
+        text_folder (str): Path to the folder containing text prompts for each image.
+        model_name (str, optional): Name for the saved model file. Defaults to "image_generation_model".
+
+    Returns:
+        str: Path to the saved model file.
+    """
+
+    class ImageTextDataset(Dataset):
+        def __init__(self, image_folder, text_folder, transform=None):
+            # Sort both listings so each image stays paired with its prompt file
+            self.image_paths = sorted(os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.lower().endswith((".png", ".jpg", ".jpeg")))
+            self.text_paths = sorted(os.path.join(text_folder, f) for f in os.listdir(text_folder) if f.lower().endswith(".txt"))
+            self.transform = transform
+
+        def __len__(self):
+            return len(self.image_paths)
+
+        def __getitem__(self, idx):
+            image = Image.open(self.image_paths[idx]).convert("RGB")
+            if self.transform:
+                image = self.transform(image)
+            with open(self.text_paths[idx], "r") as f:
+                text = f.read().strip()
+            return image, text
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Load CLIP model and tokenizer
+    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+    # Define image transformations (CLIP's expected input size and normalization stats)
+    transform = transforms.Compose([
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
+    ])
+
+    # Create dataset and dataloader
+    dataset = ImageTextDataset(image_folder, text_folder, transform=transform)
+    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
+
+    # Define optimizer and loss function
+    optimizer = torch.optim.Adam(clip_model.parameters(), lr=1e-5)
+    loss_fn = nn.CrossEntropyLoss()
+
+    # Train the model with CLIP's contrastive objective
+    clip_model.train()
+    for epoch in range(10):
+        for i, (images, texts) in enumerate(dataloader):
+            optimizer.zero_grad()
+            images = images.to(device)
+            tokens = tokenizer(list(texts), padding=True, truncation=True, return_tensors="pt").to(device)
+            image_features = clip_model.get_image_features(pixel_values=images)
+            text_features = clip_model.get_text_features(input_ids=tokens["input_ids"], attention_mask=tokens["attention_mask"])
+            # Normalize features and scale logits, matching CLIP's contrastive loss
+            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+            similarity = clip_model.logit_scale.exp() * image_features @ text_features.T
+            loss = loss_fn(similarity, torch.arange(images.size(0), device=images.device))
+            loss.backward()
+            optimizer.step()
+            print(f"Epoch: {epoch} | Iteration: {i} | Loss: {loss.item()}")
+
+    # Save the trained model
+    model_path = os.path.join(os.getcwd(), model_name + ".pt")
+    torch.save(clip_model.state_dict(), model_path)
+
+    return model_path
+
+# Define Gradio interface; the training function expects directory paths, so take them as text inputs
+iface = gr.Interface(
+    fn=train_image_generation_model,
+    inputs=[
+        gr.Textbox(label="Image Folder Path"),
+        gr.Textbox(label="Text Prompts Folder Path"),
+    ],
+    outputs=gr.File(label="Model File"),
+    title="Image Generation Model Trainer",
+    description="Provide the paths to a folder of images and a folder of their corresponding text prompts to train the model.",
+)
+
+iface.launch(share=True)