Vivien Chappelier commited on
Commit
b2cfd5f
1 Parent(s): 73c438e

use new proxy model

Browse files
Files changed (2) hide show
  1. app.py +10 -4
  2. detector_calibration.safetensors +3 -0
app.py CHANGED
@@ -7,9 +7,11 @@ import numpy as np
7
 
8
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
9
 
10
- from transformers import AutoModel, BlipImageProcessor
11
  from diffusers import DiffusionPipeline, AutoencoderKL
12
  import torchvision.transforms as transforms
 
 
13
 
14
  from copy import deepcopy
15
  from collections import OrderedDict
@@ -45,8 +47,10 @@ class BZHStableSignatureDemo(object):
45
 
46
  # load the proxy detector
47
  self.detector_image_processor = BlipImageProcessor.from_pretrained("imatag/stable-signature-bzh-detector-resnet18")
48
- commit_hash = "584a7bc01dc0f02e53bf8b8b295717ed09ed7294"
49
- self.detector_model = AutoModel.from_pretrained("imatag/stable-signature-bzh-detector-resnet18", trust_remote_code=True, revision=commit_hash)
 
 
50
 
51
  def generate(self, mode, seed, prompt):
52
  generator = torch.Generator(device=device)
@@ -132,7 +136,9 @@ class BZHStableSignatureDemo(object):
132
  inputs = self.detector_image_processor(img, return_tensors="pt")
133
 
134
  with torch.no_grad():
135
- pvalue = torch.sigmoid(self.detector_model(**inputs).logits).item()
 
 
136
 
137
  return pvalue
138
 
 
7
 
8
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
9
 
10
+ from transformers import AutoModelForImageClassification, BlipImageProcessor
11
  from diffusers import DiffusionPipeline, AutoencoderKL
12
  import torchvision.transforms as transforms
13
+ from huggingface_hub import hf_hub_download
14
+ from safetensors import safe_open
15
 
16
  from copy import deepcopy
17
  from collections import OrderedDict
 
47
 
48
  # load the proxy detector
49
  self.detector_image_processor = BlipImageProcessor.from_pretrained("imatag/stable-signature-bzh-detector-resnet18")
50
+ self.detector_model = AutoModelForImageClassification.from_pretrained("imatag/stable-signature-bzh-detector-resnet18")
51
+ calibration = hf_hub_download("imatag/stable-signature-bzh-detector-resnet18", filename="calibration.safetensors")
52
+ with safe_open(calibration, framework="pt") as f:
53
+ self.calibration_logits = f.get_tensor("logits")
54
 
55
  def generate(self, mode, seed, prompt):
56
  generator = torch.Generator(device=device)
 
136
  inputs = self.detector_image_processor(img, return_tensors="pt")
137
 
138
  with torch.no_grad():
139
+ logit = self.detector_model(**inputs).logits[...,0]
140
+ pvalue = (1 + torch.searchsorted(self.calibration_logits, logit)) / self.calibration_logits.shape[0]
141
+ pvalue = pvalue.item()
142
 
143
  return pvalue
144
 
detector_calibration.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5146e98a98a65559fbdb8ccf8cfd9a18982e09e13aa16f3005a640639ca5b298
3
+ size 1999942