在上一篇文章中,我们讨论了在分子动力学里面使用LINCS约束算法及其在具备自动微分能力的Jax框架下的代码实现。约束算法,在分子动力学模拟的过程中时常会使用到,用于固定一些既定的成键关系。例如LINCS算法一般用于固定分子体系中的键长关系,而本文将要提到的SETTLE算法,常用于固定一个构成三角形的体系,最常见的就是水分子体系。对于一个水分子而言,O-H键的键长在模拟的过程中可以固定,H-H的长度,或者我们更常见的作为一个H-O-H的夹角出现的参量,也需要固定。纯粹从计算量来考虑的话,RATTLE约束算法需要迭代计算,LINCS算法需要求矩阵逆(虽然已经给出了截断优化的算法),而SETTLE只涉及到坐标变换,显然SETTLE在约束大规模的水盒子时,性能会更加优秀。
`$\begin{align}
a'_0&=(0,\ r_a,\ 0)\\
b'_0&=(-r_c,\ -r_b,\ 0)\\
c'_0&=(r_c,\ -r_b,\ 0)
\end{align}
$`
关于这个坐标数值,再回头看下这个图可能会更加清晰明了一些:
那么我们最终可以得到的旋转角为:
`$\begin{align}
\phi&=\arcsin\left(\frac{Z'_{A_1}}{r_a}\right)\\
\psi&=\arcsin\left(\frac{Z'_{B_1}-Z'_{C_1}}{2r_c\cos\phi}\right)\\
\theta&=\arcsin\left(\frac{\gamma}{\sqrt{\alpha^2+\beta^2}}\right)-\arctan\left(\frac{\beta}{\alpha}\right)
\end{align}
$`
我们发现这里其实不仅仅包含坐标轴的旋转,还包含了坐标系原点的偏移,不过这个偏移倒是比较好处理,在后续的计算过程中将其扣除即可。
通过这三个点联立的方程组可以表示为:
`$\begin{align}
R\left[\left(\begin{matrix}
X_{A_0}\\
Y_{A_0}\\
Z_{A_0}
\end{matrix}\right)-\vec{M}\right]
&=\left(\begin{matrix}
0\\
r_a\\
0
\end{matrix}\right)\\
R\left[\left(\begin{matrix}
X_{B_0}\\
Y_{B_0}\\
Z_{B_0}
\end{matrix}\right)-\vec{M}\right]
&=\left(\begin{matrix}
-r_c\\
-r_b\\
0
\end{matrix}\right)\\
R\left[\frac{\vec{BC}\times\vec{CA}}{\left|\vec{BC}\times\vec{CA}\right|}\right]
&=\left(\begin{matrix}
0\\
0\\
1
\end{matrix}\right)
\end{align}
$`
相关的求解代码如下所示:
# settle.py
from jax import numpy as np
from jax import vmap, jit
def rotation(psi, phi, theta, v):
    """Rotate one 3-vector with three successive axis matrices.

    The vector is multiplied first by a matrix built from ``psi`` that mixes
    the x and z components (y is left untouched), then by a matrix built from
    ``phi`` that mixes y and z, and finally by a matrix built from ``theta``
    that mixes x and y.

    Args:
        psi: Angle in radians used for the y-axis matrix.
        phi: Angle in radians used for the x-axis matrix.
        theta: Angle in radians used for the z-axis matrix.
        v: Vector with three components.

    Returns:
        The transformed vector ``RZ . (RX . (RY . v))``.
    """
    cos_psi, sin_psi = np.cos(psi), np.sin(psi)
    cos_phi, sin_phi = np.cos(phi), np.sin(phi)
    cos_theta, sin_theta = np.cos(theta), np.sin(theta)

    about_y = np.array([[cos_psi, 0, -sin_psi],
                        [0, 1, 0],
                        [sin_psi, 0, cos_psi]])
    about_x = np.array([[1, 0, 0],
                        [0, cos_phi, -sin_phi],
                        [0, sin_phi, cos_phi]])
    about_z = np.array([[cos_theta, -sin_theta, 0],
                        [sin_theta, cos_theta, 0],
                        [0, 0, 1]])

    # Apply the factors innermost first: y-axis, then x-axis, then z-axis.
    turned = np.dot(about_y, v)
    turned = np.dot(about_x, turned)
    return np.dot(about_z, turned)


