EditGuard / utils /jpegtest.py
Ricoooo's picture
'folder'
5d21dd2
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import random, string
class JpegTest(nn.Module):
    """Apply real JPEG compression to a batch of images via a temporary file.

    Each image in the batch is quantised to 8 bits, written to disk as a JPEG
    with Pillow, read back, and converted to a tensor again.  The round trip
    goes through ``uint8`` NumPy arrays, so no gradient flows through this
    module.

    Args:
        Q: JPEG quality forwarded to Pillow's ``Image.save(quality=...)``.
        subsample: Value forwarded to Pillow's ``subsampling`` save option.
        path: Directory in which the temporary JPEG files are created.  It is
            created (including parents) if it does not exist.  An empty
            string means the current working directory.
    """

    def __init__(self, Q=50, subsample=0, path="temp/"):
        super(JpegTest, self).__init__()
        self.Q = Q
        self.subsample = subsample
        self.path = path
        # makedirs(exist_ok=True) handles nested directories and avoids the
        # check-then-create race when several workers start at once.
        if path:
            os.makedirs(path, exist_ok=True)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def get_path(self):
        """Return a random ``.jpg`` file path located inside ``self.path``."""
        name = ''.join(random.sample(string.ascii_letters + string.digits, 16)) + ".jpg"
        # os.path.join keeps the file inside the directory even when
        # ``self.path`` was given without a trailing separator.
        return os.path.join(self.path, name)

    def forward(self, image_cover_mask):
        """JPEG-compress every image of the batch and return the result.

        Args:
            image_cover_mask: Batch tensor indexed as ``(N, C, H, W)``; values
                are clamped to ``[0, 1]`` before quantisation.
                NOTE(review): presumably 3-channel RGB -- confirm with callers.

        Returns:
            A tensor with the same shape, dtype and device as the input that
            holds the decoded JPEG images scaled back to ``[0, 1]``.
        """
        image = image_cover_mask
        noised_image = torch.zeros_like(image)
        for i in range(image.shape[0]):
            # CHW float in [0, 1] -> HWC uint8 in [0, 255], rounding to nearest.
            single_image = (
                (image[i].clamp(0, 1).permute(1, 2, 0) * 255)
                .add(0.5)
                .clamp(0, 255)
                .to('cpu', torch.uint8)
                .numpy()
            )
            im = Image.fromarray(single_image)

            file = self.get_path()
            while os.path.exists(file):
                file = self.get_path()

            try:
                im.save(file, format="JPEG", quality=self.Q, subsampling=self.subsample)
                # The context manager releases the file handle before removal.
                with Image.open(file) as decoded:
                    jpeg = np.array(decoded, dtype=np.uint8)
            finally:
                # Never leave the temporary file behind, even if encoding or
                # decoding raised.
                if os.path.exists(file):
                    os.remove(file)

            # ToTensor converts the HWC uint8 array to a CHW float tensor.
            noised_image[i] = self.transform(jpeg).to(image.device)
        return noised_image