前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >降水临近预报_Weather4cast_RainAI代码分享

降水临近预报_Weather4cast_RainAI代码分享

原创
作者头像
勤劳小王
发布2024-07-12 21:33:37
620
发布2024-07-12 21:33:37
举报
文章被收录于专栏:rainAI

降水临近预报_Weather4cast_RainAI代码分享

主程序w4c23

代码语言:python
代码运行次数:0
复制
def main():
    parser = set_parser()
    options = parser.parse_args()
    params = update_params_based_on_args(options)
    selected_model = params["model"]["model_name"]
    if selected_model == "2D_UNET_base":
        model = UNetModule
    elif selected_model == "SWIN":
        model = SWINModule
    train(params, options.gpus, options.mode, options.checkpoint, model)

set_parser()是一个函数,用于设置和返回一个argparse.ArgumentParser对象

parser.parse_args()方法来解析命令行参数并将结果存储在options变量中

代码语言:python
代码运行次数:0
复制
def update_params_based_on_args(options):
    config_p = os.path.join("configurations", options.config_path)
    params = load_config(config_p)

    if options.name != "":
        print(params["experiment"]["name"])
        params["experiment"]["name"] = options.name
    if options.epochs is not None:
        params["train"]["max_epochs"] = options.epochs
    if options.batch_size is not None:
        params["train"]["batch_size"] = options.batch_size
    if options.num_workers is not None:
        params["train"]["n_workers"] = options.num_workers
    if options.input_path != "":
        params["dataset"]["data_root"] = options.input_path
    if options.output_path != "":
        params["experiment"]["experiment_folder"] = options.output_path
    if options.region_to_predict != "":
        params["predict"]["region_to_predict"] = options.region_to_predict
    if options.year_to_predict != "":
        params["predict"]["year_to_predict"] = options.year_to_predict
    if options.submission_out_dir != "":
        params["predict"]["submission_out_dir"] = options.submission_out_dir
    return params

models

baseModule

具有强度输出和概率输出的模型的基本模块。需要验证和预测实现的抽象类。

BaseModule的类,它继承自LightningModuleABC

因为继承自LightningModule,要重写training_stepvalidation_steppredict_stepconfigure_optimizers方法,详见后续。

ABC(Abstract Base Class)是一个用于定义抽象基类的元类。抽象基类是不能被实例化的类,它主要用于定义接口和共享方法的规范。通过继承抽象基类,子类需要实现抽象基类中定义的抽象方法,以满足基类的接口规范。抽象基类可以提供一种约束,确保子类的一致性和可替换性。


代码语言:python
代码运行次数:0
复制
        if self.probabilistic:
            # Store bucket means (but not as model parameter) as the channel dimension of the data
            self.register_buffer(
                "bucket_means",
                torch.tensor(self.buckets.means).view(1, -1, 1, 1, 1),
            )
            self.bucket_means: torch.Tensor

如果损失函数是概率型的(probabilistic=True),则代码会使用self.register_buffer方法注册一个缓冲区(buffer)bucket_means,用于存储损失函数的桶均值。这里使用torch.tensor将桶均值转换成张量,并通过view方法对其进行形状变换,以便后续使用。需要注意的是,注册的缓冲区不会作为模型的参数进行优化。


代码语言:python
代码运行次数:0
复制
        if model_params["upsample"] == "bilinear":
            self.upsample = BilinearUpsample(42, 252, self.forecast_length)
        elif model_params["upsample"] == "nearest":
            self.upsample = NearestUpsample(42, 252, self.forecast_length)
        elif model_params["upsample"] == "ninasr":
            self.upsample = NinaSRUpsample(
                42, 252, self.forecast_length, self.num_classes
            )
        elif model_params["upsample"] == "edsr":
            self.upsample = EDSRUpsample(
                42, 252, self.forecast_length, self.num_classes
            )
        else:
            self.upsample = None

根据model_params["upsample"]的值选择相应的上采样方法对象赋值给self.upsample。根据代码片段提供的信息,上采样方法可以是BilinearUpsampleNearestUpsampleNinaSRUpsampleEDSRUpsample。如果model_params["upsample"]的值不在这些选项中,self.upsample将被设置为None


代码语言:python
代码运行次数:0
复制
    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

abstractmethod是一个装饰器,用于定义抽象方法。抽象方法是在抽象基类中声明但没有实现的方法,它只有方法的声明部分,没有具体的方法体。抽象方法必须在子类中被重写实现,否则子类也会成为抽象类。通过使用abstractmethod装饰器,可以明确地表示某个方法是抽象方法。

