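# Source: Hugging Face file viewer for app.py ("Create app.py" by Kvikontent,
# commit 1cbcd7d, verified; 3.62 kB; malware scan: "No virus").
# The avatar label "Kvikontent's picture" and the "raw" / "history" / "blame"
# navigation links were part of the copied page, not of the program.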
from huggingface_hub import cached_download, hf_hub_url
from PIL import Image
import os
import gradio as gr
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPModel
from transformers.pipelines import ImagePipeline
try:
    # ZeroGPU helper; only installed on Hugging Face Spaces.
    import spaces

    # NOTE(review): used with its default time limit, as before; long training
    # runs may need spaces.GPU(duration=...) — confirm against the ZeroGPU docs.
    _gpu_task = spaces.GPU()
except ImportError:
    # Running outside Spaces: no GPU allocation wrapper, call the function as-is.
    def _gpu_task(fn):
        return fn


_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
# Per-channel normalization constants of the OpenAI CLIP image preprocessor.
_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]


def _pair_by_stem(image_folder, text_folder):
    """Return sorted (image_path, text_path) pairs whose file names share a stem.

    ``cat.png`` is paired with ``cat.txt``. Images without a caption file and
    caption files without an image are ignored.
    """
    captions = {}
    for name in os.listdir(text_folder):
        stem, ext = os.path.splitext(name)
        if ext.lower() == ".txt":
            captions[stem] = os.path.join(text_folder, name)

    pairs = []
    # Sorted so the dataset order is reproducible across file systems.
    for name in sorted(os.listdir(image_folder)):
        stem, ext = os.path.splitext(name)
        if ext.lower() in _IMAGE_EXTENSIONS and stem in captions:
            pairs.append((os.path.join(image_folder, name), captions[stem]))
    return pairs


class ImageTextDataset(Dataset):
    """Image/caption pairs read from two folders and matched by file stem.

    Each item is ``(image, text)``: the image converted to RGB (then passed
    through ``transform`` when one is given) and the stripped contents of its
    UTF-8 caption file.

    Raises:
        ValueError: If no image in ``image_folder`` has a matching caption.
    """

    def __init__(self, image_folder, text_folder, transform=None):
        self.samples = _pair_by_stem(image_folder, text_folder)
        if not self.samples:
            raise ValueError(
                f"No image in {image_folder!r} has a matching .txt caption in "
                f"{text_folder!r}; name each caption after its image "
                "(e.g. cat.png + cat.txt)."
            )
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, text_path = self.samples[idx]
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        with open(text_path, "r", encoding="utf-8") as f:
            text = f.read().strip()
        return image, text


@_gpu_task
def train_image_generation_model(
    image_folder,
    text_folder,
    model_name="image_generation_model",
    *,
    epochs=10,
    batch_size=8,
    learning_rate=1e-5,
):
    """Fine-tune CLIP (openai/clip-vit-base-patch32) on paired images and captions.

    Despite the name, no image generator is trained. This runs CLIP's
    contrastive image-text objective and saves the fine-tuned CLIP weights.

    Args:
        image_folder (str): Folder containing training images (.png/.jpg/.jpeg).
        text_folder (str): Folder containing one ``.txt`` caption per image,
            named after it (``cat.png`` -> ``cat.txt``).
        model_name (str, optional): Base name of the saved file; ``.pt`` is
            appended. Defaults to "image_generation_model".
        epochs (int, optional): Passes over the dataset. Defaults to 10.
        batch_size (int, optional): Image/caption pairs per step. Defaults to 8.
        learning_rate (float, optional): Adam learning rate. Defaults to 1e-5.

    Returns:
        str: Path of the saved ``state_dict`` file in the current working directory.

    Raises:
        ValueError: If no image has a matching caption file.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CLIP_MEAN, std=_CLIP_STD),
    ])
    # Built before loading CLIP so a bad upload fails fast, without a model download.
    dataset = ImageTextDataset(image_folder, text_folder, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT).to(device)
    tokenizer = CLIPTokenizer.from_pretrained(_CLIP_CHECKPOINT)
    clip_model.train()  # from_pretrained returns the model in eval mode
    optimizer = torch.optim.Adam(clip_model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        for i, (images, texts) in enumerate(dataloader):
            # Captions differ in length and CLIP's text encoder has a fixed
            # maximum length, so pad to the batch's longest and truncate.
            tokens = tokenizer(
                list(texts), padding=True, truncation=True, return_tensors="pt"
            )
            optimizer.zero_grad()
            # return_loss=True uses CLIP's own symmetric contrastive loss on
            # normalized, logit-scaled similarities.
            outputs = clip_model(
                input_ids=tokens["input_ids"].to(device),
                attention_mask=tokens["attention_mask"].to(device),
                pixel_values=images.to(device),
                return_loss=True,
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f"Epoch: {epoch} | Iteration: {i} | Loss: {loss.item()}")

    model_path = os.path.join(os.getcwd(), model_name + ".pt")
    # Save CPU tensors so the file loads on machines without a GPU.
    torch.save(clip_model.to("cpu").state_dict(), model_path)
    return model_path
def _copy_uploads(uploaded, dest_dir, label):
    """Copy uploaded files flat into ``dest_dir``, keeping their file names.

    Args:
        uploaded: Value delivered by a ``gr.File`` component: a path string,
            an object exposing the path as ``.name`` (older Gradio), or a
            list of those.
        dest_dir (str): Existing folder to copy the files into.
        label (str): Human-readable name of the input, used in the error.

    Raises:
        gr.Error: If nothing was uploaded (shown to the user in the UI).
    """
    import shutil

    if not uploaded:
        raise gr.Error(f"Please upload the {label}.")
    if not isinstance(uploaded, (list, tuple)):
        uploaded = [uploaded]
    for item in uploaded:
        src = getattr(item, "name", item)
        shutil.copy(src, os.path.join(dest_dir, os.path.basename(src)))


def _train_from_uploads(image_files, text_files):
    """Gradio callback: rebuild both uploads as folders, then train on them."""
    import tempfile

    # The trainer takes two folder paths, but gr.File delivers individual file
    # paths, so recreate each folder temporarily for the duration of training.
    with tempfile.TemporaryDirectory() as image_dir, tempfile.TemporaryDirectory() as text_dir:
        _copy_uploads(image_files, image_dir, "image folder")
        _copy_uploads(text_files, text_dir, "text prompts folder")
        # The trainer writes the model into the working directory, so the
        # returned path stays valid after these temp folders are removed.
        return train_image_generation_model(image_dir, text_dir)


# Define Gradio interface
iface = gr.Interface(
    fn=_train_from_uploads,
    inputs=[
        # file_count="directory" lets the user pick a whole folder; the
        # default single-file mode could never deliver one.
        gr.File(label="Image Folder", file_count="directory"),
        gr.File(label="Text Prompts Folder", file_count="directory"),
    ],
    outputs=gr.File(label="Model File"),
    title="Image Generation Model Trainer",
    description="Upload a folder of images and their corresponding text prompts to train a model.",
)

if __name__ == "__main__":
    iface.launch(share=True)