# Compiled, batched variant: the three angles are shared by every call
# (in_axes None) while the leading axis of ``v`` is mapped over.
multi_rotation = jit(vmap(rotation, (None, None, None, 0)))
if __name__ == '__main__':
    # Demo: build a triangle, rotate and shift it, then recover the matrix
    # that maps it back to its reference placement.

    # construct params
    # The origin used below is the unweighted centroid of the three points.
    # For the reference triangle (0, ra, 0), (-rc, -rb, 0), (rc, -rb, 0) the
    # centroid sits at the origin only when ra == 2 * rb, so the parameters
    # must satisfy that relation for Rot to be a pure rotation.
    ra = 1.0
    rb = 0.5
    rc = 1.2
    psi = 0.4
    phi = 0.5
    theta = 1.3

    # construct initial crd
    crd = np.array([[0, ra, 0],
                    [-rc, -rb, 0],
                    [rc, -rb, 0]])
    shift = np.array([0.1, 0.1, 0.1])
    crd = multi_rotation(psi, phi, theta, crd) + shift

    # get the center of mass (unweighted mean of the three points)
    com = np.average(crd, 0)

    # 3 vectors are selected to solve the rotation matrix: two centred
    # vertices and the unit normal of the triangle plane.
    cross = np.cross(crd[2] - crd[1], crd[0] - crd[2])
    cross = cross / np.linalg.norm(cross)
    xyz = np.array([crd[0] - com, crd[1] - com, cross])
    inv_xyz = np.linalg.inv(xyz)

    # v0, v1, v2 are the x, y and z components of the three target vectors
    # (0, ra, 0), (-rc, -rb, 0) and (0, 0, 1); each row of Rot is obtained
    # by solving xyz . row = v.
    v0 = np.array([0, -rc, 0])
    v1 = np.array([ra, -rb, 0])
    v2 = np.array([0, 0, 1])

    # final rotation matrix is constructed by following
    Rot = np.array([np.dot(inv_xyz, v0), np.dot(inv_xyz, v1), np.dot(inv_xyz, v2)])
    print (Rot)

    # some test cases and results (values are approximate, float32 rounding)
    origin = crd[0]
    print(np.dot(Rot, origin-com))
    # expected: approximately [0, ra, 0] = [0.0, 1.0, 0.0]
    origin = crd[1]
    print(np.dot(Rot, origin-com))
    # expected: approximately [-rc, -rb, 0] = [-1.2, -0.5, 0.0]
    origin = crd[2]
    print(np.dot(Rot, origin-com))
    # expected: approximately [rc, -rb, 0] = [1.2, -0.5, 0.0]
    origin = xyz[2]
    print(np.dot(Rot, origin))
    # expected: approximately [0.0, 0.0, 1.0]
上述代码中所得到的`Rot`这个矩阵,就是我们所需的、将扣除质心$\vec{M}$之后的$XYZ$坐标变换到$X'Y'Z'$坐标系下的变换矩阵。
需要特别提及的是,上述代码中所使用到的JAX框架支持了`vmap`这种便捷矢量化计算的操作,因此在`rotation`函数中只实现了一个旋转矩阵对一个向量的操作,再通过`vmap`将其扩展到了对多个矢量,也就是多个点的空间旋转操作上,变成了`multi_rotation`函数,这样的操作也更加符合我们对多个原子坐标的定义形式。
# settle.py
from jax import numpy as np
from jax import vmap, jit
def rotation(psi, phi, theta, v):
    """Rotate one 3-vector with three successive axis matrices.

    The vector is multiplied first by a matrix built from ``psi`` that mixes
    the x and z components (y is left untouched), then by a matrix built from
    ``phi`` that mixes y and z, and finally by a matrix built from ``theta``
    that mixes x and y.

    Args:
        psi: Angle in radians used for the y-axis matrix.
        phi: Angle in radians used for the x-axis matrix.
        theta: Angle in radians used for the z-axis matrix.
        v: Vector with three components.

    Returns:
        The transformed vector ``RZ . (RX . (RY . v))``.
    """
    cos_psi, sin_psi = np.cos(psi), np.sin(psi)
    cos_phi, sin_phi = np.cos(phi), np.sin(phi)
    cos_theta, sin_theta = np.cos(theta), np.sin(theta)

    about_y = np.array([[cos_psi, 0, -sin_psi],
                        [0, 1, 0],
                        [sin_psi, 0, cos_psi]])
    about_x = np.array([[1, 0, 0],
                        [0, cos_phi, -sin_phi],
                        [0, sin_phi, cos_phi]])
    about_z = np.array([[cos_theta, -sin_theta, 0],
                        [sin_theta, cos_theta, 0],
                        [0, 0, 1]])

    # Apply the factors innermost first: y-axis, then x-axis, then z-axis.
    turned = np.dot(about_y, v)
    turned = np.dot(about_x, turned)
    return np.dot(about_z, turned)


