anky-degen-pixels / README.md
jpfraneto's picture
Update README.md
a524e5f verified
|
raw
history blame
10.9 kB
metadata
license: mit

This model was trained using the 8888 images of the Anky Genesis NFT Collection, and its mission is to transform an image into pixel art, like so:

Anky Degen Pixel Example

The code used for training it is the following:

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np

# Custom dataset for loading the images
class PixelArtDataset(Dataset):
    """Dataset of numbered PNG files ``1.png`` .. ``{num_images}.png``.

    Each item is an ``(input, target)`` pair in which both elements are the
    same transformed image, so a model trained on it learns to reconstruct
    its input.

    Args:
        image_folder: Directory containing the numbered PNG files.
        transform: Optional callable applied to each loaded RGB PIL image.
        num_images: How many numbered files to list. Defaults to 8888, the
            size of the Anky Genesis collection this script was written for.
    """

    def __init__(self, image_folder, transform=None, num_images=8888):
        self.image_folder = image_folder
        self.transform = transform
        # File names are generated (1-based), not discovered on disk, so a
        # missing file only surfaces when __getitem__ tries to open it.
        self.image_files = [f"{i}.png" for i in range(1, num_images + 1)]

        # Debug: Check if images are correctly listed
        print(f"Total images found: {len(self.image_files)}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_folder, self.image_files[idx])
        # The context manager closes the underlying file deterministically
        # instead of leaving it to garbage collection; convert() returns an
        # independent, fully-loaded image that outlives the `with` block.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, image

# Define the neural network
class PixelArtGenerator(nn.Module):
    """Convolutional encoder-decoder mapping an RGB image to an RGB image.

    The encoder applies three stride-2 convolutions (3 -> 64 -> 128 -> 256
    channels), each halving height and width. The decoder mirrors it with
    three stride-2 transposed convolutions (256 -> 128 -> 64 -> 3 channels),
    each doubling height and width. A final ``Tanh`` bounds the output to
    [-1, 1], the same range produced by the ``Normalize((0.5, ...), (0.5, ...))``
    transforms used in this script.

    The position of each layer inside its ``nn.Sequential`` determines the
    ``state_dict`` keys (``encoder.0.weight``, ``decoder.6.bias``, ...), so
    the order must not change or saved checkpoints will stop loading.
    """

    def __init__(self):
        super().__init__()
        print("Initializing PixelArtGenerator Model...")
        # Every (transposed) convolution shares the same geometry.
        conv = dict(kernel_size=4, stride=2, padding=1)

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, **conv),
            nn.ReLU(),
            nn.Conv2d(64, 128, **conv),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, **conv),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, **conv),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, **conv),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, **conv),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` and decode it back to an image-shaped tensor."""
        return self.decoder(self.encoder(x))

def train(model, dataloader, criterion, optimizer, device, epochs=50):
    """Optimize ``model`` in place with a plain supervised training loop.

    Args:
        model: Module to optimize; switched to training mode here.
        dataloader: Iterable of ``(inputs, targets)`` batches that also
            supports ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.
        epochs: Number of full passes over ``dataloader``.

    Prints the loss of every 10th batch and, after each epoch, the mean of
    the per-batch losses.
    """
    print("Starting training...")
    model.train()
    for epoch in range(1, epochs + 1):
        print(f"Epoch [{epoch}/{epochs}] starting...")
        epoch_total = 0.0
        for step, (inputs, targets) in enumerate(dataloader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_total += batch_loss

            # Only report batches 0, 10, 20, ... to keep the log readable.
            if step % 10 == 0:
                print(f"Epoch [{epoch}/{epochs}], Batch [{step + 1}/{len(dataloader)}], Loss: {batch_loss:.4f}")

        print(f"Epoch [{epoch}/{epochs}] completed with Loss: {epoch_total / len(dataloader):.4f}")

def create_pixel_art(model, input_image_path, output_image_path, device):
    """Run ``model`` on a single image file and save the result as an image.

    The input is resized to 64x64 and normalized to [-1, 1], using the same
    transform parameters as the training pipeline in this script. The model
    output is mapped back to 0-255 ``uint8`` RGB and written to
    ``output_image_path`` at the model's output resolution.
    """
    print("Creating pixel art...")
    model.eval()
    preprocess = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    source = Image.open(input_image_path).convert("RGB")
    batch = preprocess(source).unsqueeze(0).to(device)
    with torch.no_grad():
        chw = model(batch).squeeze(0).cpu().numpy()
        # CHW -> HWC, then undo Normalize(0.5, 0.5): [-1, 1] -> [0, 255].
        hwc = chw.transpose(1, 2, 0)
        pixels = np.clip((hwc * 0.5 + 0.5) * 255.0, 0, 255).astype(np.uint8)
        Image.fromarray(pixels).save(output_image_path)
    print(f"Pixel art saved to {output_image_path}")

if __name__ == "__main__":
    # Transform for input images
    print("Setting up image transformations...")
    transform = transforms.Compose([
        transforms.Resize((64, 64)),  # Resize to 64x64 for input
        transforms.ToTensor(),
        # Maps [0, 1] pixel values to [-1, 1], the range of the model's Tanh output.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load dataset
    print("Loading dataset...")
    image_folder = "./"  # Change this to your images folder path
    dataset = PixelArtDataset(image_folder, transform)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)  # Reduce batch size for debugging

    # Check for GPU availability
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the model, criterion, and optimizer
    model = PixelArtGenerator().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0002)

    # Enable data parallelism if multiple GPUs are available.
    # DataParallel wraps the same parameters, so the optimizer created above
    # still updates them.
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs")
        model = nn.DataParallel(model)

    # Train the model
    train(model, dataloader, criterion, optimizer, device, epochs=50)

    # Save the weights of the underlying network, not the DataParallel
    # wrapper: the wrapper's state_dict prefixes every key with "module.",
    # which would not load directly into a plain PixelArtGenerator.
    base_model = model.module if isinstance(model, nn.DataParallel) else model
    torch.save(base_model.state_dict(), "pixel_art_generator.pth")
    print("Model saved as 'pixel_art_generator.pth'")

    # Create pixel art from a new input image
    input_image_path = "input_image.png"  # Path to the high-resolution input image
    output_image_path = "pixel_art.png"  # Path to save the generated pixel art
    create_pixel_art(model, input_image_path, output_image_path, device)
    print("Pixel art creation completed.")

Training ran on a Cognition PRO machine called poiesis: 50 epochs, taking about 4 hours on 2× NVIDIA RTX 4090 GPUs.

Its intended use is to transform any image into its pixel-art counterpart, as shown in the example above.

To run it this way, execute the following Python code from the folder that contains the model file (`pixel_art_generator.pth`). It transforms an image called `pfp.jpeg` and saves the result as `pfp_pixelated.png`:

import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms
import os

# Define the neural network (same as the one used during training)
class PixelArtGenerator(nn.Module):
    """Encoder-decoder network; must mirror the training definition exactly.

    Layers are appended in the same order as in the training script, so the
    ``nn.Sequential`` indices, and therefore the ``state_dict`` keys
    (``encoder.0.weight``, ``decoder.6.bias``, ...), match the saved
    checkpoint. Each stride-2 convolution halves height and width, each
    stride-2 transposed convolution doubles them, and the final ``Tanh``
    keeps outputs in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        print("Initializing PixelArtGenerator Model...")
        geometry = {"kernel_size": 4, "stride": 2, "padding": 1}

        # Encoder: plain first conv, then two conv + batch-norm stages.
        encoder_layers = [nn.Conv2d(3, 64, **geometry), nn.ReLU()]
        for c_in, c_out in ((64, 128), (128, 256)):
            encoder_layers += [nn.Conv2d(c_in, c_out, **geometry), nn.BatchNorm2d(c_out), nn.ReLU()]

        # Decoder: two transposed-conv + batch-norm stages, then back to RGB.
        decoder_layers = []
        for c_in, c_out in ((256, 128), (128, 64)):
            decoder_layers += [nn.ConvTranspose2d(c_in, c_out, **geometry), nn.BatchNorm2d(c_out), nn.ReLU()]
        decoder_layers += [nn.ConvTranspose2d(64, 3, **geometry), nn.Tanh()]

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return the decoded reconstruction of ``x``."""
        latent = self.encoder(x)
        return self.decoder(latent)

def create_pixel_art(model, input_image_path, output_image_path, device, target_width=828):
    """Pixelate one image with ``model`` and save an enlarged copy.

    Args:
        model: Trained ``PixelArtGenerator``; switched to eval mode here.
        input_image_path: Path of the image to transform. It is converted to
            RGB and resized to 64x64 before inference.
        output_image_path: Where the upscaled result is written.
        device: Device the input tensor is moved to (the device ``model``
            lives on).
        target_width: Width in pixels of the saved image. Defaults to 828,
            the iPhone 11 screen width. The height keeps the model output's
            aspect ratio, and nearest-neighbour resampling keeps every
            generated pixel a hard-edged block.

    Returns:
        None. If ``input_image_path`` is not an existing file, an error is
        printed and nothing is written.

    Raises:
        ValueError: If ``target_width`` is not positive.
    """
    if target_width <= 0:
        raise ValueError(f"target_width must be positive, got {target_width}")

    print(f"Creating pixel art for {input_image_path}...")

    # Check if the input image file exists
    if not os.path.isfile(input_image_path):
        print(f"Error: Input image file '{input_image_path}' not found.")
        return

    model.eval()
    print("Model set to evaluation mode.")

    # Same preprocessing as training: 64x64, values normalized to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    print("Image transformation defined.")

    # Load and preprocess the input image; the context manager closes the file.
    with Image.open(input_image_path) as image:
        input_image = transform(image.convert("RGB")).unsqueeze(0).to(device)
    print(f"Input image '{input_image_path}' loaded and preprocessed.")

    # Generate pixel art using the model
    with torch.no_grad():
        output_image = model(input_image).squeeze(0).cpu().numpy()
        print("Pixel art generated by the model.")

    # Post-process: CHW -> HWC, undo the [-1, 1] normalization, convert to uint8.
    output_image = np.transpose(output_image, (1, 2, 0))
    output_image = (output_image * 0.5 + 0.5) * 255.0
    output_image = np.clip(output_image, 0, 255).astype(np.uint8)
    output_image = Image.fromarray(output_image)

    # Scale up to the requested width, preserving the aspect ratio.
    width, height = output_image.size
    scaled_output_image = output_image.resize((target_width, int(target_width * height / width)), Image.NEAREST)
    scaled_output_image.save(output_image_path)
    print(f"Pixel art saved to '{output_image_path}'.")

if __name__ == "__main__":
    print("Starting pixel art generation script...")

    # Pick the device before loading. A checkpoint saved from CUDA tensors
    # (as when the training script runs on GPUs) cannot be loaded on a
    # CPU-only machine unless torch.load is told where to map it.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load the trained model
    model = PixelArtGenerator()
    model_path = "pixel_art_generator.pth"  # Path to the saved model
    print(f"Loading model from '{model_path}'...")

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # you trust. On torch >= 1.13, consider passing weights_only=True to
    # restrict unpickling to tensors and plain containers.
    state_dict = torch.load(model_path, map_location=device)

    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # "module."; strip that prefix, and only at the start of a key.
    prefix = "module."
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    model.load_state_dict(state_dict)
    print("Model loaded successfully.")

    model.to(device)
    print(f"Using device: {device}")

    # Define the input and output paths for the single image
    input_image_path = "pfp.jpeg"  # Path to the input image
    output_image_path = "pfp_pixelated.png"  # Path to save the generated pixel art

    # Create pixel art for the single image
    create_pixel_art(model, input_image_path, output_image_path, device)

    print("Pixel art creation completed for the single image.")

Hope you enjoy it! If you have any questions, feel free to reach out to @jpfraneto on Telegram.

If you want to contribute to Anky, we have plenty of compute available, and a powerful story (and intention) that puts the unfolding of AI at the core of our experience as humans.

Think of it as a playground for your inner child, with boundless potential.

Our farcaster channel is here: https://warpcast.com/~/channel/anky

Your uniqueness is a gift.

🎩