# distildire / custommodel.py
# Yewon Lim
import argparse
import json
from time import perf_counter
from datetime import datetime
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.transforms.functional as TF
from torchvision.transforms import Compose, Resize, CenterCrop
from torchvision.io import decode_jpeg, encode_jpeg
from PIL import Image
import os
import torch
from utils.sanic_utils import *
import typing
import requests
import time # Import the time module
from guided_diffusion.compute_dire_eps import dire_get_first_step_noise, create_argparser
from networks.distill_model import DistilDIREOnlyEPS, DistilDIRE
from guided_diffusion.guided_diffusion.script_util import (
def download_file(input_path, timeout=30):
    """
    Return a local path for ``input_path``, downloading it first if it is a URL.

    If ``input_path`` starts with ``http://`` or ``https://`` the file name is
    taken from the last segment of the URL *path* (query string and fragment are
    ignored, ``.jpg`` is appended when the name has no extension) and the file is
    stored in the current working directory. A file with that name that already
    exists is reused without any network access. Any other ``input_path`` is
    treated as a local file path and must already exist.

    :param input_path: The URL of the file to download or a local file path.
    :param timeout: Value passed to ``requests.get`` as ``timeout`` (seconds).
    :return: The local filepath to the downloaded or existing file.
    :raises FileNotFoundError: If ``input_path`` is a local path that does not exist.
    :raises requests.HTTPError: If the server answers with an error status.
    """
    # Function-scope import: keeps this block independent of the module's import list.
    from urllib.parse import urlsplit

    if not input_path.startswith(('http://', 'https://')):
        # Assume input_path is a local file path.
        if not os.path.isfile(input_path):
            raise FileNotFoundError(f"No such file: '{input_path}'")
        print(f"Using existing file: {input_path}")
        return input_path

    # Use only the path component so that a '/' inside the query string or
    # fragment cannot be mistaken for the file name.
    filename = urlsplit(input_path).path.rsplit('/', 1)[-1] or "download"
    # Put a jpg extension if none is present.
    if '.' not in filename:
        filename += ".jpg"
    local_filepath = os.path.join('.', filename)

    if os.path.isfile(local_filepath):
        print(f"The file already exists locally: {local_filepath}")
        return local_filepath

    start_time = time.perf_counter()
    # Download into a temporary name first: an interrupted transfer must not
    # leave a truncated file that the cache check above would accept later.
    partial_filepath = local_filepath + ".part"
    try:
        with requests.get(input_path, stream=True, timeout=timeout) as response:
            # Raise an exception if the request was unsuccessful.
            response.raise_for_status()
            with open(partial_filepath, 'wb') as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        os.replace(partial_filepath, local_filepath)
    finally:
        # Only still present when the download or the rename failed.
        if os.path.isfile(partial_filepath):
            os.remove(partial_filepath)

    download_duration = time.perf_counter() - start_time
    print(
        f"Downloaded file saved to {local_filepath} in {download_duration:.2f} seconds.")
    return local_filepath
class CustomModel:
"""Wrapper for a DIRE model."""
def __init__(self, net='DIRE', num_frames=15, ckpt='truemedia-global-scaled.pth'):
    """
    Build the DistilDIRE classifier and the diffusion model that feeds it.

    :param net: Label stored on the instance; ``predict`` prints it when it runs.
    :param num_frames: Stored on the instance; not used anywhere else in this method.
    :param ckpt: Path of the classifier checkpoint, handed to ``_load_state_dict``.
    """
    self.net = net
    self.num_frames = num_frames
    # self.model = DistilDIREOnlyEPS('cuda').to('cuda')
    self.model = DistilDIRE('cuda').to('cuda')
    # ``ckpt`` used to be ignored, so the classifier kept whatever weights its
    # constructor produced. Load the requested checkpoint explicitly.
    self._load_state_dict(ckpt)
    # This wrapper only runs inference, so switch off training-mode behaviour.
    self.model.eval()

    # Resize the shorter side to 256, then take a 256x256 centre crop.
    self.trans = transforms.Compose((
        transforms.Resize(256, antialias=True),
        transforms.CenterCrop((256, 256)),
    ))

    args = create_argparser()
    args['timestep_respacing'] = 'ddim20'
    diffusion_kwargs = dict_parse(args, model_and_diffusion_defaults().keys())
    adm_model, diffusion = create_model_and_diffusion(**diffusion_kwargs)
    adm_model.load_state_dict(torch.load(args['model_path'], map_location="cpu"))

    self.adm_model = adm_model
    self.diffusion = diffusion
    self.args = args
def _load_state_dict(self, ckpt):
print(f"Loading the model from {ckpt}...")
state_dict = torch.load(ckpt, map_location="cpu")['model']
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
print("The model is successfully loaded")
def _forward_dire_img(self, img_path, save_dire=True, thr=0.5):
    """
    Score a single image file with the DistilDIRE classifier.

    :param img_path: Path of the image file to open.
    :param save_dire: Not used by this method; kept so existing calls still work.
    :param thr: Threshold handed to ``real_or_fake_thres`` together with the score.
    :return: Dict with ``"df_probability"`` (median of the sigmoid of the model's
        ``'logit'`` output) and ``"prediction"`` (result of ``real_or_fake_thres``).
    """
    # Close the file handle as soon as the pixels have been converted.
    with Image.open(img_path) as source:
        img = TF.to_tensor(source.convert("RGB"))
    # to_tensor yields values in [0, 1]; "* 2 - 1" rescales them to [-1, 1].
    img = self.trans(img).cuda() * 2 - 1
    img = img.unsqueeze(0)  # add a leading batch dimension

    # Inference only: no autograd graph is needed for either network.
    with torch.no_grad():
        eps = dire_get_first_step_noise(img, self.adm_model, self.diffusion, self.args, "cuda")
        prob = self.model(img, eps)['logit'].sigmoid()

    score = prob.median().item()
    return {"df_probability": score, "prediction": real_or_fake_thres(score, thr)}
def predict(self, inputs: typing.Dict[str, str]) -> typing.Dict[str, typing.Any]:
    """
    Run the detector on the media referenced by ``inputs['file_path']``.

    The path or URL is resolved through ``download_file``. Every failure is
    reported as a ``{"msg": ...}`` dict instead of an exception, so the caller
    always receives a dict.

    :param inputs: Mapping that should contain ``'file_path'`` (local path or
        http(s) URL).
    :return: The result of ``_forward_dire_img`` on success, otherwise a dict
        with a single ``"msg"`` entry describing the problem.
    """
    file_path = inputs.get('file_path', None)
    if not file_path:
        # download_file would otherwise fail on None with an AttributeError.
        print("No file_path was provided.")
        return {"msg": "No file_path was provided."}

    # download_file raises for missing local files and failed requests; turn
    # those into the same message dicts used for the other failure cases.
    try:
        video_file = download_file(file_path)
    except FileNotFoundError:
        print(f"The file {file_path} does not exist.")
        return {"msg": f"The file {file_path} does not exist."}
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return {"msg": f"An error occurred: {str(e)}"}

    if not os.path.isfile(video_file):
        print(f"The file {video_file} does not exist.")
        return {"msg": f"The file {video_file} does not exist."}

    try:
        # is_image is defined outside this block (presumably utils.sanic_utils).
        if is_image(video_file):
            print(f"{self.net} is being run.")
            return self._forward_dire_img(video_file)
        print(
            f"Invalid media file: {video_file}. Please provide a valid video/img file.")
        return {"msg": f"Invalid media file: {video_file}. Please provide a valid video/img file."}
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return {"msg": f"An error occurred: {str(e)}"}
def fetch(cls) -> None: