Spaces:
Running
Running
# -*- coding:utf-8 -*-
# Rhizome
# Version beta 0.0, August 2023
# Property of IBM Research, Accelerated Discovery
#
import os | |
import pickle | |
import sys | |
from rdkit import Chem | |
import torch | |
from torch_geometric.utils.smiles import from_smiles | |
from typing import Any, Dict, List, Optional, Union | |
from typing_extensions import Self | |
from .graph_grammar.io.smi import hg_to_mol | |
from .models.mhgvae import GrammarGINVAE | |
from huggingface_hub import hf_hub_download | |
class PretrainedModelWrapper:
    """Inference wrapper around a pretrained ``GrammarGINVAE`` model.

    The wrapper builds the model from a checkpoint dictionary, loads the
    stored weights, switches the model to evaluation mode, and offers
    helpers to map SMILES strings to latent vectors and back.
    """

    # Quoted so the annotation is not evaluated when the class is created.
    model: "GrammarGINVAE"

    def __init__(self, model_dict: Dict[str, Any]) -> None:
        """Build the model from *model_dict* and load its weights.

        Keys read from *model_dict*: ``'gnn_params'`` (itself holding
        ``'encoder_params'``, ``'decoder_params'`` and
        ``'prod_rule_embed_params'``), ``'num_features'``,
        ``'num_edge_features'``, ``'hrg'``, ``'max_length'`` and
        ``'model_state_dict'``.

        Note that ``model_dict['gnn_params']['encoder_params']`` is updated
        in place with the node and edge feature sizes.
        """
        json_params = model_dict['gnn_params']
        encoder_params = json_params['encoder_params']
        encoder_params['node_feature_size'] = model_dict['num_features']
        encoder_params['edge_feature_size'] = model_dict['num_edge_features']
        self.model = GrammarGINVAE(
            model_dict['hrg'],
            rank=-1,
            encoder_params=encoder_params,
            decoder_params=json_params['decoder_params'],
            prod_rule_embed_params=json_params["prod_rule_embed_params"],
            batch_size=512,
            max_len=model_dict['max_length'],
        )
        self.model.load_state_dict(model_dict['model_state_dict'])
        self.model.eval()

    def to(self, device: Union[str, int, torch.device]) -> Self:
        """Move the wrapped model to *device* and return ``self``.

        A ``torch.device`` is used as given. A string is passed to
        ``torch.device``. An integer is passed to ``torch.device`` directly
        when CUDA is available, and is otherwise used as the index of an
        ``"mps"`` device.
        """
        if not isinstance(device, torch.device):
            if isinstance(device, str) or torch.cuda.is_available():
                device = torch.device(device)
            else:
                device = torch.device("mps", device)
        self.model = self.model.to(device)
        return self

    def encode(self, data: List[str]) -> List[torch.Tensor]:
        """Embed each SMILES string in *data* with the model's graph encoder.

        Returns one entry per input string: the first element of whatever
        ``self.model.graph_embed`` returns for that molecule graph.
        """
        # The model's device does not change while encoding; look it up once.
        device = next(self.model.parameters()).device
        output = []
        for smiles in data:
            # Always place the graph on the model's device. The previous
            # guard compared a parameter tensor with the string 'cpu' and
            # therefore never tested the device at all.
            g = from_smiles(smiles).to(device)
            ltvec = self.model.graph_embed(g.x, g.edge_index, g.edge_attr, g.batch)
            output.append(ltvec[0])
        return output

    def decode(self, data: List[torch.Tensor]) -> List[Optional[str]]:
        """Decode each latent vector in *data* back to a SMILES string.

        Every vector is given a leading batch dimension, turned into a mean
        and log-variance, sampled through ``reparameterize`` and decoded.
        The result holds a SMILES string where the model's decode flag is
        truthy, and ``None`` where it is not.
        """
        output = []
        for d in data:
            mu, logvar = self.model.get_mean_var(d.unsqueeze(0))
            z = self.model.reparameterize(mu, logvar)
            flags, _, hgs = self.model.decode(z)
            if flags[0]:
                reconstructed_mol, _ = hg_to_mol(hgs[0], True)
                output.append(Chem.MolToSmiles(reconstructed_mol))
            else:
                output.append(None)
        return output
def load(
    model_name: str = "mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle",
) -> Optional[PretrainedModelWrapper]:
    """Download the pretrained checkpoint from the Hugging Face Hub and wrap it.

    The checkpoint ``pytorch_model.bin`` is fetched from the
    ``ibm/materials.mhg-ged`` repository with ``hf_hub_download`` and handed
    to ``PretrainedModelWrapper``.

    Args:
        model_name: Currently unused. It is kept so that existing callers
            that pass a local pickle path continue to work.

    Returns:
        A ``PretrainedModelWrapper`` built from the downloaded checkpoint.
    """
    repo_id = "ibm/materials.mhg-ged"
    filename = "pytorch_model.bin"
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(file_path, "rb") as f:
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted repository.
        # NOTE(review): the checkpoint holds more than tensors (e.g. 'hrg');
        # torch releases whose torch.load defaults to weights_only=True may
        # need an explicit opt-out here - confirm against the pinned version.
        # map_location="cpu" lets a checkpoint saved from a GPU load on
        # CPU-only hosts; callers move the model afterwards with `.to(...)`.
        model_dict = torch.load(f, map_location="cpu")
    return PretrainedModelWrapper(model_dict)