# Compiled, batched variant: the three angles are shared by every call
# (in_axes None) while the leading axis of ``v`` is mapped over.
multi_rotation = jit(vmap(rotation, (None, None, None, 0)))
def get_rot(crd):
    """Solve the matrix taking a centred triangle to its reference placement.

    The three lengths are measured from ``crd`` itself: ``rc`` is half the
    distance between rows 1 and 2, ``ra`` is the distance from row 0 to the
    unweighted mean of the rows, and ``rb`` is derived from the distance
    between rows 0 and 2. The returned matrix ``Rot`` satisfies

        Rot . (crd[0] - mean) = (0, ra, 0)
        Rot . (crd[1] - mean) = (-rc, -rb, 0)
        Rot . n               = (0, 0, 1)

    where ``n`` is the unit vector along ``(crd[2]-crd[1]) x (crd[0]-crd[2])``.

    Args:
        crd: Array whose rows 0, 1, 2 are the three vertex positions.

    Returns:
        Tuple ``(Rot, inv_Rot)`` with ``inv_Rot`` the matrix inverse of ``Rot``.
    """
    centre = np.average(crd, 0)
    first, second, third = crd[0], crd[1], crd[2]

    rc = np.linalg.norm(third - second) / 2
    ra = np.linalg.norm(first - centre)
    rb = np.sqrt(np.linalg.norm(third - first) ** 2 - rc ** 2) - ra

    # Unit normal of the triangle plane.
    normal = np.cross(third - second, first - third)
    normal = normal / np.linalg.norm(normal)

    # Rows are the three source vectors whose images are prescribed.
    basis = np.array([first - centre, second - centre, normal])
    inv_basis = np.linalg.inv(basis)

    # x, y and z components of the three prescribed images; solving
    # basis . row = components gives one row of Rot per component.
    components = (np.array([0, -rc, 0]),
                  np.array([ra, -rb, 0]),
                  np.array([0, 0, 1]))
    Rot = np.array([np.dot(inv_basis, comp) for comp in components])
    inv_Rot = np.linalg.inv(Rot)
    return Rot, inv_Rot
def xyzto(Rot, crd, com):
    """Shift a point by ``-com`` and multiply it by ``Rot``.

    Args:
        Rot: Matrix applied to the shifted point.
        crd: Coordinates of a single point.
        com: Offset subtracted from ``crd`` before the multiplication.

    Returns:
        ``Rot . (crd - com)``.
    """
    displaced = crd - com
    return np.dot(Rot, displaced)


# Compiled, batched variant: one matrix and one offset are shared, the
# leading axis of ``crd`` is mapped over.
multi_xyzto = jit(vmap(xyzto, (None, 0, None)))
def toxyz(Rot, crd, com):
    """Shift a point by ``-com`` and multiply it by the supplied matrix.

    The function does not invert anything itself: to undo a forward
    transform the caller has to pass the already-inverted matrix as ``Rot``
    and add any final offset to the result.

    Args:
        Rot: Matrix applied to the shifted point (the inverse transform
            matrix when used for a back-transformation).
        crd: Coordinates of a single point.
        com: Offset subtracted from ``crd`` before the multiplication.

    Returns:
        ``Rot . (crd - com)``.
    """
    recentred = crd - com
    return np.dot(Rot, recentred)


# Compiled, batched variant: one matrix and one offset are shared, the
# leading axis of ``crd`` is mapped over.
multi_toxyz = jit(vmap(toxyz, (None, 0, None)))
def get_circumference(crd):
    """Return the perimeter of the triangle formed by rows 0, 1 and 2.

    Args:
        crd: Array whose first three rows are the vertex positions.

    Returns:
        Sum of the three pairwise Euclidean distances.
    """
    p0, p1, p2 = crd[0], crd[1], crd[2]
    side_01 = np.linalg.norm(p0 - p1)
    side_02 = np.linalg.norm(p0 - p2)
    side_12 = np.linalg.norm(p1 - p2)
    return side_01 + side_02 + side_12


