
RNA Structure Prediction with Boltz-1

Tom2Code

Published 2026-04-17 17:24:35

looooooooong time no post!

I'm back! Thanks to all of you for your support.

Today I want to share a Kaggle competition: Stanford RNA 3D Folding.

There is a lot to learn from this competition and its models, so today's post is only a simple inference demo: get it running first, then dig into the details step by step. Stay tuned; I will keep updating as time allows.

1. About the Boltz-1 Model

In one sentence:

Boltz-1, released by researchers at MIT, is the first fully open-source, commercially usable deep learning model to reach the accuracy level reported for AlphaFold3 in predicting the 3D structures of biomolecular complexes. In other words, Boltz-1's performance is on par with AlphaFold3.

As the figure shows, Boltz-1 achieves high accuracy in protein structure prediction.

The figure below shows Boltz-1's performance on CASP15 targets:

So in this post we will use the Boltz-1 model weights to predict RNA structures.

Enough talk; on to the code:

2. Environment Setup

About the data:

The above covers the competition's dataset, plus some notes on the timeline:

PS: the code in this post runs on a Kaggle server.

2.1 Loading the input data

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

Output:

This folder holds the model's dependency files and the input data.

2.2 Installing dependencies

%ls /kaggle/input/boltz-dependencies

The listing above shows the dependencies we need. Next, run the install commands:

!pip install --no-index /kaggle/input/boltz-dependencies/*whl --no-deps

Installation starts:

Continuing:

!pip install --no-index /kaggle/input/fairscale-0413/*whl --no-deps

Next up (biopython):

!pip install --no-index /kaggle/input/biopython/*whl --no-deps

With that, all dependencies are installed.
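As a quick sanity check that the offline installs worked, the sketch below probes for the packages by import name. The names boltz, fairscale, and Bio (biopython's import name) are assumptions inferred from the wheel files above:

```python
from importlib.util import find_spec

# Import names assumed from the wheels installed above:
# boltz, fairscale, and Bio (biopython's import name)
required = ["boltz", "fairscale", "Bio"]
status = {name: find_spec(name) is not None for name in required}
for name, ok in status.items():
    print(f"{name}: {'OK' if ok else 'MISSING'}")
```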

2.3 Preparing the script

2.4 Running inference

The rough idea is as follows:

Code first:

%%writefile inference.py

import pickle
import urllib.request
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Literal, Optional

import click
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm

from boltz.data import const
from boltz.data.module.inference import BoltzInferenceDataModule
from boltz.data.msa.mmseqs2 import run_mmseqs2
from boltz.data.parse.a3m import parse_a3m
from boltz.data.parse.csv import parse_csv
from boltz.data.parse.fasta import parse_fasta
from boltz.data.parse.yaml import parse_yaml
from boltz.data.types import MSA, Manifest, Record
from boltz.data.write.writer import BoltzWriter
from boltz.model.model import Boltz1

CCD_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/ccd.pkl"
MODEL_URL = (
    "https://huggingface.co/boltz-community/boltz-1/resolve/main/boltz1_conf.ckpt"
)


@dataclass
class BoltzProcessedInput:
    """Processed input data."""

    manifest: Manifest
    targets_dir: Path
    msa_dir: Path


@dataclass
class BoltzDiffusionParams:
    """Diffusion process parameters."""

    gamma_0: float = 0.605
    gamma_min: float = 1.107
    noise_scale: float = 0.901
    rho: float = 8
    step_scale: float = 1.638
    sigma_min: float = 0.0004
    sigma_max: float = 160.0
    sigma_data: float = 16.0
    P_mean: float = -1.2
    P_std: float = 1.5
    coordinate_augmentation: bool = True
    alignment_reverse_diff: bool = True
    synchronize_sigmas: bool = True
    use_inference_model_cache: bool = True


@rank_zero_only
def download(cache: Path) -> None:
    """Download all the required data.

    Parameters
    ----------
    cache : Path
        The cache directory.

    """
    # Download CCD
    ccd = cache / "ccd.pkl"
    if not ccd.exists():
        click.echo(
            f"Downloading the CCD dictionary to {ccd}. You may "
            "change the cache directory with the --cache flag."
        )
        urllib.request.urlretrieve(CCD_URL, str(ccd))  # noqa: S310

    # Download model
    model = cache / "boltz1_conf.ckpt"
    if not model.exists():
        click.echo(
            f"Downloading the model weights to {model}. You may "
            "change the cache directory with the --cache flag."
        )
        urllib.request.urlretrieve(MODEL_URL, str(model))  # noqa: S310


def check_inputs(
    data: Path,
    outdir: Path,
    override: bool = False,
) -> list[Path]:
    """Check the input data and output directory.

    If the input data is a directory, it will be expanded
    to all files in this directory. Then, we check if there
    are any existing predictions and remove them from the
    list of input data, unless the override flag is set.

    Parameters
    ----------
    data : Path
        The input data.
    outdir : Path
        The output directory.
    override: bool
        Whether to override existing predictions.

    Returns
    -------
    list[Path]
        The list of input data.

    """
    click.echo("Checking input data.")

    # Check if data is a directory
    if data.is_dir():
        data: list[Path] = list(data.glob("*"))

        # Filter out non .fasta or .yaml files, raise
        # an error on directory and other file types
        filtered_data = []
        for d in data:
            if d.suffix in (".fa", ".fas", ".fasta", ".yml", ".yaml"):
                filtered_data.append(d)
            elif d.is_dir():
                msg = f"Found directory {d} instead of .fasta or .yaml."
                raise RuntimeError(msg)
            else:
                msg = (
                    f"Unable to parse filetype {d.suffix}, "
                    "please provide a .fasta or .yaml file."
                )
                raise RuntimeError(msg)

        data = filtered_data
    else:
        data = [data]

    # Check if existing predictions are found
    existing = (outdir / "predictions").rglob("*")
    existing = {e.name for e in existing if e.is_dir()}

    # Remove them from the input data
    if existing and not override:
        data = [d for d in data if d.stem not in existing]
        num_skipped = len(existing) - len(data)
        msg = (
            f"Found some existing predictions ({num_skipped}), "
            f"skipping and running only the missing ones, "
            "if any. If you wish to override these existing "
            "predictions, please set the --override flag."
        )
        click.echo(msg)
    elif existing and override:
        msg = "Found existing predictions, will override."
        click.echo(msg)

    return data


def compute_msa(
    data: dict[str, str],
    target_id: str,
    msa_dir: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
) -> None:
    """Compute the MSA for the input data.

    Parameters
    ----------
    data : dict[str, str]
        The input protein sequences.
    target_id : str
        The target id.
    msa_dir : Path
        The msa directory.
    msa_server_url : str
        The MSA server URL.
    msa_pairing_strategy : str
        The MSA pairing strategy.

    """
    if len(data) > 1:
        paired_msas = run_mmseqs2(
            list(data.values()),
            msa_dir / f"{target_id}_paired_tmp",
            use_env=True,
            use_pairing=True,
            host_url=msa_server_url,
            pairing_strategy=msa_pairing_strategy,
        )
    else:
        paired_msas = [""] * len(data)

    unpaired_msa = run_mmseqs2(
        list(data.values()),
        msa_dir / f"{target_id}_unpaired_tmp",
        use_env=True,
        use_pairing=False,
        host_url=msa_server_url,
        pairing_strategy=msa_pairing_strategy,
    )

    for idx, name in enumerate(data):
        # Get paired sequences
        paired = paired_msas[idx].strip().splitlines()
        paired = paired[1::2]  # ignore headers
        paired = paired[: const.max_paired_seqs]

        # Set key per row and remove empty sequences
        keys = [idx for idx, s in enumerate(paired) if s != "-" * len(s)]
        paired = [s for s in paired if s != "-" * len(s)]

        # Combine paired-unpaired sequences
        unpaired = unpaired_msa[idx].strip().splitlines()
        unpaired = unpaired[1::2]
        unpaired = unpaired[: (const.max_msa_seqs - len(paired))]
        if paired:
            unpaired = unpaired[1:]  # ignore query is already present

        # Combine
        seqs = paired + unpaired
        keys = keys + [-1] * len(unpaired)

        # Dump MSA
        csv_str = ["key,sequence"] + [f"{key},{seq}" for key, seq in zip(keys, seqs)]

        msa_path = msa_dir / f"{name}.csv"
        with msa_path.open("w") as f:
            f.write("\n".join(csv_str))


@rank_zero_only
def process_inputs(  # noqa: C901, PLR0912, PLR0915
    data: list[Path],
    out_dir: Path,
    ccd_path: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
    max_msa_seqs: int = 4096,
    use_msa_server: bool = False,
) -> None:
    """Process the input data and output directory.

    Parameters
    ----------
    data : list[Path]
        The input data.
    out_dir : Path
        The output directory.
    ccd_path : Path
        The path to the CCD dictionary.
    max_msa_seqs : int, optional
        Max number of MSA sequences, by default 4096.
    use_msa_server : bool, optional
        Whether to use the MMSeqs2 server for MSA generation, by default False.

    Returns
    -------
    BoltzProcessedInput
        The processed input data.

    """
    click.echo("Processing input data.")
    existing_records = None

    # Check if manifest exists at output path
    manifest_path = out_dir / "processed" / "manifest.json"
    if manifest_path.exists():
        click.echo(f"Found a manifest file at output directory: {out_dir}")

        manifest: Manifest = Manifest.load(manifest_path)
        input_ids = [d.stem for d in data]
        existing_records, processed_ids = zip(
            *[
                (record, record.id)
                for record in manifest.records
                if record.id in input_ids
            ]
        )

        if isinstance(existing_records, tuple):
            existing_records = list(existing_records)

        # Check how many examples need to be processed
        missing = len(input_ids) - len(processed_ids)
        if not missing:
            click.echo("All examples in data are processed. Updating the manifest")
            # Dump updated manifest
            updated_manifest = Manifest(existing_records)
            updated_manifest.dump(out_dir / "processed" / "manifest.json")
            return

        click.echo(f"{missing} missing ids. Preprocessing these ids")
        missing_ids = list(set(input_ids).difference(set(processed_ids)))
        data = [d for d in data if d.stem in missing_ids]
        assert len(data) == len(missing_ids)

    # Create output directories
    msa_dir = out_dir / "msa"
    structure_dir = out_dir / "processed" / "structures"
    processed_msa_dir = out_dir / "processed" / "msa"
    predictions_dir = out_dir / "predictions"

    out_dir.mkdir(parents=True, exist_ok=True)
    msa_dir.mkdir(parents=True, exist_ok=True)
    structure_dir.mkdir(parents=True, exist_ok=True)
    processed_msa_dir.mkdir(parents=True, exist_ok=True)
    predictions_dir.mkdir(parents=True, exist_ok=True)

    # Load CCD
    with ccd_path.open("rb") as file:
        ccd = pickle.load(file)  # noqa: S301

    if existing_records is not None:
        click.echo(f"Found {len(existing_records)} records. Adding them to records")

    # Parse input data
    records: list[Record] = existing_records if existing_records is not None else []
    for path in tqdm(data):
        try:
            # Parse data
            if path.suffix in (".fa", ".fas", ".fasta"):
                target = parse_fasta(path, ccd)
            elif path.suffix in (".yml", ".yaml"):
                target = parse_yaml(path, ccd)
            elif path.is_dir():
                msg = f"Found directory {path} instead of .fasta or .yaml, skipping."
                raise RuntimeError(msg)
            else:
                msg = (
                    f"Unable to parse filetype {path.suffix}, "
                    "please provide a .fasta or .yaml file."
                )
                raise RuntimeError(msg)

            # Get target id
            target_id = target.record.id

            # Get all MSA ids and decide whether to generate MSA
            to_generate = {}
            prot_id = const.chain_type_ids["PROTEIN"]
            for chain in target.record.chains:
                # Add to generate list, assigning entity id
                if (chain.mol_type == prot_id) and (chain.msa_id == 0):
                    entity_id = chain.entity_id
                    msa_id = f"{target_id}_{entity_id}"
                    to_generate[msa_id] = target.sequences[entity_id]
                    chain.msa_id = msa_dir / f"{msa_id}.csv"

                # We do not support msa generation for non-protein chains
                elif chain.msa_id == 0:
                    chain.msa_id = -1

            # Generate MSA
            if to_generate and not use_msa_server:
                msg = "Missing MSA's in input and --use_msa_server flag not set."
                raise RuntimeError(msg)

            if to_generate:
                msg = f"Generating MSA for {path} with {len(to_generate)} protein entities."
                click.echo(msg)
                compute_msa(
                    data=to_generate,
                    target_id=target_id,
                    msa_dir=msa_dir,
                    msa_server_url=msa_server_url,
                    msa_pairing_strategy=msa_pairing_strategy,
                )

            # Parse MSA data
            msas = sorted({c.msa_id for c in target.record.chains if c.msa_id != -1})
            msa_id_map = {}
            for msa_idx, msa_id in enumerate(msas):
                # Check that raw MSA exists
                msa_path = Path(msa_id)
                if not msa_path.exists():
                    msg = f"MSA file {msa_path} not found."
                    raise FileNotFoundError(msg)

                # Dump processed MSA
                processed = processed_msa_dir / f"{target_id}_{msa_idx}.npz"
                msa_id_map[msa_id] = f"{target_id}_{msa_idx}"
                if not processed.exists():
                    # Parse A3M
                    if msa_path.suffix == ".a3m":
                        msa: MSA = parse_a3m(
                            msa_path,
                            taxonomy=None,
                            max_seqs=max_msa_seqs,
                        )
                    elif msa_path.suffix == ".csv":
                        msa: MSA = parse_csv(msa_path, max_seqs=max_msa_seqs)
                    else:
                        msg = f"MSA file {msa_path} not supported, only a3m or csv."
                        raise RuntimeError(msg)

                    msa.dump(processed)

            # Modify records to point to processed MSA
            for c in target.record.chains:
                if (c.msa_id != -1) and (c.msa_id in msa_id_map):
                    c.msa_id = msa_id_map[c.msa_id]

            # Keep record
            records.append(target.record)

            # Dump structure
            struct_path = structure_dir / f"{target.record.id}.npz"
            target.structure.dump(struct_path)

        except Exception as e:
            if len(data) > 1:
                print(f"Failed to process {path}. Skipping. Error: {e}.")
            else:
                raise e

    # Dump manifest
    manifest = Manifest(records)
    manifest.dump(out_dir / "processed" / "manifest.json")

def predict(
    data: str,
    out_dir: str,
    cache: str = "~/.boltz",
    checkpoint: Optional[str] = None,
    devices: int = 1,
    accelerator: str = "gpu",
    recycling_steps: int = 3,
    sampling_steps: int = 200,
    diffusion_samples: int = 1,
    step_scale: float = 1.638,
    write_full_pae: bool = False,
    write_full_pde: bool = False,
    output_format: Literal["pdb", "mmcif"] = "mmcif",
    num_workers: int = 2,
    override: bool = False,
    seed: Optional[int] = None,
    use_msa_server: bool = False,
    msa_server_url: str = "https://api.colabfold.com",
    msa_pairing_strategy: str = "greedy",
) -> None:
    """Run predictions with Boltz-1."""
    # If cpu, write a friendly warning
    if accelerator == "cpu":
        msg = "Running on CPU, this will be slow. Consider using a GPU."
        click.echo(msg)

    # Set no grad
    torch.set_grad_enabled(False)

    # Ignore matmul precision warning
    torch.set_float32_matmul_precision("highest")

    # Set seed if desired
    if seed is not None:
        seed_everything(int(seed))

    # Set cache path
    cache = Path(cache).expanduser()
    cache.mkdir(parents=True, exist_ok=True)

    # Create output directories
    data = Path(data).expanduser()
    out_dir = Path(out_dir).expanduser()
    out_dir = out_dir / f"boltz_results_{data.stem}"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Download necessary data and model
    download(cache)

    # Validate inputs
    data = check_inputs(data, out_dir, override)
    if not data:
        click.echo("No predictions to run, exiting.")
        return

    # Set up trainer
    strategy = "auto"
    if (isinstance(devices, int) and devices > 1) or (
        isinstance(devices, list) and len(devices) > 1
    ):
        strategy = DDPStrategy()
        if len(data) < devices:
            msg = (
                "Number of requested devices is greater "
                "than the number of predictions."
            )
            raise ValueError(msg)

    msg = f"Running predictions for {len(data)} structure"
    msg += "s" if len(data) > 1 else ""
    click.echo(msg)

    # Process inputs
    ccd_path = cache / "ccd.pkl"
    process_inputs(
        data=data,
        out_dir=out_dir,
        ccd_path=ccd_path,
        use_msa_server=use_msa_server,
        msa_server_url=msa_server_url,
        msa_pairing_strategy=msa_pairing_strategy,
    )

    # Load processed data
    processed_dir = out_dir / "processed"
    processed = BoltzProcessedInput(
        manifest=Manifest.load(processed_dir / "manifest.json"),
        targets_dir=processed_dir / "structures",
        msa_dir=processed_dir / "msa",
    )

    # Create data module
    data_module = BoltzInferenceDataModule(
        manifest=processed.manifest,
        target_dir=processed.targets_dir,
        msa_dir=processed.msa_dir,
        num_workers=num_workers,
    )

    # Load model
    if checkpoint is None:
        checkpoint = cache / "boltz1_conf.ckpt"

    predict_args = {
        "recycling_steps": recycling_steps,
        "sampling_steps": sampling_steps,
        "diffusion_samples": diffusion_samples,
        "write_confidence_summary": True,
        "write_full_pae": write_full_pae,
        "write_full_pde": write_full_pde,
    }
    diffusion_params = BoltzDiffusionParams()
    diffusion_params.step_scale = step_scale
    model_module: Boltz1 = Boltz1.load_from_checkpoint(
        checkpoint,
        strict=True,
        predict_args=predict_args,
        map_location="cpu",
        diffusion_process_args=asdict(diffusion_params),
        ema=False,
    )
    model_module.eval()

    # Create prediction writer
    pred_writer = BoltzWriter(
        data_dir=processed.targets_dir,
        output_dir=out_dir / "predictions",
        output_format=output_format,
    )

    trainer = Trainer(
        default_root_dir=out_dir,
        strategy=strategy,
        callbacks=[pred_writer],
        accelerator=accelerator,
        devices=devices,
        precision=32,
    )

    # Compute predictions
    trainer.predict(
        model_module,
        datamodule=data_module,
        return_predictions=False,
    )


if __name__ == "__main__":
    predict(data="./inputs_prediction",
            out_dir="./outputs_prediction",
            cache="/kaggle/input/rna-prediction-boltz/",
            diffusion_samples=5,
            seed=42,
            override=True)

Let's walk through the code briefly:

Below is a detailed explanation: the overall logic, the role of each module, and the purpose of the key steps.


(1) Overview

This script (inference.py) implements the inference pipeline for the Boltz-1 structure prediction model: data preprocessing, multiple sequence alignment (MSA) generation, model loading, prediction management, and saving the predicted structures.


(2) Core modules and their roles

a. Dataclasses

Two important dataclasses are defined:

  • BoltzProcessedInput
    • Holds the processed input data (the manifest, the target-structure directory, and the MSA directory).
  • BoltzDiffusionParams
    • Stores the diffusion model's hyperparameters (e.g. noise scale, step scale).

b. File download and cache management

  • Function: download(cache: Path)
  • Downloads the required files from Hugging Face:
    • the CCD (Chemical Component Dictionary) data;
    • the model weights (checkpoint).
  • The download locations are defined by the constants:
    • CCD_URL
    • MODEL_URL

c. Input validation

  • Function: check_inputs(...)
    • Inputs:
      • the data path (a directory or a single file);
      • the output path.
    • What it does:
      • verifies the inputs are in .fasta or .yaml format;
      • detects existing predictions and skips them (unless override=True).

d. Multiple sequence alignment (MSA) computation

  • Function: compute_msa(...)
    • Calls a remote server (MMseqs2) to generate MSAs for the protein sequences.
    • MSA sequences:
      • paired sequences are used for structure prediction and provide more reliable alignment information;
      • unpaired sequences serve as supplementary information.

The MSA result is stored as CSV:

key,sequence
0,MSA_sequence_1
-1,MSA_sequence_2
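To make the key column concrete: rows with key >= 0 come from the paired MSA (the key links sequences across chains), while rows with key == -1 are unpaired. A minimal sketch, with made-up sequences, of splitting such a CSV back apart:

```python
# Made-up MSA CSV in the format written by compute_msa above
csv_str = """key,sequence
0,MKTAYIAKQR
1,MKTAYIAKQA
-1,MKTAYIAKQG"""

paired, unpaired = [], []
for line in csv_str.splitlines()[1:]:  # skip the header row
    key, seq = line.split(",", 1)
    (unpaired if int(key) == -1 else paired).append(seq)

print("paired:", paired)
print("unpaired:", unpaired)
```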

e. Data preprocessing (the process_inputs function)

  • What it does:
    • parses the raw input files (.fasta or .yaml);
    • computes (or loads) the MSA for each protein;
    • saves the processed data as .npz files;
    • writes a global manifest.json recording the paths and metadata of all processed data.
  • Flow:
    • checks the existing processing records (the manifest) to avoid reprocessing;
    • shows progress with tqdm;
    • parse input files → generate or load MSAs → store the processed data.

f. The prediction pipeline (the predict function)

This function defines the full prediction flow:

  • Key parameters (a selection)
    • data: input data path.
    • out_dir: output directory.
    • checkpoint: model weights path.
    • devices: number of GPUs/CPUs to use.
    • accelerator: compute device type.
    • sampling_steps / diffusion_samples: diffusion inference parameters.
    • output_format: output format for the predictions (PDB or mmCIF).
  • Steps:
    1. Configure PyTorch (disable gradients, set the random seed, etc.).
    2. Check or create the cache directory and download the required model weights and data files.
    3. Validate the inputs and skip existing predictions.
    4. Set up PyTorch Lightning's Trainer, with a distributed strategy if multiple GPUs are used.
    5. Call process_inputs to preprocess the data (generate or load MSAs).
    6. Initialize the model (Boltz1), load the weights, and set the prediction arguments.
    7. Build a BoltzInferenceDataModule to load the processed data.
    8. Use the BoltzWriter callback to write predictions to disk.
    9. Launch prediction (trainer.predict(...)).

g. Entry point (if __name__ == "__main__":)

A concrete invocation example:

predict(
    data="./inputs_prediction",
    out_dir="./outputs_prediction",
    cache="/kaggle/input/rna-prediction-boltz/",
    diffusion_samples=5,
    seed=42,
    override=True
)

(3) Key third-party libraries

  • click: CLI (command-line interface) support.
  • torch / pytorch_lightning: the PyTorch deep learning framework and its distributed training utilities.
  • tqdm: progress bars.

(4) Output file layout

After inference, the directory structure looks roughly like this:

outputs_prediction
├── msa                    # intermediate MSA data (CSV)
├── processed
│   ├── structures         # processed structure files (npz)
│   ├── msa                # processed MSAs (npz)
│   └── manifest.json      # data processing records
└── predictions            # final predicted structures (PDB or mmCIF)

3. Preparing the Prediction Data

sub_file = pd.read_csv('/kaggle/input/stanford-rna-3d-folding/test_sequences.csv')

sub_file.head()

names = sub_file['target_id'].tolist()
sequences = sub_file['sequence'].tolist()

# Write one YAML input file per target sequence
os.makedirs('/kaggle/working/inputs_prediction', exist_ok=True)
for tmp_id, tmp_sequence in zip(names, sequences):
    with open(f'/kaggle/working/inputs_prediction/{tmp_id}.yaml', 'w') as f:
        f.write("constraints: []\n")
        f.write("sequences:\n")
        f.write("- rna:\n")
        f.write("    id:\n")
        f.write("    - A1\n")
        f.write(f"    sequence: {tmp_sequence}")

Let's take a look at one of these files:

This step splits the RNA sequences in the CSV into separate per-target input files; the run result is shown below:
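In case the screenshot does not render, the snippet below reproduces what one of the generated files looks like. The target id and sequence here are made up for illustration:

```python
tmp_id, tmp_sequence = "EXAMPLE_1", "GGGAUCC"  # made-up target id and sequence

# Same template as the writing loop above
yaml_text = (
    "constraints: []\n"
    "sequences:\n"
    "- rna:\n"
    "    id:\n"
    "    - A1\n"
    f"    sequence: {tmp_sequence}"
)
print(f"--- {tmp_id}.yaml ---")
print(yaml_text)
```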

4. Running the Prediction

Note that the prediction is executed by running the Python file from the command line (i.e. !python inference.py in a notebook cell); the inference.py in the figure is the code from part three.

Printing the final result:

5. Extracting the Results

from Bio.PDB.MMCIF2Dict import MMCIF2Dict

def get_coords(tmp_id, idx):
    cif_file = f"outputs_prediction/boltz_results_inputs_prediction/predictions/{tmp_id}/{tmp_id}_model_{idx}.cif"

    mmcif_dict = MMCIF2Dict(cif_file)

    entity_poly_seq = mmcif_dict.get("_entity_poly_seq.mon_id", [])
    sequence = "".join(entity_poly_seq)
    print("RNA sequence:", sequence)

    x_coords = mmcif_dict["_atom_site.Cartn_x"]
    y_coords = mmcif_dict["_atom_site.Cartn_y"]
    z_coords = mmcif_dict["_atom_site.Cartn_z"]
    atom_names = mmcif_dict["_atom_site.label_atom_id"]

    c1_coords = []
    for i, atom in enumerate(atom_names):
        if atom == "C1'":
            c1_coords.append((float(x_coords[i]), float(y_coords[i]), float(z_coords[i])))
    return c1_coords

all_preds = os.listdir('outputs_prediction/boltz_results_inputs_prediction/predictions')
submission = pd.read_csv('/kaggle/input/stanford-rna-3d-folding/sample_submission.csv')

This step extracts the C1' atom coordinates of every nucleotide from the CIF files predicted by Boltz-1; these coordinates form the RNA backbone in 3D space.

Per the competition rules, five sets of coordinates must be predicted for every nucleotide of each RNA:

So let's take a look at our output:

We then need to convert this output into a CSV file:

Here is what the CSV contains:

The file records each RNA sequence's id and nucleotide types, followed by the five predicted coordinate sets, each containing x, y, and z values.
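As a sketch of that conversion step, one row per nucleotide can be assembled like this. The column names (ID, resname, resid, x_1..z_5) and the "target_resid" ID convention are my assumptions based on the description above; check them against the competition's sample_submission.csv:

```python
def build_rows(target_id, sequence, models):
    """models: five lists of (x, y, z) tuples, one tuple per nucleotide."""
    rows = []
    for i, base in enumerate(sequence):
        # Assumed ID convention: "<target_id>_<resid>", resid starting at 1
        row = {"ID": f"{target_id}_{i + 1}", "resname": base, "resid": i + 1}
        for m, coords in enumerate(models, start=1):
            x, y, z = coords[i]
            row[f"x_{m}"], row[f"y_{m}"], row[f"z_{m}"] = x, y, z
        rows.append(row)
    return rows

# Dummy data: a 2-nt sequence with five identical predicted models
dummy_models = [[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]] * 5
rows = build_rows("EXAMPLE_1", "GC", dummy_models)
print(rows[0])
```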

This should look familiar: protein PDB files are mainly composed of exactly this kind of atom information, and RNA CIF files are no different.
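For reference, an abridged _atom_site loop from a CIF file looks roughly like this (real files carry many more columns; the coordinates below are made up, and only the categories that get_coords reads are shown):

```text
loop_
_atom_site.group_PDB
_atom_site.label_atom_id
_atom_site.label_comp_id
_atom_site.Cartn_x
_atom_site.Cartn_y
_atom_site.Cartn_z
ATOM C1' G 1.234 5.678 9.012
ATOM C1' C 4.321 8.765 2.109
```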

That's it for today. In the next post we will look at how to compute the loss between a predicted RNA structure and the ground truth, and how to score predictions with TM-score.

See you next time!

This article was shared from the WeChat official account Tom的小院 via the Tencent Cloud self-media syndication program. Originally published 2025-05-17.

