koukyo1994 committed on
Commit
3b36933
1 Parent(s): 46619e2

upload LFQ implementation

Browse files
Files changed (1) hide show
  1. configuration_lfq_tokenizer.py +43 -0
configuration_lfq_tokenizer.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face compatible implementation of Open-MAGVIT2
3
+ Code reference: https://github.com/TencentARC/Open-MAGVIT2
4
+ """
5
+
6
+
7
+ from transformers import PretrainedConfig
8
+
9
+
10
class EncoderDecoderConfig(PretrainedConfig):
    """Settings container registered under the ``resnet_encoder_decoder`` model type.

    Each setting is looked up in the keyword arguments and falls back to the
    default below when the key is absent:

    * ``ch`` -- 128
    * ``in_channels`` -- 3
    * ``out_ch`` -- 3
    * ``z_channels`` -- 18
    * ``num_res_blocks`` -- 2
    * ``ch_mult`` -- ``[1, 1, 2, 2, 4]``

    The complete keyword set is also forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): the names suggest base channel width, image input/output
    channels, latent channels, residual blocks per stage and per-stage channel
    multipliers -- confirm against the encoder/decoder modules that read them.
    """

    model_type = "resnet_encoder_decoder"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The table is rebuilt on every call, so the list default is a fresh
        # object for each instance.
        fallbacks = {
            "ch": 128,
            "in_channels": 3,
            "out_ch": 3,
            "z_channels": 18,
            "num_res_blocks": 2,
            "ch_mult": [1, 1, 2, 2, 4],
        }
        for field, fallback in fallbacks.items():
            setattr(self, field, kwargs.get(field, fallback))
21
+
22
+
23
class QuantizerConfig(PretrainedConfig):
    """Settings container registered under the ``lfq_quantizer`` model type.

    Each setting is looked up in the keyword arguments and falls back to the
    default below when the key is absent:

    * ``dim`` -- 18
    * ``codebook_size`` -- 262144 (equal to ``2 ** 18``)
    * ``batch_maximization_weight`` -- 1.0
    * ``sample_minimization_weight`` -- 1.0

    The complete keyword set is also forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): the two weights presumably scale entropy terms of the
    quantizer loss, and ``codebook_size`` presumably has to stay equal to
    ``2 ** dim`` -- confirm against the quantizer implementation.
    """

    model_type = "lfq_quantizer"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        fallbacks = {
            "dim": 18,
            "codebook_size": 262144,
            "batch_maximization_weight": 1.0,
            "sample_minimization_weight": 1.0,
        }
        for field, fallback in fallbacks.items():
            setattr(self, field, kwargs.get(field, fallback))
32
+
33
+
34
class LFQTokenizerConfig(PretrainedConfig):
    r"""
    Top-level configuration registered under the ``lfq_tokenizer`` model type.

    It bundles two nested configurations:

    * ``encoder_decoder_config`` -- an ``EncoderDecoderConfig``
    * ``quantizer_config`` -- a ``QuantizerConfig``

    Each may be passed as a ready-made config object or as a plain ``dict`` of
    that config's keyword arguments. When one is omitted or ``None``, the
    sub-config is built with its own defaults. All remaining keyword arguments
    are forwarded to ``PretrainedConfig.__init__``.
    """
    model_type = "lfq_tokenizer"

    def __init__(self, **kwargs):
        # Take the nested configs out first so the base class never stores
        # their raw, possibly dict-shaped, form on the instance.
        encoder_decoder_config = kwargs.pop("encoder_decoder_config", None)
        quantizer_config = kwargs.pop("quantizer_config", None)
        super().__init__(**kwargs)

        # A config reloaded from JSON receives its nested configs as dicts, so
        # rebuild real config objects to keep attribute access working.
        if encoder_decoder_config is None:
            encoder_decoder_config = EncoderDecoderConfig()
        elif isinstance(encoder_decoder_config, dict):
            encoder_decoder_config = EncoderDecoderConfig(**encoder_decoder_config)

        if quantizer_config is None:
            quantizer_config = QuantizerConfig()
        elif isinstance(quantizer_config, dict):
            quantizer_config = QuantizerConfig(**quantizer_config)

        self.encoder_decoder_config = encoder_decoder_config
        self.quantizer_config = quantizer_config