Replace max_batch_size with batch_size for HybridCache
#3
by
pbaylies
- opened
- modeling_ovis.py +3 -3
modeling_ovis.py
CHANGED
@@ -552,14 +552,14 @@ class Ovis(OvisPreTrainedModel):
|
|
552 |
self.get_text_tokenizer().save_pretrained(save_directory)
|
553 |
self.get_visual_tokenizer().get_image_processor().save_pretrained(save_directory)
|
554 |
|
555 |
-
def _get_hybrid_cache_for_llm(self,
|
556 |
cache_cls = HybridCache
|
557 |
llm = self.get_llm()
|
558 |
|
559 |
need_new_cache = (
|
560 |
not hasattr(llm, "_cache")
|
561 |
or (not isinstance(llm._cache, cache_cls))
|
562 |
-
or llm._cache.
|
563 |
or llm._cache.max_cache_len < max_cache_len
|
564 |
)
|
565 |
|
@@ -570,7 +570,7 @@ class Ovis(OvisPreTrainedModel):
|
|
570 |
cache_dtype = llm.dtype
|
571 |
llm._cache = cache_cls(
|
572 |
config=llm.config,
|
573 |
-
|
574 |
max_cache_len=max_cache_len,
|
575 |
device=llm.device,
|
576 |
dtype=cache_dtype,
|
|
|
552 |
self.get_text_tokenizer().save_pretrained(save_directory)
|
553 |
self.get_visual_tokenizer().get_image_processor().save_pretrained(save_directory)
|
554 |
|
555 |
+
def _get_hybrid_cache_for_llm(self, batch_size: int, max_cache_len: int):
|
556 |
cache_cls = HybridCache
|
557 |
llm = self.get_llm()
|
558 |
|
559 |
need_new_cache = (
|
560 |
not hasattr(llm, "_cache")
|
561 |
or (not isinstance(llm._cache, cache_cls))
|
562 |
+
or llm._cache.batch_size != batch_size
|
563 |
or llm._cache.max_cache_len < max_cache_len
|
564 |
)
|
565 |
|
|
|
570 |
cache_dtype = llm.dtype
|
571 |
llm._cache = cache_cls(
|
572 |
config=llm.config,
|
573 |
+
batch_size=batch_size,
|
574 |
max_cache_len=max_cache_len,
|
575 |
device=llm.device,
|
576 |
dtype=cache_dtype,
|