#!/usr/bin/env python
# coding: utf-8
"""Toy gradient-descent fit of SMIRNOFF bond and angle parameters with JAX.

An ethanol ``Interchange`` is built from the ``openff-1.0.0`` force field.
For the "Bonds" handler, and then again for the "Angles" handler, the script:

1. reads the force field parameters ``p``, the system parameters ``q`` and
   the parametrization matrix ``m`` (``m.dot(p.flatten())`` reproduces ``q``);
2. invents random target values for column 1 of ``p``;
3. minimises ``loss(p) = ||parametrize(p) - q_target||`` with plain gradient
   descent, obtaining the gradient from ``jax.vjp``, and plots the trajectory.
"""

# In[ ]:


from copy import deepcopy

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from openff.toolkit.topology import Molecule, Topology
from openff.toolkit.typing.engines.smirnoff import ForceField

from openff.interchange import Interchange

# In[ ]:


# Construct a single-molecule system from toolkit classes
mol = Molecule.from_smiles("CCO")
mol.generate_conformers(n_conformers=1)
top = Topology.from_molecules([mol])
parsley = ForceField("openff-1.0.0.offxml")

off_sys = Interchange.from_smirnoff(force_field=parsley, topology=top)


# In[ ]:


bonds = off_sys.handlers["Bonds"]

# Transform parameters into matrix representations
p = bonds.get_force_field_parameters()
mapping = bonds.get_mapping()
q = bonds.get_system_parameters()
m = bonds.get_param_matrix()


# In[ ]:


# force field parameters, each row is something like [k (kcal/mol/A), length (A)]
p


# In[ ]:


# system parameters, a.k.a. force field parameters as they exist in a parametrized system
q


# In[ ]:


# m is the parametrization matrix, which can be dotted with p to get out q
assert np.allclose(m.dot(p.flatten()).reshape((-1, 2)), q)
m


# In[ ]:


# save the initial values; copies (not aliases) so that later updates of
# `p` and `q` cannot alter what is stored here
q0 = deepcopy(q)
p0 = deepcopy(p)

# set learning rate
a = 0.1

# number of gradient-descent iterations (also the x-extent of the plots)
n_steps = 100


# In[ ]:


# let jax run with autodiff
_, f_vjp_bonds = jax.vjp(bonds.parametrize, jnp.asarray(p0))

# d/dp

# jax.jvp( ..., has_aux=True) is another approach, but requires that
# bonds.parametrize returns the indices as well


# In[ ]:


p_target = deepcopy(p0)

# modify a few of the force field targets to arbitrary values;
# this mimics some "true" values we wish to tune to, despite
# these values not being known in real-world fitting.
# One random value per force field parameter row, in [0.5, 1.5).
p_target[:, 1] = 0.5 + np.random.rand(p_target.shape[0])

# obtain the target _system_ parameters by dotting the parametrization
# matrix with target force field values
q_target = m.dot(p_target.flatten()).reshape((-1, 2))


# create a dummy loss function via faking known target parameters;
# in practice this could be the result of an MD run, FE calculation, etc.
def loss(p):
    """Return the norm of the residual between parametrized bonds and target.

    Parameters
    ----------
    p
        Force field parameters, passed straight to ``bonds.parametrize``.

    Returns
    -------
    Scalar ``jnp.linalg.norm(bonds.parametrize(p) - q_target)``.
    """
    return jnp.linalg.norm(bonds.parametrize(p) - q_target)


# In[ ]:


out, f_vjp_bonds = jax.vjp(loss, p0)

# composes a jax.grad


# In[ ]:


# This also returns loss(p0), which we do not need to store
out == loss(p0)


# In[ ]:


f_vjp_bonds(1.0)


# In[ ]:


# this does the same as the jax.vjp above
jax_loss = jax.grad(loss)
jax_loss(p0)  # dL/dp


# In[ ]:


f_vjp_bonds(1.0)[0] == jax_loss(p0)  # dL/dp


# In[ ]:


# derivative of loss function evaluated at the original system parameters;
# note that column 0 matches target values, so the derivative is flat
# NOTE(review): `loss` is written in terms of force field parameters but is
# given the system parameters `q0` here, and its value is then used as the
# cotangent of the VJP taken at `p0` (i.e. it only scales that gradient).
# Left unchanged -- confirm this is the intended demonstration.
f_vjp_bonds(loss(q0))  # dL/dp (!) can be used as gradient in fitting


