Spaces:
Running
Running
import numpy as np | |
import os.path as osp | |
from nltk.translate.bleu_score import corpus_bleu | |
from rdkit import RDLogger | |
from Levenshtein import distance as lev | |
from rdkit import Chem | |
from rdkit.Chem import MACCSkeys | |
from rdkit import DataStructs | |
from rdkit.Chem import AllChem | |
from rdkit import DataStructs | |
RDLogger.DisableLog('rdApp.*') | |
from fcd import get_fcd, load_ref_model, canonical_smiles | |
import warnings | |
import os | |
warnings.filterwarnings('ignore') | |
def get_smis(filepath): | |
print(filepath) | |
with open(filepath) as f: | |
lines = f.readlines() | |
gt_smis= [] | |
op_smis = [] | |
for s in lines: | |
if len(s)<3: | |
continue | |
s0,s1 = s.split(' || ') | |
s0,s1 = s0.strip().replace('[EOS]','').replace('[SOS]','').replace('[X]','').replace('[XPara]','').replace('[XRing]',''),s1.strip() | |
gt_smis.append(s1) | |
op_smis.append(s0) | |
return gt_smis,op_smis | |
def evaluate(gt_smis,op_smis): | |
references = [] | |
hypotheses = [] | |
for i, (gt, out) in enumerate(zip(gt_smis,op_smis)): | |
gt_tokens = [c for c in gt] | |
out_tokens = [c for c in out] | |
references.append([gt_tokens]) | |
hypotheses.append(out_tokens) | |
# BLEU score | |
bleu_score = corpus_bleu(references, hypotheses) | |
references = [] | |
hypotheses = [] | |
levs = [] | |
num_exact = 0 | |
bad_mols = 0 | |
for i, (gt, out) in enumerate(zip(gt_smis,op_smis)): | |
hypotheses.append(out) | |
references.append(gt) | |
try: | |
m_out = Chem.MolFromSmiles(out) | |
m_gt = Chem.MolFromSmiles(gt) | |
if Chem.MolToInchi(m_out) == Chem.MolToInchi(m_gt): num_exact += 1 | |
except: | |
bad_mols += 1 | |
levs.append(lev(out, gt)) | |
# Exact matching score | |
exact_match_score = num_exact/(i+1) | |
# Levenshtein score | |
levenshtein_score = np.mean(levs) | |
validity_score = 1 - bad_mols/len(gt_smis) | |
return bleu_score, exact_match_score, levenshtein_score, validity_score | |
def fevaluate(gt_smis,op_smis, morgan_r=2): | |
outputs = [] | |
bad_mols = 0 | |
for n, (gt_smi,ot_smi) in enumerate(zip(gt_smis,op_smis)): | |
try: | |
gt_m = Chem.MolFromSmiles(gt_smi) | |
ot_m = Chem.MolFromSmiles(ot_smi) | |
if ot_m == None: raise ValueError('Bad SMILES') | |
outputs.append((gt_m, ot_m)) | |
except: | |
bad_mols += 1 | |
validity_score = len(outputs)/(len(outputs)+bad_mols) | |
MACCS_sims = [] | |
morgan_sims = [] | |
RDK_sims = [] | |
enum_list = outputs | |
for i, (gt_m, ot_m) in enumerate(enum_list): | |
MACCS_sims.append(DataStructs.FingerprintSimilarity(MACCSkeys.GenMACCSKeys(gt_m), MACCSkeys.GenMACCSKeys(ot_m), metric=DataStructs.TanimotoSimilarity)) | |
RDK_sims.append(DataStructs.FingerprintSimilarity(Chem.RDKFingerprint(gt_m), Chem.RDKFingerprint(ot_m), metric=DataStructs.TanimotoSimilarity)) | |
morgan_sims.append(DataStructs.TanimotoSimilarity(AllChem.GetMorganFingerprint(gt_m,morgan_r), AllChem.GetMorganFingerprint(ot_m, morgan_r))) | |
maccs_sims_score = np.mean(MACCS_sims) | |
rdk_sims_score = np.mean(RDK_sims) | |
morgan_sims_score = np.mean(morgan_sims) | |
return validity_score, maccs_sims_score, rdk_sims_score, morgan_sims_score | |
def fcdevaluate(qgt_smis,qop_smis): | |
gt_smis = [] | |
ot_smis = [] | |
for n, (gt_smi,ot_smi) in enumerate(zip(qgt_smis,qop_smis)): | |
if len(ot_smi) == 0: ot_smi = '[]' | |
gt_smis.append(gt_smi) | |
ot_smis.append(ot_smi) | |
model = load_ref_model() | |
canon_gt_smis = [w for w in canonical_smiles(gt_smis) if w is not None] | |
canon_ot_smis = [w for w in canonical_smiles(ot_smis) if w is not None] | |
fcd_sim_score = get_fcd(canon_gt_smis, canon_ot_smis, model) | |
return fcd_sim_score | |
os.environ["CUDA_VISIBLE_DEVICES"] = "1" | |
gt,op = get_smis('output.txt') | |
bleu_score, exact_match_score, levenshtein_score,_ = evaluate(gt,op) | |
validity_score, maccs_sims_score, rdk_sims_score, morgan_sims_score = fevaluate(gt,op) | |
fcd_metric_score = fcdevaluate(gt,op) | |
print(f'BLEU: {round(bleu_score, 3)}') | |
print(f'Exact: {round(exact_match_score, 3)}') | |
print(f'Levenshtein: {round(levenshtein_score, 3)}') | |
print(f'MACCS FTS: {round(maccs_sims_score, 3)}') | |
print(f'RDK FTS: {round(rdk_sims_score, 3)}') | |
print(f'Morgan FTS: {round(morgan_sims_score, 3)}') | |
print(f'FCD Metric: {round(fcd_metric_score, 3)}') | |
print(f'Validity: {round(validity_score, 3)}') |