# Compiled variant of the perimeter helper.
jit_get_circumference = jit(get_circumference)
def get_angles(crd_0, crd_t0, crd_t1):
    """Get the rotation angles phi, psi and theta.

    The lengths ``ra``, ``rb``, ``rc`` are measured on ``crd_0``. ``phi`` and
    ``psi`` follow from the z components of ``crd_t1``; ``theta`` solves

        alpha * sin(theta) + beta * cos(theta) = gamma

    with alpha, beta, gamma built from ``crd_t0``, ``crd_t1`` and the two
    angles already found.

    Args:
        crd_0: Triangle (rows 0, 1, 2) used to measure ra, rb, rc.
        crd_t0: Constrained triangle of the previous step in the primed frame.
        crd_t1: Unconstrained triangle of the current step in the primed
            frame (the caller passes it with its centre removed).

    Returns:
        Tuple ``(phi, psi, theta)`` in radians. Note the order differs from
        the ``(psi, phi, theta)`` argument order of ``get_d3``.
    """
    com = np.average(crd_0, 0)
    rc = np.linalg.norm(crd_0[2] - crd_0[1]) / 2
    ra = np.linalg.norm(crd_0[0] - com)
    rb = np.sqrt(np.linalg.norm(crd_0[2] - crd_0[0]) ** 2 - rc ** 2) - ra

    a0, b0, c0 = crd_t0[0], crd_t0[1], crd_t0[2]
    a1, b1, c1 = crd_t1[0], crd_t1[1], crd_t1[2]

    phi = np.arcsin(a1[2] / ra)
    psi = np.arcsin((b1[2] - c1[2]) / 2 / rc / np.cos(phi))

    # In-plane components of the triangle after the phi/psi tilts only.
    x_b2 = -rc * np.cos(psi)
    y_b2 = -rb * np.cos(phi) - rc * np.sin(psi) * np.sin(phi)
    y_c2 = -rb * np.cos(phi) + rc * np.sin(psi) * np.sin(phi)

    alpha = x_b2 * (b0[0] - c0[0]) + y_b2 * (b0[1] - a0[1]) + y_c2 * (c0[1] - a0[1])
    beta = x_b2 * (c0[1] - b0[1]) + y_b2 * (b0[0] - a0[0]) + y_c2 * (c0[0] - a0[0])
    gamma = b1[1] * (b0[0] - a0[0]) - b1[0] * (b0[1] - a0[1]) + \
        c1[1] * (c0[0] - a0[0]) - c1[0] * (c0[1] - a0[1])

    sin_part = gamma / np.sqrt(alpha ** 2 + beta ** 2)
    # arctan2 keeps the quadrant of (alpha, beta); arctan(beta/alpha) is only
    # the correct phase for alpha > 0 and divides by zero when alpha == 0.
    theta = np.arcsin(sin_part) - np.arctan2(beta, alpha)
    return phi, psi, theta


# Compiled variant of the angle solver.
jit_get_angles = jit(get_angles)
def get_d3(crd_0, psi, phi, theta):
    """Build the three vertex positions from the angles psi, phi and theta.

    The lengths ``ra``, ``rb``, ``rc`` are measured on ``crd_0`` (row 0
    against the mean of the rows, and rows 1/2 against each other); the
    result is expressed around the origin, without any offset added.

    Args:
        crd_0: Triangle (rows 0, 1, 2) used to measure ra, rb, rc.
        psi: Angle in radians.
        phi: Angle in radians.
        theta: Angle in radians.

    Returns:
        Array of shape (3, 3) holding the new positions of rows 0, 1 and 2.
    """
    com = np.average(crd_0, 0)
    rc = np.linalg.norm(crd_0[2] - crd_0[1]) / 2
    ra = np.linalg.norm(crd_0[0] - com)
    rb = np.sqrt(np.linalg.norm(crd_0[2] - crd_0[0]) ** 2 - rc ** 2) - ra

    c_psi, s_psi = np.cos(psi), np.sin(psi)
    c_phi, s_phi = np.cos(phi), np.sin(phi)
    c_theta, s_theta = np.cos(theta), np.sin(theta)

    vertex_a = [-ra * c_phi * s_theta,
                ra * c_phi * c_theta,
                ra * s_phi]
    vertex_b = [-rc * c_psi * c_theta + rb * s_theta * c_phi + rc * s_theta * s_psi * s_phi,
                -rc * c_psi * s_theta - rb * c_theta * c_phi - rc * c_theta * s_psi * s_phi,
                -rb * s_phi + rc * s_psi * c_phi]
    vertex_c = [rc * c_psi * c_theta + rb * s_theta * c_phi - rc * s_theta * s_psi * s_phi,
                rc * c_psi * s_theta - rb * c_theta * c_phi + rc * c_theta * s_psi * s_phi,
                -rb * s_phi - rc * s_psi * c_phi]
    return np.array([vertex_a, vertex_b, vertex_c])


