byh711 committed on
Commit
4a77a5e
1 Parent(s): 35df19e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -9,11 +9,9 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
  torch_dtype = torch.float32
10
 
11
  # Load the fine-tuned base model
12
- base_model = AutoModelForCausalLM.from_pretrained('byh711/FLODA-deepfake', trust_remote_code=True, torch_dtype=torch_dtype).to(device)
 
13
  processor = AutoProcessor.from_pretrained('byh711/FLODA-deepfake', trust_remote_code=True)
14
-
15
- # Load the LoRA weights
16
- model = PeftModel.from_pretrained(base_model, peft_model_path)
17
  model.eval()
18
 
19
  def caption_generate(task_prompt, text_input=None, image=None):
@@ -25,7 +23,7 @@ def caption_generate(task_prompt, text_input=None, image=None):
25
  else:
26
  prompt = task_prompt + text_input
27
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
28
- generated_ids = model.generate(
29
  input_ids=inputs["input_ids"],
30
  pixel_values=inputs["pixel_values"],
31
  max_new_tokens=1024,
@@ -52,7 +50,7 @@ def run_example(task_prompt, text_input=None, image=None):
52
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
53
  inputs = {k: v.to(torch_dtype) if v.is_floating_point() else v for k, v in inputs.items()}
54
 
55
- generated_ids = base_model.generate(
56
  input_ids=inputs["input_ids"],
57
  pixel_values=inputs["pixel_values"],
58
  max_new_tokens=1024,
 
9
  torch_dtype = torch.float32
10
 
11
  # Load the fine-tuned base model
12
+ caption_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True, revision='refs/pr/6', torch_dtype=torch_dtype).to(device)
13
+ model = AutoModelForCausalLM.from_pretrained('byh711/FLODA-deepfake', trust_remote_code=True, torch_dtype=torch_dtype).to(device)
14
  processor = AutoProcessor.from_pretrained('byh711/FLODA-deepfake', trust_remote_code=True)
 
 
 
15
  model.eval()
16
 
17
  def caption_generate(task_prompt, text_input=None, image=None):
 
23
  else:
24
  prompt = task_prompt + text_input
25
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
26
+ generated_ids = caption_model.generate(
27
  input_ids=inputs["input_ids"],
28
  pixel_values=inputs["pixel_values"],
29
  max_new_tokens=1024,
 
50
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
51
  inputs = {k: v.to(torch_dtype) if v.is_floating_point() else v for k, v in inputs.items()}
52
 
53
+ generated_ids = model.generate(
54
  input_ids=inputs["input_ids"],
55
  pixel_values=inputs["pixel_values"],
56
  max_new_tokens=1024,