import os
import glob

import gradio as gr
import numpy as np
import torch
from torch import nn
from PIL import Image

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class RuleCA(nn.Module):
    def __init__(self, hidden_n=6, rule_channels=4, zero_w2=True, device=device):
        super().__init__()
        # The hard-coded perception filters: identity, Sobel-x, Sobel-y and a
        # 3x3 Laplacian, applied to every channel of the grid.
        self.filters = torch.stack([
            torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]),
            torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]),
            torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]).T,
            torch.tensor([[1.0, 2.0, 1.0], [2.0, -12.0, 2.0], [1.0, 2.0, 1.0]])
        ]).to(device)
        self.chn = 4  # number of grid channels
        self.rule_channels = rule_channels
        # 4 channels x 4 filters = 16 perception features, plus the rule channels
        self.w1 = nn.Conv2d(4 * 4 + rule_channels, hidden_n, 1).to(device)
        self.relu = nn.ReLU()
        self.w2 = nn.Conv2d(hidden_n, 4, 1, bias=False).to(device)
        if zero_w2:
            # Zero the final layer so the untrained CA starts out making no updates
            self.w2.weight.data.zero_()
        self.device = device
    def perchannel_conv(self, x, filters):
        '''Apply each filter in filters ([filter_n, h, w]) to every channel of x.
        E.g. x of shape [1, 4, 64, 64] with 4 filters -> [1, 16, 64, 64].'''
        b, ch, h, w = x.shape
        y = x.reshape(b * ch, 1, h, w)
        y = torch.nn.functional.pad(y, [1, 1, 1, 1], 'circular')  # wrap-around edges
        y = torch.nn.functional.conv2d(y, filters[:, None])
        return y.reshape(b, -1, h, w)
    def forward(self, x, rule=0, update_rate=0.5):
        b, ch, xsz, ysz = x.shape
        # One-hot rule grid: the selected rule channel is 1 everywhere
        rule_grid = torch.zeros(b, self.rule_channels, xsz, ysz).to(self.device)
        rule_grid[:, rule] = 1
        y = self.perchannel_conv(x, self.filters)  # Apply the filters
        y = torch.cat([y, rule_grid], dim=1)
        y = self.w2(self.relu(self.w1(y)))  # pass the result through our 'brain'
        b, c, h, w = y.shape
        # Stochastic update: floor(U(0,1) + rate) is 1 with probability update_rate
        update_mask = (torch.rand(b, 1, h, w).to(self.device) + update_rate).floor()
        return x + y * update_mask
    def forward_w_rule_grid(self, x, rule_grid, update_rate=0.5):
        # Same as forward(), but with a user-supplied rule grid (here, video
        # frames) instead of a one-hot rule selection.
        y = self.perchannel_conv(x, self.filters)  # Apply the filters
        y = torch.cat([y, rule_grid], dim=1)
        y = self.w2(self.relu(self.w1(y)))  # pass the result through our 'brain'
        b, c, h, w = y.shape
        update_mask = (torch.rand(b, 1, h, w).to(self.device) + update_rate).floor()
        return x + y * update_mask
    def to_rgb(self, x):
        # The first 3 channels are the visible ones; shift by 0.5 so that a
        # zero grid maps to mid-grey.
        return x[..., :3, :, :] + 0.5
    def seed(self, n, sz=128):
        """Initializes n 'grids', size sz. In this case all 0s."""
        return torch.zeros(n, self.chn, sz, sz).to(self.device)
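# A minimal, commented-out sketch of running RuleCA on its own (not part of the
# app; with zero_w2=True an untrained CA makes no updates, so this is only
# interesting once a checkpoint is loaded):
# ca_demo = RuleCA(hidden_n=32, rule_channels=3)
# grid = ca_demo.seed(1, sz=64)
# with torch.no_grad():
#     for _ in range(100):
#         grid = ca_demo(grid, rule=0)  # one stochastic CA step
# rgb = ca_demo.to_rgb(grid).clip(0, 1)  # [1, 3, 64, 64], ready for display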
def to_frames(video_file):
    # Extract the video into numbered jpg frames (clearing any previous run;
    # -f avoids an error the first time, when the directory doesn't exist yet)
    os.system('rm -rf guide_frames; mkdir guide_frames')
    os.system(f"ffmpeg -i {video_file} guide_frames/%04d.jpg")
def update(preset, enhance, scale2x, video_file):
    # Load the chosen preset's weights
    ca = RuleCA(hidden_n=32, rule_channels=3)
    ca_fn = ''
    if preset == 'Glowing Crystals':
        ca_fn = 'glowing_crystals.pt'
    elif preset == 'Rainbow Diamonds':
        ca_fn = 'rainbow_diamonds.pt'
    elif preset == 'Dark Diamonds':
        ca_fn = 'dark_diamonds.pt'
    elif preset == 'Dragon Scales':
        ca = RuleCA(hidden_n=16, rule_channels=3)  # this preset uses a smaller model
        ca_fn = 'dragon_scales.pt'
    ca.load_state_dict(torch.load(ca_fn, map_location=device))
    # Get video frames and pick the output size: cap the longer side at 256 px,
    # preserving the aspect ratio
    to_frames(video_file)
    vid_size = Image.open('guide_frames/0001.jpg').size
    if vid_size[0] > vid_size[1]:
        size = (256, int(256 * (vid_size[1] / vid_size[0])))
    else:
        size = (int(256 * (vid_size[0] / vid_size[1])), 256)
    if scale2x:
        size = (size[0] * 2, size[1] * 2)
    # Starting grid (all zeros); note PIL sizes are (w, h) but tensors are (h, w)
    x = torch.zeros(1, 4, size[1], size[0]).to(ca.device)
    os.system("rm -rf steps; mkdir steps")
    # Run two CA steps per guide frame, saving one output frame per guide frame
    for i in range(2 * len(glob.glob('guide_frames/*.jpg')) - 1):
        # Load the guide frame and use it as the rule grid
        im = Image.open(f'guide_frames/{i//2+1:04}.jpg').resize(size)
        rule_grid = torch.tensor(np.array(im) / 255).permute(2, 0, 1).unsqueeze(0).to(ca.device)
        if enhance:
            rule_grid = rule_grid * 2 - 0.3  # rescale inputs to 'enhance' the effect
        # Apply the update
        with torch.no_grad():
            x = ca.forward_w_rule_grid(x, rule_grid.float())
        if i % 2 == 0:
            img = ca.to_rgb(x).detach().cpu().clip(0, 1).squeeze().permute(1, 2, 0)
            img = Image.fromarray(np.array(img * 255).astype(np.uint8))
            img.save(f'steps/{i//2:05}.jpeg')
    # Write the output video from the saved frames
    os.system("ffmpeg -y -v 0 -framerate 24 -i steps/%05d.jpeg video.mp4")
    return 'video.mp4'
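# Hedged example of running the pipeline without the UI (the input filename
# here is hypothetical):
# update('Dragon Scales', enhance=False, scale2x=False, video_file='clip.mp4')
# -> writes frames to steps/ and returns 'video.mp4'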
demo = gr.Blocks()
with demo:
    gr.Markdown("Choose a preset below, upload a video and then click **Run** to see the output. Read [this report](https://wandb.ai/johnowhitaker/nca/reports/Fun-with-Neural-Cellular-Automata--VmlldzoyMDQ5Mjg0) for background on this project, or check out my [AI art course](https://github.com/johnowhitaker/aiaiart) for an in-depth lesson on Neural Cellular Automata like this.")
    with gr.Row():
        preset = gr.Dropdown(['Glowing Crystals', 'Rainbow Diamonds', 'Dark Diamonds', 'Dragon Scales'], label='Preset')
        with gr.Column():
            enhance = gr.Checkbox(label='Rescale inputs (more extreme results)')
            scale2x = gr.Checkbox(label='Larger output (slower)')
    with gr.Row():
        inp = gr.Video(format='mp4', source='upload', label="Input video (ideally <30s)")
        out = gr.Video(label="Output")
    btn = gr.Button("Run")
    btn.click(fn=update, inputs=[preset, enhance, scale2x, inp], outputs=out)
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=gradio-blocks_video_nca)")
demo.launch(enable_queue=True)