Writing an atomic neural network

This notebook demonstrates the construction of a “network” in PiNN.

[1]:
import os, warnings
import tensorflow as tf
from glob import glob
from pinn.io import load_qm9, sparse_batch
from pinn.layers import cell_list_nl
from pinn.models import potential_model
[2]:
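# Point the glob at your local copy of the QM9 .xyz files
filelist = glob('/home/yunqi/datasets/QM9/dsgdb9nsd/*.xyz')
# Split the structures into training and test sets with an 8:2 ratio
dataset = lambda: load_qm9(filelist, split={'train': 8, 'test': 2})
# Batch 100 structures at a time; the training set is also shuffled
train = lambda: dataset()['train'].repeat().shuffle(1000).apply(sparse_batch(100))
test = lambda: dataset()['test'].repeat().apply(sparse_batch(100))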

Network function

A network function represents a learnable function that maps structures to atomic predictions.
We’ll build a network function from scratch to familiarize you with the structure of a network.
The network below takes the elements of each atom pair and their distance as input,
and predicts a pairwise energy.
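Written as an equation, the model built here assigns atom i the sum of learned pair terms ε over its neighbors N(i) from the neighbor list, where Z is the element and r_ij the pair distance. The second relation, that the structure energy compared with e_data is the sum of the atomic energies, is an assumption of this sketch about how potential_model combines the outputs.

```latex
E_i = \sum_{j \in \mathcal{N}(i)} \varepsilon\left(Z_i, Z_j, r_{ij}\right),
\qquad
E = \sum_i E_i
```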

PS. If you are not sure what an operation does, evaluate it with sess.run(the_tensor) to find out.

Starting with the input tensors:

[3]:
tensors = train().make_one_shot_iterator().get_next(); tensors

[3]:
{'elems': <tf.Tensor 'IteratorGetNext:2' shape=(?,) dtype=int32>,
 'coord': <tf.Tensor 'IteratorGetNext:0' shape=(?, 3) dtype=float32>,
 'e_data': <tf.Tensor 'IteratorGetNext:1' shape=(?,) dtype=float32>,
 'ind_1': <tf.Tensor 'IteratorGetNext:3' shape=(?, 1) dtype=int32>}
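The input dict stores a whole batch of molecules as flat per-atom arrays. A plausible reading, which this sketch assumes rather than takes from PiNN, is that ind_1 records which structure in the batch each atom belongs to. The hypothetical helper flatten_batch below builds that layout in plain Python for illustration only; it is not PiNN's sparse_batch.

```python
def flatten_batch(structures):
    """Concatenate per-structure data into flat per-atom lists.

    Each structure is a dict with 'elems' (list of atomic numbers) and
    'coord' (list of (x, y, z) tuples). In the result, ind_1[k] is
    [index of the structure that atom k came from].
    """
    batch = {'elems': [], 'coord': [], 'ind_1': []}
    for s_idx, s in enumerate(structures):
        batch['elems'].extend(s['elems'])
        batch['coord'].extend(s['coord'])
        # One [structure_index] entry per atom, mirroring the (?, 1) shape
        batch['ind_1'].extend([[s_idx] for _ in s['elems']])
    return batch


water = {'elems': [8, 1, 1],
         'coord': [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]}
h2 = {'elems': [1, 1], 'coord': [(0.0, 0.0, 0.0), (0.74, 0.0, 0.0)]}
batch = flatten_batch([water, h2])
# batch['ind_1'] == [[0], [0], [0], [1], [1]]
```

Storing batches this way avoids padding molecules of different sizes to a common length.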

We then build the neighbor list.

cell_list_nl is an important component of PiNN. It implements the cell list algorithm, which builds the neighbor list of a structure with linear scaling in the number of atoms. It returns three tensors: ind_2 holds the indices of the i, j atoms in each pair, dist is the pairwise distance, and diff is the displacement vector. The first dimension of each tensor is the number of pairs.

Both batched inputs and periodic boundary conditions are handled.

[4]:
nl = cell_list_nl(tensors); nl

[4]:
{'ind_2': <tf.Tensor 'cell_list_nl/concat_3:0' shape=(?, 2) dtype=int32>,
 'dist': <tf.Tensor 'cell_list_nl/GatherNd_5:0' shape=(?,) dtype=float32>,
 'diff': <tf.Tensor 'cell_list_nl/GatherNd_6:0' shape=(?, 3) dtype=float32>}
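PiNN's cell_list_nl is written in TensorFlow and also handles batches and periodic cells. To illustrate only the underlying idea, here is a minimal pure-Python sketch for a single, non-periodic structure; the function cell_list_pairs and the cutoff rc are illustrative names, not PiNN's API. Atoms are binned into cubic cells whose side equals the cutoff, so every neighbor of an atom lies in its own cell or one of the 26 adjacent cells; at fixed density each cell holds a bounded number of atoms, which is where the linear scaling comes from. The sketch returns both orderings (i, j) and (j, i) and takes diff as coords[j] - coords[i]; both are conventions chosen here and may differ from PiNN's.

```python
import itertools
import math
from collections import defaultdict


def cell_list_pairs(coords, rc):
    """Neighbor list via cell lists for one non-periodic structure.

    Returns (ind_2, dist, diff) for every ordered pair i != j whose
    distance is below rc, with diff = coords[j] - coords[i].
    """
    # Bin atoms into cubic cells of side rc
    cells = defaultdict(list)
    for idx, (x, y, z) in enumerate(coords):
        key = (math.floor(x / rc), math.floor(y / rc), math.floor(z / rc))
        cells[key].append(idx)

    ind_2, dist, diff = [], [], []
    offsets = list(itertools.product((-1, 0, 1), repeat=3))
    for (cx, cy, cz), members in cells.items():
        # Only the home cell and its 26 neighbors can contain partners
        for dx, dy, dz in offsets:
            neighbors = cells.get((cx + dx, cy + dy, cz + dz), [])
            for i in members:
                for j in neighbors:
                    if i == j:
                        continue
                    d = [coords[j][k] - coords[i][k] for k in range(3)]
                    r = math.sqrt(sum(c * c for c in d))
                    if r < rc:
                        ind_2.append((i, j))
                        dist.append(r)
                        diff.append(d)
    return ind_2, dist, diff
```

For example, cell_list_pairs([(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 0.0, 0.0)], 1.5) finds only the pair between the first two atoms, once in each direction.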

We can gather the elements of atoms i and j for each pair using tf.gather. Here we cast them to floating-point values so that they can be used as inputs to the neural network.

[5]:
elems = tensors['elems']
natoms = tf.shape(elems)[0]
ind_2 = nl['ind_2']
dist = nl['dist']

elem_i = tf.cast(tf.gather(elems, ind_2[:, 0]), tf.float32)
elem_j = tf.cast(tf.gather(elems, ind_2[:, 1]), tf.float32)
elem_i, elem_j
[5]:
(<tf.Tensor 'Cast:0' shape=(?,) dtype=float32>,
 <tf.Tensor 'Cast_1:0' shape=(?,) dtype=float32>)
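In plain Python, gathering the elements for each pair amounts to indexing the per-atom element list with each column of ind_2, and the later tf.stack(..., axis=1) builds one feature row per pair. The helper gather_pair_elements and the toy values below (a water molecule with hand-written pairs and made-up distances) are hypothetical and only illustrate the semantics.

```python
def gather_pair_elements(elems, ind_2):
    """Mimic tf.gather(elems, ind_2[:, 0]) and tf.gather(elems, ind_2[:, 1]),
    followed by a cast to float."""
    elem_i = [float(elems[i]) for i, _ in ind_2]
    elem_j = [float(elems[j]) for _, j in ind_2]
    return elem_i, elem_j


elems = [8, 1, 1]                          # O, H, H
ind_2 = [(0, 1), (0, 2), (1, 0), (2, 0)]   # O-H pairs in both directions
elem_i, elem_j = gather_pair_elements(elems, ind_2)
# elem_i == [8.0, 8.0, 1.0, 1.0], elem_j == [1.0, 1.0, 8.0, 8.0]

dist = [0.96, 0.96, 0.96, 0.96]            # made-up distances
# Mimic tf.stack([elem_i, elem_j, dist], axis=1): one feature row per pair
pair = [list(row) for row in zip(elem_i, elem_j, dist)]
# pair[0] == [8.0, 1.0, 0.96]
```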

We take the elements of a pair and their distance as input, and output a pairwise energy.

[6]:
pair = tf.stack([elem_i, elem_j, dist], axis=1)

nodes = pair
for n in [16, 16]:
    nodes = tf.layers.dense(nodes, n, activation='tanh')

# use a linear output layer to produce the energy
e_pair = tf.layers.dense(nodes, 1, activation=None); e_pair
[6]:
<tf.Tensor 'dense_2/BiasAdd:0' shape=(?, 1) dtype=float32>
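Each tf.layers.dense call above computes activation(x W + b). The plain-Python sketch below mirrors the 3 → 16 → 16 → 1 architecture with hypothetical helpers dense and count_params and arbitrary constant weights. Counting weights and biases gives 3·16 + 16 + 16·16 + 16 + 16·1 + 1 = 353, which matches the "Total number of trainable variables: 353" line in the training log.

```python
import math


def dense(x, weights, bias, activation=None):
    """One dense layer, activation(x @ W + b), for a single input vector x.

    weights has shape [len(x)][n_out]; bias has length n_out.
    """
    out = [sum(xi * w_row[k] for xi, w_row in zip(x, weights)) + bias[k]
           for k in range(len(bias))]
    return [activation(v) for v in out] if activation else out


def count_params(layer_sizes):
    """Number of weights plus biases in a fully connected network."""
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))


# 3 input features (elem_i, elem_j, dist) -> 16 -> 16 -> 1 output
sizes = [3, 16, 16, 1]
print(count_params(sizes))  # 353

# Forward pass with constant, arbitrary weights, only to show the shapes
x = [8.0, 1.0, 0.96]
for n_in, n_out in zip(sizes[:-2], sizes[1:-1]):
    x = dense(x, [[0.01] * n_out for _ in range(n_in)], [0.0] * n_out, math.tanh)
e_pair = dense(x, [[0.01] for _ in range(sizes[-2])], [0.0])  # linear output
print(len(e_pair))  # 1
```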

The pairwise energies can be summed into atomic energies with tf.unsorted_segment_sum: each pair energy is added to the atom given by the first index of the pair, ind_2[:, 0].

[7]:
e_atom = tf.unsorted_segment_sum(e_pair, ind_2[:, 0], natoms); e_atom
[7]:
<tf.Tensor 'UnsortedSegmentSum:0' shape=(?, 1) dtype=float32>
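tf.unsorted_segment_sum adds each entry of its data into the output slot named by the segment id, and num_segments (here natoms) fixes the output length, so an atom that appears in no pair still gets an energy of zero. A plain-Python equivalent for one-dimensional data, using a hypothetical segment_sum helper and made-up pair energies, is:

```python
def segment_sum(data, segment_ids, num_segments):
    """Plain-Python analogue of tf.unsorted_segment_sum for 1-D data."""
    out = [0.0] * num_segments
    for value, seg in zip(data, segment_ids):
        out[seg] += value
    return out


# Atoms 0-2 form a water molecule (atom 0 is O); atom 3 is too far
# away to have any neighbors
ind_2 = [(0, 1), (0, 2), (1, 0), (2, 0)]
e_pair = [-0.5, -0.5, -0.25, -0.25]   # made-up pair energies
e_atom = segment_sum(e_pair, [i for i, _ in ind_2], num_segments=4)
# e_atom == [-1.0, -0.25, -0.25, 0.0]
```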

Putting everything together, we get a new network function.

[8]:
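def my_network(tensors, n_nodes=()):
    elems = tensors['elems']
    natoms = tf.shape(elems)[0]

    # Neighbor list: pair indices and pairwise distances
    nl = cell_list_nl(tensors)
    ind_2 = nl['ind_2']
    dist = nl['dist']

    # Per-pair input features: (element_i, element_j, distance)
    elem_i = tf.cast(tf.gather(elems, ind_2[:, 0]), tf.float32)
    elem_j = tf.cast(tf.gather(elems, ind_2[:, 1]), tf.float32)
    pair = tf.stack([elem_i, elem_j, dist], axis=1)

    # Hidden layers with tanh activation, sizes given by n_nodes
    nodes = pair
    for n in n_nodes:
        nodes = tf.layers.dense(nodes, n, activation='tanh')

    # Linear output layer: one energy per pair
    e_pair = tf.layers.dense(nodes, 1, activation=None)
    # Sum pair energies onto the first atom of each pair
    e_atom = tf.unsorted_segment_sum(e_pair, ind_2[:, 0], natoms)
    # Drop the trailing size-1 axis to get one energy per atom
    return tf.squeeze(e_atom, axis=-1)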

Training with the network

Once you have your network function, you can use it directly in the existing models.
Just set the 'network' entry of params to your function, and pass its keyword arguments through 'network_params'.
[9]:
params = {
    'model_dir': '/tmp/my_network',
    'network': my_network,
    'network_params': {'n_nodes': [16, 16]},
    'model_params': {}}
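The params dictionary pairs the network function with its keyword arguments. How potential_model uses these entries internally is not shown in this notebook; a plausible reading, assumed in this sketch, is that it calls the function as network(tensors, **network_params). The hypothetical call_network helper and the stand-in toy_network below only illustrate that convention and are not PiNN code.

```python
def call_network(params, tensors):
    """Hypothetical: call params['network'] with params['network_params']."""
    network_fn = params['network']
    return network_fn(tensors, **params.get('network_params', {}))


def toy_network(tensors, n_nodes=()):
    # Stand-in for my_network: reports what it received
    return {'n_atoms': len(tensors['elems']), 'n_nodes': list(n_nodes)}


params = {
    'model_dir': '/tmp/my_network',
    'network': toy_network,
    'network_params': {'n_nodes': [16, 16]},
    'model_params': {}}
print(call_network(params, {'elems': [8, 1, 1]}))
# {'n_atoms': 3, 'n_nodes': [16, 16]}
```

Keeping the hyperparameters in network_params, rather than hard-coding them in the function, lets the same network function be reused with different layer sizes.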
[10]:
model = potential_model(params)
train_spec = tf.estimator.TrainSpec(input_fn=train, max_steps=1000)
eval_spec = tf.estimator.EvalSpec(input_fn=test, steps=100)
tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/my_network', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcddb384150>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Calling model_fn.
Total number of trainable variables: 353
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/my_network/model.ckpt.
INFO:tensorflow:loss = 346152.0, step = 1
INFO:tensorflow:global_step/sec: 9.72204
INFO:tensorflow:loss = 40745.125, step = 101 (10.289 sec)
INFO:tensorflow:global_step/sec: 9.65198
INFO:tensorflow:loss = 19949.152, step = 201 (10.360 sec)
INFO:tensorflow:global_step/sec: 9.44661
INFO:tensorflow:loss = 14702.94, step = 301 (10.588 sec)
INFO:tensorflow:global_step/sec: 9.58072
INFO:tensorflow:loss = 9526.419, step = 401 (10.438 sec)
INFO:tensorflow:global_step/sec: 9.54791
INFO:tensorflow:loss = 3616.6038, step = 501 (10.474 sec)
INFO:tensorflow:global_step/sec: 9.2748
INFO:tensorflow:loss = 1968.7462, step = 601 (10.779 sec)
INFO:tensorflow:global_step/sec: 9.42503
INFO:tensorflow:loss = 2020.4435, step = 701 (10.612 sec)
INFO:tensorflow:global_step/sec: 9.53732
INFO:tensorflow:loss = 2231.4812, step = 801 (10.483 sec)
INFO:tensorflow:global_step/sec: 9.45915
INFO:tensorflow:loss = 1619.0415, step = 901 (10.573 sec)
INFO:tensorflow:Saving checkpoints for 1000 into /tmp/my_network/model.ckpt.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-09-30T22:43:41Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/my_network/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [80/100]
INFO:tensorflow:Evaluation [90/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Finished evaluation at 2019-09-30-22:43:53
INFO:tensorflow:Saving dict for global step 1000: METRICS/E_LOSS = 1788.0197, METRICS/E_MAE = 32.56714, METRICS/E_RMSE = 42.28498, METRICS/TOT_LOSS = 1788.0195, global_step = 1000, loss = 1788.0195
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmp/my_network/model.ckpt-1000
INFO:tensorflow:Loss for final step: 1588.4066.
[10]:
({'METRICS/E_LOSS': 1788.0197,
  'METRICS/E_MAE': 32.56714,
  'METRICS/E_RMSE': 42.28498,
  'METRICS/TOT_LOSS': 1788.0195,
  'loss': 1788.0195,
  'global_step': 1000},
 [])
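The reported metrics are consistent with E_LOSS being the mean squared error of the energy, since 42.28498² ≈ 1788.02, the value of E_LOSS; this is an inference from the numbers, not something the log states. The hypothetical helper energy_metrics below computes the three quantities from lists of predicted and reference energies (the values are made up).

```python
import math


def energy_metrics(pred, ref):
    """Return (MSE, MAE, RMSE) between predicted and reference energies."""
    errors = [p - r for p, r in zip(pred, ref)]
    mse = sum(e * e for e in errors) / len(errors)
    mae = sum(abs(e) for e in errors) / len(errors)
    return mse, mae, math.sqrt(mse)


mse, mae, rmse = energy_metrics([1.0, 2.0, 5.0], [0.0, 2.0, 2.0])
# errors are 1, 0, 3, so MSE = 10/3, MAE = 4/3, RMSE = sqrt(10/3)

# The evaluation above: squaring E_RMSE recovers E_LOSS
print(math.isclose(42.28498 ** 2, 1788.0197, rel_tol=1e-6))  # True
```

RMSE is never smaller than MAE, which also holds for the reported values (42.28 vs 32.57).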