import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from openff.toolkit.topology import Molecule, Topology
from openff.toolkit.typing.engines.smirnoff import ForceField
from openff.interchange import Interchange
# Load the SMIRNOFF force field ("Parsley", openff-1.0.0) used to type the system.
parsley = ForceField("openff-1.0.0.offxml")

# Single-molecule system built from toolkit classes: ethanol with one
# generated conformer, wrapped in a topology.
mol = Molecule.from_smiles("CCO")
mol.generate_conformers(n_conformers=1)
top = Topology.from_molecules([mol])

# Parametrize the topology, then grab the bond-parameter handler.
off_sys = Interchange.from_smirnoff(force_field=parsley, topology=top)
bonds = off_sys.handlers["Bonds"]
# Transform parameters into matrix representations.
p = bonds.get_force_field_parameters()
mapping = bonds.get_mapping()
q = bonds.get_system_parameters()
m = bonds.get_param_matrix()
# Force field parameters; each row is something like
# [k (presumably kcal/mol/A**2), length (A)].
# (Bare expressions such as the one below only display a value in an
# interactive/notebook session; in a plain script they are no-ops.)
p
# System parameters, a.k.a. force field parameters as they exist in a
# parametrized system.
q
# m is the parametrization matrix: dotting it with the flattened p gives
# back q (two values per row).
assert np.allclose(m.dot(p.flatten()).reshape((-1, 2)), q)
m
# Save the initial values as real copies. Plain assignment would only alias
# the arrays, so any later in-place update of p (e.g. `p -= ...`) would
# silently overwrite p0 as well.
q0 = q.copy()
p0 = p.copy()
# Learning rate for gradient descent.
a = 0.1
from copy import deepcopy

# Let JAX differentiate the parametrization itself. jax.vjp returns a
# function mapping a cotangent on the system parameters q back to the force
# field parameters p (d/dp).
_, f_vjp_parametrize = jax.vjp(bonds.parametrize, jnp.asarray(p0))
# jax.jvp(..., has_aux=True) is another approach, but it requires that
# bonds.parametrize return the indices as well.

# Perturb the equilibrium lengths (column 1) of the force field parameters.
# This mimics some "true" values we wish to tune to, even though such values
# are not known in real-world fitting. Column 0 is left untouched.
p_target = deepcopy(p0)
p_target[:, 1] = 0.5 + np.random.rand(p_target.shape[0])
# Obtain the target _system_ parameters by dotting the parametrization
# matrix with the target force field values.
q_target = m.dot(p_target.flatten()).reshape((-1, 2))


# A dummy loss function built from faked, known target parameters; in
# practice this could be the result of an MD run, FE calculation, etc.
def loss(p):
    """Return the norm of the difference between parametrize(p) and q_target.

    p: force field parameters (same layout as p0), NOT system parameters.
    Reads the module-level ``bonds`` and ``q_target`` at call time.
    """
    return jnp.linalg.norm(bonds.parametrize(p) - q_target)


# jax.vjp also returns loss(p0); pulling back a cotangent of 1.0 gives dL/dp,
# i.e. it composes a jax.grad.
out, f_vjp_bonds = jax.vjp(loss, p0)
# The primal output equals loss(p0), which we do not otherwise need.
np.isclose(out, loss(p0))
f_vjp_bonds(1.0)
# jax.grad does the same as the jax.vjp above.
jax_loss = jax.grad(loss)
jax_loss(p0)  # dL/dp
jnp.allclose(f_vjp_bonds(1.0)[0], jax_loss(p0))  # dL/dp
# Gradient of the loss at the original force field parameters p0. The
# cotangent must be 1.0: any other value just rescales the gradient, and
# system parameters such as q0 are not a valid input to `loss`. Column 0
# already matches the target values, so its derivative should be flat
# (~zero). This dL/dp can be used as the gradient in fitting.
f_vjp_bonds(1.0)[0]
fig, ax = plt.subplots()
n_steps = 100
# One entry per row of p_target / p, drawn in matching colors. This assumes
# the rows are in this SMIRKS order — TODO confirm against `mapping`.
bond_labels = [
    "[#6X4:1]-[#6X4:2]",
    "[#6X4:1]-[#1:2]",
    "[#6:1]-[#8:2]",
    "[#8:1]-[#1:2]",
]
bond_colors = ["k", "r", "g", "b"]
# Label target values with dashed horizontal lines.
for row, (color, label) in enumerate(zip(bond_colors, bond_labels)):
    ax.hlines(p_target[row, 1], 0, n_steps, color=color, ls="--", label=label)

# Optimize a copy so that p0 keeps the initial force field parameters.
p = np.array(p0, copy=True)
for i in range(n_steps):
    if i % 10 == 0:
        print(f"step {i}\tloss: {loss(p)}")
    for row, color in enumerate(bond_colors):
        ax.plot(i, p[row][1], f"{color}.")
    # Use JAX to get the gradient dL/dp (cotangent 1.0 for a scalar loss).
    _, f_vjp_bonds = jax.vjp(loss, p)
    grad = f_vjp_bonds(1.0)[0]
    # Plain gradient-descent update of the force field parameters.
    p -= a * grad
    # Use the parametrization matrix to propagate the new force field
    # parameters into new system parameters.
    q = m.dot(p.flatten()).reshape((-1, 2))
ax.legend(loc=0)
ax.set_xlabel("iteration")
ax.set_ylabel("parameter value (bond length-ish)")
ax.set_xlim((0, n_steps))
ax.set_ylim((0, 1.5))
# We can do everything all over again with angles, almost identically.
angles = off_sys.handlers["Angles"]
q0 = angles.get_system_parameters()
p0 = angles.get_force_field_parameters()
mapping = angles.get_mapping()
m = angles.get_param_matrix()
# Work on copies so that in-place updates of q/p cannot alter q0/p0.
q = q0.copy()
p = p0.copy()
# Learning rate for gradient descent.
a = 0.1
# Fake "true" equilibrium angles (column 1): random integers in [100, 120),
# one per force field parameter row.
p_target = deepcopy(p0)
p_target[:, 1] = np.random.randint(100, 120, p_target.shape[0])
q_target = angles.parametrize(p_target)


def loss(p):
    """Return the norm of the difference between angles.parametrize(p) and q_target.

    p: angle force field parameters (same layout as p0). Reads the
    module-level ``angles`` and ``q_target`` at call time.
    """
    return jnp.linalg.norm(angles.parametrize(p) - q_target)
fig, ax = plt.subplots()
n_steps = 100
# One entry per row of p_target / p, drawn in matching colors. This assumes
# the rows are in this SMIRKS order — TODO confirm against `mapping`.
angle_labels = [
    "[*:1]~[#6X4:2]-[*:3]",
    "[*:1]-[#8:2]-[*:3]",
    "[#1:1]-[#6X4:2]-[#1:3]",
]
angle_colors = ["k", "r", "g"]
# Label target values with dashed horizontal lines.
for row, (color, label) in enumerate(zip(angle_colors, angle_labels)):
    ax.hlines(p_target[row, 1], 0, n_steps, color=color, ls="--", label=label)

# Optimize a copy so that p0 keeps the initial force field parameters.
p = np.array(p0, copy=True)
for i in range(n_steps):
    if i % 10 == 0:
        print(f"step {i}\tloss: {loss(p)}")
    for row, color in enumerate(angle_colors):
        ax.plot(i, p[row][1], f"{color}.")
    # Use JAX to get the gradient dL/dp (cotangent 1.0 for a scalar loss).
    _, f_vjp_angles = jax.vjp(loss, p)
    grad = f_vjp_angles(1.0)[0]
    # Plain gradient-descent update of the force field parameters.
    p -= a * grad
    # Propagate the new force field parameters into new system parameters.
    q = m.dot(p.flatten()).reshape((-1, 2))
ax.legend(loc=0)
ax.set_xlabel("iteration")
ax.set_ylabel("parameter value (angle-ish)")
ax.set_xlim((0, n_steps))
ax.set_ylim((100, 120))