在这段代码中,forward方法被定义为抽象方法,即没有具体的实现。抽象方法使用abstractmethod装饰器进行修饰,表示它是一个需要在子类中被重写实现的方法。子类必须提供forward方法的具体实现,以满足抽象基类的接口规范。


代码语言:python
代码运行次数:0
复制
    def augment_batch(self, batch):
        """Apply augmentation on training batches (flips and 90-degrees rotation)"""
        # TODO - Change to data loader
        if not self.transform:
            return batch
        input, label, metadata = batch
        angle = random.choice([-90, 0, 90, 180])
        transformations = [
            v2.RandomHorizontalFlip(),
            v2.RandomVerticalFlip(),
            v2.RandomRotation([angle, angle]),
        ]

        t = random.choice(transformations)
        input = t(input).contiguous()
        label = t(label).contiguous()
        # Transform masks
        metadata["input"]["mask"] = t(metadata["input"]["mask"])
        metadata["target"]["mask"] = t(metadata["target"]["mask"])
        # Transform static data if any
        if self.static_data:
            metadata["input"]["topo"] = t(metadata["input"]["topo"])
            metadata["target"]["topo"] = t(metadata["target"]["topo"])
            metadata["input"]["lat-long"] = t(metadata["input"]["lat-long"])
            metadata["target"]["lat-long"] = t(metadata["target"]["lat-long"])
        return input, label, metadata

augment_batch方法接受一个batch参数,表示训练批次数据。该方法的作用是对训练批次数据进行增强操作,包括翻转和旋转。增强操作可以提高模型的鲁棒性和泛化能力,使其能够更好地适应不同的输入样本。

在当前的实现中,首先判断是否需要进行数据增强操作,如果self.transformFalse,则直接返回原始的批次数据。否则,从批次数据中获取输入、标签和元数据。然后,随机选择一个角度(-90度、0度、90度或180度),并定义一些变换操作,包括随机水平翻转、随机垂直翻转和随机旋转。接下来,从变换操作中随机选择一个变换t,并将其应用于输入、标签和元数据的对应部分。其中,输入和标签通过调用变换对象的__call__方法进行转换,并使用contiguous方法保证数据的连续性。对于元数据中的掩码(mask)数据和静态数据(如果有的话),也需要进行相应的变换操作。最后,返回经过增强操作后的输入、标签和元数据。


代码语言:python
代码运行次数:0
复制
    def add_static(self, input, metadata):
        lat_long = (
            metadata["input"]["lat-long"]
            .unsqueeze(2)
            .repeat(1, 1, self.history_length, 1, 1)
        )
        topo = (
            metadata["input"]["topo"]
            .unsqueeze(2)
            .repeat(1, 1, self.history_length, 1, 1)
        )
        input = torch.cat([input, lat_long, topo], dim=1)
        return input

add_static方法接受两个参数,input表示输入数据,metadata表示元数据。该方法的作用是将静态数据(lat_longtopo)添加到输入数据中。

首先,代码从metadata中获取了lat_longtopo数据。这些数据可能是二维张量,表示地理坐标和地形信息。然后,通过使用unsqueeze方法在适当的维度上添加一个维度,以便进行重复复制。使用repeat方法将lat_longtopo在相应的维度上进行重复,以匹配输入数据的形状。接下来,使用torch.cat方法将输入数据、lat_longtopo在维度1上进行连接,将它们合并成一个更大的输入张量。最后,返回合并后的输入数据。


代码语言:python
代码运行次数:0
复制
    def training_step(self, batch):
        batch = self.augment_batch(batch)
        input, label, metadata = batch
        # Add static data to input if required
        if self.static_data:
            input = self.add_static(input, metadata)
        input = self.transform_input(input)
        prediction = self.forward(input)
        if self.upsample:
            prediction = self.upsample(prediction)
        mask = metadata["target"]["mask"]
        loss = self.loss_fn(prediction, label, mask)
        self.log("train/loss", loss, sync_dist=True)
        return loss

training_step方法接受一个batch参数,表示训练批次数据。该方法的作用是执行一次训练步骤,包括数据增强、添加静态数据、输入转换、模型前向传播、上采样、计算损失和记录训练损失。

