shuklaji9810 committed commit 5e014de
1 Parent(s): c329af9
first commit
- .gitignore +46 -0
- Examples/DeepFakes_10.png +0 -0
- Examples/DeepFakes_2.png +0 -0
- Examples/DeepFakes_4.png +0 -0
- Examples/DeepFakes_8.png +0 -0
- Examples/DeepFakes_9.png +0 -0
- Examples/SimSwap_8.png +0 -0
- Examples/StyleGAN_7.png +0 -0
- Examples/o_11.jpg +0 -0
- Examples/o_3.jpg +0 -0
- Examples/o_5.jpg +0 -0
- Examples/o_6.jpg +0 -0
- Examples/o_7.jpg +0 -0
- app.py +206 -0
- dataset/real_n_fake_dataloader.py +119 -0
- face_cropper.py +99 -0
- net/Multimodalmodel.py +41 -0
- test_image_fusion.py +182 -0
- utils/__init__.py +1 -0
- utils/basicblocks.py +32 -0
- utils/classifier.py +32 -0
- utils/config.py +38 -0
- utils/data_transforms.py +33 -0
- utils/feature_fusion_block.py +46 -0
- weights/faceswap-fft-best_model.pth +3 -0
- weights/faceswap-hh-best_model.pth +3 -0
.gitignore
ADDED
@@ -0,0 +1,46 @@
working.ipynb
training.py

# Compiled source #
###################
*.com
*.class
*.dll
*.exe
*.o
*.so

# Packages #
############
# it's better to unpack these files and commit the raw source because
# git has its own built in compression methods
*.7z
*.dmg
*.gz
*.iso
*.jar
*.rar
*.tar
*.zip

# Logs and databases #
######################
*.log
*.sql
*.sqlite

# OS generated files #
######################
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

___pycache__/
test_image.py
*.pyc
Examples/DeepFakes_10.png
ADDED
Examples/DeepFakes_2.png
ADDED
Examples/DeepFakes_4.png
ADDED
Examples/DeepFakes_8.png
ADDED
Examples/DeepFakes_9.png
ADDED
Examples/SimSwap_8.png
ADDED
Examples/StyleGAN_7.png
ADDED
Examples/o_11.jpg
ADDED
Examples/o_3.jpg
ADDED
Examples/o_5.jpg
ADDED
Examples/o_6.jpg
ADDED
Examples/o_7.jpg
ADDED
app.py
ADDED
@@ -0,0 +1,206 @@
import gradio as gr
from PIL import Image
import numpy as np
import os
from face_cropper import detect_and_label_faces

# Define a custom function to convert an image to grayscale
def to_grayscale(input_image):
    grayscale_image = Image.fromarray(np.array(input_image).mean(axis=-1).astype(np.uint8))
    return grayscale_image


description_markdown = """
# Fake Face Detection tool from TrustWorthy BiometraVision Lab IISER Bhopal

## Usage
This tool expects a face image as input. Upon submission, it will process the image and provide an output with bounding boxes drawn on the face. Alongside the visual markers, the tool will give a detection result indicating whether the face is fake or real.

## Disclaimer
Please note that this tool is for research purposes only and may not always be 100% accurate. Users are advised to exercise discretion and supervise the tool's usage accordingly.

## Licensing and Permissions
This tool has been developed solely for research and demonstrative purposes. Any commercial utilization of this tool is strictly prohibited unless explicit permission has been obtained from the developers.

## Developer Contact
For further inquiries or permissions, you can reach out to the developer through the following social media accounts:
- [LAB Webpage](https://sites.google.com/iiitd.ac.in/agarwalakshay/labiiserb?authuser=0)
- [LinkedIn](https://www.linkedin.com/in/shivam-shukla-0a50ab1a2/)
- [GitHub](https://github.com/SaShukla090)
"""


# Create the Gradio app
app = gr.Interface(
    fn=detect_and_label_faces,
    inputs=gr.Image(type="pil"),
    outputs="image",
    # examples=[
    #     "path_to_example_image_1.jpg",
    #     "path_to_example_image_2.jpg"
    # ]
    examples=[
        os.path.join("Examples", image_name) for image_name in os.listdir("Examples")
    ],
    title="Fake Face Detection",
    description=description_markdown,
)

# Run the app
app.launch()


# Commented-out earlier single-image classification prototype, kept in the commit:

# import torch.nn.functional as F
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import DataLoader
# from sklearn.metrics import accuracy_score, precision_recall_fscore_support
# from torch.optim.lr_scheduler import CosineAnnealingLR
# from tqdm import tqdm
# import warnings
# warnings.filterwarnings("ignore")

# from utils.config import cfg
# from dataset.real_n_fake_dataloader import Extracted_Frames_Dataset
# from utils.data_transforms import get_transforms_train, get_transforms_val
# from net.Multimodalmodel import Image_n_DCT
# import gradio as gr

# import os
# import json
# import torch
# from torchvision import transforms
# from torch.utils.data import DataLoader, Dataset
# from PIL import Image
# import numpy as np
# import pandas as pd
# import cv2
# import argparse

# from sklearn.metrics import classification_report, confusion_matrix
# import matplotlib.pyplot as plt
# import seaborn as sns

# class Test_Dataset(Dataset):
#     def __init__(self, test_data_path = None, transform = None, image = None):
#         """
#         Args:
#         returns:
#         """
#         if test_data_path is None and image is not None:
#             self.dataset = [(image, 2)]
#         self.transform = transform

#     def __len__(self):
#         return len(self.dataset)

#     def __getitem__(self, idx):
#         sample_input = self.get_sample_input(idx)
#         return sample_input

#     def get_sample_input(self, idx):
#         rgb_image = self.get_rgb_image(self.dataset[idx][0])
#         dct_image = self.compute_dct_color(self.dataset[idx][0])
#         # label = self.get_label(idx)
#         sample_input = {"rgb_image": rgb_image, "dct_image": dct_image}
#         return sample_input

#     def get_rgb_image(self, rgb_image):
#         # rgb_image_path = self.dataset[idx][0]
#         # rgb_image = Image.open(rgb_image_path)
#         if self.transform:
#             rgb_image = self.transform(rgb_image)
#         return rgb_image

#     def get_dct_image(self, idx):
#         rgb_image_path = self.dataset[idx][0]
#         rgb_image = cv2.imread(rgb_image_path)
#         dct_image = self.compute_dct_color(rgb_image)
#         if self.transform:
#             dct_image = self.transform(dct_image)
#         return dct_image

#     def get_label(self, idx):
#         return self.dataset[idx][1]

#     def compute_dct_color(self, image):
#         image_float = np.float32(image)
#         dct_image = np.zeros_like(image_float)
#         for i in range(3):
#             dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
#         if self.transform:
#             dct_image = self.transform(dct_image)
#         return dct_image

# device = torch.device("cpu")
# # print(device)
# model = Image_n_DCT()
# model.load_state_dict(torch.load('weights/best_model.pth', map_location = device))
# model.to(device)
# model.eval()

# def classify(image):
#     test_dataset = Test_Dataset(transform = get_transforms_val(), image = image)
#     inputs = test_dataset[0]
#     rgb_image, dct_image = inputs['rgb_image'].to(device), inputs['dct_image'].to(device)
#     output = model(rgb_image.unsqueeze(0), dct_image.unsqueeze(0))
#     # _, predicted = torch.max(output.data, 1)
#     # print(f"the face is {'real' if predicted==1 else 'fake'}")
#     return {'Fake': output[0][0], 'Real': output[0][1]}

# iface = gr.Interface(fn=classify, inputs="image", outputs="label")
# if __name__ == "__main__":
#     iface.launch()
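The interface above simply forwards the uploaded image to detect_and_label_faces. A minimal sketch (not part of the commit) of exercising that same callback without the UI, assuming the repo root as working directory and using one of the committed example images:

from PIL import Image
from face_cropper import detect_and_label_faces

# The callback returns a PIL image with a coloured, labelled box per detected face.
input_image = Image.open("Examples/o_3.jpg").convert("RGB")
annotated = detect_and_label_faces(input_image)
annotated.save("o_3_annotated.jpg")   # output filename is illustrative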
dataset/real_n_fake_dataloader.py
ADDED
@@ -0,0 +1,119 @@
# We will use this file to create a dataloader for the real and fake dataset
import os
import json
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import pywt

class Extracted_Frames_Dataset(Dataset):
    def __init__(self, root_dir, split="train", transform=None, extend='None', multi_modal="dct"):
        """
        Args:
            root_dir: directory holding the split CSVs (image path in column 0, label in column 1).
            split: one of "train", "val", "test".
            transform: torchvision transform applied to both the RGB and the frequency image.
            extend: 'faceswap' or 'fsgan' to use the extended CSVs, anything else for the base split.
            multi_modal: frequency representation paired with the RGB image ("dct", "fft" or "hh").
        Returns:
            dict samples with keys "rgb_image", "dct_image" and "label".
        """
        assert split in ["train", "val", "test"], "Split must be one of (train, val, test)"
        self.multi_modal = multi_modal
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        if extend == 'faceswap':
            self.dataset = pd.read_csv(os.path.join(root_dir, f"faceswap_extended_{self.split}.csv"))
        elif extend == 'fsgan':
            self.dataset = pd.read_csv(os.path.join(root_dir, f"fsgan_extended_{self.split}.csv"))
        else:
            self.dataset = pd.read_csv(os.path.join(root_dir, f"{self.split}.csv"))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample_input = self.get_sample_input(idx)
        return sample_input

    def get_sample_input(self, idx):
        rgb_image = self.get_rgb_image(idx)
        label = self.get_label(idx)
        if self.multi_modal == "dct":
            dct_image = self.get_dct_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": dct_image, "label": label}
        elif self.multi_modal == "fft":
            fft_image = self.get_fft_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": fft_image, "label": label}
        elif self.multi_modal == "hh":
            hh_image = self.get_hh_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": hh_image, "label": label}
        else:
            raise AssertionError("multi_modal must be one of (dct: discrete cosine transform, fft: fast Fourier transform, hh: Haar HH sub-band)")
        return sample_input

    def get_fft_image(self, idx):
        gray_image_path = self.dataset.iloc[idx, 0]
        gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
        fft_image = self.compute_fft(gray_image)
        if self.transform:
            fft_image = self.transform(fft_image)
        return fft_image

    def compute_fft(self, image):
        f = np.fft.fft2(image)
        fshift = np.fft.fftshift(f)
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # Add 1 to avoid log(0)
        return magnitude_spectrum

    def get_hh_image(self, idx):
        gray_image_path = self.dataset.iloc[idx, 0]
        gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
        hh_image = self.compute_hh(gray_image)
        if self.transform:
            hh_image = self.transform(hh_image)
        return hh_image

    def compute_hh(self, image):
        coeffs2 = pywt.dwt2(image, 'haar')
        LL, (LH, HL, HH) = coeffs2
        return HH

    def get_rgb_image(self, idx):
        rgb_image_path = self.dataset.iloc[idx, 0]
        rgb_image = Image.open(rgb_image_path)
        if self.transform:
            rgb_image = self.transform(rgb_image)
        return rgb_image

    def get_dct_image(self, idx):
        rgb_image_path = self.dataset.iloc[idx, 0]
        rgb_image = cv2.imread(rgb_image_path)
        dct_image = self.compute_dct_color(rgb_image)
        if self.transform:
            dct_image = self.transform(dct_image)
        return dct_image

    def get_label(self, idx):
        return self.dataset.iloc[idx, 1]

    def compute_dct_color(self, image):
        # Per-channel 2D DCT of the BGR image read by OpenCV
        image_float = np.float32(image)
        dct_image = np.zeros_like(image_float)
        for i in range(3):
            dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
        return dct_image
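A minimal sketch (not part of the commit) of feeding this dataset to a DataLoader, assuming a directory named data/ that holds train.csv with an image path in column 0 and a 0/1 label in column 1; batch size and worker count mirror utils/config.py:

from torch.utils.data import DataLoader
from dataset.real_n_fake_dataloader import Extracted_Frames_Dataset
from utils.data_transforms import get_transforms_train

# "data" is an assumed root_dir containing train.csv
train_set = Extracted_Frames_Dataset(root_dir="data",
                                     split="train",
                                     transform=get_transforms_train(),
                                     multi_modal="fft")   # "dct", "fft" or "hh"
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)

batch = next(iter(train_loader))
# For "fft": rgb_image is (64, 3, 224, 224) and dct_image is (64, 1, 224, 224)
print(batch["rgb_image"].shape, batch["dct_image"].shape, batch["label"].shape)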
face_cropper.py
ADDED
@@ -0,0 +1,99 @@
import cv2
import mediapipe as mp
import os
import numpy as np
from PIL import Image
from gradio_client import Client
from test_image_fusion import Test

# client = Client("https://tbvl-real-and-fake-face-detection.hf.space/--replicas/40d41jxhhx/")

data = 'faceswap'
dct = 'fft'

testet = Test(model_paths=[f"weights/{data}-hh-best_model.pth",
                           f"weights/{data}-fft-best_model.pth"],
              multi_modal=['hh', 'fft'])

# Initialize MediaPipe Face Detection
mp_face_detection = mp.solutions.face_detection
mp_drawing = mp.solutions.drawing_utils
face_detection = mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.35)

# Create a directory to save the cropped face images if it does not exist
save_dir = "cropped_faces"
os.makedirs(save_dir, exist_ok=True)

# def detect_and_label_faces(image_path):

# Function to crop faces from a video and save them as images
# def crop_faces_from_video(video_path):
#     # Read the video
#     cap = cv2.VideoCapture(video_path)
#     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
#     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
#     fps = int(cap.get(cv2.CAP_PROP_FPS))
#     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
#
#     # Define the codec and create VideoWriter object
#     out = cv2.VideoWriter(f'output_{real}_{data}_fusion.avi', cv2.VideoWriter_fourcc('M','J','P','G'), fps, (frame_width, frame_height))
#
#     if not cap.isOpened():
#         print("Error: Could not open video.")
#         return

# Convert PIL Image to NumPy array for OpenCV
def pil_to_opencv(pil_image):
    open_cv_image = np.array(pil_image)
    # Convert RGB to BGR for OpenCV
    open_cv_image = open_cv_image[:, :, ::-1].copy()
    return open_cv_image

# Convert OpenCV NumPy array to PIL Image
def opencv_to_pil(opencv_image):
    # Convert BGR to RGB
    pil_image = Image.fromarray(opencv_image[:, :, ::-1])
    return pil_image


def detect_and_label_faces(frame):
    frame = pil_to_opencv(frame)
    print(type(frame))
    # Convert the frame to RGB
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Perform face detection
    results = face_detection.process(frame_rgb)

    # If faces are detected, crop and save each face as an image
    if results.detections:
        for face_count, detection in enumerate(results.detections):
            bboxC = detection.location_data.relative_bounding_box
            ih, iw, _ = frame.shape
            x, y, w, h = int(bboxC.xmin * iw), int(bboxC.ymin * ih), int(bboxC.width * iw), int(bboxC.height * ih)
            # Crop the face region and make sure the bounding box is within the frame dimensions
            crop_img = frame[max(0, y):min(ih, y + h), max(0, x):min(iw, x + w)]
            if crop_img.size > 0:
                face_filename = os.path.join(save_dir, f'face_{face_count}.jpg')
                cv2.imwrite(face_filename, crop_img)

                label = testet.testimage(face_filename)

                if os.path.exists(face_filename):
                    os.remove(face_filename)

                color = (0, 0, 255) if label == 'fake' else (0, 255, 0)
                cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
                cv2.putText(frame, label, (x, y + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
    return opencv_to_pil(frame)
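The least obvious step above is turning MediaPipe's relative bounding box into pixel coordinates before cropping. A small helper that isolates that arithmetic (the function name is ours, not part of the commit):

def relative_box_to_pixels(bboxC, frame_shape):
    # bboxC fields are fractions of the frame size (MediaPipe convention)
    ih, iw = frame_shape[:2]
    x, y = int(bboxC.xmin * iw), int(bboxC.ymin * ih)
    w, h = int(bboxC.width * iw), int(bboxC.height * ih)
    # Clamp so frame[y1:y2, x1:x2] always stays inside the image,
    # matching the max/min guards in detect_and_label_faces
    x1, y1 = max(0, x), max(0, y)
    x2, y2 = min(iw, x + w), min(ih, y + h)
    return x1, y1, x2, y2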
net/Multimodalmodel.py
ADDED
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.config import cfg
from utils.basicblocks import BasicBlock
from utils.feature_fusion_block import DCT_Attention_Fusion_Conv
from utils.classifier import ClassifierModel

class Image_n_DCT(nn.Module):
    def __init__(self):
        super(Image_n_DCT, self).__init__()
        self.Img_Block = nn.ModuleList()
        self.DCT_Block = nn.ModuleList()
        self.RGB_n_DCT_Fusion = nn.ModuleList()
        self.num_classes = len(cfg.CLASSES)

        for i in range(len(cfg.MULTIMODAL_FUSION.IMG_CHANNELS) - 1):
            self.Img_Block.append(BasicBlock(cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i], cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i + 1], stride=1))
            self.DCT_Block.append(BasicBlock(cfg.MULTIMODAL_FUSION.DCT_CHANNELS[i], cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i + 1], stride=1))
            self.RGB_n_DCT_Fusion.append(DCT_Attention_Fusion_Conv(cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i + 1]))

        self.classifier = ClassifierModel(self.num_classes)

    def forward(self, rgb_image, dct_image):
        image = [rgb_image]
        dct_image = [dct_image]

        for i in range(len(self.Img_Block)):
            image.append(self.Img_Block[i](image[-1]))
            dct_image.append(self.DCT_Block[i](dct_image[-1]))
            image[-1] = self.RGB_n_DCT_Fusion[i](image[-1], dct_image[-1])
            dct_image[-1] = image[-1]
        out = self.classifier(image[-1])

        return out
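A quick shape check (not part of the commit) of the fused model with random tensors, assuming the channel lists from utils/config.py below (RGB branch starts at 3 channels, frequency branch at 1, two output classes). Each BasicBlock halves the spatial size, so a 224 x 224 input reaches the classifier as a 512 x 14 x 14 map:

import torch
from net.Multimodalmodel import Image_n_DCT

model = Image_n_DCT().eval()
rgb = torch.randn(1, 3, 224, 224)    # RGB input
freq = torch.randn(1, 1, 224, 224)   # DCT/FFT/HH map, passed as dct_image
with torch.no_grad():
    probs = model(rgb, freq)
print(probs.shape)   # torch.Size([1, 2]); softmax scores over the two classes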
test_image_fusion.py
ADDED
@@ -0,0 +1,182 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pywt

from utils.config import cfg
from dataset.real_n_fake_dataloader import Extracted_Frames_Dataset
from utils.data_transforms import get_transforms_train, get_transforms_val
from net.Multimodalmodel import Image_n_DCT

import os
import json
from torchvision import transforms
from PIL import Image
import pandas as pd
import argparse

class Test_Dataset(Dataset):
    def __init__(self, test_data_path=None, transform=None, image_path=None, multi_modal="dct"):
        """
        Args:
            test_data_path: directory with "real/" and "fake/" sub-folders, or None to test a single image.
            transform: torchvision transform applied to every image.
            image_path: path of a single image to test when test_data_path is None (its label is the dummy value 2).
            multi_modal: frequency representation paired with the RGB image ("dct", "fft" or "hh").
        Returns:
            dict samples with keys "rgb_image", "dct_image" and "label".
        """
        self.multi_modal = multi_modal
        if test_data_path is None and image_path is not None:
            self.dataset = [[image_path, 2]]
            self.transform = transform
        else:
            self.transform = transform
            self.real_data = os.listdir(test_data_path + "/real")
            self.fake_data = os.listdir(test_data_path + "/fake")
            self.dataset = []
            for image in self.real_data:
                self.dataset.append([test_data_path + "/real/" + image, 1])
            for image in self.fake_data:
                self.dataset.append([test_data_path + "/fake/" + image, 0])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample_input = self.get_sample_input(idx)
        return sample_input

    def get_sample_input(self, idx):
        rgb_image = self.get_rgb_image(idx)
        label = self.get_label(idx)
        if self.multi_modal == "dct":
            dct_image = self.get_dct_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": dct_image, "label": label}
        elif self.multi_modal == "fft":
            fft_image = self.get_fft_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": fft_image, "label": label}
        elif self.multi_modal == "hh":
            hh_image = self.get_hh_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": hh_image, "label": label}
        else:
            raise AssertionError("multi_modal must be one of (dct: discrete cosine transform, fft: fast Fourier transform, hh: Haar HH sub-band)")
        return sample_input

    def get_fft_image(self, idx):
        gray_image_path = self.dataset[idx][0]
        gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
        fft_image = self.compute_fft(gray_image)
        if self.transform:
            fft_image = self.transform(fft_image)
        return fft_image

    def compute_fft(self, image):
        f = np.fft.fft2(image)
        fshift = np.fft.fftshift(f)
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # Add 1 to avoid log(0)
        return magnitude_spectrum

    def get_hh_image(self, idx):
        gray_image_path = self.dataset[idx][0]
        gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
        hh_image = self.compute_hh(gray_image)
        if self.transform:
            hh_image = self.transform(hh_image)
        return hh_image

    def compute_hh(self, image):
        coeffs2 = pywt.dwt2(image, 'haar')
        LL, (LH, HL, HH) = coeffs2
        return HH

    def get_rgb_image(self, idx):
        rgb_image_path = self.dataset[idx][0]
        rgb_image = Image.open(rgb_image_path)
        if self.transform:
            rgb_image = self.transform(rgb_image)
        return rgb_image

    def get_dct_image(self, idx):
        rgb_image_path = self.dataset[idx][0]
        rgb_image = cv2.imread(rgb_image_path)
        dct_image = self.compute_dct_color(rgb_image)
        if self.transform:
            dct_image = self.transform(dct_image)
        return dct_image

    def get_label(self, idx):
        return self.dataset[idx][1]

    def compute_dct_color(self, image):
        image_float = np.float32(image)
        dct_image = np.zeros_like(image_float)
        for i in range(3):
            dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
        return dct_image


class Test:
    def __init__(self, model_paths=['weights/faceswap-hh-best_model.pth',
                                    'weights/faceswap-fft-best_model.pth'],
                 multi_modal=["hh", "fft"]):
        self.model_path = model_paths
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)
        # Load the two fusion models, one per frequency modality
        self.model1 = Image_n_DCT()
        self.model1.load_state_dict(torch.load(self.model_path[0], map_location=self.device))
        self.model1.to(self.device)
        self.model1.eval()

        self.model2 = Image_n_DCT()
        self.model2.load_state_dict(torch.load(self.model_path[1], map_location=self.device))
        self.model2.to(self.device)
        self.model2.eval()

        self.multi_modal = multi_modal

    def testimage(self, image_path):
        test_dataset1 = Test_Dataset(transform=get_transforms_val(), image_path=image_path, multi_modal=self.multi_modal[0])
        test_dataset2 = Test_Dataset(transform=get_transforms_val(), image_path=image_path, multi_modal=self.multi_modal[1])

        inputs1 = test_dataset1[0]
        rgb_image1, dct_image1 = inputs1['rgb_image'].to(self.device), inputs1['dct_image'].to(self.device)

        inputs2 = test_dataset2[0]
        rgb_image2, dct_image2 = inputs2['rgb_image'].to(self.device), inputs2['dct_image'].to(self.device)

        output1 = self.model1(rgb_image1.unsqueeze(0), dct_image1.unsqueeze(0))
        output2 = self.model2(rgb_image2.unsqueeze(0), dct_image2.unsqueeze(0))

        # Average the two models' softmax outputs and pick the winning class
        output = (output1 + output2) / 2
        _, predicted = torch.max(output.data, 1)
        return 'real' if predicted == 1 else 'fake'
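A minimal sketch (not part of the commit) of scoring a single cropped face with the fusion ensemble, assuming the two weight files committed under weights/ and a placeholder image path face.jpg:

from test_image_fusion import Test

tester = Test(model_paths=["weights/faceswap-hh-best_model.pth",
                           "weights/faceswap-fft-best_model.pth"],
              multi_modal=["hh", "fft"])   # modality fed to each model, in order
print(tester.testimage("face.jpg"))        # prints 'real' or 'fake'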
utils/__init__.py
ADDED
@@ -0,0 +1 @@
import os
utils/basicblocks.py
ADDED
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


BatchNorm2d = nn.BatchNorm2d

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                     padding=0, bias=False)

class BasicBlock(nn.Module):
    def __init__(self, inplanes, outplanes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, outplanes, stride)
        self.bn1 = BatchNorm2d(outplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(outplanes, outplanes, 2 * stride)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out
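The block keeps the spatial size in conv1 and downsamples in conv2 (stride 2*stride), so the default configuration halves the resolution while changing the channel count. A quick illustrative check (not part of the commit):

import torch
from utils.basicblocks import BasicBlock

block = BasicBlock(3, 64)          # 3 -> 64 channels, default stride=1
x = torch.randn(1, 3, 224, 224)
print(block(x).shape)              # torch.Size([1, 64, 112, 112])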
utils/classifier.py
ADDED
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassifierModel(nn.Module):
    def __init__(self, num_classes):
        super(ClassifierModel, self).__init__()
        # Apply adaptive average pooling to convert (512, 14, 14) to (512)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Define multiple fully connected layers
        self.fc1 = nn.Linear(512, 256)          # First FC layer, reducing to 256 features
        self.fc2 = nn.Linear(256, 128)          # Second FC layer, reducing to 128 features
        self.fc3 = nn.Linear(128, num_classes)  # Final FC layer, outputting num_classes for classification

        # Dropout for regularization
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        # Pool and flatten the incoming feature map
        x = self.adaptive_pool(x)
        x = torch.flatten(x, 1)

        # Pass through the fully connected layers with ReLU activations and dropout
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)           # Raw class scores
        x = F.softmax(x, dim=1)   # Converted to class probabilities

        return x
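A sketch (not part of the commit) of the head applied to a feature map of the size the backbone produces for a 224 x 224 input (512 x 14 x 14); because of the final softmax, each output row sums to 1:

import torch
from utils.classifier import ClassifierModel

head = ClassifierModel(num_classes=2).eval()
features = torch.randn(4, 512, 14, 14)    # batch of 4 fused feature maps
with torch.no_grad():
    probs = head(features)
print(probs.shape, probs.sum(dim=1))      # torch.Size([4, 2]), rows sum to 1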
utils/config.py
ADDED
@@ -0,0 +1,38 @@
from easydict import EasyDict as edict
import numpy as np

__C = edict()
cfg = __C

# 0. basic config
__C.TAG = 'default'
__C.CLASSES = ['Real', 'Fake']

# config of network input
__C.MULTIMODAL_FUSION = edict()
__C.MULTIMODAL_FUSION.IMG_CHANNELS = [3, 64, 128, 256, 512]
__C.MULTIMODAL_FUSION.DCT_CHANNELS = [1, 64, 128, 256, 512]

# training settings
__C.NUM_EPOCHS = 100
__C.BATCH_SIZE = 64
__C.NUM_WORKERS = 4
__C.LEARNING_RATE = 0.0001
__C.PRETRAINED = False
__C.PRETRAINED_PATH = "/home/user/Documents/Real_and_DeepFake/src/best_model.pth"

# evaluation settings
__C.TEST_BATCH_SIZE = 512
__C.TEST_CSV = "/home/user/Documents/Real_and_DeepFake/src/dataset/extended_val.csv"
__C.MODEL_PATH = "/home/user/Documents/Real_and_DeepFake/src/best_model.pth"
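Since cfg is a plain EasyDict, settings are read as attributes; a training or evaluation script (not part of this commit) would consume it roughly like this:

from utils.config import cfg

print(cfg.CLASSES)                          # ['Real', 'Fake']
print(cfg.MULTIMODAL_FUSION.IMG_CHANNELS)   # [3, 64, 128, 256, 512]
batch_size = cfg.BATCH_SIZE                 # 64
learning_rate = cfg.LEARNING_RATE           # 0.0001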
utils/data_transforms.py
ADDED
@@ -0,0 +1,33 @@
from torchvision import transforms


def get_transforms_train():
    # Define the training-time transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.float()),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.Normalize(mean=[(0.485 + 0.456 + 0.406) / 3], std=[(0.229 + 0.224 + 0.225) / 3]),
    ])

    return transform


def get_transforms_val():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.float()),
        transforms.Resize((224, 224)),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Normalize(mean=[(0.485 + 0.456 + 0.406) / 3], std=[(0.229 + 0.224 + 0.225) / 3]),
    ])

    return transform
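Applying the validation transform to an RGB image yields a (3, 224, 224) float tensor; the single-element mean/std broadcasts over all channels. A sketch (not part of the commit) using one of the committed example images:

from PIL import Image
from utils.data_transforms import get_transforms_val

tfm = get_transforms_val()
img = Image.open("Examples/o_5.jpg").convert("RGB")
tensor = tfm(img)
print(tensor.shape, tensor.dtype)   # torch.Size([3, 224, 224]) torch.float32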
utils/feature_fusion_block.py
ADDED
@@ -0,0 +1,46 @@
from torch import nn
from torch.nn import functional as F

class SpatialAttention(nn.Module):
    def __init__(self, in_channels):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Calculate attention scores
        attention_scores = self.conv1(x)
        attention_scores = F.softmax(attention_scores, dim=2)

        # Apply attention to input features
        attended_features = x * attention_scores

        return attended_features

class DCT_Attention_Fusion_Conv(nn.Module):
    def __init__(self, channels):
        super(DCT_Attention_Fusion_Conv, self).__init__()
        self.rgb_attention = SpatialAttention(channels)
        self.depth_attention = SpatialAttention(channels)
        self.rgb_pooling = nn.AdaptiveAvgPool2d(1)
        self.depth_pooling = nn.AdaptiveAvgPool2d(1)

    def forward(self, rgb_features, DCT_features):
        # Spatial attention for both modalities
        rgb_attended_features = self.rgb_attention(rgb_features)
        depth_attended_features = self.depth_attention(DCT_features)

        # Adaptive pooling for both modalities
        rgb_pooled = self.rgb_pooling(rgb_attended_features)
        depth_pooled = self.depth_pooling(depth_attended_features)

        # Upsample attended and pooled features to the original size
        rgb_upsampled = F.interpolate(rgb_pooled, size=rgb_features.size()[2:], mode='bilinear', align_corners=False)
        depth_upsampled = F.interpolate(depth_pooled, size=DCT_features.size()[2:], mode='bilinear', align_corners=False)

        # Fuse the upsampled features by element-wise addition followed by ReLU
        fused_features = F.relu(rgb_upsampled + depth_upsampled)
        # fused_features = fused_features.sum(dim=1)

        return fused_features
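The fusion block returns a tensor of the same shape as its inputs, which is why Image_n_DCT can apply it after every stage. A shape check (not part of the commit) with random feature maps:

import torch
from utils.feature_fusion_block import DCT_Attention_Fusion_Conv

fuse = DCT_Attention_Fusion_Conv(channels=64)
rgb_feat = torch.randn(2, 64, 56, 56)
freq_feat = torch.randn(2, 64, 56, 56)
print(fuse(rgb_feat, freq_feat).shape)   # torch.Size([2, 64, 56, 56])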
weights/faceswap-fft-best_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c42f82049bed6db4edb5e933ffe4ce6e3612e7fbf351c29327d9cfe81f8c5ff
size 38189260
weights/faceswap-hh-best_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:15272d1439ef629566cf43b3d4d1bc4f2091f3db1c0d0430038b56880c7ef385
size 38189178