{
"cells": [
{
"cell_type": "markdown",
"id": "3b5fbb8f-5789-45db-a551-f0e6633b4f46",
"metadata": {},
"source": [
"# Populate an HDF5 dataset with base64 Pokémon images keyed by energy type"
]
},
{
"cell_type": "markdown",
"id": "a9234c78-1ac5-4c71-b11b-3a81be57f3f3",
"metadata": {},
"source": [
"Used in [**This Pokémon Does Not Exist**](https://huggingface.co/spaces/ronvolutional/ai-pokemon-card)\n",
"\n",
"Model fine-tuned by [**Max Woolf**](https://huggingface.co/minimaxir/ai-generated-pokemon-rudalle)\n",
"\n",
"ruDALL-E by [**Sber**](https://rudalle.ru/en)"
]
},
{
"cell_type": "markdown",
"id": "09850949-235d-4997-b8e3-c5b2aeffe109",
"metadata": {},
"source": [
"## Initialise datasets"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cead6fa8-e9ef-4672-bbfb-beadcaf5f3a0",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import h5py\n",
"\n",
"datasets_dir = './datasets'\n",
"datasets_file = 'pregenerated_pokemon.h5'\n",
"h5_file = os.path.join(datasets_dir, datasets_file)\n",
"\n",
"energy_types = ['grass', 'fire', 'water', 'lightning', 'fighting', 'psychic', 'colorless', 'darkness', 'metal', 'dragon', 'fairy']"
]
},
{
"cell_type": "raw",
"id": "2df90e94-15c0-4eb6-914e-875ec80b7c24",
"metadata": {},
"source": [
"# Only run if the datasets file does not exist\n",
"# (mode 'x' makes h5py fail instead of overwriting an existing file)\n",
"\n",
"# The file cannot be created inside a missing directory, so create it first\n",
"os.makedirs(datasets_dir, exist_ok=True)\n",
"\n",
"with h5py.File(h5_file, 'x') as datasets:\n",
"    for energy in energy_types:\n",
"        # Start empty with shape (0, 1); maxshape=(None, 1) allows rows to be appended later\n",
"        datasets.create_dataset(energy, (0,1), h5py.string_dtype(encoding='utf-8'), maxshape=(None,1))\n",
"\n",
"    print(datasets.keys())"
]
},
{
"cell_type": "markdown",
"id": "cdd3eb59-bbf5-4b85-b6bc-35f591317b47",
"metadata": {},
"source": [
"### Dataset functions"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fca1947f-8a66-4636-8049-99d6ff0ace93",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"from time import gmtime, strftime, time\n",
"from random import choices, randint\n",
"from IPython import display\n",
"\n",
"def get_stats(h5_file=h5_file):\n",
"    \"\"\"Summarise the HDF5 file.\n",
"\n",
"    Returns a dict with the element count of every dataset (`size_counts`),\n",
"    the sum of those counts (`size_total`) and the size of the file on disk\n",
"    in MB rounded to one decimal (`size_mb`).\n",
"    \"\"\"\n",
"    with h5py.File(h5_file, 'r') as datasets:\n",
"        size_counts = {name: datasets[name].size.item() for name in datasets.keys()}\n",
"        file_bytes = os.path.getsize(h5_file)\n",
"\n",
"        return {\n",
"            \"size_counts\": size_counts,\n",
"            \"size_total\": sum(size_counts.values()),\n",
"            \"size_mb\": round(file_bytes / 1024**2, 1)\n",
"        }\n",
"\n",
"\n",
"def add_row(energy, image):\n",
"    \"\"\"Append one image string as a new last row of the `energy` dataset.\n",
"\n",
"    Args:\n",
"        energy: Name of the dataset inside the HDF5 file to append to.\n",
"        image: Value written into the newly added row.\n",
"    \"\"\"\n",
"    with h5py.File(h5_file, 'r+') as datasets:\n",
"        dataset = datasets[energy]\n",
"        # Grow axis 0 by exactly one row. The row count is shape[0]; `size` is the\n",
"        # total element count and only equals it while there is a single column.\n",
"        dataset.resize(dataset.shape[0] + 1, axis=0)\n",
"        dataset[-1] = image\n",
"\n",
"\n",
"def get_image(energy=None, row=None):\n",
"    \"\"\"Return one stored image string.\n",
"\n",
"    Args:\n",
"        energy: Dataset name to read from; a random entry of `energy_types` when omitted.\n",
"        row: Zero-based row index; a random row is picked when this is None.\n",
"\n",
"    Returns:\n",
"        The string held in the first column of the selected row.\n",
"\n",
"    Raises:\n",
"        ValueError: If a random row is requested from a dataset with no rows.\n",
"    \"\"\"\n",
"    if not energy:\n",
"        energy = choices(energy_types)[0]\n",
"\n",
"    with h5py.File(h5_file, 'r') as datasets:\n",
"        dataset = datasets[energy]\n",
"\n",
"        # Compare against None so that row 0 is honoured instead of being\n",
"        # mistaken for \"not provided\" and replaced by a random row.\n",
"        if row is None:\n",
"            row_count = dataset.shape[0]\n",
"            if row_count == 0:\n",
"                raise ValueError(f\"No images stored for energy type '{energy}'\")\n",
"            row = randint(0, row_count - 1)\n",
"\n",
"        return dataset.asstr()[row][0]\n",
"\n",
"def pretty_time(seconds):\n",
" m, s = divmod(seconds, 60)\n",
" h, m = divmod(m, 60)\n",
" return f\"{f'{math.floor(h)}h ' if h else ''}{f'{math.floor(m)}m ' if m else ''}{f'{math.floor(s)}s' if s else ''}\"\n",
" \n",
"def populate_dataset(batches=1, batch_size=1, image_cap=100_000, filesize_cap=4_000):\n",
"    \"\"\"Generate images for every energy type in batches and report progress.\n",
"\n",
"    One batch calls `generate_pokemon` with `batch_size` for each entry of\n",
"    `energy_types`. The caps are checked before a batch starts, so a batch\n",
"    that has begun always runs to completion.\n",
"\n",
"    Args:\n",
"        batches: Maximum number of batches to run.\n",
"        batch_size: Number of images requested per energy type per batch.\n",
"        image_cap: No new batch starts once `size_total` reaches this value.\n",
"        filesize_cap: No new batch starts once the file size in MB reaches this value.\n",
"    \"\"\"\n",
"    initial_stats = get_stats()\n",
"    planned_images = batches * batch_size * len(energy_types)\n",
"\n",
"    iterations = 0\n",
"    start_time = time()\n",
"\n",
"    while iterations < batches:\n",
"        # Read the stats once per batch rather than once per cap check\n",
"        stats = get_stats()\n",
"        if stats['size_total'] >= image_cap or stats['size_mb'] >= filesize_cap:\n",
"            break\n",
"\n",
"        for energy in energy_types:\n",
"            current = get_stats()\n",
"            new_images_count = (current['size_total'] - initial_stats['size_total'])\n",
"            new_mb_count = round(current['size_mb'] - initial_stats['size_mb'], 1)\n",
"            elapsed = time() - start_time\n",
"            # Average time per new image so far, scaled to the whole planned run\n",
"            eta_total = elapsed / (new_images_count or 1) * planned_images\n",
"\n",
"            display.clear_output(wait=True)\n",
"            if new_images_count:\n",
"                print(f\"ETA: {pretty_time(eta_total - elapsed)} left of {pretty_time(eta_total)}\")\n",
"            print(f\"Images in dataset: {current['size_total']}{f' (+{new_images_count})' if new_images_count else ''}\")\n",
"            print(f\"Size of dataset: {current['size_mb']}MB{f' (+{new_mb_count}MB)' if new_mb_count else ''}\")\n",
"            print(f\"Batch {iterations + 1} of {batches}:\")\n",
"            print(f\"{strftime('%Y-%m-%d %H:%M:%S', gmtime())} Generating {batch_size} {energy} Pokémon...\")\n",
"\n",
"            generate_pokemon(energy, batch_size)\n",
"\n",
"        iterations += 1\n",
"\n",
"    new_stats = get_stats()\n",
"    elapsed = time() - start_time\n",
"\n",
"    display.clear_output(wait=True)\n",
"    # Report the batches actually completed; a cap may have ended the run early\n",
"    print(f\"{strftime('%Y-%m-%d %H:%M:%S', gmtime())} Finished populating dataset with {iterations} {'batch' if iterations == 1 else 'batches'} after {pretty_time(elapsed)}\")\n",
"    print(f\"Images in dataset: {new_stats['size_total']} (+{new_stats['size_total'] - initial_stats['size_total']})\")\n",
"    print(f\"Size of dataset: {new_stats['size_mb']}MB (+{round(new_stats['size_mb'] - initial_stats['size_mb'], 1)}MB)\")"
]
},
{
"cell_type": "markdown",
"id": "b0da582a-2b29-4df6-8a3f-ddd56377af16",
"metadata": {},
"source": [
"## Load Pokémon model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ad435db5-bd35-4440-87b2-5a108f4ae385",
"metadata": {},
"outputs": [],
"source": [
"from rudalle import get_rudalle_model, get_tokenizer, get_vae\n",
"from huggingface_hub import cached_download, hf_hub_url\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5a2a4e6a-2086-4f98-b2b5-65a3631be61e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPUs available: 1\n"
]
}
],
"source": [
"# Number of CUDA devices visible to torch\n",
"available_gpus = torch.cuda.device_count()\n",
"print(f\"GPUs available: {available_gpus}\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "720df30d-f42c-406a-92ba-4a465f6ff1d3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Working with z of shape (1, 256, 32, 32) = 262144 dimensions.\n",
"vae --> ready\n",
"tokenizer --> ready\n",
"GPU[0] memory: 11263Mib\n",
"GPU[0] memory reserved: 5144Mib\n",
"GPU[0] memory allocated: 2767Mib\n"
]
}
],
"source": [
"# Query CUDA availability once and derive every device setting from it\n",
"cuda_available = torch.cuda.is_available()\n",
"device = \"cuda\" if cuda_available else \"cpu\"\n",
"fp16 = cuda_available\n",
"map_location = \"cuda:0\" if cuda_available else \"cpu\"\n",
"\n",
"# Download the fine-tuned weights into ./models under a fixed file name\n",
"file_dir = \"./models\"\n",
"file_name = \"pytorch_model.bin\"\n",
"config_file_url = hf_hub_url(repo_id=\"minimaxir/ai-generated-pokemon-rudalle\", filename=file_name)\n",
"cached_download(config_file_url, cache_dir=file_dir, force_filename=file_name)\n",
"\n",
"# Build the model without pretrained weights, then load the downloaded state dict\n",
"model = get_rudalle_model('Malevich', pretrained=False, fp16=fp16, device=device)\n",
"model.load_state_dict(torch.load(f\"{file_dir}/{file_name}\", map_location=map_location))\n",
"\n",
"vae = get_vae().to(device)\n",
"tokenizer = get_tokenizer()\n",
"\n",
"# The memory report queries CUDA device 0, so skip it when running on CPU\n",
"if cuda_available:\n",
"    print(f\"GPU[0] memory: {int(torch.cuda.get_device_properties(0).total_memory / 1024**2)}Mib\")\n",
"    print(f\"GPU[0] memory reserved: {int(torch.cuda.memory_reserved(0) / 1024**2)}Mib\")\n",
"    print(f\"GPU[0] memory allocated: {int(torch.cuda.memory_allocated(0) / 1024**2)}Mib\")"
]
},
{
"cell_type": "markdown",
"id": "88d413a4-a8c9-401e-9cb5-32a1ae34c179",
"metadata": {},
"source": [
"### Model functions"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0624c686-c75f-46c6-afae-bbe6c455caa1",
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"from io import BytesIO\n",
"from time import gmtime, strftime, time\n",
"from rudalle.pipelines import generate_images\n",
"\n",
"def english_to_russian(english):\n",
" word_map = {\n",
" \"colorless\": \"Покемон нормального типа\",\n",
" \"dragon\": \"Покемон типа дракона\",\n",
" \"darkness\": \"Покемон темного типа\",\n",
" \"fairy\": \"Покемон фея\",\n",
" \"fighting\": \"Покемон боевого типа\",\n",
" \"fire\": \"Покемон огня\",\n",
" \"grass\": \"Покемон трава\",\n",
" \"lightning\": \"Покемон электрического типа\",\n",
" \"metal\": \"Покемон из стали типа\",\n",
" \"psychic\": \"Покемон психического типа\",\n",
" \"water\": \"Покемон в воду\"\n",
" }\n",
"\n",
" return word_map[english.lower()]\n",
"\n",
"\n",
"def generate_pokemon(energy, num=1):\n",
"    \"\"\"Generate `num` images for an energy type and store each one with add_row.\n",
"\n",
"    Nothing happens when `energy` is not listed in `energy_types`. Every image is\n",
"    saved as a JPEG and stored as a base64 `data:image/jpeg` string.\n",
"    \"\"\"\n",
"    if energy not in energy_types:\n",
"        return\n",
"\n",
"    prompt = english_to_russian(energy)\n",
"    generated, _ = generate_images(prompt, tokenizer, model, vae, top_k=2048, images_num=num, top_p=0.995)\n",
"\n",
"    for picture in generated:\n",
"        jpeg_buffer = BytesIO()\n",
"        picture.save(jpeg_buffer, format=\"JPEG\", quality=100, optimize=True)\n",
"        encoded = base64.b64encode(jpeg_buffer.getvalue()).decode(\"UTF-8\")\n",
"        add_row(energy, \"data:image/jpeg;base64,\" + encoded)"
]
},
{
"cell_type": "markdown",
"id": "7b309b8e-0c34-4a80-a411-8f093a105494",
"metadata": {},
"source": [
"## Populate dataset"
]
},
{
"cell_type": "markdown",
"id": "a96d026f-e300-40c7-88a6-20be0012c584",
"metadata": {},
"source": [
"Total number of images per population = `batches` × `len(energy_types)` (11) × `batch_size`"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "74d2c50a-93ae-4040-89a0-87f818187bbb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-03-16 05:07:48 Finished populating dataset with 1 batch after 10m 8s\n",
"Images in dataset: 5082 (+66)\n",
"Size of dataset: 199.8MB (+2.5MB)\n"
]
}
],
"source": [
"# Settings for this population run\n",
"batches = 1\n",
"batch_size = 6\n",
"image_cap = 100_000\n",
"filesize_cap = 4_000 # MB\n",
"\n",
"populate_dataset(\n",
"    batches=batches,\n",
"    batch_size=batch_size,\n",
"    image_cap=image_cap,\n",
"    filesize_cap=filesize_cap,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "4eb6f750-aa2d-42b9-89dc-2cb103e8869e",
"metadata": {},
"source": [
"## Getting images"
]
},
{
"cell_type": "markdown",
"id": "9365ac7b-36e4-42b9-8380-a039d556356b",
"metadata": {},
"source": [
"### Random image"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dca3bc8e-1f65-4566-8385-bf6b9f20eaaf",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# With no arguments get_image picks a random energy type and a random row\n",
"display.HTML(f'<img src=\"{get_image()}\" />')"
]
},
{
"cell_type": "markdown",
"id": "f3b9278e-b6dc-4bea-9ef6-3aa312739718",
"metadata": {},
"source": [
"### Random image of specific energy type"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b6e98817-19f7-44f3-8970-0a11b01cf37b",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fix the energy type and let get_image pick a random row from it\n",
"display.HTML(f'<img src=\"{get_image(\"fire\")}\" />')"
]
},
{
"cell_type": "markdown",
"id": "34ca6f96-625b-459a-ba90-2c58a1d0ea47",
"metadata": {},
"source": [
"### Specific image"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5a60fc60-4198-4318-a5b4-26095cb2c0bb",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fix both the energy type and the row index to show one particular image\n",
"specific_image = get_image('grass', 3)\n",
"display.HTML(f'<img src=\"{specific_image}\" />')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}