首先,代码调用augment_batch方法对批次数据进行增强操作,得到增强后的批次数据。然后,从增强后的批次数据中获取输入、标签和元数据。接下来,根据是否需要添加静态数据的设置,判断是否需要将静态数据添加到输入中。如果需要添加静态数据,调用add_static方法将静态数据添加到输入数据中,得到添加了静态数据的输入。然后,调用transform_input方法对输入数据进行转换,得到转换后的输入数据。接着,调用forward方法对转换后的输入数据进行模型的前向传播,得到预测结果。如果定义了上采样方法(self.upsample不为None),则对预测结果进行上采样操作。接下来,从元数据中获取目标数据的掩码(mask)。然后,使用损失函数self.loss_fn计算预测结果与标签之间的损失,传入预测结果、标签和掩码作为参数。最后,使用self.log方法记录训练损失,并返回损失值。


代码语言:python
代码运行次数:0
复制
    def validation_step(self, batch, batch_idx) -> ValidationOutput:
        input, label, metadata = batch
        # Add static data to input if required
        if self.static_data:
            input = self.add_static(input, metadata)
        input = self.transform_input(input)
        prediction = self.forward(input)
        if self.upsample:
            prediction = self.upsample(prediction)
        mask = metadata["target"]["mask"]
        loss = self.loss_fn(prediction, label, mask)
        self.log("val/loss", loss, sync_dist=True)
        if self.probabilistic:
            # If no softmax, apply as it is required for the metrics (i.e. CRPS)
            if self.activation == "none":
                prediction = nn.functional.softmax(prediction, dim=1)
            probabilities = prediction
            intensity = self.integrate(prediction)
        else:
            probabilities = None
            intensity = prediction
        return ValidationOutput(intensity=intensity, probabilities=probabilities)

validation_step方法接受两个参数,batch表示验证批次数据,batch_idx表示批次索引。该方法的作用是执行一次验证步骤,包括添加静态数据、输入转换、模型前向传播、上采样、计算损失、记录验证损失和返回验证输出。

首先,代码从验证批次数据中获取输入、标签和元数据。接下来,根据是否需要添加静态数据的设置,判断是否需要将静态数据添加到输入中。如果需要添加静态数据,调用add_static方法将静态数据添加到输入数据中,得到添加了静态数据的输入。然后,调用transform_input方法对输入数据进行转换,得到转换后的输入数据。接着,调用forward方法对转换后的输入数据进行模型的前向传播,得到预测结果。如果定义了上采样方法(self.upsample不为None),则对预测结果进行上采样操作。接下来,从元数据中获取目标数据的掩码(mask)。然后,使用损失函数self.loss_fn计算预测结果与标签之间的损失,传入预测结果、标签和掩码作为参数。接着,使用self.log方法记录验证损失,并传入"val/loss"作为日志名称,loss作为损失值,并设置sync_dist=True以确保在分布式训练中同步日志。如果模型的损失函数是概率型的(self.probabilistic=True),则进行一些额外的操作。首先,如果激活函数是"none"(即没有使用激活函数),则将预测结果进行 softmax 操作,因为一些指标(如 CRPS)需要概率分布的预测结果。然后,将预测结果作为概率分布probabilities,并将预测结果进行积分得到intensity。最后,返回一个ValidationOutput对象,包含intensityprobabilities


代码语言:python
代码运行次数:0
复制
    def predict_step(self, batch, batch_idx=None) -> torch.Tensor:
        input, _, metadata = batch
        # Add static data to input if required
        if self.static_data:
            input = self.add_static(input, metadata)
        input = self.transform_input(input)
        prediction = self.forward(input)
        if self.upsample:
            prediction = self.upsample(prediction)
        if self.probabilistic:
            # If no softmax, apply as it to sum 1
            if self.activation == "none":
                prediction = nn.functional.softmax(prediction, dim=1)
            probabilities = prediction
            intensity = self.integrate(prediction)
        else:
            probabilities = None
            intensity = prediction
        intensity = intensity[:, :, : self.forecast_length, :, :]
        return intensity

首先,代码从预测批次数据中获取输入数据和元数据,忽略了标签数据(_)。接下来,根据是否需要添加静态数据的设置,判断是否需要将静态数据添加到输入中。如果需要添加静态数据,调用add_static方法将静态数据添加到输入数据中,得到添加了静态数据的输入。然后,调用transform_input方法对输入数据进行转换,得到转换后的输入数据。接着,调用forward方法对转换后的输入数据进行模型的前向传播,得到预测结果。如果定义了上采样方法(self.upsample不为None),则对预测结果进行上采样操作。如果模型的损失函数是概率型的(self.probabilistic=True),则进行一些额外的操作。首先,如果激活函数是"none"(即没有使用激活函数),则将预测结果进行 softmax 操作,以确保预测结果的和为1。然后,将预测结果作为概率分布probabilities,并将预测结果进行积分得到intensity。最后,根据预测长度截取intensity中的相应部分,并返回截取后的intensity作为预测结果。


