vumichien commited on
Commit
2b58cff
1 Parent(s): fa5f5ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -6,6 +6,11 @@ import numpy as np
6
  import tensorflow as tf
7
  from tensorflow import keras
8
 
 
 
 
 
 
9
  RDLogger.DisableLog("rdApp.*")
10
 
11
  def graph_to_molecule(graph):
@@ -50,7 +55,7 @@ generator = from_pretrained_keras("keras-io/wgan-molecular-graphs")
50
 
51
  def predict(num_mol):
52
  samples = num_mol*2
53
- z = tf.random.normal((samples, 64))
54
  graph = generator.predict(z)
55
  # obtain one-hot encoded adjacency tensor
56
  adjacency = tf.argmax(graph[0], axis=1)
@@ -59,7 +64,7 @@ def predict(num_mol):
59
  adjacency = tf.linalg.set_diag(adjacency, tf.zeros(tf.shape(adjacency)[:-1]))
60
  # obtain one-hot encoded feature tensor
61
  features = tf.argmax(graph[1], axis=2)
62
- features = tf.one_hot(features, depth=5, axis=2)
63
  molecules = [
64
  graph_to_molecule([adjacency[i].numpy(), features[i].numpy()])
65
  for i in range(samples)
 
6
  import tensorflow as tf
7
  from tensorflow import keras
8
 
9
+ # Config
10
+ NUM_ATOMS = 9 # Maximum number of atoms
11
+ ATOM_DIM = 4 + 1 # Number of atom types
12
+ BOND_DIM = 4 + 1 # Number of bond types
13
+ LATENT_DIM = 64 # Size of the latent space
14
  RDLogger.DisableLog("rdApp.*")
15
 
16
  def graph_to_molecule(graph):
 
55
 
56
  def predict(num_mol):
57
  samples = num_mol*2
58
+ z = tf.random.normal((samples, LATENT_DIM))
59
  graph = generator.predict(z)
60
  # obtain one-hot encoded adjacency tensor
61
  adjacency = tf.argmax(graph[0], axis=1)
 
64
  adjacency = tf.linalg.set_diag(adjacency, tf.zeros(tf.shape(adjacency)[:-1]))
65
  # obtain one-hot encoded feature tensor
66
  features = tf.argmax(graph[1], axis=2)
67
+ features = tf.one_hot(features, depth=ATOM_DIM, axis=2)
68
  molecules = [
69
  graph_to_molecule([adjacency[i].numpy(), features[i].numpy()])
70
  for i in range(samples)