
looooooooong time no post!
I'm back! Thanks, everyone, for the continued support.
Today I'm sharing a Kaggle competition: Stanford RNA 3D Folding.

There's a lot to learn from the models in this competition, so today's post is just a simple inference demo: get it running end to end first, then dig into the details step by step. Stay tuned — I'll post updates as time allows.
1. About the Boltz-1 model

In one sentence:
Boltz-1, released by researchers at MIT, is the first fully open-source, commercially usable deep learning model for predicting the 3D structures of biomolecular complexes that reaches the accuracy level reported for AlphaFold3. In other words, Boltz-1's performance is on par with AlphaFold3.

As the figure shows, Boltz-1 achieves high accuracy at protein structure prediction.
The figure below shows Boltz-1's performance at CASP15:

So in this post we'll use the Boltz-1 model weights to predict RNA structures.
Enough talk — straight to the code:
2. Environment setup
About the data:



That covers the competition dataset; there are also some notes on the timeline:

PS: this code runs on a Kaggle server.
2.1 Load the input data:
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
Output:

This folder holds the model's dependency files and our input data.
2.2 Install dependencies
%ls /kaggle/input/boltz-dependencies
The listing above shows the dependencies we need; next, run the install commands:
!pip install --no-index /kaggle/input/boltz-dependencies/*whl --no-deps
The installation starts:

Continue installing:
!pip install --no-index /kaggle/input/fairscale-0413/*whl --no-deps
Next up (biopython):
!pip install --no-index /kaggle/input/biopython/*whl --no-deps
With that, all dependencies are installed.
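As a quick sanity check that the offline installs worked (this snippet is my own addition, not part of the competition notebook), we can probe for the packages with importlib:

```python
import importlib.util

def check_installed(mods=("boltz", "fairscale", "Bio")):
    """Return True/False for each package we just installed offline."""
    return {m: importlib.util.find_spec(m) is not None for m in mods}

print(check_installed())
```

Any False here means the corresponding wheel didn't install cleanly and the later steps will fail on import.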
2.3 Prepare the script

2.4 Run the prediction
The rough idea is as follows:

The code first:
%%writefile inference.py
import pickle
import urllib.request
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Literal, Optional
import click
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm
from boltz.data import const
from boltz.data.module.inference import BoltzInferenceDataModule
from boltz.data.msa.mmseqs2 import run_mmseqs2
from boltz.data.parse.a3m import parse_a3m
from boltz.data.parse.csv import parse_csv
from boltz.data.parse.fasta import parse_fasta
from boltz.data.parse.yaml import parse_yaml
from boltz.data.types import MSA, Manifest, Record
from boltz.data.write.writer import BoltzWriter
from boltz.model.model import Boltz1
CCD_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/ccd.pkl"
MODEL_URL = (
    "https://huggingface.co/boltz-community/boltz-1/resolve/main/boltz1_conf.ckpt"
)
@dataclass
class BoltzProcessedInput:
    """Processed input data."""

    manifest: Manifest
    targets_dir: Path
    msa_dir: Path


@dataclass
class BoltzDiffusionParams:
    """Diffusion process parameters."""

    gamma_0: float = 0.605
    gamma_min: float = 1.107
    noise_scale: float = 0.901
    rho: float = 8
    step_scale: float = 1.638
    sigma_min: float = 0.0004
    sigma_max: float = 160.0
    sigma_data: float = 16.0
    P_mean: float = -1.2
    P_std: float = 1.5
    coordinate_augmentation: bool = True
    alignment_reverse_diff: bool = True
    synchronize_sigmas: bool = True
    use_inference_model_cache: bool = True
@rank_zero_only
def download(cache: Path) -> None:
    """Download all the required data.

    Parameters
    ----------
    cache : Path
        The cache directory.

    """
    # Download CCD
    ccd = cache / "ccd.pkl"
    if not ccd.exists():
        click.echo(
            f"Downloading the CCD dictionary to {ccd}. You may "
            "change the cache directory with the --cache flag."
        )
        urllib.request.urlretrieve(CCD_URL, str(ccd))  # noqa: S310

    # Download model
    model = cache / "boltz1_conf.ckpt"
    if not model.exists():
        click.echo(
            f"Downloading the model weights to {model}. You may "
            "change the cache directory with the --cache flag."
        )
        urllib.request.urlretrieve(MODEL_URL, str(model))  # noqa: S310
def check_inputs(
    data: Path,
    outdir: Path,
    override: bool = False,
) -> list[Path]:
    """Check the input data and output directory.

    If the input data is a directory, it will be expanded
    to all files in this directory. Then, we check if there
    are any existing predictions and remove them from the
    list of input data, unless the override flag is set.

    Parameters
    ----------
    data : Path
        The input data.
    outdir : Path
        The output directory.
    override: bool
        Whether to override existing predictions.

    Returns
    -------
    list[Path]
        The list of input data.

    """
    click.echo("Checking input data.")

    # Check if data is a directory
    if data.is_dir():
        data: list[Path] = list(data.glob("*"))

        # Filter out non .fasta or .yaml files, raise
        # an error on directory and other file types
        filtered_data = []
        for d in data:
            if d.suffix in (".fa", ".fas", ".fasta", ".yml", ".yaml"):
                filtered_data.append(d)
            elif d.is_dir():
                msg = f"Found directory {d} instead of .fasta or .yaml."
                raise RuntimeError(msg)
            else:
                msg = (
                    f"Unable to parse filetype {d.suffix}, "
                    "please provide a .fasta or .yaml file."
                )
                raise RuntimeError(msg)

        data = filtered_data
    else:
        data = [data]

    # Check if existing predictions are found
    existing = (outdir / "predictions").rglob("*")
    existing = {e.name for e in existing if e.is_dir()}

    # Remove them from the input data
    if existing and not override:
        data = [d for d in data if d.stem not in existing]
        num_skipped = len(existing) - len(data)
        msg = (
            f"Found some existing predictions ({num_skipped}), "
            f"skipping and running only the missing ones, "
            "if any. If you wish to override these existing "
            "predictions, please set the --override flag."
        )
        click.echo(msg)
    elif existing and override:
        msg = "Found existing predictions, will override."
        click.echo(msg)

    return data
def compute_msa(
    data: dict[str, str],
    target_id: str,
    msa_dir: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
) -> None:
    """Compute the MSA for the input data.

    Parameters
    ----------
    data : dict[str, str]
        The input protein sequences.
    target_id : str
        The target id.
    msa_dir : Path
        The msa directory.
    msa_server_url : str
        The MSA server URL.
    msa_pairing_strategy : str
        The MSA pairing strategy.

    """
    if len(data) > 1:
        paired_msas = run_mmseqs2(
            list(data.values()),
            msa_dir / f"{target_id}_paired_tmp",
            use_env=True,
            use_pairing=True,
            host_url=msa_server_url,
            pairing_strategy=msa_pairing_strategy,
        )
    else:
        paired_msas = [""] * len(data)

    unpaired_msa = run_mmseqs2(
        list(data.values()),
        msa_dir / f"{target_id}_unpaired_tmp",
        use_env=True,
        use_pairing=False,
        host_url=msa_server_url,
        pairing_strategy=msa_pairing_strategy,
    )

    for idx, name in enumerate(data):
        # Get paired sequences
        paired = paired_msas[idx].strip().splitlines()
        paired = paired[1::2]  # ignore headers
        paired = paired[: const.max_paired_seqs]

        # Set key per row and remove empty sequences
        keys = [idx for idx, s in enumerate(paired) if s != "-" * len(s)]
        paired = [s for s in paired if s != "-" * len(s)]

        # Combine paired-unpaired sequences
        unpaired = unpaired_msa[idx].strip().splitlines()
        unpaired = unpaired[1::2]
        unpaired = unpaired[: (const.max_msa_seqs - len(paired))]
        if paired:
            unpaired = unpaired[1:]  # ignore query is already present

        # Combine
        seqs = paired + unpaired
        keys = keys + [-1] * len(unpaired)

        # Dump MSA
        csv_str = ["key,sequence"] + [f"{key},{seq}" for key, seq in zip(keys, seqs)]

        msa_path = msa_dir / f"{name}.csv"
        with msa_path.open("w") as f:
            f.write("\n".join(csv_str))
@rank_zero_only
def process_inputs(  # noqa: C901, PLR0912, PLR0915
    data: list[Path],
    out_dir: Path,
    ccd_path: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
    max_msa_seqs: int = 4096,
    use_msa_server: bool = False,
) -> None:
    """Process the input data and output directory.

    Parameters
    ----------
    data : list[Path]
        The input data.
    out_dir : Path
        The output directory.
    ccd_path : Path
        The path to the CCD dictionary.
    max_msa_seqs : int, optional
        Max number of MSA sequences, by default 4096.
    use_msa_server : bool, optional
        Whether to use the MMSeqs2 server for MSA generation, by default False.

    Returns
    -------
    BoltzProcessedInput
        The processed input data.

    """
    click.echo("Processing input data.")
    existing_records = None

    # Check if manifest exists at output path
    manifest_path = out_dir / "processed" / "manifest.json"
    if manifest_path.exists():
        click.echo(f"Found a manifest file at output directory: {out_dir}")

        manifest: Manifest = Manifest.load(manifest_path)
        input_ids = [d.stem for d in data]
        existing_records, processed_ids = zip(
            *[
                (record, record.id)
                for record in manifest.records
                if record.id in input_ids
            ]
        )

        if isinstance(existing_records, tuple):
            existing_records = list(existing_records)

        # Check how many examples need to be processed
        missing = len(input_ids) - len(processed_ids)
        if not missing:
            click.echo("All examples in data are processed. Updating the manifest")
            # Dump updated manifest
            updated_manifest = Manifest(existing_records)
            updated_manifest.dump(out_dir / "processed" / "manifest.json")
            return

        click.echo(f"{missing} missing ids. Preprocessing these ids")
        missing_ids = list(set(input_ids).difference(set(processed_ids)))
        data = [d for d in data if d.stem in missing_ids]
        assert len(data) == len(missing_ids)

    # Create output directories
    msa_dir = out_dir / "msa"
    structure_dir = out_dir / "processed" / "structures"
    processed_msa_dir = out_dir / "processed" / "msa"
    predictions_dir = out_dir / "predictions"

    out_dir.mkdir(parents=True, exist_ok=True)
    msa_dir.mkdir(parents=True, exist_ok=True)
    structure_dir.mkdir(parents=True, exist_ok=True)
    processed_msa_dir.mkdir(parents=True, exist_ok=True)
    predictions_dir.mkdir(parents=True, exist_ok=True)

    # Load CCD
    with ccd_path.open("rb") as file:
        ccd = pickle.load(file)  # noqa: S301

    if existing_records is not None:
        click.echo(f"Found {len(existing_records)} records. Adding them to records")

    # Parse input data
    records: list[Record] = existing_records if existing_records is not None else []
    for path in tqdm(data):
        try:
            # Parse data
            if path.suffix in (".fa", ".fas", ".fasta"):
                target = parse_fasta(path, ccd)
            elif path.suffix in (".yml", ".yaml"):
                target = parse_yaml(path, ccd)
            elif path.is_dir():
                msg = f"Found directory {path} instead of .fasta or .yaml, skipping."
                raise RuntimeError(msg)
            else:
                msg = (
                    f"Unable to parse filetype {path.suffix}, "
                    "please provide a .fasta or .yaml file."
                )
                raise RuntimeError(msg)

            # Get target id
            target_id = target.record.id

            # Get all MSA ids and decide whether to generate MSA
            to_generate = {}
            prot_id = const.chain_type_ids["PROTEIN"]
            for chain in target.record.chains:
                # Add to generate list, assigning entity id
                if (chain.mol_type == prot_id) and (chain.msa_id == 0):
                    entity_id = chain.entity_id
                    msa_id = f"{target_id}_{entity_id}"
                    to_generate[msa_id] = target.sequences[entity_id]
                    chain.msa_id = msa_dir / f"{msa_id}.csv"

                # We do not support msa generation for non-protein chains
                elif chain.msa_id == 0:
                    chain.msa_id = -1

            # Generate MSA
            if to_generate and not use_msa_server:
                msg = "Missing MSA's in input and --use_msa_server flag not set."
                raise RuntimeError(msg)

            if to_generate:
                msg = f"Generating MSA for {path} with {len(to_generate)} protein entities."
                click.echo(msg)
                compute_msa(
                    data=to_generate,
                    target_id=target_id,
                    msa_dir=msa_dir,
                    msa_server_url=msa_server_url,
                    msa_pairing_strategy=msa_pairing_strategy,
                )

            # Parse MSA data
            msas = sorted({c.msa_id for c in target.record.chains if c.msa_id != -1})
            msa_id_map = {}
            for msa_idx, msa_id in enumerate(msas):
                # Check that raw MSA exists
                msa_path = Path(msa_id)
                if not msa_path.exists():
                    msg = f"MSA file {msa_path} not found."
                    raise FileNotFoundError(msg)

                # Dump processed MSA
                processed = processed_msa_dir / f"{target_id}_{msa_idx}.npz"
                msa_id_map[msa_id] = f"{target_id}_{msa_idx}"
                if not processed.exists():
                    # Parse A3M
                    if msa_path.suffix == ".a3m":
                        msa: MSA = parse_a3m(
                            msa_path,
                            taxonomy=None,
                            max_seqs=max_msa_seqs,
                        )
                    elif msa_path.suffix == ".csv":
                        msa: MSA = parse_csv(msa_path, max_seqs=max_msa_seqs)
                    else:
                        msg = f"MSA file {msa_path} not supported, only a3m or csv."
                        raise RuntimeError(msg)

                    msa.dump(processed)

            # Modify records to point to processed MSA
            for c in target.record.chains:
                if (c.msa_id != -1) and (c.msa_id in msa_id_map):
                    c.msa_id = msa_id_map[c.msa_id]

            # Keep record
            records.append(target.record)

            # Dump structure
            struct_path = structure_dir / f"{target.record.id}.npz"
            target.structure.dump(struct_path)

        except Exception as e:
            if len(data) > 1:
                print(f"Failed to process {path}. Skipping. Error: {e}.")
            else:
                raise e

    # Dump manifest
    manifest = Manifest(records)
    manifest.dump(out_dir / "processed" / "manifest.json")
def predict(
    data: str,
    out_dir: str,
    cache: str = "~/.boltz",
    checkpoint: Optional[str] = None,
    devices: int = 1,
    accelerator: str = "gpu",
    recycling_steps: int = 3,
    sampling_steps: int = 200,
    diffusion_samples: int = 1,
    step_scale: float = 1.638,
    write_full_pae: bool = False,
    write_full_pde: bool = False,
    output_format: Literal["pdb", "mmcif"] = "mmcif",
    num_workers: int = 2,
    override: bool = False,
    seed: Optional[int] = None,
    use_msa_server: bool = False,
    msa_server_url: str = "https://api.colabfold.com",
    msa_pairing_strategy: str = "greedy",
) -> None:
    """Run predictions with Boltz-1."""
    # If cpu, write a friendly warning
    if accelerator == "cpu":
        msg = "Running on CPU, this will be slow. Consider using a GPU."
        click.echo(msg)

    # Set no grad
    torch.set_grad_enabled(False)

    # Ignore matmul precision warning
    torch.set_float32_matmul_precision("highest")

    # Set seed if desired
    if seed is not None:
        seed_everything(int(seed))

    # Set cache path
    cache = Path(cache).expanduser()
    cache.mkdir(parents=True, exist_ok=True)

    # Create output directories
    data = Path(data).expanduser()
    out_dir = Path(out_dir).expanduser()
    out_dir = out_dir / f"boltz_results_{data.stem}"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Download necessary data and model
    download(cache)

    # Validate inputs
    data = check_inputs(data, out_dir, override)
    if not data:
        click.echo("No predictions to run, exiting.")
        return

    # Set up trainer
    strategy = "auto"
    if (isinstance(devices, int) and devices > 1) or (
        isinstance(devices, list) and len(devices) > 1
    ):
        strategy = DDPStrategy()
        if len(data) < devices:
            msg = (
                "Number of requested devices is greater "
                "than the number of predictions."
            )
            raise ValueError(msg)

    msg = f"Running predictions for {len(data)} structure"
    msg += "s" if len(data) > 1 else ""
    click.echo(msg)

    # Process inputs
    ccd_path = cache / "ccd.pkl"
    process_inputs(
        data=data,
        out_dir=out_dir,
        ccd_path=ccd_path,
        use_msa_server=use_msa_server,
        msa_server_url=msa_server_url,
        msa_pairing_strategy=msa_pairing_strategy,
    )

    # Load processed data
    processed_dir = out_dir / "processed"
    processed = BoltzProcessedInput(
        manifest=Manifest.load(processed_dir / "manifest.json"),
        targets_dir=processed_dir / "structures",
        msa_dir=processed_dir / "msa",
    )

    # Create data module
    data_module = BoltzInferenceDataModule(
        manifest=processed.manifest,
        target_dir=processed.targets_dir,
        msa_dir=processed.msa_dir,
        num_workers=num_workers,
    )

    # Load model
    if checkpoint is None:
        checkpoint = cache / "boltz1_conf.ckpt"

    predict_args = {
        "recycling_steps": recycling_steps,
        "sampling_steps": sampling_steps,
        "diffusion_samples": diffusion_samples,
        "write_confidence_summary": True,
        "write_full_pae": write_full_pae,
        "write_full_pde": write_full_pde,
    }
    diffusion_params = BoltzDiffusionParams()
    diffusion_params.step_scale = step_scale
    model_module: Boltz1 = Boltz1.load_from_checkpoint(
        checkpoint,
        strict=True,
        predict_args=predict_args,
        map_location="cpu",
        diffusion_process_args=asdict(diffusion_params),
        ema=False,
    )
    model_module.eval()

    # Create prediction writer
    pred_writer = BoltzWriter(
        data_dir=processed.targets_dir,
        output_dir=out_dir / "predictions",
        output_format=output_format,
    )

    trainer = Trainer(
        default_root_dir=out_dir,
        strategy=strategy,
        callbacks=[pred_writer],
        accelerator=accelerator,
        devices=devices,
        precision=32,
    )

    # Compute predictions
    trainer.predict(
        model_module,
        datamodule=data_module,
        return_predictions=False,
    )


if __name__ == "__main__":
    predict(
        data="./inputs_prediction",
        out_dir="./outputs_prediction",
        cache="/kaggle/input/rna-prediction-boltz/",
        diffusion_samples=5,
        seed=42,
        override=True,
    )
Let's walk through this code a bit:
Here is a detailed walkthrough: the script's logic, what each module does, and the purpose of the key steps.
This script (inference.py) implements the inference pipeline for the Boltz-1 structure prediction model: data preprocessing, multiple sequence alignment (MSA) generation, model loading, running the prediction, and saving the results.
It defines two important dataclasses:
- BoltzProcessedInput: the processed inputs (manifest, structures directory, MSA directory).
- BoltzDiffusionParams: the hyperparameters of the diffusion process.
- download(cache: Path): fetches the required files (CCD_URL and MODEL_URL) from the remote server if they are not already cached.
- check_inputs(...): validates the inputs, which must be in .fasta or .yaml format; targets that already have predictions are skipped unless override=True is set.
- compute_msa(...): generates the MSAs and stores each result as a CSV:
key,sequence
0,<MSA sequence 1>
-1,<MSA sequence 2>
- process_inputs(...): preprocesses each target and writes a manifest.json recording the paths and metadata of all processed data.
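The key/sequence layout that compute_msa writes can be illustrated with a tiny standalone sketch (the sequences below are made up; only the CSV shape matters):

```python
# Paired hits keep their row index as the key so chains can be matched up;
# unpaired hits all get the sentinel key -1, mirroring compute_msa above.
paired = ["MKTAYIAK", "MKSAYIAK"]   # hypothetical paired MSA rows
unpaired = ["MQTAYIAG"]             # hypothetical unpaired MSA rows

keys = list(range(len(paired))) + [-1] * len(unpaired)
rows = ["key,sequence"] + [f"{k},{s}" for k, s in zip(keys, paired + unpaired)]
print("\n".join(rows))
```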
The processing loop shows its progress with tqdm. predict(...) then defines the complete prediction flow. Its main parameters:
- data: input data path.
- out_dir: output directory.
- checkpoint: model weights path.
- devices: number of GPUs/CPUs to use.
- accelerator: type of compute device.
- sampling_steps, diffusion_samples: diffusion-model inference parameters.
- output_format: output format for predictions (PDB or mmCIF).
Internally it sets up a Lightning Trainer (with a distributed strategy when multiple GPUs are requested), calls process_inputs to preprocess the data (generating or loading MSAs), loads the Boltz1 model weights with the prediction arguments, feeds the processed data through BoltzInferenceDataModule, writes results to disk via the BoltzWriter callback, and finally runs trainer.predict(...). The script ends with a concrete call example:
predict(
    data="./inputs_prediction",
    out_dir="./outputs_prediction",
    cache="/kaggle/input/rna-prediction-boltz/",
    diffusion_samples=5,
    seed=42,
    override=True,
)
As for the dependencies: click provides the CLI support; torch and pytorch_lightning are the deep learning framework and its distributed-training tooling; tqdm draws progress bars. After inference, the output directory structure looks like this:
outputs_prediction
├── msa                  # intermediate MSA data (CSV)
├── processed
│   ├── structures       # processed structure files (npz)
│   ├── msa              # processed MSAs (npz)
│   └── manifest.json    # record of processed data
└── predictions          # final predicted structures (PDB or mmCIF)
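To check what actually got processed, the manifest can be read back with plain json. This is a sketch of my own: I'm assuming, based on the Record objects dumped above, that the manifest stores a list of records each carrying an id — verify against your own manifest.json:

```python
import json
from pathlib import Path

def list_processed_ids(out_dir: str) -> list[str]:
    """Read processed/manifest.json and return the record ids it lists.

    Assumes a {"records": [{"id": ...}, ...]} layout, which may differ
    from the real Manifest serialization.
    """
    manifest = json.loads(Path(out_dir, "processed", "manifest.json").read_text())
    return [record["id"] for record in manifest.get("records", [])]
```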
3. Prepare the prediction data
sub_file = pd.read_csv('/kaggle/input/stanford-rna-3d-folding/test_sequences.csv')
sub_file.head()
names = sub_file['target_id'].tolist()
sequences = sub_file['sequence'].tolist()

# Inference
idx = 0
for tmp_id, tmp_sequence in zip(names, sequences):
    with open(f'/kaggle/working/inputs_prediction/{tmp_id}.yaml', 'w') as f:
        f.write("constraints: []\n")
        f.write("sequences:\n")
        f.write("- rna:\n")
        f.write("    id:\n")
        f.write("    - A1\n")
        f.write(f"    sequence: {tmp_sequence}")
Let's take a look at one of these files:
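To make sure the hand-written YAML is well-formed, one of these strings can be round-tripped through yaml.safe_load (the sequence below is a made-up example, not a real competition target):

```python
import yaml

tmp_sequence = "GGGAAACCC"  # made-up RNA sequence for illustration
doc = (
    "constraints: []\n"
    "sequences:\n"
    "- rna:\n"
    "    id:\n"
    "    - A1\n"
    f"    sequence: {tmp_sequence}"
)
parsed = yaml.safe_load(doc)
print(parsed["sequences"][0]["rna"])
```

If safe_load raises, the indentation inside the f.write calls is off and parse_yaml will fail on it too.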

This step splits the RNA sequences in the csv into separate input files, one per target. The result looks like this:

4. Run the prediction
Note that the prediction here is executed by running the Python file from the command line; the inference.py in the screenshot is the script we wrote above.

Let's print the final result:

5. Extract the results
from Bio.PDB.MMCIF2Dict import MMCIF2Dict

def get_coords(tmp_id, idx):
    cif_file = f"outputs_prediction/boltz_results_inputs_prediction/predictions/{tmp_id}/{tmp_id}_model_{idx}.cif"
    mmcif_dict = MMCIF2Dict(cif_file)
    entity_poly_seq = mmcif_dict.get("_entity_poly_seq.mon_id", [])
    sequence = "".join(entity_poly_seq)
    print("RNA sequence:", sequence)
    x_coords = mmcif_dict["_atom_site.Cartn_x"]
    y_coords = mmcif_dict["_atom_site.Cartn_y"]
    z_coords = mmcif_dict["_atom_site.Cartn_z"]
    atom_names = mmcif_dict["_atom_site.label_atom_id"]
    c1_coords = []
    for i, atom in enumerate(atom_names):
        if atom == "C1'":
            c1_coords.append((float(x_coords[i]), float(y_coords[i]), float(z_coords[i])))
    return c1_coords

all_preds = os.listdir('outputs_prediction/boltz_results_inputs_prediction/predictions')
submission = pd.read_csv('/kaggle/input/stanford-rna-3d-folding/sample_submission.csv')
The main job of this step is to extract each nucleotide's C1' atom coordinates from the cif files predicted by Boltz-1; these coordinates form the RNA backbone in 3D space.
As the competition requires, five sets of coordinates must be predicted for every nucleotide of every RNA:

So let's take a look at our output:

Next, we need to convert this output into a csv file:

Here's what the csv file contains:

This file records each RNA sequence's id and nucleotide type, followed by the five predicted coordinate sets, each with x, y, and z values.
This should look familiar: protein pdb files mainly consist of exactly this kind of information, and RNA cif files are built from the same ingredients.
That's all for today. In the next post we'll look at how to compute the loss between a predicted RNA structure and the ground truth, and how to score it with TM-score.
See you next time!