代码语言:python
代码运行次数:0
复制
def configure_optimizers(self):
    optimizer = torch.optim.AdamW(
        self.parameters(),
        lr=self.lr,
        weight_decay=self.weight_decay,
    )
    return optimizer

使用了torch.optim.AdamW优化器类来创建一个AdamW优化器对象。AdamW是Adam优化器的一种变体,它在优化过程中引入了权重衰减(weight decay)的正则化项,有助于控制模型的复杂度并提高泛化能力。在创建优化器对象时,传入了两个参数。self.parameters()表示要优化的模型参数,即模型中所有需要进行梯度更新的参数。lr=self.lrweight_decay=self.weight_decay分别指定了学习率和权重衰减的数值,这些数值是在模型初始化时从参数中获取的。最后,将创建的优化器对象返回。

losses

交叉熵和均方误差计算,对应概率输出和强度输出。

callbacks

callbacks文件夹应该放回调代码就可以了,不知道为什么把metrics代码也放这里。

log
image.png
image.png

用于在PyTorch Lightning框架中记录和计算各种指标(metrics)的值

init
代码语言:python
代码运行次数:0
复制
    def __init__(self, num_leadtimes, probabilistic, buckets, logging):
        super().__init__()
        self.num_leadtimes = num_leadtimes
        self.probabilistic = probabilistic

        if buckets != "none":
            self.buckets = BUCKET_CONSTANTS[buckets]
        else:
            self.buckets = None

        self.logging = logging
        self.thresholds = [0.2, 1, 5, 10, 15]

接收参数num_leadtimes(leading time steps)、probabilistic(是否概率性指标)、buckets(用于概率性指标的桶大小)、logging(指标记录的方式)。

  1. num_leadtimes:leading time steps。
  2. probabilistic:一个布尔值,表示是否使用概率性指标。
  3. buckets:一个字符串,表示概率性指标中用于分桶的参数。如果不需要分桶,则为"none"。
  4. logging:一个字符串,表示指标记录的方式,可以是"tensorboard"或"wandb"。
  5. thresholds:一个列表,包含阈值的值。这些阈值将用于计算关键成功指数(Critical Success Index,CSI)

代码语言:python
代码运行次数:0
复制
from dataclasses import dataclass
from typing import List


@dataclass
class Bucket:
    idx: int
    mean: float
    max: float
    weight: float


@dataclass
class BucketConstants:
    buckets: List[Bucket]
    means: List[float]
    weights: List[float]
    boundaries: List[float]
    ranges: List[float]
    num_buckets: int
    
# Custom buckets used for classification when using mm/h
_buckets_mmh = [
    Bucket(idx=0, mean=0, max=0.08, weight=0.5107),
    Bucket(idx=1, mean=0.12, max=0.16, weight=0.6014),
    Bucket(idx=2, mean=0.2, max=0.25, weight=0.627),
    Bucket(idx=3, mean=0.32, max=0.4, weight=0.6295),
    Bucket(idx=4, mean=0.51, max=0.63, weight=0.631),
    Bucket(idx=5, mean=0.81, max=1, weight=0.6359),
    Bucket(idx=6, mean=1.3, max=1.6, weight=0.6472),
    Bucket(idx=7, mean=2.0, max=2.5, weight=0.6667),
    Bucket(idx=8, mean=3.25, max=4, weight=0.6901),
    Bucket(idx=9, mean=5.15, max=6.3, weight=0.7298),
    Bucket(idx=10, mean=8.1, max=10, weight=0.7823),
    Bucket(idx=11, mean=13, max=16, weight=0.8428),
    Bucket(idx=12, mean=20.5, max=25, weight=0.9084),
    Bucket(idx=13, mean=32.5, max=40, weight=0.9617),
    Bucket(
        idx=14, mean=45, max=128, weight=1.0
    ),  # Max is 128 as defined by preprocessing
]

def getBucketObject(buckets_list):
    return BucketConstants(
        buckets=buckets_list,
        means=[b.mean for b in buckets_list],
        weights=[b.weight for b in buckets_list],
        boundaries=[b.max for b in buckets_list[:-1]],
        ranges=[
            buckets_list[i].max - buckets_list[i - 1].max
            if i > 0
            else buckets_list[i].max
            for i in range(len(buckets_list))
        ],
        num_buckets=len(buckets_list),
    )


