Harshit Agarwal committed · Commit eaefa93 · Parent(s): ee0ac2c
initial comm
Browse files:
- .gitignore (+126 -0)
- .vscode/settings.json (+5 -0)
- README.md (+98 -3)
- dataloader/dataloader_cifar10.py (+39 -0)
- ddpm.ipynb (+0 -0)
- journals/2025_01_29.md (+63 -0)
- model/attn_utils.py (+97 -0)
- model/model.py (+320 -0)
- model/precomputes.py (+14 -0)
- model/utils.py (+90 -0)
- notebooks/ddpm (1).ipynb (+0 -0)
- notebooks/ddpm.ipynb (+0 -0)
- pretrained_weights/sample_outputs/epoch5_1.png (+0 -0)
- pretrained_weights/sample_outputs/epochs5_1_cbam.png (+0 -0)
- requirements.txt (+5 -0)
- simple_game/ddpm.py (+73 -0)
- simple_game/estelle-peplum-top-tops-509.webp (+0 -0)
- simple_game/ideation.md (+0 -0)
- slides/notes/dwn5.png (+0 -0)
- slides/notes/dwn52.png (+0 -0)
- slides/notes/dwn53.png (+0 -0)
- usage/generate.py (+3 -0)
- usage/train.py (+3 -0)
.gitignore
ADDED
@@ -0,0 +1,126 @@
+noise-game/*
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyderworkspace
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
.vscode/settings.json
ADDED
@@ -0,0 +1,5 @@
+{
+    "files.exclude": {
+        "logseq/*": true
+    }
+}
README.md
CHANGED
@@ -1,3 +1,98 @@
-
-
-
+# DDPM Project
+
+This repository contains an implementation of Denoising Diffusion Probabilistic Models (DDPM).
+
+## Table of Contents
+- [Introduction](#introduction)
+- [Installation](#installation)
+- [Usage](#usage)
+- [Contributing](#contributing)
+
+## Introduction
+Denoising Diffusion Probabilistic Models (DDPM) are a class of generative models that learn to generate data by reversing a diffusion process. This repository provides a comprehensive implementation of DDPM.
+
+## Installation
+To install the necessary dependencies, run:
+```bash
+pip install -r requirements.txt
+```
+
+## Usage
+To train the model, use the following command:
+```bash
+python usage/train.py
+```
+To generate samples, use:
+```bash
+python usage/generate.py
+```
+
+## Game
+To help people understand the model and its workings, we're working on a cute little game where the user plays the UNET/reverse-diffusion model and is tasked with denoising images whose noise is made of grids of lines.
+
+Use [learndiffusion.vercel.app](https://learndiffusion.vercel.app) to access the primitive version of the game. You can also contribute to the game by checking out the diffusion_game branch. A model showcase will also be added, in which the model's weights are loaded from the internet and the model files are installed and loaded into a Gradio interface for direct use/inference on Vercel. Feel free to make changes for this; an issue is open.
+
+## Explanations and Mathematics
+- slides from the presentation:
+- notes/explanations: [HERE](slides/notes)
+- a cute lab talk ppt:
+- Plato's allegory: \<link to REPUBLIC>
+
+## Resources
+- Original Paper: https://arxiv.org/pdf/2006.11239
+- Improvement Paper: https://arxiv.org/abs/2102.09672
+- Improvement by OpenAI: https://arxiv.org/pdf/2105.05233
+- Stable Diffusion Paper: https://arxiv.org/abs/2112.10752
+-
+
+### Papers for background
+- UNET Paper for Biomedical Segmentation
+- Autoencoder
+- Variational Autoencoder
+- Markov Hierarchical VAE
+- Introductory Lectures on Diffusion Processes
+
+### YouTube videos and courses
+#### Mathematics
+- Outliers
+- Omar Jahil
+
+#### PyTorch Implementation
+- [Deep Findr](https://www.youtube.com/watch?v=a4Yfz2FxXiY)
+- [Notebook from Deep Findr](https://colab.research.google.com/drive/1sjy9odlSSy0RBVgMTgP7s99NXsqglsUL?usp=sharing)
+
+## Pretrained Weights
+Weights from the model can be found in [pretrained_weights](https://drive.google.com/drive/folders/1NiQDI3e67I9FITVnrzNPP2Az0LABRpic?usp=sharing).
+
+For loading the pretrained weights:
+```python
+model2 = SimpleUnet()
+model2.load_state_dict(torch.load("/content/drive/MyDrive/Research Work/mlsa/DDPM/model_weights.pth"))
+model2.eval()
+```
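If you are loading the weights on a machine without a GPU, note that `torch.load` accepts a `map_location` argument. A minimal sketch; the weights path below is just wherever you saved the downloaded file:

```python
import torch

model2 = SimpleUnet()
# map_location lets weights saved on a CUDA machine load on a CPU-only one
state = torch.load("model_weights.pth", map_location="cpu")
model2.load_state_dict(state)
model2.eval()
```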
+
+For making inferences:
+TODO: There are errors in the sampling function (boolean errors, etc.); issues will be opened so others can solve them as an exercise if needed.
+```python
+from torchvision.utils import save_image  # needed for save_image below
+
+num_samples = 8  # Number of images to generate
+image_size = (3, 32, 32)  # Example for CIFAR10
+noise = torch.randn(num_samples, *image_size).to("cuda")
+
+model2.to("cuda")
+# Generate images by denoising
+with torch.no_grad():
+    generated_images = model2.sample(noise)
+
+# Save the generated images
+save_image(generated_images, "generated_images.png", nrow=4, normalize=True)
+```
+
+
+## Contributing
+Contributions are welcome! Please open an issue or submit a pull request.
+
+
+## Future Ideas
+- Make the model ONNX compatible for training and inferencing on Intel GPUs
+- Build a Stable Diffusion Text2Img model using a CLIP implementation!
+- Train the current model on a much larger dataset with more generalizations and nuances
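For the ONNX idea, a minimal export sketch, assuming the `SimpleUnet` above; whether its time-embedding code traces cleanly is untested, and the file name and sizes are illustrative:

```python
import torch

model = SimpleUnet().eval()
dummy_x = torch.randn(1, 3, 64, 64)            # one image
dummy_t = torch.tensor([0], dtype=torch.long)  # one timestep
torch.onnx.export(
    model, (dummy_x, dummy_t), "ddpm_unet.onnx",
    input_names=["x", "timestep"], output_names=["noise_pred"],
    dynamic_axes={"x": {0: "batch"}, "timestep": {0: "batch"}},
)
```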
dataloader/dataloader_cifar10.py
ADDED
@@ -0,0 +1,39 @@
+from torchvision import transforms
+from torch.utils.data import Subset, DataLoader
+import numpy as np
+import torch
+import torchvision
+import matplotlib.pyplot as plt
+from sklearn.model_selection import train_test_split
+
+def load_transformed_dataset(IMG_SIZE=64):
+    # Define the transformation pipeline
+    data_transforms = [
+        transforms.Resize((IMG_SIZE, IMG_SIZE)),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),  # Scales data into [0,1]
+        transforms.Lambda(lambda t: (t * 2) - 1)  # Scale between [-1, 1]
+    ]
+    data_transform = transforms.Compose(data_transforms)
+
+    # Load CIFAR10 dataset without splitting
+    cifar10_dataset = torchvision.datasets.CIFAR10(root=".", download=True, transform=data_transform)
+
+    # Split indices into train and test using sklearn's train_test_split
+    dataset_size = len(cifar10_dataset)
+    indices = list(range(dataset_size))
+    train_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42)
+
+    # Create train and test subsets
+    train_subset = Subset(cifar10_dataset, train_indices)
+    test_subset = Subset(cifar10_dataset, test_indices)
+
+    # Combine train and test subsets into a single ConcatDataset
+    combined_dataset = torch.utils.data.ConcatDataset([train_subset, test_subset])
+
+    return combined_dataset
+
+def load_dataloader(combined_dataset, batch_size=64):
+    # Create dataloaders
+    dataloader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+    return dataloader
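A quick usage sketch for these helpers; the import path assumes you run from the repository root:

```python
from dataloader.dataloader_cifar10 import load_transformed_dataset, load_dataloader

dataset = load_transformed_dataset(IMG_SIZE=64)  # downloads CIFAR10 on first run
loader = load_dataloader(dataset, batch_size=64)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([64, 3, 64, 64]), values scaled to [-1, 1]
```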
ddpm.ipynb
ADDED
The diff for this file is too large to render; see the raw file.
journals/2025_01_29.md
ADDED
@@ -0,0 +1,63 @@
+<!-- i am gonna make a cute lil game for the first years to understand this ddpm paper i think
+for the 1st task, i am thinking of making a game where the user gets an image, and behind the scenes there are 15-20 "noise" patterns made of lines (diagonals and straight) that get transposed over the image procedurally at our ease.
+
+
+now at any given level, any of the N noise patterns can be on the image, and both the noisy image and the original image will be provided to the user.
+the user's task therefore is to choose which of the MCQ noise options are in the image.
+
+In task 2, the user will be shown a simulation. in front of the user, starting from 1 to N, each individual noise pattern will be transposed on the image. the user needs to memorize the patterns and, after the simulation, choose the patterns first forwards and then backwards.
+then in the final boss level, i'll have actual gaussian normal distributions as the noise 😂 and the user will decipher that -->
+
+
+### **Game Structure**
+1. **Task 1: Noise Pattern Identification**
+   - User sees a noisy image and the original image side by side.
+   - User selects which noise patterns (from MCQ options) are applied to the noisy image.
+   - Feedback is provided on correctness.
+
+2. **Task 2: Noise Pattern Memorization**
+   - User watches a simulation where noise patterns are sequentially applied to an image.
+   - After the simulation, the user must recall and select the patterns in the correct order (forwards and backwards).
+   - Feedback is provided on accuracy.
+
+3. **Boss Level: Gaussian Noise Deciphering**
+   - User is shown an image with Gaussian noise applied.
+   - User must identify the noise characteristics (e.g., intensity, distribution) or match it to a reference.
+
+---
+
+#### 1. **Image and Noise Generator**
+   - Create a `SquareImage` component that displays the original and noisy images.
+   - Create a `NoisePattern` component that generates the noise patterns (grids with lines).
+   - Use a `NoiseGenerator` utility to procedurally apply noise patterns to the image.
+
+#### 2. **Task 1: Noise Identification**
+   - Create a `Task1` component that:
+     - Displays the original and noisy images.
+     - Provides MCQ options for noise patterns.
+     - Handles user input and provides feedback.
+
+#### 3. **Task 2: Noise Memorization**
+   - Create a `Task2` component that:
+     - Plays a simulation of noise patterns being applied sequentially.
+     - Provides an interface for the user to select patterns in order.
+     - Validates the user's input and provides feedback.
+
+#### 4. **Boss Level: Gaussian Noise**
+   - Create a `BossLevel` component that:
+     - Applies Gaussian noise to an image.
+     - Provides tools for the user to analyze and decipher the noise.
+
+#### 5. **Game Manager**
+   - Create a `GameManager` component that:
+     - Tracks the user's progress through tasks.
+     - Handles transitions between tasks.
+     - Displays scores and feedback.
+
+
+### **Next Steps**
+1. Implement the `NoiseGenerator` utility to create and apply noise patterns (see the Python sketch below).
+2. Add feedback mechanisms for user selections.
+3. Build the simulation for Task 2.
+4. Integrate Gaussian noise for the Boss Level.
+5. Style the app to make it visually appealing.
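For step 1, a minimal Python sketch of procedural line-pattern noise; the function names and the subtractive blending scheme are my assumptions, not a fixed design:

```python
import numpy as np

def make_line_pattern(size=300, n_lines=50, rng=None):
    """One binary 'noise' pattern made of random straight lines."""
    rng = rng if rng is not None else np.random.default_rng()
    pattern = np.zeros((size, size), dtype=np.float32)
    for _ in range(n_lines):
        x0, y0 = rng.integers(0, size, 2)
        x1, y1 = rng.integers(0, size, 2)
        # Rasterize the segment with simple linear interpolation
        steps = max(abs(int(x1) - int(x0)), abs(int(y1) - int(y0)), 1)
        xs = np.linspace(x0, x1, steps).astype(int)
        ys = np.linspace(y0, y1, steps).astype(int)
        pattern[ys, xs] = 1.0
    return pattern

def apply_patterns(image, patterns, strength=0.5):
    """Overlay a set of patterns on a grayscale image in [0, 1] as dark lines."""
    noisy = image.copy()
    for p in patterns:
        noisy = np.clip(noisy - strength * p, 0.0, 1.0)
    return noisy
```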
model/attn_utils.py
ADDED
@@ -0,0 +1,97 @@
+import torch
+from torch import nn
+import math
+import torch.nn.functional as F
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, chs, num_heads=1, ffn_expansion=4, dropout=0.1):
+        super().__init__()
+        self.norm = nn.LayerNorm(chs)
+        self.attn = nn.MultiheadAttention(embed_dim=chs, num_heads=num_heads, batch_first=True)
+
+        self.ffn = nn.Sequential(
+            nn.Linear(chs, chs*ffn_expansion),
+            nn.GELU(),
+            nn.Linear(chs*ffn_expansion, chs)
+        )
+
+        self.norm2 = nn.LayerNorm(chs)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        # Flatten the spatial grid so each pixel becomes a token: (b, h*w, c)
+        x_reshaped = x.view(b, c, h*w).transpose(1, 2)
+
+        attn_out, _ = self.attn(self.norm(x_reshaped), self.norm(x_reshaped), self.norm(x_reshaped))
+        x_attn = x_reshaped + self.dropout(attn_out)
+
+        ffn_out = self.ffn(self.norm2(x_attn))
+        x_out = x_attn + self.dropout(ffn_out)
+
+        x_out = x_out.transpose(1, 2).view(b, c, h, w)
+        return x_out
+
+
+class CBAM(nn.Module):
+    def __init__(self, chs, reduction=16):
+        super().__init__()
+
+        self.channel_attn = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(chs, chs//reduction, 1),
+            nn.ReLU(),
+            nn.Conv2d(chs//reduction, chs, 1),
+            nn.Sigmoid()
+        )
+
+        self.spatial_attn = nn.Sequential(
+            nn.Conv2d(2, 1, kernel_size=7, padding=3),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        ch_wt = self.channel_attn(x)
+        x = x * ch_wt
+
+        avg_pool = torch.mean(x, dim=1, keepdim=True)
+        max_pool, _ = torch.max(x, dim=1, keepdim=True)
+        sp_wt = self.spatial_attn(torch.cat([avg_pool, max_pool], dim=1))
+        x = x * sp_wt
+
+        return x
+
+
+class Block_CBAM(nn.Module):
+    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
+        super().__init__()
+        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
+        if up:
+            ## up path: the input arrives concatenated with a skip connection (2*in_ch),
+            ## so conv1 reduces channels; transform then doubles the spatial size
+            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
+            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
+        else:
+            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
+            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
+
+        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
+        self.relu = nn.ReLU()
+        self.batch_norm1 = nn.BatchNorm2d(out_ch)
+        self.batch_norm2 = nn.BatchNorm2d(out_ch)
+
+        self.cbam = CBAM(out_ch)
+
+    def forward(self, x, t):
+        h = self.batch_norm1(self.relu(self.conv1(x)))
+        time_emb = self.relu(self.time_mlp(t))
+        # Broadcast the time embedding over the spatial dims: (b, c) -> (b, c, 1, 1)
+        time_emb = time_emb[(..., ) + (None, ) * 2]
+        h = h + time_emb
+        h = self.batch_norm2(self.relu(self.conv2(h)))
+
+        h = self.cbam(h)
+        return self.transform(h)
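Both attention modules are shape-preserving; a quick sanity check with illustrative sizes (run from the model/ directory so the flat import resolves):

```python
import torch
from attn_utils import SelfAttention, CBAM

x = torch.randn(2, 64, 16, 16)
print(SelfAttention(64)(x).shape)  # torch.Size([2, 64, 16, 16])
print(CBAM(64)(x).shape)           # torch.Size([2, 64, 16, 16])
```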
model/model.py
ADDED
@@ -0,0 +1,320 @@
+'''
+This file contains the DDPM implementation, modularized for loading, prediction, and training.
+'''
+
+from torch import nn
+import math
+import torch
+from utils import forward_diffusion_sample, sample_timestep, sample_plot_image
+import torch.nn.functional as F
+from attn_utils import SelfAttention, CBAM, Block_CBAM
+
+class Block(nn.Module):
+    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
+        super().__init__()
+        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
+        if up:
+            ## up path: the input arrives concatenated with a skip connection (2*in_ch),
+            ## so conv1 reduces channels; transform then doubles the spatial size
+            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
+            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
+        else:
+            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
+            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
+        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
+        self.relu = nn.ReLU()
+        self.batch_norm1 = nn.BatchNorm2d(out_ch)
+        self.batch_norm2 = nn.BatchNorm2d(out_ch)
+
+    def forward(self, x, t):
+        h = self.batch_norm1(self.relu(self.conv1(x)))
+        time_emb = self.relu(self.time_mlp(t))
+        # Broadcast the time embedding over the spatial dims: (b, c) -> (b, c, 1, 1)
+        time_emb = time_emb[(..., ) + (None, ) * 2]
+        h = h + time_emb
+        h = self.batch_norm2(self.relu(self.conv2(h)))
+        return self.transform(h)
+
+class PositionEmbeddings(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, time):
+        # Standard sinusoidal embeddings: sin/cos at geometrically spaced frequencies
+        device = time.device
+        half_dim = self.dim // 2
+        embeddings = math.log(10000) / (half_dim - 1)
+        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
+        embeddings = time[:, None] * embeddings[None, :]
+        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
+        return embeddings
+
+
+class SimpleUnet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        image_channels = 3
+        down_channels = (64, 128, 256, 512, 1024)
+        up_channels = (1024, 512, 256, 128, 64)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        out_dim = 3
+        time_emb_dim = 32
+        self.num_timesteps = 300  # total diffusion steps; sample() below relies on this
+
+        ## timestep stored as a sinusoidal positional encoding
+        self.time_mlp = nn.Sequential(
+            PositionEmbeddings(time_emb_dim),
+            nn.Linear(time_emb_dim, time_emb_dim),
+            nn.ReLU()
+        )
+
+        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
+        self.down_blocks = nn.ModuleList([
+            Block(down_channels[i], down_channels[i+1], time_emb_dim)
+            for i in range(len(down_channels)-1)
+        ])
+        self.up_blocks = nn.ModuleList([
+            Block(up_channels[i], up_channels[i+1], time_emb_dim, up=True)
+            for i in range(len(up_channels)-1)
+        ])
+
+        ## readout layer
+        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)
+
+    def forward(self, x, timestep):
+        t = self.time_mlp(timestep)
+        x = self.conv0(x)
+        residual_inputs = []
+        for down in self.down_blocks:
+            x = down(x, t)
+            residual_inputs.append(x)
+        for up in self.up_blocks:
+            residual_x = residual_inputs.pop()
+            x = torch.cat((x, residual_x), dim=1)
+            x = up(x, t)
+        return self.output(x)
+
+    @torch.no_grad()
+    def sample(self, noise):
+        """
+        Generate an image by denoising a given noise tensor using the reverse diffusion process.
+
+        Args:
+            noise (torch.Tensor): Initial noise tensor (e.g., sampled from a Gaussian distribution).
+
+        Returns:
+            torch.Tensor: Denoised image.
+        """
+        img = noise  # Start with the provided noise tensor
+        T = self.num_timesteps  # Total timesteps for diffusion
+
+        # Iterate through the timesteps in reverse order
+        for i in range(T - 1, -1, -1):
+            t = torch.full((noise.size(0),), i, device=noise.device, dtype=torch.long)  # Current timestep
+            img = sample_timestep(self, img, t)  # Perform one reverse diffusion step
+            img = torch.clamp(img, -1.0, 1.0)  # Clamp the image to keep values in [-1, 1]
+
+        return img
+
+    def get_loss(self, x_0, t):
+        x_noisy, noise = forward_diffusion_sample(x_0, t, self.device)
+        noise_pred = self(x_noisy, t)
+        return F.l1_loss(noise, noise_pred)
+
+    # Note: this overrides nn.Module.train(); call it deliberately with a dataloader.
+    def train(self, dataloader, BATCH_SIZE=64, T=300, EPOCHS=50, verbose=True):
+        from torch.optim import Adam
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.to(device)
+        optimizer = Adam(self.parameters(), lr=0.001)
+        epochs = EPOCHS
+
+        for epoch in range(epochs):
+            for step, batch in enumerate(dataloader):
+                optimizer.zero_grad()
+
+                t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
+                loss = self.get_loss(batch[0], t)
+                loss.backward()
+                optimizer.step()
+
+                if verbose:
+                    if epoch % 5 == 0 and step % 150 == 0:
+                        print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
+                        sample_plot_image(self)
+
+def test():
+    ## TODO: add the testing loop here
+    pass
+
+
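As a quick illustration of the `PositionEmbeddings` module above (shapes only; the sizes are illustrative):

```python
import torch

emb = PositionEmbeddings(32)
t = torch.arange(0, 300, 50)
print(emb(t).shape)  # torch.Size([6, 32]): one 32-d sinusoidal vector per timestep
```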
+################################################################################################
+####################### ATTENTION LAYERS ADDED TO THE MODEL ###################################
+################################################################################################
+
+class SimpleUnetWSelfAttn(nn.Module):
+    def __init__(self):
+        super().__init__()
+        image_channels = 3
+        down_channels = (64, 128, 256, 512, 1024)
+        up_channels = (1024, 512, 256, 128, 64)
+
+        out_dim = 3
+        time_emb_dim = 32
+
+        ## timestep stored as a sinusoidal positional encoding
+        self.time_mlp = nn.Sequential(
+            PositionEmbeddings(time_emb_dim),
+            nn.Linear(time_emb_dim, time_emb_dim),
+            nn.ReLU()
+        )
+        self.num_timesteps = 300
+
+        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
+        self.down_blocks = nn.ModuleList([
+            Block(down_channels[i], down_channels[i+1], time_emb_dim)
+            for i in range(len(down_channels)-1)
+        ])
+        self.up_blocks = nn.ModuleList([
+            Block(up_channels[i], up_channels[i+1], time_emb_dim, up=True)
+            for i in range(len(up_channels)-1)
+        ])
+
+        ## self-attention at the bottleneck, where feature maps are smallest
+        self.self_attention = SelfAttention(down_channels[-1])
+
+        ## readout layer
+        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)
+
+    def forward(self, x, timestep):
+        t = self.time_mlp(timestep)
+        x = self.conv0(x)
+        residual_inputs = []
+        for down in self.down_blocks:
+            x = down(x, t)
+            residual_inputs.append(x)
+
+        x = self.self_attention(x)
+
+        for up in self.up_blocks:
+            residual_x = residual_inputs.pop()
+            x = torch.cat((x, residual_x), dim=1)
+            x = up(x, t)
+        return self.output(x)
+
+    @torch.no_grad()
+    def sample(self, noise):
+        """
+        Generate an image by denoising a given noise tensor using the reverse diffusion process.
+
+        Args:
+            noise (torch.Tensor): Initial noise tensor (e.g., sampled from a Gaussian distribution).
+
+        Returns:
+            torch.Tensor: Denoised image.
+        """
+        img = noise  # Start with the provided noise tensor
+        T = self.num_timesteps  # Total timesteps for diffusion
+
+        # Iterate through the timesteps in reverse order
+        for i in range(T - 1, -1, -1):
+            t = torch.full((noise.size(0),), i, device=noise.device, dtype=torch.long)  # Current timestep
+            img = sample_timestep(self, img, t)  # Perform one reverse diffusion step
+            img = torch.clamp(img, -1.0, 1.0)  # Clamp the image to keep values in [-1, 1]
+
+        return img
+
+
+################################################################################################
+#################### Convolutional Block Attention Module ADDED TO THE MODEL ###################
+################################################################################################
+
+class SimpleUnetWCBAM(nn.Module):
+    def __init__(self):
+        super().__init__()
+        image_channels = 3
+        down_channels = (64, 128, 256, 512, 1024)
+        up_channels = (1024, 512, 256, 128, 64)
+
+        out_dim = 3
+        time_emb_dim = 32
+
+        ## timestep stored as a sinusoidal positional encoding
+        self.time_mlp = nn.Sequential(
+            PositionEmbeddings(time_emb_dim),
+            nn.Linear(time_emb_dim, time_emb_dim),
+            nn.ReLU()
+        )
+        self.num_timesteps = 300
+
+        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
+        self.down_blocks = nn.ModuleList([
+            Block_CBAM(down_channels[i], down_channels[i+1], time_emb_dim)
+            for i in range(len(down_channels)-1)
+        ])
+        self.up_blocks = nn.ModuleList([
+            Block_CBAM(up_channels[i], up_channels[i+1], time_emb_dim, up=True)
+            for i in range(len(up_channels)-1)
+        ])
+
+        ## self-attention at the bottleneck, where feature maps are smallest
+        self.self_attention = SelfAttention(down_channels[-1])
+
+        ## readout layer
+        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)
+
+    def forward(self, x, timestep):
+        t = self.time_mlp(timestep)
+        x = self.conv0(x)
+        residual_inputs = []
+        for down in self.down_blocks:
+            x = down(x, t)
+            residual_inputs.append(x)
+
+        x = self.self_attention(x)
+
+        for up in self.up_blocks:
+            residual_x = residual_inputs.pop()
+            x = torch.cat((x, residual_x), dim=1)
+            x = up(x, t)
+        return self.output(x)
+
+    @torch.no_grad()
+    def sample(self, noise):
+        """
+        Generate an image by denoising a given noise tensor using the reverse diffusion process.
+
+        Args:
+            noise (torch.Tensor): Initial noise tensor (e.g., sampled from a Gaussian distribution).
+
+        Returns:
+            torch.Tensor: Denoised image.
+        """
+        img = noise  # Start with the provided noise tensor
+        T = self.num_timesteps  # Total timesteps for diffusion
+
+        # Iterate through the timesteps in reverse order
+        for i in range(T - 1, -1, -1):
+            t = torch.full((noise.size(0),), i, device=noise.device, dtype=torch.long)  # Current timestep
+            img = sample_timestep(self, img, t)  # Perform one reverse diffusion step
+            img = torch.clamp(img, -1.0, 1.0)  # Clamp the image to keep values in [-1, 1]
+
+        return img
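A forward-pass sanity check for the attention variants (sizes illustrative; the output is the predicted noise, with the same shape as the input):

```python
import torch

model = SimpleUnetWCBAM()
x = torch.randn(2, 3, 64, 64)
t = torch.randint(0, 300, (2,))
print(model(x, t).shape)  # torch.Size([2, 3, 64, 64])
```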
model/precomputes.py
ADDED
@@ -0,0 +1,14 @@
+import torch
+from torch.nn import functional as F
+
+T = 300  ## according to the paper
+
+### So many precomputed values to track!
+betas = torch.linspace(1e-4, 0.02, T)
+alphas = 1. - betas
+alphas_cumulative_products = torch.cumprod(alphas, dim=0)
+alphas_cumulative_products_prev = F.pad(alphas_cumulative_products[:-1], (1, 0), value=1.0)
+sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
+sqrt_alphas_cumulative_products = torch.sqrt(alphas_cumulative_products)
+sqrt_one_minus_alphas_cumulative_products = torch.sqrt(1. - alphas_cumulative_products)
+posterior_variance = betas * (1. - alphas_cumulative_products_prev) / (1. - alphas_cumulative_products)
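An illustrative look at what this linear schedule implies at the endpoints (values approximate):

```python
print(sqrt_alphas_cumulative_products[0].item())   # ~0.99995: x_0 barely perturbed at t=0
print(sqrt_alphas_cumulative_products[-1].item())  # ~0.22: mostly noise by t=T-1
print(sqrt_one_minus_alphas_cumulative_products[-1].item())  # ~0.98
```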
model/utils.py
ADDED
@@ -0,0 +1,90 @@
+## Scheduler
+'''
+sequentially add noise to images
+precomputed values used
+'''
+
+import torch.nn.functional as F
+import torch
+from precomputes import betas, sqrt_recip_alphas, sqrt_alphas_cumulative_products, sqrt_one_minus_alphas_cumulative_products, posterior_variance
+from torchvision import transforms
+from matplotlib import pyplot as plt
+import numpy as np
+
+def get_index_from_list(vals, t, x_shape):
+    # Pick the schedule value for each batch element's timestep t and reshape
+    # it so it broadcasts against an image tensor of shape x_shape: (b, 1, 1, 1)
+    batch_size = t.shape[0]
+    out = vals.gather(-1, t.cpu())
+    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
+
+def forward_diffusion_sample(x_0, t, device="cpu"):
+    noise = torch.randn_like(x_0)
+    sqrt_alphas_cumulative_products_t = get_index_from_list(sqrt_alphas_cumulative_products, t, x_0.shape)
+    sqrt_one_minus_alphas_cumulative_products_t = get_index_from_list(
+        sqrt_one_minus_alphas_cumulative_products, t, x_0.shape
+    )
+    ## closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, noise ~ N(0, I)
+    return sqrt_alphas_cumulative_products_t.to(device) * x_0.to(device) \
+        + sqrt_one_minus_alphas_cumulative_products_t.to(device) * noise.to(device), noise.to(device)
+
+
+@torch.no_grad()
+def sample_timestep(model, x, t):
+    """
+    Calls the model to predict the noise in the image and returns
+    the denoised image.
+    Applies noise to this image, if we are not in the last step yet.
+    """
+    betas_t = get_index_from_list(betas, t, x.shape)
+    sqrt_one_minus_alphas_cumulative_products_t = get_index_from_list(
+        sqrt_one_minus_alphas_cumulative_products, t, x.shape
+    )
+    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
+
+    # Call model (current image - noise prediction)
+    model_mean = sqrt_recip_alphas_t * (
+        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumulative_products_t
+    )
+    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
+
+    # All batch elements share the same timestep during sampling, so .all() is safe here
+    if (t == 0).all():
+        return model_mean
+    else:
+        noise = torch.randn_like(x)
+        return model_mean + torch.sqrt(posterior_variance_t) * noise
+
+@torch.no_grad()
+def sample_plot_image(model, IMG_SIZE=64, device="cpu", T=300):
+    # Sample noise
+    img_size = IMG_SIZE
+    img = torch.randn((1, 3, img_size, img_size), device=device)
+    plt.figure(figsize=(15, 15))
+    plt.axis('off')
+    num_images = 10
+    stepsize = int(T/num_images)
+
+    for i in range(T - 1, -1, -1):
+        t = torch.full((1,), i, device=device, dtype=torch.long)
+        img = sample_timestep(model, img, t)
+        # This is to maintain the natural range of the distribution
+        img = torch.clamp(img, -1.0, 1.0)
+        if i % stepsize == 0:
+            plt.subplot(1, num_images, int(i/stepsize)+1)
+            show_tensor_image(img.detach().cpu())
+    # plt.show()
+    return img
+
+def show_tensor_image(image):
+    reverse_transforms = transforms.Compose([
+        transforms.Lambda(lambda t: (t + 1) / 2),
+        transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
+        transforms.Lambda(lambda t: t * 255.),
+        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
+        transforms.ToPILImage(),
+    ])
+
+    # Take first image of batch
+    if len(image.shape) == 4:
+        image = image[0, :, :, :]
+    plt.imshow(reverse_transforms(image))
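Illustrative use of the forward process (random tensors stand in for a batch of images in [-1, 1]; run from the model/ directory so the flat import resolves):

```python
import torch
from utils import forward_diffusion_sample

x0 = torch.randn(4, 3, 64, 64)
t = torch.randint(0, 300, (4,))  # a random timestep per image
x_t, eps = forward_diffusion_sample(x0, t)
print(x_t.shape, eps.shape)      # torch.Size([4, 3, 64, 64]) for both
```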
notebooks/ddpm (1).ipynb
ADDED
The diff for this file is too large to render; see the raw file.
notebooks/ddpm.ipynb
ADDED
The diff for this file is too large to render; see the raw file.
pretrained_weights/sample_outputs/epoch5_1.png
ADDED
pretrained_weights/sample_outputs/epochs5_1_cbam.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+torch
+torchvision
+matplotlib
+numpy
+scikit-learn
simple_game/ddpm.py
ADDED
@@ -0,0 +1,73 @@
+# import pygame
+# import random
+
+# # Initialize Pygame
+# pygame.init()
+
+# # Screen dimensions
+# WIDTH, HEIGHT = 800, 600
+# screen = pygame.display.set_mode((WIDTH, HEIGHT))
+# pygame.display.set_caption("DDPM Noise Game")
+
+# # Colors
+# WHITE = (255, 255, 255)
+# BLACK = (0, 0, 0)
+
+# # Load image
+# image = pygame.image.load("estelle-peplum-top-tops-509.webp")
+# image = pygame.transform.scale(image, (300, 300))
+
+# # Generate noise patterns
+# def generate_noise_pattern():
+#     pattern = pygame.Surface((300, 300), pygame.SRCALPHA)
+#     for _ in range(50):  # Draw random lines
+#         start = (random.randint(0, 300), random.randint(0, 300))
+#         end = (random.randint(0, 300), random.randint(0, 300))
+#         pygame.draw.line(pattern, BLACK, start, end, 2)
+#     return pattern
+
+# noise_patterns = [generate_noise_pattern() for _ in range(20)]
+
+# # Task 1: Identify the Noise
+# def task1():
+#     noisy_image = image.copy()
+#     current_noise = random.choice(noise_patterns)
+#     noisy_image.blit(current_noise, (0, 0))
+
+#     screen.fill(WHITE)
+#     screen.blit(image, (50, 50))
+#     screen.blit(noisy_image, (400, 50))
+#     pygame.display.flip()
+
+#     # Wait for user input
+#     running = True
+#     while running:
+#         for event in pygame.event.get():
+#             if event.type == pygame.QUIT:
+#                 pygame.quit()
+#                 return
+#             if event.type == pygame.KEYDOWN:
+#                 if event.key == pygame.K_1:  # Example: User selects pattern 1
+#                     print("You selected Pattern 1")
+#                     running = False
+
+# # Task 2: Memorize the Sequence
+# def task2():
+#     sequence = random.sample(noise_patterns, 5)  # Random sequence of 5 patterns
+#     for pattern in sequence:
+#         screen.fill(WHITE)
+#         screen.blit(image, (50, 50))
+#         # noisy_image = image.copy()
+#         # screen.blit(pattern, (0, 0))
+#         screen.set_alpha(0)
+#         screen.blit(pattern, (400, 50))
+#         pygame.display.flip()
+#         pygame.time.wait(1000)  # Show each pattern for 1 second
+
+#     # Ask user to recall the sequence
+#     print("Recall the sequence of patterns!")
+
+# # Main loop
+# task1()
+# task2()
+# pygame.quit()
simple_game/estelle-peplum-top-tops-509.webp
ADDED
simple_game/ideation.md
ADDED
File without changes
slides/notes/dwn5.png
ADDED
slides/notes/dwn52.png
ADDED
slides/notes/dwn53.png
ADDED
usage/generate.py
ADDED
@@ -0,0 +1,3 @@
+'''
+TODO: Add code here for generating example results from the model
+'''
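Until the TODO is filled in, a minimal sketch of what this script might contain; it assumes the imports resolve from the repo root and that a saved weights file exists, both of which are assumptions:

```python
import torch
from torchvision.utils import save_image
from model.model import SimpleUnet  # assumes the package imports resolve

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleUnet().to(device)
# "model_weights.pth" is a placeholder path for your saved checkpoint
model.load_state_dict(torch.load("model_weights.pth", map_location=device))
model.eval()

noise = torch.randn(8, 3, 32, 32, device=device)
with torch.no_grad():
    images = model.sample(noise)
save_image(images, "generated_images.png", nrow=4, normalize=True)
```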
usage/train.py
ADDED
@@ -0,0 +1,3 @@
+'''
+TODO: Add the training script here
+'''
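Likewise, a minimal training-script sketch (hyperparameters illustrative; same import-path assumption as above):

```python
import torch
from model.model import SimpleUnet
from dataloader.dataloader_cifar10 import load_transformed_dataset, load_dataloader

dataset = load_transformed_dataset(IMG_SIZE=64)
loader = load_dataloader(dataset, batch_size=64)

model = SimpleUnet()
model.train(loader, BATCH_SIZE=64, T=300, EPOCHS=50)  # the custom train(), not nn.Module.train()
torch.save(model.state_dict(), "model_weights.pth")
```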