Replace max_batch_size with batch_size for HybridCache

#3 - opened by pbaylies
Files changed (1)
  1. modeling_ovis.py +3 -3
modeling_ovis.py CHANGED
@@ -552,14 +552,14 @@ class Ovis(OvisPreTrainedModel):
552
  self.get_text_tokenizer().save_pretrained(save_directory)
553
  self.get_visual_tokenizer().get_image_processor().save_pretrained(save_directory)
554
 
555
- def _get_hybrid_cache_for_llm(self, max_batch_size: int, max_cache_len: int):
556
  cache_cls = HybridCache
557
  llm = self.get_llm()
558
 
559
  need_new_cache = (
560
  not hasattr(llm, "_cache")
561
  or (not isinstance(llm._cache, cache_cls))
562
- or llm._cache.max_batch_size != max_batch_size
563
  or llm._cache.max_cache_len < max_cache_len
564
  )
565
 
@@ -570,7 +570,7 @@ class Ovis(OvisPreTrainedModel):
570
  cache_dtype = llm.dtype
571
  llm._cache = cache_cls(
572
  config=llm.config,
573
- max_batch_size=max_batch_size,
574
  max_cache_len=max_cache_len,
575
  device=llm.device,
576
  dtype=cache_dtype,
 
552
  self.get_text_tokenizer().save_pretrained(save_directory)
553
  self.get_visual_tokenizer().get_image_processor().save_pretrained(save_directory)
554
 
555
+ def _get_hybrid_cache_for_llm(self, batch_size: int, max_cache_len: int):
556
  cache_cls = HybridCache
557
  llm = self.get_llm()
558
 
559
  need_new_cache = (
560
  not hasattr(llm, "_cache")
561
  or (not isinstance(llm._cache, cache_cls))
562
+ or llm._cache.batch_size != batch_size
563
  or llm._cache.max_cache_len < max_cache_len
564
  )
565
 
 
570
  cache_dtype = llm.dtype
571
  llm._cache = cache_cls(
572
  config=llm.config,
573
+ batch_size=batch_size,
574
  max_cache_len=max_cache_len,
575
  device=llm.device,
576
  dtype=cache_dtype,