Commit 3a29a17 by ikuinen99 (2 parents: d58d6b8, 72960a7)

Merge branch 'main' of hf.co:spaces/magicr/BuboGPT

bubogpt/models/mm_gpt4.py CHANGED
@@ -276,7 +276,7 @@ class MMGPT4(BaseModel):
         with_bind_head = cfg.get("with_bind_head", False)
         freeze_llm = cfg.get("freeze_llm", True)
         use_blip_vision = cfg.get("use_blip_vision", False)
-        proj_model = cfg.get("proj_model", "checkpoints/prerained_minigpt4_7b.pth")
+        proj_model = cfg.get("proj_model", "")

         model = cls(
             joiner_cfg=joiner_cfg,
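
The change above swaps the default for proj_model from a hard-coded MiniGPT-4 checkpoint path to an empty string, so configs that do not set proj_model no longer depend on a checkpoint file that may not exist. A minimal sketch of the resulting behaviour, assuming a plain dict in place of the project's config object (which exposes the same .get() semantics):

# Hypothetical stand-in for the cfg passed to MMGPT4.from_config.
cfg = {"with_bind_head": False, "freeze_llm": True}   # no "proj_model" key set

proj_model = cfg.get("proj_model", "")   # new default: "" instead of a path
# "" is falsy, so the projector-loading step in image_bind.py (next file)
# is skipped instead of failing on a missing checkpoints/... file.
assert not proj_model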
imagebind/models/image_bind.py CHANGED
@@ -656,8 +656,9 @@ def replace_joiner_vision(joiner, q_former_model, proj_model):

     joiner.modality_qformers[ModalityType.VISION].load_Qformer(q_former_model)

-    state_dict = torch.load(proj_model, map_location="cpu")["model"]
-    params = type(state_dict)()
-    params["fc.weight"] = state_dict["llama_proj.weight"]
-    params["fc.bias"] = state_dict["llama_proj.bias"]
-    joiner.modality_post_projectors[ModalityType.VISION].load_state_dict(params, strict=False)
+    if proj_model:
+        state_dict = torch.load(proj_model, map_location="cpu")["model"]
+        params = type(state_dict)()
+        params["fc.weight"] = state_dict["llama_proj.weight"]
+        params["fc.bias"] = state_dict["llama_proj.bias"]
+        joiner.modality_post_projectors[ModalityType.VISION].load_state_dict(params, strict=False)
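
The new guard mirrors the config change: an empty proj_model now means "do not load projector weights" rather than a torch.load error on a missing path. A self-contained sketch of the guarded loading, using a hypothetical helper name and passing the projector module in directly (the real code reaches it through joiner.modality_post_projectors[ModalityType.VISION]):

import torch
import torch.nn as nn

def load_post_projector(projector: nn.Module, proj_model: str) -> None:
    # Skip loading when proj_model is "" (the new config default) or None.
    if not proj_model:
        return
    # The MiniGPT-4-style checkpoint stores the projection as "llama_proj.*";
    # the post-projector's linear layer is named "fc", so remap the keys.
    state_dict = torch.load(proj_model, map_location="cpu")["model"]
    params = {
        "fc.weight": state_dict["llama_proj.weight"],
        "fc.bias": state_dict["llama_proj.bias"],
    }
    # strict=False tolerates any extra parameters the projector defines.
    projector.load_state_dict(params, strict=False)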