byh711 commited on
Commit
b7f2204
1 Parent(s): 1617299

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -4,13 +4,27 @@ from PIL import Image
4
  import torch
5
  from peft import PeftModel
6
  import numpy as np
 
 
 
 
 
 
 
 
 
 
7
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
  torch_dtype = torch.float32
10
 
11
  # Load the fine-tuned base model
12
- caption_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True, revision='refs/pr/6', torch_dtype=torch_dtype, attn_implementation="eager").to(device)
13
- model = AutoModelForCausalLM.from_pretrained('byh711/FLODA-deepfake', trust_remote_code=True, torch_dtype=torch_dtype, attn_implementation="eager").to(device)
 
 
 
 
14
  processor = AutoProcessor.from_pretrained('byh711/FLODA-deepfake', trust_remote_code=True)
15
  model.eval()
16
 
 
4
  import torch
5
  from peft import PeftModel
6
  import numpy as np
7
+ import os
8
+ from unittest.mock import patch
9
+ from transformers.dynamic_module_utils import get_imports
10
+
11
+ def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
12
+ if not str(filename).endswith("modeling_florence2.py"):
13
+ return get_imports(filename)
14
+ imports = get_imports(filename)
15
+ imports.remove("flash_attn")
16
+ return imports
17
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  torch_dtype = torch.float32
20
 
21
  # Load the fine-tuned base model
22
+ with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
23
+ caption_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True, revision='refs/pr/6', torch_dtype=torch_dtype).to(device)
24
+
25
+ with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
26
+ model = AutoModelForCausalLM.from_pretrained('byh711/FLODA-deepfake', trust_remote_code=True, torch_dtype=torch_dtype).to(device)
27
+
28
  processor = AutoProcessor.from_pretrained('byh711/FLODA-deepfake', trust_remote_code=True)
29
  model.eval()
30