BUCKET_CONSTANTS = {
    "mmh": getBucketObject(_buckets_mmh),
    "test": getBucketObject(_buckets_test),
    "w4c23_1": getBucketObject(_buckets_w4c23_1),
    "w4c23_2": getBucketObject(_buckets_w4c23_2),
}

创建和管理不同的桶(Bucket)对象,并将其存储在BUCKET_CONSTANTS字典中。通过调用getBucketObject函数,可以根据桶列表获取相应的BucketConstants对象。这样做的目的是为了方便地创建和使用不同的桶,并将其关联到特定的名称,以供其他代码使用。

dataclasses 模块提供了一个装饰器 @dataclass,用于方便地创建和操作数据类(data class),它自动为类的属性生成相应的方法(如构造函数、属性访问方法、比较方法等),使得创建和操作数据对象更加简洁和方便。

代码语言:python
代码运行次数:0
复制
from dataclasses import dataclass

@dataclass
class Person:
    name: str
    age: int
    occupation: str

代码语言:python
代码运行次数:0
复制
        # # Code for checking if a metric can be optimized
        # check_forward_full_state_property(
        #     metrics.MeanSquaredError,
        #     input_args={
        #         "prediction": torch.Tensor([0.5, 2.5]),
        #         "label": torch.Tensor([1.0, 2.0]),
        #         "mask": torch.zeros([2], dtype=bool),
        #     },
        # )

被注释掉的代码是用于检查一个指标是否可以进行优化的示例代码。它使用torchmetrics库中的check_forward_full_state_property函数来检查均方误差(MeanSquaredError)指标是否可以进行优化。函数的输入参数为一个字典,包含了预测值(prediction)、标签值(label)和掩码(mask)。通过检查指标的前向计算是否可以成功执行,可以确保指标的正确性和可用性。

_threshold_str
代码语言:python
代码运行次数:0
复制
    def _threshold_str(self, threshold):
        """Remove .0 and change . by -"""
        return f"{threshold:g}".replace(".", "-")

该段代码定义了一个名为"_threshold_str"的私有方法,用于处理阈值(threshold)的字符串表示。

该方法接受一个阈值参数,将其转换为字符串表示。转换过程包括以下步骤:

  1. 使用"{threshold:g}"将阈值转换为一般格式的字符串表示,去除多余的零和小数点。
  2. 使用.replace(".", "-")将字符串中的小数点替换为短横线。

"g" 是格式化字符串中的一种格式化选项,用于表示通用格式。它会根据阈值的类型自动选择合适的表示方式,并去除多余的零和小数点。具体来说,对于整数类型的阈值,它会显示为普通整数的形式,如 5、10、100 等。而对于浮点数类型的阈值,它会显示为一般的浮点数格式,如 0.5、1.0、2.5 等。在这个过程中,多余的零和小数点会被去除。

最后,该方法返回处理后的字符串表示形式。

该方法的作用是将阈值转换为特定的字符串表示形式,可能是为了后续的指标命名或其他需要使用特定格式的字符串的目的。由于该方法是私有方法(以单个下划线开头),它在类外部不可直接访问,只能在类内部被调用。

setup
代码语言:python
代码运行次数:0
复制
    def setup(self, trainer, pl_module, stage):
        # Setup scalar metrics
        scalar_metrics = {}
        scalar_metrics["mse"] = metrics.MeanSquaredError()
        scalar_metrics["mae"] = metrics.MeanAverageError()

        for threshold in self.thresholds:
            csi = metrics.CriticalSuccessIndex(threshold=threshold)
            scalar_metrics[f"csi_{self._threshold_str(threshold)}"] = csi
        scalar_metrics["avg_csi"] = metrics.AverageCriticalSuccessIndex(
            thresholds=self.thresholds
        )

        if self.probabilistic:
            scalar_metrics["crps"] = metrics.ContinuousRankedProbabilityScore(
                self.buckets
            )

        # Create metric collections and put metrics on module to automatically place on correct device
        val_scalar_metrics = torchmetrics.MetricCollection(scalar_metrics)
        pl_module.val_metrics = val_scalar_metrics.clone(prefix="val/")

        # Lead time metrics
        lead_time_metrics = {}
        lead_time_metrics[f"mse"] = metrics.MeanSquaredError(
            num_leadtimes=self.num_leadtimes
        )
        for threshold in self.thresholds:
            csi = metrics.CriticalSuccessIndex(
                threshold=threshold, num_leadtimes=self.num_leadtimes
            )
            lead_time_metrics[f"csi_{self._threshold_str(threshold)}"] = csi
        lead_time_metrics["avg_csi"] = metrics.AverageCriticalSuccessIndex(
            thresholds=self.thresholds, num_leadtimes=self.num_leadtimes
        )
        pl_module.lead_time_metrics = torchmetrics.MetricCollection(lead_time_metrics)

setup方法中,主要进行了以下操作:

  1. 设置标量指标(scalar metrics):
    • 创建一个空字典scalar_metrics,用于存储标量指标。
    • scalar_metrics字典中添加均方误差(MeanSquaredError)和平均绝对误差(MeanAverageError)指标。
    • 使用阈值列表(self.thresholds)循环遍历,为每个阈值创建关键成功指数(CriticalSuccessIndex)指标,并将其添加到scalar_metrics字典中。在添加时,指标的名称使用了f"csi_{self._threshold_str(threshold)}"的格式,其中self._threshold_str(threshold)将阈值转换为特定的字符串表示形式。
    • 添加平均关键成功指数(AverageCriticalSuccessIndex)指标到scalar_metrics字典中,其中的阈值使用了阈值列表(self.thresholds)。
    • 如果self.probabilistic为True,则添加连续排名概率评分(ContinuousRankedProbabilityScore)指标到scalar_metrics字典中,其中的桶(buckets)参数使用了self.buckets
  2. 创建指标集合(MetricCollection)并将指标放入模块(pl_module)中:
    • MetricCollectiontorchmetrics的一个方法,接收字典输入,创建指标集合。
    • 使用scalar_metrics字典创建标量指标集合(val_scalar_metrics)。
    • 使用val_scalar_metrics.clone(prefix="val/")创建一个带有前缀的克隆集合,前缀为"val/"。
    • 将克隆的标量指标集合赋值给模块的val_metrics属性,用于在验证过程中记录和计算指标。
  3. 设置引导时间指标(lead time metrics):
    • 创建一个空字典lead_time_metrics,用于存储引导时间指标。
    • lead_time_metrics字典中添加均方误差(MeanSquaredError)指标,其中的引导时间数量使用了self.num_leadtimes
    • 使用阈值列表(self.thresholds)循环遍历,为每个阈值创建引导时间关键成功指数(CriticalSuccessIndex)指标,并将其添加到lead_time_metrics字典中。在添加时,指标的名称使用了f"csi_{self._threshold_str(threshold)}"的格式,其中self._threshold_str(threshold)将阈值转换为特定的字符串表示形式。
    • 添加平均关键成功指数(AverageCriticalSuccessIndex)指标到lead_time_metrics字典中,其中的阈值使用了阈值列表(self.thresholds)和引导时间数量(self.num_leadtimes)。
    • 将引导时间指标集合(lead_time_metrics)赋值给模块的lead_time_metrics属性,用于在验证过程中记录和计算引导时间指标。

总的来说,setup方法主要用于设置回调函数中的指标,包括标量指标和引导时间指标。它创建了相应的指标对象,并将它们放入模块中,以便在训练过程中使用和记录。

on_validation_batch_end
代码语言:python
代码运行次数:0
复制
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        """Called after each validation batch with scalar and lead time metrics"""
        _, label, metadata = batch
        pl_module.val_metrics(outputs, label, metadata["target"]["mask"])
        pl_module.lead_time_metrics(outputs, label, metadata["target"]["mask"])
  1. 获取批次数据:从batch参数中解包获取到三个值,即_(不使用)、labelmetadata。这些值通常代表了模型输出、标签和元数据等。
  2. 计算标量指标:通过调用模块(pl_module)的val_metrics指标集合,传递模型输出、标签和目标掩码(metadata["target"]["mask"]),来计算标量指标的值。
  3. 计算引导时间指标:通过调用模块的lead_time_metrics指标集合,传递模型输出、标签和目标掩码,来计算引导时间指标的值。

通过调用指标集合的方法,可以将模型的输出、标签和目标掩码传递给指标集合,以便计算相应的指标值。这些指标值将用于后续的记录和评估过程。

