The previous article went over the AlphaFold model architecture. Now let's tackle the details and the implementation.
Paper link: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8371605/
The structure has two branches: the top one is multi-head self-attention plus an MLP, and the bottom one is the same.
I then went to GitHub and installed a PyTorch implementation of AlphaFold (the code below is from the lucidrains/alphafold2 repository):
The seq and msa below are randomly generated integers in the range 0-21. seq represents our target amino-acid sequence; you can see the sequence length is 128, and the values up to 21 cover the 20 amino-acid types (the meaning of 0 is not stated; since it is excluded from MLM masking later, it is presumably a padding token). msa simulates sequences matched from a database that share similar amino-acid fragments; 5 such sequences are selected here, each also of length 128.
With these inputs we can already get prediction results.
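As a minimal sketch of how that looks end to end (hyperparameter values are illustrative, following the repository's README; the distogram line relies on the ReturnValues container shown later):

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64
)

seq = torch.randint(0, 21, (1, 128))      # target sequence: batch 1, length 128
msa = torch.randint(0, 21, (1, 5, 128))   # 5 aligned sequences of the same length
mask = torch.ones_like(seq).bool()
msa_mask = torch.ones_like(msa).bool()

ret = model(seq, msa = msa, mask = mask, msa_mask = msa_mask)
distogram_logits = ret.distance           # (1, 128, 128, constants.DISTOGRAM_BUCKETS)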
Let's first look at the full model code:
class Alphafold2(nn.Module):
def __init__(
self,
*,
dim,
max_seq_len = 2048,
depth = 6,
heads = 8,
dim_head = 64,
max_rel_dist = 32,
num_tokens = constants.NUM_AMINO_ACIDS,
num_embedds = constants.NUM_EMBEDDS_TR,
max_num_msas = constants.MAX_NUM_MSA,
max_num_templates = constants.MAX_NUM_TEMPLATES,
extra_msa_evoformer_layers = 4,
attn_dropout = 0.,
ff_dropout = 0.,
templates_dim = 32,
templates_embed_layers = 4,
templates_angles_feats_dim = 55,
predict_angles = False,
symmetrize_omega = False,
predict_coords = False, # structure module related keyword arguments below
structure_module_depth = 4,
structure_module_heads = 1,
structure_module_dim_head = 4,
disable_token_embed = False,
mlm_mask_prob = 0.15,
mlm_random_replace_token_prob = 0.1,
mlm_keep_token_same_prob = 0.1,
mlm_exclude_token_ids = (0,),
recycling_distance_buckets = 32
):
super().__init__()
self.dim = dim
# token embedding
self.token_emb = nn.Embedding(num_tokens + 1, dim) if not disable_token_embed else Always(0)
self.to_pairwise_repr = nn.Linear(dim, dim * 2)
self.disable_token_embed = disable_token_embed
# positional embedding
self.max_rel_dist = max_rel_dist
self.pos_emb = nn.Embedding(max_rel_dist * 2 + 1, dim)
# extra msa embedding
self.extra_msa_evoformer = Evoformer(
dim = dim,
depth = extra_msa_evoformer_layers,
seq_len = max_seq_len,
heads = heads,
dim_head = dim_head,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout,
global_column_attn = True
)
# template embedding
self.to_template_embed = nn.Linear(templates_dim, dim)
self.templates_embed_layers = templates_embed_layers
self.template_pairwise_embedder = PairwiseAttentionBlock(
dim = dim,
dim_head = dim_head,
heads = heads,
seq_len = max_seq_len
)
self.template_pointwise_attn = Attention(
dim = dim,
dim_head = dim_head,
heads = heads,
dropout = attn_dropout
)
self.template_angle_mlp = nn.Sequential(
nn.Linear(templates_angles_feats_dim, dim),
nn.GELU(),
nn.Linear(dim, dim)
)
# projection for angles, if needed
self.predict_angles = predict_angles
self.symmetrize_omega = symmetrize_omega
if predict_angles:
self.to_prob_theta = nn.Linear(dim, constants.THETA_BUCKETS)
self.to_prob_phi = nn.Linear(dim, constants.PHI_BUCKETS)
self.to_prob_omega = nn.Linear(dim, constants.OMEGA_BUCKETS)
# custom embedding projection
self.embedd_project = nn.Linear(num_embedds, dim)
# main trunk modules
self.net = Evoformer(
dim = dim,
depth = depth,
seq_len = max_seq_len,
heads = heads,
dim_head = dim_head,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout
)
# MSA SSL MLM
self.mlm = MLM(
dim = dim,
num_tokens = num_tokens,
mask_id = num_tokens, # last token of embedding is used for masking
mask_prob = mlm_mask_prob,
keep_token_same_prob = mlm_keep_token_same_prob,
random_replace_token_prob = mlm_random_replace_token_prob,
exclude_token_ids = mlm_exclude_token_ids
)
# calculate distogram logits
self.to_distogram_logits = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, constants.DISTOGRAM_BUCKETS)
)
# to coordinate output
self.predict_coords = predict_coords
self.structure_module_depth = structure_module_depth
self.msa_to_single_repr_dim = nn.Linear(dim, dim)
self.trunk_to_pairwise_repr_dim = nn.Linear(dim, dim)
with torch_default_dtype(torch.float32):
self.ipa_block = IPABlock(
dim = dim,
heads = structure_module_heads,
)
self.to_quaternion_update = nn.Linear(dim, 6)
init_zero_(self.ipa_block.attn.to_out)
self.to_points = nn.Linear(dim, 3)
# aux confidence measure
self.lddt_linear = nn.Linear(dim, 1)
# recycling params
self.recycling_msa_norm = nn.LayerNorm(dim)
self.recycling_pairwise_norm = nn.LayerNorm(dim)
self.recycling_distance_embed = nn.Embedding(recycling_distance_buckets, dim)
self.recycling_distance_buckets = recycling_distance_buckets
def forward(
self,
seq,
msa = None,
mask = None,
msa_mask = None,
extra_msa = None,
extra_msa_mask = None,
seq_index = None,
seq_embed = None,
msa_embed = None,
templates_feats = None,
templates_mask = None,
templates_angles = None,
embedds = None,
recyclables = None,
return_trunk = False,
return_confidence = False,
return_recyclables = False,
return_aux_logits = False
):
assert not (self.disable_token_embed and not exists(seq_embed)), 'sequence embedding must be supplied if one has disabled token embedding'
assert not (self.disable_token_embed and not exists(msa_embed)), 'msa embedding must be supplied if one has disabled token embedding'
# if MSA is not passed in, just use the sequence itself
if not exists(msa):
msa = rearrange(seq, 'b n -> b () n')
msa_mask = rearrange(mask, 'b n -> b () n')
# assert on sequence length
assert msa.shape[-1] == seq.shape[-1], 'sequence length of MSA and primary sequence must be the same'
# variables
b, n, device = *seq.shape[:2], seq.device
n_range = torch.arange(n, device = device)
# unpack (AA_code, atom_pos)
if isinstance(seq, (list, tuple)):
seq, seq_pos = seq
# embed main sequence
x = self.token_emb(seq)
if exists(seq_embed):
x += seq_embed
# mlm for MSAs
if self.training and exists(msa):
original_msa = msa
msa_mask = default(msa_mask, lambda: torch.ones_like(msa).bool())
noised_msa, replaced_msa_mask = self.mlm.noise(msa, msa_mask)
msa = noised_msa
# embed multiple sequence alignment (msa)
if exists(msa):
m = self.token_emb(msa)
if exists(msa_embed):
m = m + msa_embed
# add single representation to msa representation
m = m + rearrange(x, 'b n d -> b () n d')
# get msa_mask to all ones if none was passed
msa_mask = default(msa_mask, lambda: torch.ones_like(msa).bool())
elif exists(embedds):
m = self.embedd_project(embedds)
# get msa_mask to all ones if none was passed
msa_mask = default(msa_mask, lambda: torch.ones_like(embedds[..., -1]).bool())
else:
            raise ValueError('either MSA or embeds must be given')
# derive pairwise representation
x_left, x_right = self.to_pairwise_repr(x).chunk(2, dim = -1)
x = rearrange(x_left, 'b i d -> b i () d') + rearrange(x_right, 'b j d-> b () j d') # create pair-wise residue embeds
x_mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j') if exists(mask) else None
# add relative positional embedding
seq_index = default(seq_index, lambda: torch.arange(n, device = device))
seq_rel_dist = rearrange(seq_index, 'i -> () i ()') - rearrange(seq_index, 'j -> () () j')
seq_rel_dist = seq_rel_dist.clamp(-self.max_rel_dist, self.max_rel_dist) + self.max_rel_dist
rel_pos_emb = self.pos_emb(seq_rel_dist)
x = x + rel_pos_emb
# add recyclables, if present
if exists(recyclables):
m[:, 0] = m[:, 0] + self.recycling_msa_norm(recyclables.single_msa_repr_row)
x = x + self.recycling_pairwise_norm(recyclables.pairwise_repr)
distances = torch.cdist(recyclables.coords, recyclables.coords, p=2)
boundaries = torch.linspace(2, 20, steps = self.recycling_distance_buckets, device = device)
discretized_distances = torch.bucketize(distances, boundaries[:-1])
distance_embed = self.recycling_distance_embed(discretized_distances)
x = x + distance_embed
# embed templates, if present
if exists(templates_feats):
_, num_templates, *_ = templates_feats.shape
# embed template
t = self.to_template_embed(templates_feats)
t_mask_crossed = rearrange(templates_mask, 'b t i -> b t i ()') * rearrange(templates_mask, 'b t j -> b t () j')
t = rearrange(t, 'b t ... -> (b t) ...')
t_mask_crossed = rearrange(t_mask_crossed, 'b t ... -> (b t) ...')
for _ in range(self.templates_embed_layers):
t = self.template_pairwise_embedder(t, mask = t_mask_crossed)
t = rearrange(t, '(b t) ... -> b t ...', t = num_templates)
t_mask_crossed = rearrange(t_mask_crossed, '(b t) ... -> b t ...', t = num_templates)
# template pos emb
x_point = rearrange(x, 'b i j d -> (b i j) () d')
t_point = rearrange(t, 'b t i j d -> (b i j) t d')
x_mask_point = rearrange(x_mask, 'b i j -> (b i j) ()')
t_mask_point = rearrange(t_mask_crossed, 'b t i j -> (b i j) t')
template_pooled = self.template_pointwise_attn(
x_point,
context = t_point,
mask = x_mask_point,
context_mask = t_mask_point
)
template_pooled_mask = rearrange(t_mask_point.sum(dim = -1) > 0, 'b -> b () ()')
template_pooled = template_pooled * template_pooled_mask
template_pooled = rearrange(template_pooled, '(b i j) () d -> b i j d', i = n, j = n)
x = x + template_pooled
# add template angle features to MSAs by passing through MLP and then concat
if exists(templates_angles):
t_angle_feats = self.template_angle_mlp(templates_angles)
m = torch.cat((m, t_angle_feats), dim = 1)
msa_mask = torch.cat((msa_mask, templates_mask), dim = 1)
# embed extra msa, if present
if exists(extra_msa):
            extra_m = self.token_emb(extra_msa)  # embed the extra MSA itself (not `msa`)
            extra_msa_mask = default(extra_msa_mask, torch.ones_like(extra_msa).bool())
x, extra_m = self.extra_msa_evoformer(
x,
extra_m,
mask = x_mask,
msa_mask = extra_msa_mask
)
# trunk
x, m = self.net(
x,
m,
mask = x_mask,
msa_mask = msa_mask
)
# ready output container
ret = ReturnValues()
# calculate theta and phi before symmetrization
if self.predict_angles:
ret.theta_logits = self.to_prob_theta(x)
ret.phi_logits = self.to_prob_phi(x)
# embeds to distogram
trunk_embeds = (x + rearrange(x, 'b i j d -> b j i d')) * 0.5 # symmetrize
distance_pred = self.to_distogram_logits(trunk_embeds)
ret.distance = distance_pred
# calculate mlm loss, if training
msa_mlm_loss = None
if self.training and exists(msa):
num_msa = original_msa.shape[1]
msa_mlm_loss = self.mlm(m[:, :num_msa], original_msa, replaced_msa_mask)
# determine angles, if specified
if self.predict_angles:
omega_input = trunk_embeds if self.symmetrize_omega else x
ret.omega_logits = self.to_prob_omega(omega_input)
if not self.predict_coords or return_trunk:
return ret
# derive single and pairwise embeddings for structural refinement
single_msa_repr_row = m[:, 0]
single_repr = self.msa_to_single_repr_dim(single_msa_repr_row)
pairwise_repr = self.trunk_to_pairwise_repr_dim(x)
# prepare float32 precision for equivariance
original_dtype = single_repr.dtype
single_repr, pairwise_repr = map(lambda t: t.float(), (single_repr, pairwise_repr))
# iterative refinement with equivariant transformer in high precision
with torch_default_dtype(torch.float32):
quaternions = torch.tensor([1., 0., 0., 0.], device = device) # initial rotations
quaternions = repeat(quaternions, 'd -> b n d', b = b, n = n)
translations = torch.zeros((b, n, 3), device = device)
# go through the layers and apply invariant point attention and feedforward
for i in range(self.structure_module_depth):
is_last = i == (self.structure_module_depth - 1)
# the detach comes from
# https://github.com/deepmind/alphafold/blob/0bab1bf84d9d887aba5cfb6d09af1e8c3ecbc408/alphafold/model/folding.py#L383
rotations = quaternion_to_matrix(quaternions)
if not is_last:
rotations = rotations.detach()
single_repr = self.ipa_block(
single_repr,
mask = mask,
pairwise_repr = pairwise_repr,
rotations = rotations,
translations = translations
)
# update quaternion and translation
quaternion_update, translation_update = self.to_quaternion_update(single_repr).chunk(2, dim = -1)
quaternion_update = F.pad(quaternion_update, (1, 0), value = 1.)
quaternions = quaternion_multiply(quaternions, quaternion_update)
translations = translations + einsum('b n c, b n c r -> b n r', translation_update, rotations)
points_local = self.to_points(single_repr)
rotations = quaternion_to_matrix(quaternions)
coords = einsum('b n c, b n c d -> b n d', points_local, rotations) + translations
        coords = coords.type(original_dtype)
if return_recyclables:
coords, single_msa_repr_row, pairwise_repr = map(torch.detach, (coords, single_msa_repr_row, pairwise_repr))
ret.recyclables = Recyclables(coords, single_msa_repr_row, pairwise_repr)
if return_aux_logits:
return coords, ret
if return_confidence:
return coords, self.lddt_linear(single_repr.float())
return coords
A seriously long chunk of code!
But the good news: this is the classic PyTorch model-class construction. The moment I saw the class inherit from nn.Module and define forward(), I breathed a sigh of relief.
As usual, let's start from the forward function. Once you understand how the data flows, you understand the model.
def forward(
self,
seq,
msa = None,
mask = None,
msa_mask = None,
extra_msa = None,
extra_msa_mask = None,
seq_index = None,
seq_embed = None,
msa_embed = None,
templates_feats = None,
templates_mask = None,
templates_angles = None,
embedds = None,
recyclables = None,
return_trunk = False,
return_confidence = False,
return_recyclables = False,
return_aux_logits = False
):
As you can see, the only truly required input to AlphaFold is the target sequence seq; msa and all the other features can be None and the model will still run.
from einops import rearrange, repeat, reduce
if not exists(msa):
msa = rearrange(seq, 'b n -> b () n')
msa_mask = rearrange(mask, 'b n -> b () n')
This snippet uses the rearrange function from the einops library, which is essentially a one-stop replacement for PyTorch's unsqueeze, squeeze, and permute operations. The pattern 'b n -> b () n' above implements the equivalent of .unsqueeze(1).
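A quick standalone check (toy values) confirms the equivalence:

import torch
from einops import rearrange

seq = torch.randint(0, 21, (1, 128))
print(rearrange(seq, 'b n -> b () n').shape)  # torch.Size([1, 1, 128])
print(seq.unsqueeze(1).shape)                 # torch.Size([1, 1, 128])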
assert msa.shape[-1] == seq.shape[-1], 'sequence length of MSA and primary sequence must be the same'
This line asserts that the target amino-acid sequence and the sequences in the MSA features must have the same length.
b, n, device = *seq.shape[:2], seq.device
n_range = torch.arange(n, device = device)
A few variables are defined here: b is the batch size, n is the sequence length (128 in our example), device is cuda in this run, and n_range is an increasing array from 0 to n-1.
x = self.token_emb(seq)
Now we finally arrive at the first module, token_emb:
self.token_emb = nn.Embedding(num_tokens + 1, dim) if not disable_token_embed else Always(0)
Here num_tokens is 21 and dim is 256. nn.Embedding is used to build a learnable embedding vector for each amino-acid token.
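A standalone sketch of the shapes involved (toy values):

import torch
import torch.nn as nn

token_emb = nn.Embedding(21 + 1, 256)  # num_tokens + 1: the extra token serves as the MLM mask id
seq = torch.randint(0, 21, (1, 128))
x = token_emb(seq)
print(x.shape)  # torch.Size([1, 128, 256]): one 256-dim vector per residue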
# mlm for MSAs
if self.training and exists(msa):
original_msa = msa
msa_mask = default(msa_mask, lambda: torch.ones_like(msa).bool())
noised_msa, replaced_msa_mask = self.mlm.noise(msa, msa_mask)
msa = noised_msa
During training, the MSA features first go through this processing, mlm.noise. Let's look at how the MLM class is built; MLM is also a model class that inherits from nn.Module:
class MLM(nn.Module):
def __init__(
self,
dim,
num_tokens,
mask_id,
mask_prob = 0.15,
random_replace_token_prob = 0.1,
keep_token_same_prob = 0.1,
exclude_token_ids = (0,)
):
super().__init__()
self.to_logits = nn.Linear(dim, num_tokens)
self.mask_id = mask_id
self.mask_prob = mask_prob
self.exclude_token_ids = exclude_token_ids
self.keep_token_same_prob = keep_token_same_prob
self.random_replace_token_prob = random_replace_token_prob
def noise(self, seq, mask):
num_msa = seq.shape[1]
seq = rearrange(seq, 'b n ... -> (b n) ...')
mask = rearrange(mask, 'b n ... -> (b n) ...')
# prepare masks for noising sequence
excluded_tokens_mask = mask
for token_id in self.exclude_token_ids:
excluded_tokens_mask = excluded_tokens_mask & (seq != token_id)
        # mlm_mask randomly selects mask_prob (15%) of the valid positions in each sequence
mlm_mask = get_mask_subset_with_prob(excluded_tokens_mask, self.mask_prob)
# keep some tokens the same
replace_token_with_mask = get_mask_subset_with_prob(mlm_mask, 1. - self.keep_token_same_prob)
# replace with mask
        # This is the BERT masking strategy: the ~15% of positions selected above are
        # wiped out and replaced with the mask token id, just like BERT's [MASK]
seq = seq.masked_fill(mlm_mask, self.mask_id)
# generate random tokens
        # This mimics mutations: with a small probability (0.9 * 0.1 = 0.09 of the selected
        # positions), an amino acid is replaced by a random different one, like a typo in NLP
random_replace_token_prob_mask = get_mask_subset_with_prob(mlm_mask, (1 - self.keep_token_same_prob) * self.random_replace_token_prob)
random_tokens = torch.randint(1, constants.NUM_AMINO_ACIDS, seq.shape).to(seq.device)
for token_id in self.exclude_token_ids:
random_replace_token_prob_mask = random_replace_token_prob_mask & (random_tokens != token_id) # make sure you never substitute a token with an excluded token type (pad, start, end)
# noise sequence
noised_seq = torch.where(random_replace_token_prob_mask, random_tokens, seq)
noised_seq = rearrange(noised_seq, '(b n) ... -> b n ...', n = num_msa)
mlm_mask = rearrange(mlm_mask, '(b n) ... -> b n ...', n = num_msa)
return noised_seq, mlm_mask
def forward(self, seq_embed, original_seq, mask):
logits = self.to_logits(seq_embed)
seq_logits = logits[mask]
seq_labels = original_seq[mask]
loss = F.cross_entropy(seq_logits, seq_labels, reduction = 'mean')
return loss
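Before diving into its internals, here is a small sketch of noise in action (it assumes MLM and its helpers from the repository are in scope; values are toys):

import torch

mlm = MLM(dim = 256, num_tokens = 21, mask_id = 21)
msa = torch.randint(1, 21, (1, 5, 128))
msa_mask = torch.ones_like(msa).bool()

noised_msa, replaced_msa_mask = mlm.noise(msa, msa_mask)
print(replaced_msa_mask.float().mean())   # ~0.15: fraction of positions selected for MLM
print((noised_msa == 21).float().mean())  # slightly below 0.15: most selected positions now hold the mask id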
Let's walk through the noise method it calls. One line in it is the following, to which I've added a comment:
# assume prob = 0.15 here
mlm_mask = get_mask_subset_with_prob(excluded_tokens_mask, self.mask_prob)
This line uses the helper get_mask_subset_with_prob:
def get_mask_subset_with_prob(mask, prob):
batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len) # ceil(0.15 * 128) = 20
    # num_tokens is the number of valid (non-excluded) amino acids in each sequence
num_tokens = mask.sum(dim=-1, keepdim=True)
    # (num_tokens * prob).ceil() = ceil(128 * 0.15) = 20 when every position is valid
mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    # max_masked = 20, so mask_excess.shape = [5, 20]
mask_excess = mask_excess[:, :max_masked]
rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    # sampled_indices are just the indices of the top-20 random scores per row of the (5, 128) mask
_, sampled_indices = rand.topk(max_masked, dim=-1)
    # indices beyond a row's quota of valid tokens are redirected to position 0, which is dropped below
sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)
new_mask = torch.zeros((batch, seq_len + 1), device=device)
    # scatter the randomly chosen positions back into new_mask
new_mask.scatter_(-1, sampled_indices, 1)
return new_mask[:, 1:].bool()
    # In short: the returned new_mask has shape [5, 128], and only ~15% of its elements are True
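A quick standalone check of the helper's behavior (assuming the function above is in scope):

import torch

mask = torch.ones(5, 128).bool()
subset = get_mask_subset_with_prob(mask, 0.15)
print(subset.shape)          # torch.Size([5, 128])
print(subset.sum(dim = -1))  # tensor([20, 20, 20, 20, 20]): ceil(0.15 * 128) positions per row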
The code also uses PyTorch's Tensor.masked_fill(mask, value). This function is simple: it replaces the elements at positions where mask is True with value, and it is common in NLP. See this blog post and the official docs for details: https://blog.csdn.net/weixin_44737266/article/details/116486838
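A tiny example:

import torch

t = torch.arange(5)                                # tensor([0, 1, 2, 3, 4])
m = torch.tensor([True, False, True, False, True])
print(t.masked_fill(m, -1))                        # tensor([-1, 1, -1, 3, -1])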
The next step embeds the MSA data we just noised, using the same embedding layer as seq, so m takes shape (1, 5, 128, 256). At this point the MSA features are done: the input seq passes directly through token_emb, while the MSA first passes through noise (masking plus mutation) and then through token_emb.
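Recapping both MSA steps with standalone toy tensors (mirroring the lines from forward shown earlier):

import torch
import torch.nn as nn
from einops import rearrange

token_emb = nn.Embedding(22, 256)
noised_msa = torch.randint(0, 21, (1, 5, 128))  # MSA after noise()
x = torch.randn(1, 128, 256)                    # sequence representation from token_emb

m = token_emb(noised_msa)                       # (1, 5, 128, 256)
m = m + rearrange(x, 'b n d -> b () n d')       # broadcast over the 5 MSA rows
print(m.shape)                                  # torch.Size([1, 5, 128, 256])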
Here the input x has shape (1, 128, 256). self.to_pairwise_repr is nn.Linear(dim, dim * 2), so the output becomes (1, 128, 512); torch.chunk then splits it along the last dimension into x_left and x_right, each of shape (1, 128, 256). The second step (the broadcast addition) relies on PyTorch's broadcasting rules for tensors of different shapes: the resulting x has shape (1, 128, 128, 256), which is starting to look like the pair representation from the earlier theory walkthrough!
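The same outer-sum broadcast as a standalone sketch (toy tensors):

import torch
import torch.nn as nn
from einops import rearrange

dim = 256
to_pairwise_repr = nn.Linear(dim, dim * 2)
x = torch.randn(1, 128, dim)

x_left, x_right = to_pairwise_repr(x).chunk(2, dim = -1)  # each (1, 128, 256)
pair = rearrange(x_left, 'b i d -> b i () d') + rearrange(x_right, 'b j d -> b () j d')
print(pair.shape)  # (1, 128, 1, 256) + (1, 1, 128, 256) -> torch.Size([1, 128, 128, 256])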
With that, we have completed a deep dive into this part of the paper!
The content above may feel a little scattered, since I organized it while reading through the code, so it lacks some overall structure. To summarize: this post covered how the code produces the MSA features and the pair representation. The former first masks out 15% of the amino acids and adds some mutations, then goes through token_emb; the latter simply passes through a linear layer, gets chunked in two, and is combined by a broadcast addition into a simple but effective pair representation. I have not covered the structure-database-search (template) part of the code here; those features are optional, and judging from the code their absence does not block the full inference pipeline. That's all for now.