Engineering-Grade Implementation! A Detailed Walkthrough of a TensorFlow Program for Deep Residual Shrinkage Networks

Original | Author: 用户12081513 | Published 2026-02-27 15:56:31

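In fault diagnosis of rotating machinery, measured vibration signals are often contaminated by substantial environmental noise, which makes accurate feature extraction difficult. The paper "Deep Residual Shrinkage Networks for Fault Diagnosis" proposes the deep residual shrinkage network (DRSN), which introduces soft thresholding as a nonlinear operator inside a deep residual network. It builds a residual shrinkage building unit with channel-wise thresholds (RSBU-CW), in which a small sub-network adaptively learns a threshold for every feature map (channel). Features below the threshold, which are treated as unimportant noise, are removed automatically, while the components useful for classification are kept. As a result, the network performs denoising and classification jointly, without requiring thresholds to be hand-designed from signal-processing expertise.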

I. How the RSBU-CW Module Works

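The RSBU-CW module is the core of DRSN; its structure is shown in Figure 1. On top of a conventional residual connection, the module adds a threshold-learning branch. This branch takes the absolute value of the feature map produced by the convolutional path and compresses the temporal (length) dimension with global average pooling (GAP), which leaves one value per channel. Two fully connected (FC) layers, with batch normalization and ReLU between them, followed by a Sigmoid activation, then produce a per-channel scaling factor α in the range (0, 1).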
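
The branch just described can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions: random weights stand in for trained dense-layer parameters, the batch normalization between the two dense layers is omitted, and the helper name `threshold_branch_alpha` is hypothetical. It only shows how |x|, GAP, FC-ReLU-FC and the Sigmoid yield one α per sample and channel.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def threshold_branch_alpha(x, w1, b1, w2, b2):
    """Hypothetical NumPy sketch of the RSBU-CW threshold branch (BN omitted).

    x: feature map of shape (batch, length, channels).
    Returns (alpha, avg_abs), both of shape (batch, channels).
    """
    avg_abs = np.mean(np.abs(x), axis=1)         # |x|, then GAP over the length axis
    hidden = np.maximum(avg_abs @ w1 + b1, 0.0)  # first FC layer + ReLU
    alpha = sigmoid(hidden @ w2 + b2)            # second FC layer + Sigmoid -> (0, 1)
    return alpha, avg_abs

rng = np.random.default_rng(0)
batch, length, channels = 4, 128, 8
x = rng.normal(size=(batch, length, channels))
# Random stand-ins for trained dense-layer parameters (assumption, not learned weights)
w1 = rng.normal(scale=0.5, size=(channels, channels))
w2 = rng.normal(scale=0.5, size=(channels, channels))
b1 = np.zeros(channels)
b2 = np.zeros(channels)

alpha, avg_abs = threshold_branch_alpha(x, w1, b1, w2, b2)
print(alpha.shape)  # (4, 8): one scaling factor per sample and channel
```

Because the last activation is a Sigmoid, every α stays strictly between 0 and 1 regardless of the weights.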

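The threshold for channel c is τ_c = α_c * average(|x_c|), where average(|x_c|) is the mean absolute value of the feature map in channel c (the GAP output above). Soft thresholding is then applied element-wise: y = sign(x) * max(|x| - τ, 0).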
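
A small worked example of the two formulas, written as a NumPy sketch that mirrors the shapes used by the `SoftThresholding` layer in Section II: a threshold of shape (batch, channels) is expanded to (batch, 1, channels) so it broadcasts over the length axis. The value α = 0.4 is an arbitrary stand-in for the Sigmoid output.

```python
import numpy as np

def soft_threshold(x, tau):
    """y = sign(x) * max(|x| - tau, 0), one threshold per sample and channel.

    x:   (batch, length, channels)
    tau: (batch, channels)
    """
    tau = tau[:, np.newaxis, :]  # (batch, 1, channels): broadcast over the length axis
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

# One sample, four time steps, one channel
x = np.array([[[0.2], [-1.5], [3.0], [-0.4]]])
avg_abs = np.mean(np.abs(x), axis=1)  # (0.2 + 1.5 + 3.0 + 0.4) / 4 = 1.275
alpha = np.array([[0.4]])             # arbitrary stand-in for the Sigmoid output
tau = alpha * avg_abs                 # tau = 0.4 * 1.275 = 0.51
y = soft_threshold(x, tau)
print(np.round(y[0, :, 0], 2))
```

Entries with |x| <= 0.51 (0.2 and -0.4) become zero, while -1.5 and 3.0 shrink to about -0.99 and 2.49 and keep their signs.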

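Because α lies in (0, 1), the threshold is always positive and smaller than the channel's mean absolute value, so a channel is never zeroed out entirely. Soft thresholding preserves the sign of each feature, sets features whose magnitude does not exceed τ (typically small, noise-dominated components) to zero, and shrinks the remaining ones toward zero by τ. By stacking several RSBU-CW units, the network refines its features layer by layer, much like a cascade of filters.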
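
The three properties above can be checked numerically. The sketch below applies element-wise soft thresholding to one synthetic channel of Gaussian values; α = 0.9 is an arbitrary value in (0, 1), and the data are random, not CWRU features.

```python
import numpy as np

def soft_threshold(x, tau):
    """Element-wise soft thresholding: sign(x) * max(|x| - tau, 0)."""
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

rng = np.random.default_rng(1)
x = rng.normal(size=1024)      # one synthetic channel (random data, not CWRU features)
alpha = 0.9                    # any value in (0, 1), as a Sigmoid would produce
tau = alpha * np.mean(np.abs(x))
y = soft_threshold(x, tau)

small = np.abs(x) <= tau
print("fraction set to zero:", small.mean())

# 1) Features whose magnitude does not exceed tau become exactly zero
assert np.all(y[small] == 0.0)
# 2) The remaining features keep their sign and shrink by exactly tau
assert np.all(np.sign(y[~small]) == np.sign(x[~small]))
assert np.allclose(np.abs(x[~small]) - np.abs(y[~small]), tau)
# 3) tau < mean|x| <= max|x|, so the largest feature always survives
assert np.abs(y).max() > 0.0
```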

II. TensorFlow Reproduction Code

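This section provides the complete Python reproduction code. It contains a custom implementation of the RSBU-CW unit, the overall DRSN architecture, and the preprocessing pipeline for the Case Western Reserve University (CWRU) bearing dataset.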

"""
Project:
    Fault diagnosis on the Case Western Reserve University (CWRU) bearing dataset
    using deep residual shrinkage networks
Reference:
    Zhao M, Zhong S, Fu X, Tang B, Pecht M.
    Deep residual shrinkage networks for fault diagnosis.
    IEEE Transactions on Industrial Informatics, 2020, 16(7): 4681–4690.
"""

import os
import sys
import logging

# Suppress non-critical TensorFlow C++ log messages. The variable should be set
# before TensorFlow is imported, otherwise it has little or no effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import scipy.io as sio
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers
from sklearn.model_selection import train_test_split

# =============================================================================
# Runtime environment configuration
# =============================================================================

def configure_computational_resources():
    """
    Configure compute resources and logging:
    enable on-demand GPU memory growth when a GPU is available.
    """
    logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')

    physical_processors = tf.config.list_physical_devices('GPU')
    if physical_processors:
        try:
            for processor in physical_processors:
                tf.config.experimental.set_memory_growth(processor, True)
            logging.info("GPU acceleration enabled with on-demand memory growth.")
        except RuntimeError as error:
            # Raised if the GPUs were already initialized before this call
            logging.error("Failed to configure GPU memory growth: %s", error)
    else:
        logging.info("No GPU found; running on CPU.")
# Initialize the runtime environment
configure_computational_resources()

# =============================================================================
# Core components: layers of the deep residual shrinkage network
# =============================================================================

class SoftThresholding(layers.Layer):
    """
    软阈值化算子层。
    作为深度残差收缩网络的核心非线性映射,实现基于学习阈值的特征过滤:
    公式:Output = sign(Input) * max(|Input| - Threshold, 0)
    """
    def __init__(self, **kwargs):
        super(SoftThresholding, self).__init__(**kwargs)

    def call(self, inputs):
        x, threshold = inputs
        # Expand the threshold from (Batch, Channels) to (Batch, 1, Channels) so it broadcasts over the length axis
        expanded_threshold = tf.expand_dims(threshold, axis=1)
        return tf.sign(x) * tf.maximum(tf.abs(x) - expanded_threshold, 0.0)

class RSBU_CW(layers.Layer):
    """
    深度残差收缩网络单元 (Residual Shrinkage Building Unit)。
    该模块通过注意力机制自适应学习通道级阈值,并结合残差连接增强特征传播。
    """
    def __init__(self, filters, kernel_size, strides=1, **kwargs):
        super(RSBU_CW, self).__init__(**kwargs)
        self.out_channels = filters
        self.stride = strides
        self.kernel_size = kernel_size
        self.weight_decay = regularizers.l2(1e-4)

        # Shortcut branch: a 1x1 projection is created in build() only when needed
        self.shortcut = None

        # Main (residual) branch
        self.norm_a = layers.BatchNormalization()
        self.relu_a = layers.Activation('relu')
        self.conv_a = layers.Conv1D(filters, kernel_size, strides=strides, padding='same', 
                                   kernel_initializer='he_normal', kernel_regularizer=self.weight_decay)
        
        self.norm_b = layers.BatchNormalization()
        self.relu_b = layers.Activation('relu')
        self.conv_b = layers.Conv1D(filters, kernel_size, strides=1, padding='same', 
                                   kernel_initializer='he_normal', kernel_regularizer=self.weight_decay)
        
        # Threshold-learning sub-network: GAP(|x|) -> FC -> BN -> ReLU -> FC -> Sigmoid
        self.gap = layers.GlobalAveragePooling1D()
        self.fc1 = layers.Dense(filters, kernel_initializer='he_normal')
        self.norm_enc = layers.BatchNormalization()
        self.relu_enc = layers.Activation('relu')
        self.fc2 = layers.Dense(filters, activation='sigmoid')
        self.soft_thresholding = SoftThresholding()

    def build(self, input_shape):
        # Project the shortcut with a 1x1 convolution when the stride is not 1 or the channel counts differ
        if self.stride != 1 or input_shape[-1] != self.out_channels:
            self.shortcut = layers.Conv1D(self.out_channels, 1, strides=self.stride, padding='same')
        super(RSBU_CW, self).build(input_shape)

    def call(self, inputs):
        identity = inputs
        if self.shortcut is not None:
            identity = self.shortcut(inputs)

        # Main branch: BN -> ReLU -> Conv, applied twice (pre-activation order)
        x = self.norm_a(inputs)
        x = self.relu_a(x)
        x = self.conv_a(x)
        x = self.norm_b(x)
        x = self.relu_b(x)
        x = self.conv_b(x)

        # Threshold branch: per-channel mean of |x| -> scaling factor alpha in (0, 1)
        abs_x = tf.abs(x)
        average_abs_x = self.gap(abs_x)
        
        alpha = self.fc1(average_abs_x)
        alpha = self.norm_enc(alpha)
        alpha = self.relu_enc(alpha)
        alpha = self.fc2(alpha)
        
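        # Channel-wise threshold tau = alpha * mean(|x|), then soft thresholding
        threshold = tf.multiply(alpha, average_abs_x)
        refined_x = self.soft_thresholding([x, threshold])

        # Residual connection (plain tensor addition instead of creating a new
        # Add layer on every call)
        return refined_x + identity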

class DRSN_CW(models.Model):
    """
    基于深度残差收缩网络的分类器。
    架构包含:
    1. 初始特征提取
    2. 深度堆叠的自适应残差收缩单元
    3. 决策输出层
    """
    def __init__(self, num_classes):
        super(DRSN_CW, self).__init__(name="DRSN_Diagnostic_Core")
        
        # Initial feature extraction
        self.conv1 = layers.Conv1D(32, 15, strides=2, padding='same', kernel_initializer='he_normal')
        self.bn1 = layers.BatchNormalization()
        self.relu1 = layers.Activation('relu')
        
        # Stack of residual shrinkage units
        self.rsbu_stack = [
            RSBU_CW(32, 5, strides=2),
            RSBU_CW(32, 5, strides=1),
            RSBU_CW(64, 5, strides=2),
            RSBU_CW(64, 5, strides=1),
            RSBU_CW(128, 5, strides=2),
            RSBU_CW(128, 5, strides=1)
        ]
        
        # Output layers
        self.bn2 = layers.BatchNormalization()
        self.relu2 = layers.Activation('relu')
        self.gap = layers.GlobalAveragePooling1D()
        self.fc_out = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu1(x)
        
        for rsbu_unit in self.rsbu_stack:
            x = rsbu_unit(x)
            
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.gap(x)
        return self.fc_out(x)

# =============================================================================
# Data loading and signal augmentation
# =============================================================================

class CWRULoader:
    """
    振动信号加载与解析引擎。
    负责从数据文件中提取加速度序列并执行滑动窗口切割。
    """
    def __init__(self, base_directory, window_width=1024):
        self.data_root = os.path.abspath(base_directory)
        self.sequence_len = window_width

    def _read_mat(self, file_path):
        try:
            mat_container = sio.loadmat(file_path)
            for key_id in mat_container.keys():
                if 'DE_time' in key_id:
                    return mat_container[key_id].flatten()
        except Exception:
            return None
        return None

    def load_and_segment(self, metadata_cfg):
        feature_collection, label_collection = [], []
        is_path_valid = False
        
        for label_id, source_files in metadata_cfg.items():
            for filename in source_files:
                full_path = os.path.join(self.data_root, f"{filename}.mat")
                if not os.path.exists(full_path):
                    continue
                
                time_series = self._read_mat(full_path)
                if time_series is None:
                    continue
                
                is_path_valid = True
                # Cut consecutive, non-overlapping windows (stride equals the window length)
                for pointer in range(0, len(time_series) - self.sequence_len + 1, self.sequence_len):
                    slice_data = time_series[pointer : pointer + self.sequence_len]
                    feature_collection.append(slice_data)
                    label_collection.append(label_id)
        
        if not is_path_valid:
            raise IOError("无法在指定目录中定位合法的 .mat 数据资源。")
            
        return np.array(feature_collection, dtype='float32'), np.array(label_collection, dtype='int32')

def apply_random_white_noise(input_batch, ratio_db):
    """
    向信号注入加性高斯白噪声。
    数学基准:P_noise = P_signal / (10^(SNR/10))
    """
    input_batch = np.array(input_batch)
    random_state = np.random.default_rng()
    
    current_snr = ratio_db if isinstance(ratio_db, (int, float)) else random_state.uniform(ratio_db[0], ratio_db[1])
    
    signal_intensity = np.mean(input_batch**2, axis=1, keepdims=True)
    noise_intensity = signal_intensity / (10**(current_snr / 10))
    noise_pattern = random_state.normal(0, np.sqrt(noise_intensity), input_batch.shape)
    
    return (input_batch + noise_pattern).astype('float32')

# =============================================================================
# Training workflow
# =============================================================================

def train_drsn(dataset_dir, input_dim=1024):
    """
    主控程序:执行深度残差收缩网络的完整训练与离线评估流程。
    """
    # 类别定义
    category_registry = {
        0: ['Normal_0', 'Normal_1', 'Normal_2', 'Normal_3'],
        1: ['IR007_0', 'IR007_1', 'IR007_2', 'IR007_3'],
        2: ['IR014_0', 'IR014_1', 'IR014_2', 'IR014_3'],
        3: ['IR021_0', 'IR021_1', 'IR021_2', 'IR021_3'],
        4: ['B007_0', 'B007_1', 'B007_2', 'B007_3'],
        5: ['B014_0', 'B014_1', 'B014_2', 'B014_3'],
        6: ['B021_0', 'B021_1', 'B021_2', 'B021_3'],
        7: ['OR007@6_0', 'OR007@6_1', 'OR007@6_2', 'OR007@6_3'],
        8: ['OR014@6_0', 'OR014@6_1', 'OR014@6_2', 'OR014@6_3'],
        9: ['OR021@6_0', 'OR021@6_1', 'OR021@6_2', 'OR021@6_3']
    }
    
    # Create the loader and read the data
    loader = CWRULoader(base_directory=dataset_dir, window_width=input_dim)
    try:
        x_all, y_all = loader.load_and_segment(category_registry)
    except Exception as failure:
        logging.error("数据集构建失败: %s", failure)
        return

    # Stratified split: 70% training, 15% validation, 15% test
    x_train_pre, x_holdout, y_train_pre, y_holdout = train_test_split(
        x_all, y_all, test_size=0.3, random_state=42, stratify=y_all
    )
    x_valid_pre, x_test_pre, y_valid_pre, y_test_pre = train_test_split(
        x_holdout, y_holdout, test_size=0.5, random_state=42, stratify=y_holdout
    )
    
    # Normalization statistics (computed on the training set only)
    global_mean = np.mean(x_train_pre)
    global_std = np.std(x_train_pre)
    
    def normalize_and_reshape(raw_array):
        return ((raw_array - global_mean) / global_std).reshape(-1, input_dim, 1)

    train_tensor = normalize_and_reshape(x_train_pre)
    valid_tensor = normalize_and_reshape(x_valid_pre)
    test_tensor = normalize_and_reshape(x_test_pre)
    
    class_count = len(category_registry)
    train_labels = tf.keras.utils.to_categorical(y_train_pre, class_count).astype('float32')
    valid_labels = tf.keras.utils.to_categorical(y_valid_pre, class_count).astype('float32')
    test_labels = tf.keras.utils.to_categorical(y_test_pre, class_count).astype('float32')

    # Simulate a harsh noise environment: fixed SNR of -8 dB for validation and test
    noisy_valid_set = apply_random_white_noise(valid_tensor, ratio_db=-8)
    noisy_test_set = apply_random_white_noise(test_tensor, ratio_db=-8)

    def dynamic_augmentation_engine(feat_batch, label_batch):
        """
        在线数据增强:融合循环位移、脉冲干扰及随机信噪比噪声。
        """
        rng = np.random.default_rng()
        processed_feat = feat_batch.copy()
        m, n, _ = processed_feat.shape

        # 1. Random circular shift along the time axis
        for idx in range(m):
            offset = rng.integers(0, n)
            processed_feat[idx, :, 0] = np.roll(processed_feat[idx, :, 0], offset)

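        # 2. Simulated mechanical impacts: applied to about 10% of batches,
        #    and within such a batch to about half of the samples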
        if rng.random() > 0.9: 
            for idx in range(m):
                if rng.random() > 0.5: 
                    spike_num = rng.integers(1, 3) 
                    indices = rng.integers(0, n, spike_num)
                    impact_amp = np.std(processed_feat[idx]) * rng.uniform(1.5, 2.5) 
                    processed_feat[idx, indices, 0] += impact_amp * rng.choice([-1, 1], size=spike_num)

        # 3. Additive white noise with a random SNR in [-8, 8] dB (about half of the batches)
        if rng.random() > 0.5: 
            processed_feat = apply_random_white_noise(processed_feat, ratio_db=(-8, 8))

        return processed_feat.astype(np.float32), label_batch.astype(np.float32)

    def enforce_tensor_meta(f_tensor, l_tensor):
        f_tensor.set_shape([None, input_dim, 1])
        l_tensor.set_shape([None, class_count])
        return f_tensor, l_tensor

    # Build the input pipeline: shuffle, batch, then augment each batch with NumPy
    ds_train = tf.data.Dataset.from_tensor_slices((train_tensor.astype('float32'), train_labels))
    ds_train = ds_train.shuffle(len(train_tensor)).batch(64)
    ds_train = ds_train.map(
        lambda x, y: tf.numpy_function(dynamic_augmentation_engine, [x, y], [tf.float32, tf.float32]),
        num_parallel_calls=tf.data.AUTOTUNE
    ).map(enforce_tensor_meta).prefetch(tf.data.AUTOTUNE)

    # Build and compile the model
    model = DRSN_CW(num_classes=class_count)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), 
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    logging.info("深度残差收缩网络架构编译完成,输入尺寸: %d", input_dim)
    
    # Training callbacks
    monitors = [
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=7, min_lr=1e-6, verbose=1),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
    ]

    # Train the model (EarlyStopping may end training before 100 epochs)
    model.fit(
        ds_train,
        epochs=100,
        validation_data=(noisy_valid_set, valid_labels),
        callbacks=monitors,
        verbose=2
    )

    # Classification accuracy on the noisy test set
    final_loss, final_acc = model.evaluate(noisy_test_set, test_labels, verbose=0)
    print(f"\nClassification accuracy on the noisy test set (-8 dB SNR): {final_acc*100:.2f}%")

# =============================================================================
# Entry point
# =============================================================================

if __name__ == "__main__":
    TARGET_DATA_PATH = os.path.join(os.getcwd(), 'data_path')
    
    if not os.path.exists(TARGET_DATA_PATH):
        logging.warning("默认数据路径未定位: %s", TARGET_DATA_PATH)
        user_input = input("请手动指定 CWRU 数据集所在的物理路径: ").strip()
        if user_input:
            TARGET_DATA_PATH = user_input
        else:
            logging.error("未获取有效路径,执行终止。")
            sys.exit(0)

    train_drsn(TARGET_DATA_PATH, input_dim=1024)
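
The `load_and_segment` method above cuts each recording into consecutive, non-overlapping windows and discards the tail that is shorter than one window. The standalone NumPy sketch below reproduces that slicing rule so the number of samples per file can be checked in isolation. The function name `segment_non_overlapping` and the recording length of 121,265 points are hypothetical values chosen for illustration.

```python
import numpy as np

def segment_non_overlapping(signal, window):
    """Hypothetical helper mirroring the slicing loop in CWRULoader.load_and_segment:
    consecutive windows with stride == window; the trailing remainder is dropped."""
    starts = range(0, len(signal) - window + 1, window)
    if len(starts) == 0:
        # Recording shorter than one window: no segments
        return np.empty((0, window), dtype=np.float32)
    return np.stack([signal[s:s + window] for s in starts]).astype(np.float32)

# Hypothetical recording length, used only for illustration
signal = np.arange(121_265, dtype=np.float64)
segments = segment_non_overlapping(signal, window=1024)
print(segments.shape)  # (118, 1024)
```

With 1,024-point windows, 118 segments are produced and the last 433 points are discarded. Files of different lengths therefore contribute different numbers of segments, which can leave the classes unbalanced.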

III. Experimental Results

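This reproduction uses the CWRU bearing dataset with added noise as the benchmark. As shown in the experimental data table, it covers 10 health states (normal, inner-race faults, ball faults, and outer-race faults) with defect diameters of 7, 14, and 21 mil. To test robustness, additive Gaussian white noise is injected into the signals. During training, about half of the batches receive noise with a random SNR between -8 dB and 8 dB. The validation and test sets are evaluated at a fixed SNR of -8 dB.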
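
The noise injection follows `apply_random_white_noise` in Section II, which sets the noise power of each sample to P_noise = P_signal / 10^(SNR/10). One way to sanity-check that rule is to add -8 dB noise to a clean tone and measure the resulting SNR. The NumPy sketch below does this. The helper name `add_awgn`, the 160 Hz tone, and the 12 kHz sampling rate are illustrative assumptions, not CWRU data.

```python
import numpy as np

def add_awgn(batch, snr_db, rng):
    """Same rule as apply_random_white_noise: per-sample P_noise = P_signal / 10^(SNR/10)."""
    p_signal = np.mean(batch ** 2, axis=1, keepdims=True)  # (batch, 1, 1)
    p_noise = p_signal / (10 ** (snr_db / 10))
    return batch + rng.normal(0.0, np.sqrt(p_noise), batch.shape)

rng = np.random.default_rng(42)
fs = 12_000                                  # assumed sampling rate (Hz)
t = np.arange(1024) / fs
tone = np.sin(2 * np.pi * 160 * t)           # illustrative 160 Hz tone, not CWRU data
clean = np.repeat(tone[None, :, None], 256, axis=0)  # (256, 1024, 1)

noisy = add_awgn(clean, snr_db=-8, rng=rng)
noise = noisy - clean
measured_snr = 10 * np.log10(np.mean(clean ** 2) / np.mean(noise ** 2))
print(f"measured SNR: {measured_snr:.2f} dB")
```

At -8 dB, the noise power is about 6.3 times the signal power (10^0.8 ≈ 6.31), which is why this setting is a demanding test of noise robustness.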

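The screenshot of the experimental results shows that after 100 training epochs, accuracy on the test set (at -8 dB SNR) exceeded 90%. DRSN therefore shows strong noise robustness at very low signal-to-noise ratios.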

Paper title: Deep residual shrinkage networks for fault diagnosis

Journal: IEEE Transactions on Industrial Informatics, 2020, 16(7): 4681-4690.

DOI: 10.1109/TII.2019.2943898

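Originality statement: This article is published on the Tencent Cloud Developer Community with the author's authorization. Reproduction without permission is prohibited.

In case of infringement, please contact cloudcommunity@tencent.com for removal.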