on_validation_epoch_end
代码语言:python
代码运行次数:0
复制
    def on_validation_epoch_end(self, trainer, pl_module):
        # Log validation scalar metrics
        pl_module.log_dict(
            pl_module.val_metrics, on_step=False, on_epoch=True, sync_dist=True
        )

        # Compute and log lead time metrics
        lead_time_metrics = pl_module.lead_time_metrics.compute()
        lead_time_metrics_dict = {}
        wandb_data = []
        for metric_name, arr in lead_time_metrics.items():
            # Add to logging dictionary
            for leadtime, value in enumerate(arr):
                lead_time_metrics_dict[f"val_time/{metric_name}_{leadtime+1}"] = value
            # Save to file (tensorboard)
            if self.logging == "tensorboard":
                file_path = os.path.join(
                    pl_module.logger.log_dir, f"val_lead_time_{metric_name}.pt"
                )
                torch.save(arr.cpu(), file_path)
            # Generate table for wandb
            elif self.logging == "wandb":
                columns = ["metric"] + [f"t_{i+1}" for i in range(len(arr))]
                wandb_data.append([metric_name] + arr.tolist())

        # Save table in wandb
        if self.logging == "wandb":
            pl_module.logger.log_table(
                key="leadtimes", columns=columns, data=wandb_data
            )

        # Save lead time metrics over time
        pl_module.log_dict(
            lead_time_metrics_dict, on_step=False, on_epoch=True, sync_dist=True
        )
        pl_module.lead_time_metrics.reset()
  1. 记录和保存标量指标:
    • 使用pl_module.val_metrics指标集合,通过调用模块的log_dict方法,将标量指标的值记录到日志中。
    • 设置on_step=Falseon_epoch=True,以确保在验证周期结束时记录指标的值。
    • 使用sync_dist=True来同步跨多个设备的指标值。
  2. 计算和记录引导时间指标:
    • 使用pl_module.lead_time_metrics指标集合的compute方法,计算引导时间指标的值。
    • 创建一个空字典lead_time_metrics_dict,用于存储引导时间指标的名称和值。
    • 创建一个空列表wandb_data,用于存储生成表格所需的数据。
    • 遍历引导时间指标集合中的每个指标和对应的值:
      • 将指标的名称和对应的值添加到lead_time_metrics_dict字典中,以便后续的记录和保存。
      • 如果self.logging为"tensorboard",则将引导时间指标的值保存到文件中,文件名为val_lead_time_{metric_name}.pt
      • 如果self.logging为"wandb",则生成一个表格所需的数据,其中包括指标名称和对应的值。
    • 如果self.logging为"wandb",则将生成的表格数据使用pl_module.logger.log_table方法保存到wandb中,其中的key表示表格的唯一标识,columns表示表格的列名,data表示表格的数据。
  3. 记录引导时间指标的值:
    • 使用pl_modulelog_dict方法,将引导时间指标的名称和值记录到日志中。
    • 设置on_step=Falseon_epoch=True,以确保在验证周期结束时记录指标的值。
    • 使用sync_dist=True来同步跨多个设备的指标值。
  4. 重置引导时间指标集合:
    • 使用pl_module.lead_time_metrics指标集合的reset方法,重置引导时间指标的状态,以便在下一个验证周期开始时重新计算。
metrics
整体介绍
  • 继承自torchmetrics中的Metric类,重写了full_state_updatehigher_is_better两个属性、updatecompute两个方法。

在类的定义中,full_state_update被设置为False,表示不需要完全状态更新;higher_is_better被设置为True,表示指标的值越高越好。

在PyTorch的Metric类中,通常会定义一些状态变量,用于保存指标计算过程中的中间结果。这些状态变量可以在每次更新指标时被更新。而完全状态更新是指每次更新指标时,都会将所有的状态变量进行更新。然而,并不是所有的指标都需要进行完全状态更新。有些指标的计算只依赖于最近一次更新的状态,而不需要考虑之前的状态。在这种情况下,可以将full_state_update设置为False,以优化计算性能。这次计算的CSI指标跟之前的状态就无关,因此不需要完全状态更新。


update方法中,接受了三个参数predictionlabelmask,用于更新指标的计算。根据阈值列表和预测结果,将预测结果转换为二进制形式,并根据reduce_time的值进行不同的操作。


compute方法中,计算了平均关键成功指数(CSI),即真阳性(true positives)除以真阳性和假预测(false guesses)之和的平均值。