# In[ ]:


fig, ax = plt.subplots()

# label target values
ax.hlines(p_target[0, 1], 0, n_steps, color="k", ls="--", label="[#6X4:1]-[#6X4:2]")
ax.hlines(p_target[1, 1], 0, n_steps, color="r", ls="--", label="[#6X4:1]-[#1:2]")
ax.hlines(p_target[2, 1], 0, n_steps, color="g", ls="--", label="[#6:1]-[#8:2]")
ax.hlines(p_target[3, 1], 0, n_steps, color="b", ls="--", label="[#8:1]-[#1:2]")

# NOTE(review): the plotting calls are assumed to run on every step (only the
# print is throttled to every 10th step) -- confirm against the notebook.
for i in range(n_steps):
    if i % 10 == 0:
        print(f"step {i}\tloss: {loss(p)}")
    ax.plot(i, p[0][1], "k.")
    ax.plot(i, p[1][1], "r.")
    ax.plot(i, p[2][1], "g.")
    ax.plot(i, p[3][1], "b.")

    # use jax to get the gradient
    _, f_vjp_bonds = jax.vjp(loss, p)
    grad = f_vjp_bonds(1.0)[0]

    # update force field parameters; p0 was copied above, so the saved
    # initial values are unaffected even if this update happens in place
    p -= a * grad

    # use the parametrization matrix to propagate new
    # force field parameters into new system parameters
    q = m.dot(p.flatten()).reshape((-1, 2))

ax.legend(loc=0)
ax.set_xlabel("iteration")
ax.set_ylabel("parameter value (bond length-ish)")
ax.set_xlim((0, n_steps))
ax.set_ylim((0, 1.5))


# In[ ]:


# We can do everything all over again with angles, almost identically
angles = off_sys.handlers["Angles"]

q0 = angles.get_system_parameters()
p0 = angles.get_force_field_parameters()
mapping = angles.get_mapping()
m = angles.get_param_matrix()

# working copies, so the descent below leaves q0 / p0 untouched
q = deepcopy(q0)
p = deepcopy(p0)

a = 0.1

p_target = deepcopy(p0)
# one random integer target in [100, 120) per force field parameter row
p_target[:, 1] = np.random.randint(100, 120, p_target.shape[0])

q_target = angles.parametrize(p_target)


def loss(p):
    """Return the norm of the residual between parametrized angles and target.

    Parameters
    ----------
    p
        Force field parameters, passed straight to ``angles.parametrize``.

    Returns
    -------
    Scalar ``jnp.linalg.norm(angles.parametrize(p) - q_target)``.
    """
    return jnp.linalg.norm(angles.parametrize(p) - q_target)


# In[ ]:


fig, ax = plt.subplots()

# label target values
ax.hlines(p_target[0, 1], 0, n_steps, color="k", ls="--", label="[*:1]~[#6X4:2]-[*:3]")
ax.hlines(p_target[1, 1], 0, n_steps, color="r", ls="--", label="[*:1]-[#8:2]-[*:3]")
ax.hlines(p_target[2, 1], 0, n_steps, color="g", ls="--", label="[#1:1]-[#6X4:2]-[#1:3]")

# NOTE(review): as above, plotting is assumed to happen on every step.
for i in range(n_steps):
    if i % 10 == 0:
        print(f"step {i}\tloss: {loss(p)}")
    ax.plot(i, p[0][1], "k.")
    ax.plot(i, p[1][1], "r.")
    ax.plot(i, p[2][1], "g.")

    # use jax to get the gradient
    _, f_vjp_angles = jax.vjp(loss, p)
    grad = f_vjp_angles(1.0)[0]

    # update force field parameters
    p -= a * grad

    # propagate the new force field parameters into new system parameters
    q = m.dot(p.flatten()).reshape((-1, 2))

ax.legend(loc=0)
ax.set_xlabel("iteration")
ax.set_ylabel("parameter value (angle-ish)")
ax.set_xlim((0, n_steps))
ax.set_ylim((100, 120))