dghdgkl commited on
Commit
6a3f4a1
·
verified ·
1 Parent(s): 872a8d5

Create maindata

Browse files
Files changed (1) hide show
  1. maindata +47 -0
maindata ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pip install torch diffusers transformers datasets wandb
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ # Define a basic U-Net style model (you can scale this up for an XL model)
7
+ class UNetModel(nn.Module):
8
+ def __init__(self, in_channels=3, out_channels=3, base_channels=64):
9
+ super(UNetModel, self).__init__()
10
+
11
+ # Downsample
12
+ self.enc1 = self.conv_block(in_channels, base_channels)
13
+ self.enc2 = self.conv_block(base_channels, base_channels * 2)
14
+ self.enc3 = self.conv_block(base_channels * 2, base_channels * 4)
15
+
16
+ # Middle
17
+ self.middle = self.conv_block(base_channels * 4, base_channels * 8)
18
+
19
+ # Upsample
20
+ self.dec3 = self.conv_block(base_channels * 8, base_channels * 4)
21
+ self.dec2 = self.conv_block(base_channels * 4, base_channels * 2)
22
+ self.dec1 = self.conv_block(base_channels * 2, out_channels)
23
+
24
+ def conv_block(self, in_channels, out_channels):
25
+ return nn.Sequential(
26
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
27
+ nn.ReLU(),
28
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
29
+ nn.ReLU(),
30
+ nn.MaxPool2d(2)
31
+ )
32
+
33
+ def forward(self, x):
34
+ # Encode (Downsample)
35
+ x1 = self.enc1(x)
36
+ x2 = self.enc2(x1)
37
+ x3 = self.enc3(x2)
38
+
39
+ # Middle block
40
+ x_middle = self.middle(x3)
41
+
42
+ # Decode (Upsample)
43
+ x3_dec = self.dec3(x_middle)
44
+ x2_dec = self.dec2(x3_dec + x3)
45
+ x1_dec = self.dec1(x2_dec + x2)
46
+
47
+ return x1_dec