init
  • 接下来,根据传入的参数thresholdsnum_leadtimes的值,选择不同的默认值和设置self.reduce_time的值。
    • 如果num_leadtimesNone或者等于1,表示只有一个时间步,那么默认值default将被设置为一个形状为(len(thresholds),)的全零张量,并且self.reduce_time将被设置为True,表示需要减少时间维度。
    • 如果num_leadtimes大于1,表示有多个时间步,那么默认值default将被设置为一个形状为(len(thresholds), num_leadtimes)的全零张量,并且self.reduce_time将被设置为False,表示不需要减少时间维度。
    • 如果num_leadtimes小于等于0,则会抛出ValueError异常,提示num_leadtimes必须大于0。
  • 将传入的thresholds参数赋值给self.thresholds属性,以便在后续的计算中使用。
  • 通过调用self.add_state方法,将名为"true_positives"和"false_guesses"的状态变量添加到指标类中。这两个状态变量的默认值都是通过default.clone()来创建的,同时设置了分布式合并函数dist_reduce_fx为"sum"
    • 关于dist_reduce_fx,Metric类中使用分布式合并函数的目的是支持在分布式计算环境中进行指标的计算和合并,在分布式计算环境中,通常有多个计算节点或进程同时进行计算任务。每个节点或进程都可能独立地计算指标的一部分,并生成局部的状态变量。为了得到整体的指标结果,需要将各个节点或进程上计算得到的状态变量进行合并。
update

更新状态变量。

首先,根据阈值列表self.thresholds,使用enumerate函数遍历阈值列表的索引和值,因为CSI指标的计算在不同thresholds下是不同的。

接下来,将预测结果prediction的强度(intensity)赋给变量pred

然后,将predlabel转换为二进制形式。将pred中大于等于当前thresholds的元素设置为真(True),其余为假(False)。同样,将label中大于等于当前thresholds的元素设置为真,其余为假。

接着,根据self.reduce_time的值进行不同的操作。

  • 如果self.reduce_time为True,表示只有一个时间步,那么将根据maskpredlab进行掩码操作,即将掩码为真(True)的位置从predlab中剔除。
  • 如果self.reduce_time为False,表示有多个时间步,那么通过重新排列张量的维度,将predlab的时间维度放到最后的位置,即将形状由"b c t h w"变为"(b c h w) t"。同时,对mask进行相同的重新排列操作,并使用torch.logical_and函数将predlab与掩码取反(~m)进行逻辑与操作,以将掩码位置视为真(True)。这样可以保留其他维度的信息并考虑掩码。

最后,根据预测结果和标签计算真阳性(true positives)和假预测(false guesses)的总数。使用torch.logical_and函数计算predlab的逻辑与,得到同时为真的位置,然后使用sum(dim=0)对每个时间步的结果进行求和,将结果累加到self.true_positives[i]中。使用(pred != lab)进行逻辑不等于操作,得到不一致的位置,然后使用sum(dim=0)对每个时间步的结果进行求和,将结果累加到self.false_guesses[i]中。

通过循环遍历阈值列表和计算真阳性和假预测的总数,update方法更新了指标类中的状态变量。

compute

根据状态变量计算最终指标。

utils

buckets

各种分箱策略。

config

继承自yaml库的SafeLoader类,用于解析YAML文件(里没事各种参数设定)。

data_utils

用于各种数据处理。使用的情况有:

代码语言:python
代码运行次数:0
复制
train.py:
from w4c23.utils.data_utils import get_cuda_memory_usage, tensor_to_submission_file
sampler.py:
from w4c23.utils.data_utils import get_file
w4c_dataloader.py:
from w4c23.utils.data_utils import *
sampler

数据集中样本的抽样策略,实现重要性采样。

w4c_dataloader

读取并归一化大赛数据。

其他项目

checkpoints

保存模型参数。

configurations

保存定义模型的各种参数组合。

data

原始数据。

images

2D U-Net 架构的输出与其输入具有相同的空间维度。这意味着对于大小为 128 x 128 像素的输入序列,通过 U-Net 的前向传播将生成大小为 128 x 128 像素的输出。标签对应于大小为 42 x 42 像素的中心块。因此,为了指导降水临近预报模型,我们采用中央 42 x 42 像素块并上采样到 252 x 252 像素标签。这种裁剪和上采样是在 MetNet 9 中引入的,这是由于输入和标签的空间分辨率不同所致,如第 3 节中所述。

解释为什么要对标签值上采样。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 降水临近预报_Weather4cast_RainAI代码分享
    • 主程序w4c23
      • models
      • losses
      • callbacks
      • utils
    • 其他项目
      • checkpoints
      • configurations
      • data
      • images
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档