BioMike committed
Commit
7476d14
1 Parent(s): 3b0c756

Upload 23 files

app.py ADDED
@@ -0,0 +1,10 @@
import gradio as gr
from interfaces import smiles2iupac, iupac2smiles, iupac2style, landing


demo = gr.TabbedInterface([landing, smiles2iupac, iupac2smiles, iupac2style],
                          ["Introduction", "SMILES-to-IUPAC", "IUPAC-to-SMILES", "IUPAC style prediction"],
                          title="ChemConverters 🧪🔬🧬👨🏻‍🔬",
                          theme=gr.themes.Base())

demo.launch(share=True)
article.html ADDED
@@ -0,0 +1,44 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>ChemConverters App Description</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 30px;
            line-height: 4;
        }
        .link-button {
            display: inline-block;
            margin: 50px 50px;
            padding: 50px;
            background-color: #007bff;
            color: white;
            text-decoration: none;
            border-radius: 50px;
            font-weight: bold;
        }
        .link-button:hover {
            background-color: #0056b3;
        }
    </style>
</head>
<body>
    <p>With ChemConverters, you can effortlessly:</p>
    <ul>
        <li>Convert SMILES strings to IUPAC names and vice versa 🔄</li>
        <li>Choose your preferred IUPAC naming style: BASE, SYSTEMATIC, or TRADITIONAL 📚</li>
        <li>Validate chemical naming with molecular fingerprint similarity for accuracy checks ✔️</li>
    </ul>
    <p>Developed by the brilliant minds at Knowledgator, this app showcases the abilities of our chemical transformer models. Whether you're working on a research project, studying for an exam, or just exploring the chemical universe, ChemConverters is your go-to tool. 🛠️</p>
    <p>Remember, chemistry is not just about reactions; it's about connections. Let's build those connections together! 💫</p>
    <!-- Links Section -->
    <div>
        <a href="https://www.knowledgator.com/" class="link-button" target="_blank">🔗 Visit our Website 🔗</a>
        <a href="https://www.linkedin.com/company/knowledgator/" class="link-button" target="_blank">💼 Follow on LinkedIn 💼</a>
        <a href="https://huggingface.co/knowledgator/" class="link-button" target="_blank">🤗 Hugging Face Profile 🤗</a>
    </div>
</body>
</html>
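Note on the fingerprint-based validation mentioned above: the helper the app actually calls (`validate_smiles2iupac` in a `utils` module) is not shown in this commit view. The snippet below is a minimal, hypothetical sketch of such a check using RDKit (listed in requirements.txt), comparing Morgan fingerprints of the input molecule and the molecule parsed back from the predicted name via Tanimoto similarity; the function name and parameters here are illustrative assumptions, not the app's implementation.

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

def fingerprint_similarity(smiles_in: str, smiles_back: str) -> float:
    """Tanimoto similarity between Morgan fingerprints of two SMILES strings."""
    mol_a = Chem.MolFromSmiles(smiles_in)
    mol_b = Chem.MolFromSmiles(smiles_back)
    if mol_a is None or mol_b is None:
        return 0.0  # unparsable structure -> treat as no match
    fp_a = AllChem.GetMorganFingerprintAsBitVect(mol_a, radius=2, nBits=2048)
    fp_b = AllChem.GetMorganFingerprintAsBitVect(mol_b, radius=2, nBits=2048)
    return DataStructs.TanimotoSimilarity(fp_a, fp_b)

print(fingerprint_similarity("CCO", "OCC"))  # 1.0 - same molecule written as two different SMILES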
interfaces/__init__.py ADDED
@@ -0,0 +1,4 @@
from .smiles2iupac import smiles2iupac
from .iupac2smiles import iupac2smiles
from .iupac2style import iupac2style
from .landing import landing
interfaces/iupac2smiles.py ADDED
@@ -0,0 +1,27 @@
import gradio as gr
from utils import ChemicalConverter, validate_smiles2iupac, plot_mol

def convert(chemical_name, plot):
    # Initialize the ChemicalConverter
    converter = ChemicalConverter(mode="IUPAC2SMILES")
    plot_image = None
    # The model prepends a 6-character style token (e.g. "<SYST>"); strip it from the output
    converted_name = converter.convert(chemical_name)[6:]
    if plot:
        plot_image = plot_mol(converted_name)
    return converted_name, plot_image


iupac2smiles = gr.Interface(
    fn=convert,
    allow_flagging='auto',
    inputs=[
        gr.Textbox(label="Enter your IUPAC name", placeholder="Enter IUPAC name here"),
        gr.Checkbox(label="Plot molecule", value=True)
    ],
    outputs=[gr.Text(label="Converted Name"),
             gr.Image(type='pil', label="Molecule Plot", height=170, width=890)],
    examples=[
        ["ethanol", True]
    ],
)
interfaces/iupac2style.py ADDED
@@ -0,0 +1,22 @@
import gradio as gr
from utils import ChemicalConverter, validate_smiles2iupac, plot_mol

def convert(chemical_name):
    # Initialize the ChemicalConverter
    converter = ChemicalConverter(mode="IUPAC2SMILES")
    # The first 6 characters of the output are the style token predicted by the model
    converted_name = converter.convert(chemical_name)[:6]
    styles = {"<SYST>": "SYSTEMATIC", "<TRAD>": "TRADITIONAL", "<BASE>": "BASE"}
    return styles.get(converted_name, "")


iupac2style = gr.Interface(
    fn=convert,
    allow_flagging='auto',
    inputs=[
        gr.Textbox(label="Enter your IUPAC name", placeholder="Enter IUPAC name here"),
    ],
    outputs=[gr.Text(label="IUPAC style")],
    examples=[
        ["propan-2-yl 2-[4-(4-chlorophenyl)carbonylphenoxy]-2-methyl-propanoate"]
    ],
)
interfaces/landing.py ADDED
@@ -0,0 +1,6 @@
import gradio as gr

with open('materials/introduction.html', 'r', encoding='utf-8') as file:
    html_description = file.read()

landing = gr.HTML(html_description)
interfaces/smiles2iupac.py ADDED
@@ -0,0 +1,36 @@
import gradio as gr
from utils import ChemicalConverter, validate_smiles2iupac, plot_mol

def convert(chemical_name, style, validate, plot):
    # Initialize the ChemicalConverter
    converter = ChemicalConverter(mode="SMILES2IUPAC")
    validation_score = ""
    plot_image = None
    # Build the style control token from the first four letters of the chosen style,
    # e.g. "SYSTEMATIC" -> "<SYST>", and prepend it to the SMILES input
    style_prefix = "<" + style[:4] + ">"
    converted_name = converter.convert(style_prefix + chemical_name)
    if validate:
        validation_score = validate_smiles2iupac(chemical_name, converted_name)
    if plot:
        plot_image = plot_mol(chemical_name)
    return converted_name, validation_score, plot_image

smiles2iupac = gr.Interface(
    fn=convert,
    allow_flagging='auto',
    inputs=[
        gr.Textbox(label="Enter your SMILES string", placeholder="Enter your SMILES string here"),
        gr.Radio(
            choices=["BASE", "SYSTEMATIC", "TRADITIONAL"],
            label="Choose desired IUPAC style",
        ),
        gr.Checkbox(label="Validate with molecular similarity", value=False),
        gr.Checkbox(label="Plot molecule", value=True)
    ],
    outputs=[gr.Text(label="Converted Name"),
             gr.Text(label="Input-Target similarity score"),
             gr.Image(type='pil', label="Molecule Plot", height=170, width=890)],
    examples=[
        ["CCO", "BASE", True, True]
    ],
)
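The style control tokens used above are the same ones that interfaces/iupac2style.py maps back to style names. A small illustration of how they are formed (not part of the commit):

# Each IUPAC style is encoded as a token built from its first four letters
for style in ["BASE", "SYSTEMATIC", "TRADITIONAL"]:
    print("<" + style[:4] + ">")
# -> <BASE>, <SYST>, <TRAD>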
materials/introduction.html ADDED
@@ -0,0 +1,67 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>ChemConverters App Description</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 10px;
            line-height: 1.6;
        }
        .link-button {
            display: inline-flex;
            align-items: center;
            justify-content: center;
            margin: 10px;
            padding: 10px;
            background-color: white;
            border: 1px solid grey; /* Added border to make the button visible against white background */
            color: #007bff; /* Text color changed to make it visible against white background */
            text-decoration: none;
            border-radius: 10px;
            text-align: center;
            vertical-align: middle;
            box-sizing: border-box;
        }
        .link-button:hover {
            background-color: #c0dcfc;
        }
        .link-button img {
            height: 30px;
            width: auto;
            display: block;
        }
        .links-container {
            text-align: center; /* Center the container's content */
            margin: auto; /* Auto margins for horizontal centering if necessary */
            display: flex; /* Use flexbox */
            justify-content: center; /* Center flex items horizontally */
            flex-wrap: wrap; /* Allow items to wrap */
        }
    </style>
</head>
<body>
    <h2>Welcome to ChemConverters! 🧪🔬</h2>
    <h3>With ChemConverters, you can effortlessly:</h3>
    <ol>
        <li>Convert SMILES strings to IUPAC names and vice versa 🔄</li>
        <li>Choose your preferred IUPAC naming style: BASE, SYSTEMATIC, or TRADITIONAL 📚</li>
        <li>Validate chemical naming with molecular fingerprint similarity for accuracy checks ✔️</li>
    </ol>
    <h3>What is ChemConverters?</h3>
    <p>ChemConverters serves as a foundational showcase of our technological capabilities within the chemical domain. The models deployed in this application represent our entry-level offerings, designed to provide a glimpse into the potential applications of our advanced solutions. For access to our comprehensive suite of larger and more precise models, we invite interested parties to engage directly with us. Developed by the brilliant minds at Knowledgator, this app showcases the abilities of our chemical transformer models. Whether you're working on a research project, studying for an exam, or just exploring the chemical universe, ChemConverters is your go-to tool. 🛠</p>
    <h3>Model Availability</h3>
    <p>All models used in the application are available on <a href="https://huggingface.co/knowledgator/" target="_blank">our Hugging Face page</a>. For translating from SMILES to IUPAC, the <a href="https://huggingface.co/knowledgator/SMILES2IUPAC-canonical-base" target="_blank">knowledgator/SMILES2IUPAC-canonical-base</a> model was used. To translate from IUPAC to SMILES, the <a href="https://huggingface.co/knowledgator/IUPAC2SMILES-canonical-base" target="_blank">knowledgator/IUPAC2SMILES-canonical-base</a> model was used.</p>
    <h3>Citation</h3>
    <p>Coming soon</p>
    <h3>Remember, chemistry is not just about reactions; it's about connections. Let's build those connections together! 💫</h3>
    <!-- Links Section -->
    <div class="links-container">
        <a href="https://www.knowledgator.com/" class="link-button" target="_blank"><img src="https://assets-global.website-files.com/65902be8ba48a05dfdb73331/6590476fcc8e8f35b2332781_Group%201000002504%20(1).png" alt="Visit our website"></a>
        <a href="https://www.linkedin.com/company/knowledgator/" class="link-button" target="_blank"><img src="https://www.edigitalagency.com.au/wp-content/uploads/Linkedin-logo-png.png" alt="Follow on LinkedIn"></a>
        <a href="https://huggingface.co/knowledgator/" class="link-button" target="_blank"><img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo-with-title.png" alt="Hugging Face Profile"></a>
    </div>
</body>
</html>
modeling/__init__.py ADDED
@@ -0,0 +1,2 @@
from .model import MT5ForConditionalGeneration
from .config import MT5Config
modeling/config.py ADDED
@@ -0,0 +1,133 @@
from transformers import PretrainedConfig

class MT5Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MT5Model`] or a [`TFMT5Model`]. It is used to
    instantiate a mT5 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the mT5
    [google/mt5-small](https://huggingface.co/google/mt5-small) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Arguments:
        encoder_vocab_size (`int`, *optional*, defaults to 250112):
            Vocabulary size of the encoder. Defines the number of different tokens that can be represented by the
            `input_ids` passed to the encoder.
        decoder_vocab_size (`int`, *optional*, defaults to 250112):
            Vocabulary size of the decoder. Defines the number of different tokens that can be represented by the
            `decoder_input_ids` and produced by the language modeling head.
        shared_embedding (`bool`, *optional*, defaults to `False`):
            Whether the encoder and decoder share a single embedding matrix.
        d_model (`int`, *optional*, defaults to 256):
            Size of the encoder layers and the pooler layer.
        d_kv (`int`, *optional*, defaults to 64):
            Size of the key, query, value projections per attention head. It is usually expected that
            `d_kv == d_model // num_heads`, but in the mt5-small architecture `d_kv` is not equal to
            `d_model // num_heads`; the `inner_dim` of the projection layer is defined as `num_heads * d_kv`.
        d_ff (`int`, *optional*, defaults to 512):
            Size of the intermediate feed forward layer in each `T5Block`.
        num_layers (`int`, *optional*, defaults to 4):
            Number of hidden layers in the Transformer encoder.
        num_decoder_layers (`int`, *optional*):
            Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
        num_heads (`int`, *optional*, defaults to 3):
            Number of attention heads for each attention layer in the Transformer encoder.
        relative_attention_num_buckets (`int`, *optional*, defaults to 32):
            The number of buckets to use for each attention layer.
        relative_attention_max_distance (`int`, *optional*, defaults to 128):
            The maximum distance of the longer sequences for the bucket separation.
        dropout_rate (`float`, *optional*, defaults to 0.1):
            The ratio for all dropout layers.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the classifier.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        initializer_factor (`float`, *optional*, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        feed_forward_proj (`string`, *optional*, defaults to `"gated-gelu"`):
            Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
    """

    model_type = "mt5"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        encoder_vocab_size=250112,
        decoder_vocab_size=250112,
        shared_embedding=False,
        d_model=256,
        d_kv=64,
        d_ff=512,
        num_layers=4,
        num_decoder_layers=None,
        num_heads=3,
        relative_attention_num_buckets=32,
        relative_attention_max_distance=128,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        feed_forward_proj="gated-gelu",
        is_encoder_decoder=True,
        use_cache=True,
        tokenizer_class="ChemTokenizers.SMILES_IUPAC_FAST.FastTokenizer",
        tie_word_embeddings=False,
        pad_token_id=0,
        eos_token_id=1,
        decoder_start_token_id=2,
        classifier_dropout=0.0,
        **kwargs,
    ):
        super().__init__(
            is_encoder_decoder=is_encoder_decoder,
            tokenizer_class=tokenizer_class,
            tie_word_embeddings=tie_word_embeddings,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )
        self.encoder_vocab_size = encoder_vocab_size
        self.decoder_vocab_size = decoder_vocab_size
        self.shared_embedding = shared_embedding
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.num_decoder_layers = (
            num_decoder_layers if num_decoder_layers is not None else self.num_layers
        )  # default = symmetry
        self.num_heads = num_heads
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.dropout_rate = dropout_rate
        self.classifier_dropout = classifier_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj
        self.use_cache = use_cache

        act_info = self.feed_forward_proj.split("-")
        self.dense_act_fn = act_info[-1]
        self.is_gated_act = act_info[0] == "gated"

        if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
            raise ValueError(
                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
                "'gated-gelu' or 'relu'"
            )

        # for backwards compatibility
        if feed_forward_proj == "gated-gelu":
            self.dense_act_fn = "gelu_new"

    @property
    def hidden_size(self):
        return self.d_model

    @property
    def num_attention_heads(self):
        return self.num_heads

    @property
    def num_hidden_layers(self):
        return self.num_layers
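A quick usage sketch (an assumption that this package is importable as `modeling`, not part of the commit): instantiating the config with the vocabulary sizes used by the bundled SMILES2IUPAC checkpoint shows how `feed_forward_proj="gated-gelu"` is split into a gated flag plus the `gelu_new` activation, and how the standard `hidden_size`/`num_attention_heads`/`num_hidden_layers` properties map onto the MT5 field names.

from modeling.config import MT5Config

# Vocabulary sizes taken from models/SMILES2IUPAC/config.json in this commit
cfg = MT5Config(encoder_vocab_size=137, decoder_vocab_size=822)

print(cfg.is_gated_act, cfg.dense_act_fn)             # True gelu_new
print(cfg.hidden_size, cfg.num_attention_heads)       # 256 3
print(cfg.num_hidden_layers, cfg.num_decoder_layers)  # 4 4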
modeling/docstrings.py ADDED
@@ -0,0 +1,217 @@
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (`Dict[int, list]`, optional, defaults to None):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the mt5 models have the
            following number of attention modules:

                - mt5-small: 6
                - mt5-base: 12
                - mt5-large: 24
                - mt5-xl: 24
                - mt5-xxl: 24

    Example:

    ```python
    # Here is an example of a device map on a machine with 4 GPUs using mt5-xl, which has a total of 24 attention modules:
    model = MT5ForConditionalGeneration.from_pretrained("mt5-xl")
    device_map = {
        0: [0, 1, 2],
        1: [3, 4, 5, 6, 7, 8, 9],
        2: [10, 11, 12, 13, 14, 15, 16],
        3: [17, 18, 19, 20, 21, 22, 23],
    }
    model.parallelize(device_map)
    ```
"""
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.

    Example:

    ```python
    # On a 4 GPU machine with mt5-xl:
    model = MT5ForConditionalGeneration.from_pretrained("mt5-xl")
    device_map = {
        0: [0, 1, 2],
        1: [3, 4, 5, 6, 7, 8, 9],
        2: [10, 11, 12, 13, 14, 15, 16],
        3: [17, 18, 19, 20, 21, 22, 23],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
    ```
"""

__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""

MT5_START_DOCSTRING = r"""

    The MT5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
    Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
    Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a
    text-to-text denoising generative setting.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MT5Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

MT5_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [MT5 Training](./mt5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [MT5
            Training](./mt5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.

        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

MT5_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            To know more on how to prepare `input_ids` for pretraining take a look at [MT5 Training](./mt5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
modeling/model.py ADDED
@@ -0,0 +1,612 @@
import copy
import math
import warnings
from typing import List, Optional, Tuple, Union

from transformers import MT5PreTrainedModel
from transformers.models.mt5 import MT5Stack
from transformers.modeling_outputs import Seq2SeqModelOutput, Seq2SeqLMOutput, BaseModelOutput
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from .config import MT5Config
from .docstrings import (
    PARALLELIZE_DOCSTRING,
    DEPARALLELIZE_DOCSTRING,
    __HEAD_MASK_WARNING_MSG,
    MT5_START_DOCSTRING,
    MT5_INPUTS_DOCSTRING,
)


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "MT5Config"
_CHECKPOINT_FOR_DOC = "mt5-small"


class MT5Model(MT5PreTrainedModel):
    r"""
    Examples:

    ```python
    >>> from transformers import MT5Model, AutoTokenizer

    >>> model = MT5Model.from_pretrained("google/mt5-small")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
    >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
    >>> summary = "Weiter Verhandlung in Syrien."
    >>> inputs = tokenizer(article, return_tensors="pt")
    >>> labels = tokenizer(text_target=summary, return_tensors="pt")

    >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
    >>> hidden_states = outputs.last_hidden_state
    ```"""

    model_type = "mt5"
    config_class = MT5Config
    _keys_to_ignore_on_load_missing = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
    _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    # Copied from transformers.models.t5.modeling_t5.T5Model.__init__ with T5->MT5
    def __init__(self, config: MT5Config):
        super().__init__(config)
        self.encoder_embedding = nn.Embedding(config.encoder_vocab_size, config.d_model)
        if config.shared_embedding:
            # Reuse the encoder embedding for the decoder. The attribute name keeps the spelling
            # used by the released checkpoints so the state dict keys stay compatible.
            self.decoder_emebedding = self.encoder_embedding
        else:
            self.decoder_emebedding = nn.Embedding(config.decoder_vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = MT5Stack(encoder_config, self.encoder_embedding)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = MT5Stack(decoder_config, self.decoder_emebedding)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    # Copied from transformers.models.t5.modeling_t5.T5Model.parallelize
    def parallelize(self, device_map=None):
        warnings.warn(
            "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
            " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
            " 0, 'encoder.block.1': 1, ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    # Copied from transformers.models.t5.modeling_t5.T5Model.deparallelize
    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    # Copied from transformers.models.t5.modeling_t5.T5Model.get_input_embeddings
    def get_input_embeddings(self):
        return self.encoder_embedding

    # Copied from transformers.models.t5.modeling_t5.T5Model.set_input_embeddings
    def set_input_embeddings(self, new_embeddings):
        self.encoder_embedding = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder
    def get_encoder(self):
        return self.encoder

    # Copied from transformers.models.t5.modeling_t5.T5Model.get_decoder
    def get_decoder(self):
        return self.decoder

    # Copied from transformers.models.t5.modeling_t5.T5Model._prune_heads
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    # Copied from transformers.models.t5.modeling_t5.T5Model.forward with T5->MT5, t5->mt5
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MT5Model

        >>> tokenizer = AutoTokenizer.from_pretrained("mt5-small")
        >>> model = MT5Model.from_pretrained("mt5-small")

        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1

        >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for MT5Model.
        >>> # This is not needed for torch's MT5ForConditionalGeneration as it does this internally using labels arg.
        >>> decoder_input_ids = model._shift_right(decoder_input_ids)

        >>> # forward pass
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


@add_start_docstrings("""MT5 Model with a `language modeling` head on top.""", MT5_START_DOCSTRING)
class MT5ForConditionalGeneration(MT5PreTrainedModel):
    r"""
    Examples:

    ```python
    >>> from transformers import MT5ForConditionalGeneration, AutoTokenizer

    >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
    >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
    >>> summary = "Weiter Verhandlung in Syrien."
    >>> inputs = tokenizer(article, text_target=summary, return_tensors="pt")

    >>> outputs = model(**inputs)
    >>> loss = outputs.loss
    ```"""

    model_type = "mt5"
    config_class = MT5Config
    _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.__init__ with T5->MT5
    def __init__(self, config: MT5Config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.encoder_embedding = nn.Embedding(config.encoder_vocab_size, config.d_model)
        if config.shared_embedding:
            # Reuse the encoder embedding for the decoder. The attribute name keeps the spelling
            # used by the released checkpoints so the state dict keys stay compatible.
            self.decoder_emebedding = self.encoder_embedding
        else:
            self.decoder_emebedding = nn.Embedding(config.decoder_vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = MT5Stack(encoder_config, self.encoder_embedding)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = MT5Stack(decoder_config, self.decoder_emebedding)

        self.lm_head = nn.Linear(config.d_model, config.decoder_vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.parallelize
    def parallelize(self, device_map=None):
        warnings.warn(
            "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
            " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
            " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
            " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.deparallelize
    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_input_embeddings
    def get_input_embeddings(self):
        return self.encoder_embedding

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_input_embeddings
    def set_input_embeddings(self, new_embeddings):
        self.encoder_embedding = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_output_embeddings
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_output_embeddings
    def get_output_embeddings(self):
        return self.lm_head

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_encoder
    def get_encoder(self):
        return self.encoder

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_decoder
    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward with T5->MT5, t5->mt5
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, MT5ForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("mt5-small")
        >>> model = MT5ForConditionalGeneration.from_pretrained("mt5-small")

        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache
    def _reorder_cache(self, past_key_values, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past_key_values is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past_key_values

        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
                raise ValueError(
                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
                )
            if len(reordered_layer_past_states) != len(layer_past_states):
                raise ValueError(
                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
                )

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past
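A minimal smoke-test sketch for the class above (an assumption, not part of the commit; it presupposes that `MT5Stack` is importable from the installed transformers version exactly as done in model.py): build a randomly initialised model with the asymmetric encoder/decoder vocabularies from models/SMILES2IUPAC/config.json and run one forward pass with labels.

import torch
from modeling import MT5Config, MT5ForConditionalGeneration

# Vocabulary sizes mirror models/SMILES2IUPAC/config.json (137 SMILES tokens in, 822 IUPAC tokens out)
config = MT5Config(encoder_vocab_size=137, decoder_vocab_size=822)
model = MT5ForConditionalGeneration(config)

input_ids = torch.randint(0, config.encoder_vocab_size, (1, 16))
labels = torch.randint(0, config.decoder_vocab_size, (1, 8))

out = model(input_ids=input_ids, labels=labels)
print(out.logits.shape)  # torch.Size([1, 8, 822]) - projected onto the decoder vocabulary, not the encoder one
print(float(out.loss))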
models/IUPAC2SMILES/config.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "architectures": [
+     "MT5ForConditionalGeneration"
+   ],
+   "classifier_dropout": 0.0,
+   "d_ff": 512,
+   "d_kv": 64,
+   "d_model": 256,
+   "decoder_start_token_id": 2,
+   "decoder_vocab_size": 137,
+   "dense_act_fn": "gelu_new",
+   "dropout_rate": 0.1,
+   "encoder_vocab_size": 822,
+   "eos_token_id": 1,
+   "feed_forward_proj": "gated-gelu",
+   "initializer_factor": 1.0,
+   "is_encoder_decoder": true,
+   "is_gated_act": true,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "mt5",
+   "num_decoder_layers": 4,
+   "num_heads": 3,
+   "num_layers": 4,
+   "pad_token_id": 0,
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 32,
+   "shared_embedding": false,
+   "tie_word_embeddings": false,
+   "tokenizer_class": "T5Tokenizer",
+   "torch_dtype": "float32",
+   "transformers_version": "4.37.1",
+   "use_cache": true
+ }
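For orientation: this IUPAC2SMILES config declares 822 encoder tokens (IUPAC names in) and 137 decoder tokens (SMILES out); the SMILES2IUPAC config further down mirrors the two numbers. A small, assumption-marked sketch for reading those fields back once the repository is checked out (the models/IUPAC2SMILES path is taken from the layout above; transformers keeps unknown config keys such as encoder_vocab_size as plain attributes):

from transformers import AutoConfig

# Local path assumed from the repository layout; adjust if loading from the Hub.
cfg = AutoConfig.from_pretrained("models/IUPAC2SMILES")
print(cfg.model_type)                                  # mt5
print(cfg.encoder_vocab_size, cfg.decoder_vocab_size)  # 822 137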
models/IUPAC2SMILES/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "decoder_start_token_id": 2,
+   "eos_token_id": 1,
+   "pad_token_id": 0,
+   "transformers_version": "4.37.1"
+ }
models/IUPAC2SMILES/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d1f38994ec986388a2f099652139d6a05b5981fb57bdf62361d4614f84ca07ed
+ size 23177168
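This is a Git LFS pointer rather than the weights themselves; the real model.safetensors is identified by the SHA-256 and byte size recorded above. After fetching the blob (for example with git lfs pull), a quick integrity check is straightforward. A sketch, assuming the file lands at the path shown in this diff:

import hashlib
import os

path = "models/IUPAC2SMILES/model.safetensors"  # assumed local path after fetching the LFS blob

sha256 = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha256.update(chunk)

print(os.path.getsize(path) == 23177168)
print(sha256.hexdigest() == "d1f38994ec986388a2f099652139d6a05b5981fb57bdf62361d4614f84ca07ed")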
models/SMILES2IUPAC/config.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "architectures": [
+     "MT5ForConditionalGeneration"
+   ],
+   "classifier_dropout": 0.0,
+   "d_ff": 512,
+   "d_kv": 64,
+   "d_model": 256,
+   "decoder_start_token_id": 2,
+   "decoder_vocab_size": 822,
+   "dense_act_fn": "gelu_new",
+   "dropout_rate": 0.1,
+   "encoder_vocab_size": 137,
+   "eos_token_id": 1,
+   "feed_forward_proj": "gated-gelu",
+   "initializer_factor": 1.0,
+   "is_encoder_decoder": true,
+   "is_gated_act": true,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "mt5",
+   "num_decoder_layers": 4,
+   "num_heads": 3,
+   "num_layers": 4,
+   "pad_token_id": 0,
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 32,
+   "shared_embedding": false,
+   "tie_word_embeddings": false,
+   "tokenizer_class": "T5Tokenizer",
+   "torch_dtype": "float32",
+   "transformers_version": "4.37.1",
+   "use_cache": true
+ }
models/SMILES2IUPAC/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "decoder_start_token_id": 2,
+   "eos_token_id": 1,
+   "pad_token_id": 0,
+   "transformers_version": "4.37.1"
+ }
models/SMILES2IUPAC/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4307a50d6b192a06bb81552d7cd6bcf6ac7ea6bb21d72ca4755e28d7d28655d2
+ size 23878608
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ transformers
+ torch
+ rdkit
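The dependencies are unpinned, so any recent releases should resolve; the configs above were produced with transformers 4.37.1. A purely illustrative smoke test that the three packages import:

import importlib

for pkg in ("transformers", "torch", "rdkit"):
    module = importlib.import_module(pkg)
    print(pkg, getattr(module, "__version__", "unknown"))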
test.py ADDED
@@ -0,0 +1,12 @@
+ from gradio_client import Client
+
+ client = Client("https://knowledgator-chemicalconverters.hf.space/--replicas/ucig0/")
+ result = client.predict(
+     "CCO",  # str in 'Enter your chemical name' Textbox component
+     "SMILES2IUPAC",  # Literal['SMILES2IUPAC', 'IUPAC2SMILES', 'IUPAC style prediction'] in 'Choose method to convert chemical names' Radio component
+     "BASE",  # Literal['BASE', 'SYSTEMATIC', 'TRADITIONAL'] in 'If SMILES to IUPAC, choose desired IUPAC style' Radio component
+     True,  # bool in 'Validate with molecular similarity' Checkbox component
+     True,  # bool in 'Plot molecule' Checkbox component
+     api_name="/predict"
+ )
+ print(result)
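The hard-coded --replicas URL above is tied to a specific deployment and tends to go stale; gradio_client also accepts a Space id, which is usually the more durable handle. A hedged variant (the Space id is inferred from the URL, and the shape of result depends on the Space's output components, so both are assumptions):

from gradio_client import Client

# Space id inferred from the replica URL; replace it if the Space is renamed.
client = Client("knowledgator/chemicalconverters")
result = client.predict("CCO", "SMILES2IUPAC", "BASE", True, True, api_name="/predict")
print(result)  # typically one entry per output component of the interface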
utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .main_model import ChemicalConverter
+ from .rdkit_utils import validate_smiles2iupac, plot_mol
utils/main_model.py ADDED
@@ -0,0 +1,47 @@
+ from modeling import MT5ForConditionalGeneration
+ from transformers import AutoTokenizer
+ import os
+
+
+ class ChemicalConverter:
+     def __init__(self, mode: str):
+         self.mode = mode
+         model_directory = os.path.abspath("models")
+         model_path = os.path.join(model_directory, mode)
+         if not os.path.exists(model_path):
+             raise ValueError(f"Model path does not exist: {model_path}")
+         self.model = MT5ForConditionalGeneration.from_pretrained(model_path)
+         self.smiles_tokenizer = AutoTokenizer.from_pretrained("BioMike/smiles")
+         self.iupac_tokenizer = AutoTokenizer.from_pretrained("BioMike/iupac")
+         self.smiles_max_len = 128
+         self.iupac_max_len = 156
+
+     def convert(self, input):
+         if self.mode == "SMILES2IUPAC":
+             tokenizer = self.smiles_tokenizer
+             reverse_tokenizer = self.iupac_tokenizer
+             max_length = self.smiles_max_len
+         else:
+             tokenizer = self.iupac_tokenizer
+             reverse_tokenizer = self.smiles_tokenizer
+             max_length = self.iupac_max_len
+
+         encoding = tokenizer(input,
+                              return_tensors='pt',
+                              padding="max_length",
+                              truncation=True,
+                              max_length=max_length)
+         # Move the input tensors to the model's device (GPU if available, otherwise CPU)
+         encoding = {key: value.to(self.model.device) for key, value in encoding.items()}
+
+         # Generate names
+         output = self.model.generate(input_ids=encoding['input_ids'],
+                                      attention_mask=encoding['attention_mask'],
+                                      max_new_tokens=156,
+                                      num_beams=1,
+                                      num_return_sequences=1)
+
+         # Decode names
+         output = [reverse_tokenizer.decode(ids, skip_special_tokens=True) for ids in output]
+
+         return output[0]
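A minimal usage sketch for ChemicalConverter, assumed to run from the repository root so that the relative models/ directory and the local modeling module resolve. The [6:] slice mirrors what utils/rdkit_utils.py does with the converter's raw output; the exact SMILES returned depends on the model.

from utils import ChemicalConverter

# IUPAC name in, SMILES out; the leading six characters of the raw output are
# stripped, as utils/rdkit_utils.py does before parsing the result with RDKit.
converter = ChemicalConverter(mode="IUPAC2SMILES")
raw = converter.convert("ethanol")
print(raw[6:])  # expected to be a SMILES string such as CCO (model-dependent)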
utils/rdkit_utils.py ADDED
@@ -0,0 +1,39 @@
+ from rdkit import DataStructs, Chem
+ from rdkit.Chem import AllChem
+ from rdkit.Chem import Draw
+ from PIL import Image
+ import io
+ from .main_model import ChemicalConverter
+
+ def validate_smiles2iupac(input_smiles, predicted_iupac):
+     converter = ChemicalConverter(mode="IUPAC2SMILES")
+     predicted_smiles = converter.convert(predicted_iupac)
+
+     ms = [Chem.MolFromSmiles(input_smiles), Chem.MolFromSmiles(predicted_smiles[6:])]
+
+     if None in ms:
+         return None
+
+     fpgen = AllChem.GetRDKitFPGenerator()
+     fps = [fpgen.GetFingerprint(x) for x in ms]
+
+     return DataStructs.TanimotoSimilarity(fps[0], fps[1])
+
+ def plot_mol(smiles):
+     # Convert the SMILES string to an RDKit molecule object
+     mol = Chem.MolFromSmiles(smiles)
+
+     # Draw the molecule with RDKit at its original intended size
+     img = Draw.MolToImage(mol, size=(185, 185))
+
+     # Create a new, blank image at the final size (890x185 pixels) with a white background
+     final_img = Image.new('RGB', (890, 185), 'white')
+
+     # Calculate the position to paste the original image onto the blank image to keep it centered
+     left = (890 - 185) // 2
+     top = (185 - 185) // 2  # zero in this case, but included for clarity
+
+     # Paste the original image onto the blank image
+     final_img.paste(img, (left, top))
+
+     return final_img
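To make the validation path concrete: validate_smiles2iupac converts the predicted IUPAC name back to SMILES, fingerprints both molecules, and returns their Tanimoto similarity (None if either SMILES fails to parse). The round trip needs the models, but the similarity step can be illustrated with RDKit alone; the molecules below are illustrative.

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

# Two SMILES spellings of ethanol, plus methanol for contrast.
mols = [Chem.MolFromSmiles(s) for s in ("CCO", "OCC", "CO")]

fpgen = AllChem.GetRDKitFPGenerator()
fps = [fpgen.GetFingerprint(m) for m in mols]

print(DataStructs.TanimotoSimilarity(fps[0], fps[1]))  # 1.0, same molecule
print(DataStructs.TanimotoSimilarity(fps[0], fps[2]))  # below 1.0, related but different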