# Compiled variant of the position builder.
jit_get_d3 = jit(get_d3)
if __name__ == '__main__':
    # Demo: move a triangle with random velocities, restore its shape with
    # the helpers above, then print perimeters and plot the three triangles.
    from mpl_toolkits.mplot3d import Axes3D
    import matplotlib.pyplot as plt
    import numpy as onp

    onp.random.seed(0)

    # Reference lengths (ra == 2 * rb) and the angles used to orient the
    # initial triangle.
    ra, rb, rc = 1.0, 0.5, 1.2
    psi, phi, theta = 0.4, 0.5, 1.3

    # Reference triangle, then rotated and shifted to give the start state.
    crd = np.array([[0, ra, 0],
                    [-rc, -rb, 0],
                    [rc, -rb, 0]])
    shift = np.array([0.1, 0.1, 0.1])
    crd_0 = multi_rotation(psi, phi, theta, crd) + shift

    # One unconstrained step with random velocities in [-0.5, 0.5).
    vel = np.array(onp.random.random(crd_0.shape) - 0.5)
    dt = 1
    crd_1 = crd_0 + vel * dt

    com_0 = np.average(crd_0, 0)
    com_1 = np.average(crd_1, 0)

    # Transform matrix of the start triangle and its inverse.
    rot, inv_rot = get_rot(crd_0)

    crd_t0 = multi_xyzto(rot, crd_0, com_0)
    com_t0 = np.average(crd_t0, 0)
    crd_t1 = multi_xyzto(rot, crd_1, com_1) + com_1
    com_t1 = np.average(crd_t1, 0)
    print ('crd_t1:\n', crd_t1)
    # crd_t1:
    #  [[ 0.11285806  1.1888411   0.22201033]
    #  [-1.3182535  -0.35559598  0.3994387 ]
    #  [ 1.5366794  -0.00262779  0.3908713 ]]

    # Solve the three angles on the centred coordinates, rebuild the
    # triangle from them and transform it back.
    phi, psi, theta = jit_get_angles(crd_0, crd_t0, crd_t1 - com_t1)
    crd_t3 = jit_get_d3(crd_t0, psi, phi, theta) + com_t1
    com_t3 = np.average(crd_t3, 0)
    crd_3 = multi_toxyz(inv_rot, crd_t3, com_t3) + com_1
    print ('crd_t3:\n', crd_t3)
    # crd_t3:
    #  [[ 0.01470824  1.2655654   0.22201033]
    #  [-1.0361676  -0.3326143   0.39943868]
    #  [ 1.3527434  -0.10233352  0.39087126]]

    # Perimeters, in order: crd_0, crd_t0, crd_1, crd_t1, crd_t3, crd_3.
    # Recorded output:
    # 6.2418747
    # 6.2418737
    # 6.853938
    # 6.8539376
    # 6.2418737
    # 6.241874
    for triangle in (crd_0, crd_t0, crd_1, crd_t1, crd_t3, crd_3):
        print(jit_get_circumference(triangle))

    # Plotting: black = start triangle, blue = unconstrained step,
    # red = triangle after the constraint.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for triangle, colour in ((crd_0, 'black'), (crd_1, 'blue'), (crd_3, 'red')):
        # Repeat the first vertex at the end so the outline is closed.
        xs = np.append(triangle[:, 0], triangle[0][0])
        ys = np.append(triangle[:, 1], triangle[0][1])
        zs = np.append(triangle[:, 2], triangle[0][2])
        ax.plot(xs, ys, zs, color=colour)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.show()
其中黑色的是原始的三角形,蓝色的是未施加约束条件的偏移,其中重心也发生了较为明显的变化,而红色的三角形对应的是施加约束后的三角形。还可以从另外一个角度来查看施加约束前后的两个三角形的平面关系:
当SETTLE应用在分子模拟当中的时候,不仅仅是更新约束前后的位置,相对应的,速度也需要更新。这里我们没有将其实现到代码当中,仅仅放一下公式,以供参考:
然后将
的值代入到如下的公式:
就可以得到更新后的速度。相关内容并不是很复杂,读者可以自行实现。
继上一篇文章介绍了分子动力学模拟中常用的LINCS约束算法之后,本文再介绍一种SETTLE约束算法,及其基于Jax的实现方案。LINCS约束算法相对来说比较通用,更适合于成键关系比较复杂的通用的体系,而SETTLE算法更加适用于三原子轴对称体系,比如水分子。SETTLE算法结合velocity-verlet算法,可以确保一个分子只进行整体的旋转运动,互相之间的距离又保持不变。比较关键的是,SETTLE算法所依赖的参数较少,也不需要微分,因此在性能上十分有优势。