
Model Distillation - Study Notes

Author: Johns

Model Distillation

I. Basic Theory

Background

Knowledge Distillation was first proposed by Hinton in the 2014 paper Distilling the Knowledge in a Neural Network. The core idea is to let a teacher model guide the training of a student model: the feature representations learned by a complex, high-capacity teacher are "distilled" out and transferred to a student with fewer parameters and weaker learning capacity, yielding a student model that is both fast and expressive.
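Concretely, the transferred "knowledge" is usually the teacher's temperature-softened class distribution (the soft targets from the Hinton paper):

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

where $z_i$ are the logits and a temperature $T > 1$ flattens the distribution, so the student can also learn from the small probabilities the teacher assigns to incorrect classes.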

The Core Problem

Because knowledge distillation involves two models during training (the teacher model and the student model) and three losses (teacher loss, student loss, and distillation loss), how to balance the training of the different models is an important factor affecting the final model quality.
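In practice (and in the code later in this article), the student's total objective is a weighted sum of the hard-label loss and the distillation loss:

$$\mathcal{L}_{total} = \alpha \cdot \mathcal{L}_{student}(y, p_s) + (1 - \alpha) \cdot \mathcal{L}_{distill}(q_t, q_s)$$

where $\alpha$ trades off fitting the ground-truth labels against imitating the teacher.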

Industry Approaches

[Figure 1: comparison of distillation schemes, (a) MD and (b) PFD]
(1) MD

In MD, the teacher model and the student model consume the same input features, with the teacher more complex than the student; for example, a teacher with a deeper network structure guides the learning of a student with a shallower one. MD was first proposed in 2014 by Hinton and Google's Jeff Dean in Distilling the Knowledge in a Neural Network.


Hinton is a legendary scholar, an expert across multiple fields, and a Turing Award winner; Jeff Dean is Google employee No. 20 and the head of Google AI, under whose leadership TensorFlow was developed.

(2) PFD

In PFD, the teacher model and the student model use the same network structure but consume different input features. In Figure 1(b), the PFD student processes only regular features, while the teacher processes both regular features and privileged features.

Privileged features: features that are highly discriminative but only obtainable offline. The feature set has four groups: user behavior features, user features, item features, and cross features, where:

  • All cross features are privileged features: e.g., the user's click count on same-category items over the past 24 hours.
  • All "leaky" (post-event) features are privileged features: e.g., features from the item detail page, such as the user's dwell time. "Leaky" here means the feature cannot be obtained online at serving time.

The MD scheme fits the student's outputs to the teacher's outputs, whereas the PFD scheme fits the student's feature distribution to the teacher's. PFD was proposed by Alibaba in 2018 in Privileged Features Distillation at Taobao Recommendations.
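Below is a minimal sketch of the PFD idea, assuming toy feature sizes and a hidden-representation fitting term (all names and dimensions are illustrative, not from the paper):

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

regular_in = keras.Input(shape=(32,), name="regular")       # available both offline and online
privileged_in = keras.Input(shape=(8,), name="privileged")  # offline only, e.g. dwell time

# Teacher consumes regular + privileged features.
t_hidden = layers.Dense(64, activation="relu")(
    layers.Concatenate()([regular_in, privileged_in]))
t_prob = layers.Dense(1, activation="sigmoid")(t_hidden)
teacher = keras.Model([regular_in, privileged_in], [t_hidden, t_prob])

# Student consumes regular features only, so it can run online.
s_hidden = layers.Dense(64, activation="relu")(regular_in)
s_prob = layers.Dense(1, activation="sigmoid")(s_hidden)
student = keras.Model(regular_in, [s_hidden, s_prob])

bce = keras.losses.BinaryCrossentropy()

def pfd_loss(y, s_hidden, s_prob, t_hidden, lam=0.5):
    # Pull the student's hidden representation toward the teacher's.
    distill = tf.reduce_mean(tf.square(s_hidden - tf.stop_gradient(t_hidden)))
    return bce(y, s_prob) + lam * distill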

(3) Rocket Training

In industry, some online models face very strict latency requirements, which limits model complexity to some degree; limited complexity can reduce a model's learning capacity and hence hurt quality. To solve this, Alimama (Alibaba's advertising arm) proposed Rocket Training in 2018: a complex model assists the training of a slim model, and at test time the learned small model is used for inference. Paper: Rocket Launching: A Universal and Efficient Framework for Training Well-performing Light Net.

[Figure: Rocket Training model structure]

Why is it called Rocket Training? Because the whole process from training to prediction resembles a rocket launch.

  • First stage (training): the booster carries the satellite (light net) forward, providing the thrust that propels it.
  • Second stage (prediction): the booster is discarded, and only the lightweight satellite (light net) travels on alone.

The figure above shows the Rocket Training model structure: the teacher network on the left, the student network on the right. The model has the following characteristics:

  • The student (light) network and the teacher (booster) network share the bottom-layer parameters. In the paper's notation (reconstructed here), the light net produces logits $l(x) = f(x;\, \mathcal{W}_S, \mathcal{W}_L)$ and the booster produces $z(x) = f(x;\, \mathcal{W}_S, \mathcal{W}_B)$, where $\mathcal{W}_S$ is shared.
  • The teacher network uses a more complex model structure: its private parameters $\mathcal{W}_B$ form a deeper/wider tower than the light net's $\mathcal{W}_L$.
  • The distillation target is the logits: the student's logits are trained to fit the teacher's logits.

The full objective (reconstructed from the paper) combines both classification losses with the logit-fitting (hint) term:

$$\mathcal{L} = \mathcal{H}(y, p(x)) + \mathcal{H}(y, q(x)) + \lambda \, \lVert l(x) - z(x) \rVert_2^2$$

where $p(x) = \mathrm{softmax}(l(x))$, $q(x) = \mathrm{softmax}(z(x))$, $\mathcal{H}$ is the cross entropy, and the gradient of the hint term is blocked from flowing into the booster so the teacher is not dragged toward the student.
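A minimal sketch of this structure, assuming a toy MLP (all layer sizes are illustrative): shared bottom layers feed both heads, and tf.stop_gradient implements the gradient block on the hint term:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(64,))
shared = layers.Dense(128, activation="relu")(inputs)   # shared bottom layers (W_S)

light_logits = layers.Dense(10, name="light")(shared)   # light net head (W_L)
boost = layers.Dense(256, activation="relu")(shared)    # booster's extra capacity (W_B)
booster_logits = layers.Dense(10, name="booster")(boost)

model = keras.Model(inputs, [light_logits, booster_logits])
ce = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

def rocket_loss(y, light_logits, booster_logits, lam=0.1):
    # Gradient block: stop_gradient keeps the hint term from updating the booster.
    hint = tf.reduce_mean(tf.square(light_logits - tf.stop_gradient(booster_logits)))
    return ce(y, light_logits) + ce(y, booster_logits) + lam * hint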

Application Scenarios

See Privileged Features Distillation at Taobao Recommendations for details.

[Figures: distillation application scenarios at Taobao, from the PFD paper]

II. Model Distillation in Practice

Note: the official Keras distillation recipe is a standard MD scheme: the teacher and student take the same input, and the teacher's knowledge is transferred to the student by fitting the student's outputs to the teacher's.

Reference paper: Distilling the Knowledge in a Neural Network

Official Example

Step 1. Setup
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
Step 2. Define the Distiller model
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
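            # Note: following the Hinton paper, the distillation loss is often
            # scaled by temperature**2 so its gradient magnitude stays comparable
            # to the hard-label loss; this example omits that factor.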
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
Step 3. Build the teacher and student models
# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)

# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
Step 4. Prepare the dataset
# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
Step 5. Train the teacher model
# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate the teacher. The teacher network is large, so it needs more epochs to avoid underfitting
teacher.fit(x_train, y_train, epochs=6)
teacher.evaluate(x_test, y_test)
Step 6. Distill the teacher into a student model
# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=5)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
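Note: in train_step the total loss is alpha * student_loss + (1 - alpha) * distillation_loss, so with alpha=0.1 the student is trained mostly to imitate the teacher's softened outputs, while temperature=10 flattens both softmax distributions so the teacher's small inter-class probabilities still carry signal. Also note that evaluate goes through test_step, which only reports the student's own loss and metrics.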
Step 7. Train a student model from scratch for comparison
# Train the student from scratch as usual
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)

Step 8. Performance test
import random
import time

def model_test(model):
    num_images = 256
    start = time.perf_counter()
    for _ in range(num_images):
        index = random.randint(0, x_test.shape[0] - 1)  # randint is inclusive at both ends
        x = x_test[index].reshape(1, 28, 28, 1)  # add the batch dimension
        predict = model.predict(x)
        predict = np.argmax(predict)  # index of the largest score

    end = time.perf_counter()
    time_ir = end - start

    print(
        f"model in Inference Engine/CPU: {time_ir/num_images:.4f} "
        f"seconds per image, FPS: {num_images/time_ir:.2f}"
    )

model_test(teacher)
model_test(distiller.student)  # the Distiller wrapper defines no call(), so benchmark its student directly
model_test(student_scratch)
Step 9. Experiment results

| Model | Size | Evaluation | Performance (-n 200) |
| --- | --- | --- | --- |
| teacher model | 5.46 M | loss: 0.0565, acc: 0.9855 | 0.0561 s per image, FPS: 17.82 |
| distiller model | 0.80 M | loss: 0.0525, acc: 0.9801 | 0.0502 s per image, FPS: 19.86 |
| student model | 0.80 M | loss: 0.06131, acc: 0.97129 | 0.0502 s per image, FPS: 19.92 |

Testing a Production Recall Model

Step 1. Define the teacher model

To reuse existing components, the teacher directly uses a standard wide_deep & DIN model, with the number of cross layers set to 3 and a GateNet module added.

Step 2. Define the student model

Compared with the teacher, the student mainly reduces the number of cross layers and the number of units in each dense layer of the DIN block; a hedged sketch of both is given below.
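A minimal sketch (all names and layer sizes here are hypothetical) of how the teacher and student could share one builder, differing only in cross depth and DIN-tower width; the real models also include embeddings, attention, and GateNet, which are omitted:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def build_recall_model(num_cross_layers, din_units, name):
    dense_in = keras.Input(shape=(587,), name="dense_input")
    # DCN-style cross layers: x_{l+1} = x_0 * (w^T x_l + b) + x_l
    x = dense_in
    for _ in range(num_cross_layers):
        x = dense_in * layers.Dense(1)(x) + x
    # DIN-style dense tower (the attention part is omitted for brevity)
    for units in din_units:
        x = layers.Dense(units, activation="relu")(x)
    out = layers.Dense(1, activation="sigmoid")(x)
    return keras.Model(dense_in, out, name=name)

teacher_model = build_recall_model(num_cross_layers=3, din_units=[256, 128, 64], name="teacher")
student_model = build_recall_model(num_cross_layers=1, din_units=[64, 32], name="student")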

Step 3. Model distillation

Here is the definition of the distillation model; it mainly follows the MD scheme from the official example above:


import tensorflow as tf
from tensorflow import keras


# Model distillation - MD scheme
class KnowledgeDistillation(keras.Model):
    def __init__(self, student_model, teacher_model):
        super(KnowledgeDistillation, self).__init__()
        self.teacher_model = teacher_model
        self.student_model = student_model

    def compile(
            self,
            optimizer,
            metrics,
            student_loss_fn,
            distillation_loss_fn,
            alpha=0.1,
    ):
        super(KnowledgeDistillation, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha

    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher_model(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student_model(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)

            # KD loss
            distillation_loss = self.distillation_loss_fn(teacher_predictions, student_predictions)

            total_loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student_model.trainable_variables
        gradients = tape.gradient(total_loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student_model(x, training=False)

        # Calculate the loss
        loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": loss})
        return results

    def call(self, inputs, training=None, mask=None):
        return self.student_model(inputs, training=training)

At training time, the difference is that KL divergence is no longer used to measure the gap between the teacher's and the student's predictions; BinaryCrossentropy is used instead, treating the teacher's predictions as soft labels.

optimizer = tf.keras.optimizers.Adam(learning_rate=params["learning_rate"])
loss = tf.keras.losses.BinaryCrossentropy()
metrics = [tf.keras.metrics.AUC()]
teacher_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

# Initialize and compile distiller
distiller = KnowledgeDistillation(student_model=student_model, teacher_model=teacher_model)
distiller.compile(
    optimizer=optimizer,
    metrics=metrics,
    student_loss_fn=loss,
    distillation_loss_fn=loss,
    alpha=0.2
)

path = "models/distiller_model"
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir="./logs", histogram_freq=5, profile_batch=3),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_student_loss', min_lr=1e-6),
    tf.keras.callbacks.EarlyStopping(monitor='val_student_loss', patience=1, mode="min"),
    tf.keras.callbacks.ModelCheckpoint(
        path, monitor='val_student_loss', verbose=1, save_best_only=True, mode='min'
    ),
]

# Distill teacher to student
distiller.fit(x=train_dataset, validation_data=valid_dataset, epochs=20, callbacks=callbacks)

# Save the student model directly (see "Problems Encountered" below)
tf.saved_model.save(distiller.student_model, path)
Step 4. Performance test
  • Performance test script
# -*- coding: utf-8 -*-
import os
import argparse
import time

from statistics import mean, median

os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"
import tensorflow as tf


def decode_line(line):
    line = tf.strings.regex_replace(line, "\x00", "")  # tf.Tensor:shape=(1,)
    columns = tf.strings.split([line], ' ')  # tf.RaggedTensor:shape=(1, None)
    # get columns
    labels = tf.strings.to_number(columns.values[7], out_type=tf.float32)  # tf.RaggedTensor:shape=()
    features = tf.strings.split(columns.values[8], '&')
    # get dense
    dense = tf.strings.split(features[0], ',')[:587]  # tf.EagerTensor:shape=(`field_size`, None)
    dense = tf.strings.to_number(dense, out_type=tf.float32)
    dense = tf.reshape(dense, shape=[-1, ])
    # get sparse
    pairs = tf.strings.split(features[1], ',')  # tf.EagerTensor:shape=(`field_size`, None)
    id_vals = tf.strings.split(pairs, ':').to_tensor()
    feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
    feat_ids = tf.strings.to_number(feat_ids, out_type=tf.int32)  # tf.Tensor:shape=(`field_size`, 1)
    feat_vals = tf.strings.to_number(feat_vals, out_type=tf.float32)  # tf.Tensor:shape=(`field_size`, 1)
    feat_ids = tf.reshape(feat_ids, [-1])  # tf.Tensor:shape=(`field_size`,)
    feat_vals = tf.reshape(feat_vals, [-1])  # tf.Tensor:shape=(`field_size`,)
    # get seq
    seq_50 = tf.strings.split(features[2], '#')
    seq_50 = tf.strings.split(seq_50, ',')
    seq_50 = tf.strings.to_number(seq_50, out_type=tf.int32).to_tensor()

    # return
    return {"dense_input": dense, "sparse_input": feat_ids, "sparse_wgt_input": feat_vals,
            "seq_input": seq_50, }, labels


def input_fn(filenames, batch_size=256, batch_shuffle=False):
    """Directly read from `filenames`."""
    ds = tf.data.TextLineDataset(filenames).map(decode_line, num_parallel_calls=20)
    if batch_shuffle:
        ds = ds.shuffle(buffer_size=batch_size)
    ds = ds.batch(batch_size)
    return ds


def getFileSize(filePath, size=0):
    for root, dirs, files in os.walk(filePath):
        for f in files:
            size += os.path.getsize(os.path.join(root, f))
    return size


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', dest="model_path", help="model_path",
                        default="/Users/guirong/go/src/ai/kd/models/teacher_model", type=str)
    parser.add_argument('--batch_size', dest="batch_size", type=int, default=100, help='Batch size')
    parser.add_argument('--total_step', dest="total_step", type=int, default=1000, help='total_step')
    parser.add_argument('--warm_step', dest="warm_step", type=int, default=500, help='warm_step')
    parser.add_argument('--train_files', dest="train_files", type=str,
                        default="/Users/guirong/go/src/ai/data/0000000000/train_10005.libsvm",
                        help='train files')

    parser.add_argument('--test_files', dest="test_files", type=str,
                        default="/Users/guirong/go/src/ai/data/0000000000/test_10005.libsvm",
                        help='test files')
    parser.add_argument('--val_files', dest="val_files", type=str,
                        default="/Users/guirong/go/src/ai/data/0000000000/valid_10005.libsvm",
                        help='valid files')
    args = parser.parse_args()

    print("\n=========================================")
    print("Inference using Native TensorFlow\n")
    print("Model Path:", args.model_path)
    print("Test File Path:", args.test_files)
    print("Model size:", getFileSize(args.model_path) / 1024 / 1024, "M")
    print("Batch size:", args.batch_size)
    print("Total Step:", args.total_step)
    print("Warm Step:", args.warm_step)
    time.sleep(2)
    root = tf.saved_model.load(os.path.join(args.model_path))
    predict = root.signatures['serving_default']
    output_tensor_name = list(predict.structured_outputs.keys())[0]

    ds = input_fn(filenames=args.test_files, batch_size=args.batch_size)
    iterator = iter(ds)
    features, labels = iterator.get_next()
    start_test = time.time()
    print("\nStart Process At ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    start_first = start_test
    try:
        step_times = list()
        for step in range(1, args.total_step + 1):
            start_t = time.time()
            if step % args.warm_step == 0:
                print("Processing step: %04d ..." % step)
                print("time cost %d" % (start_t - start_test))
                start_test = time.time()

            probs = predict(dense_input=features['dense_input'],
                            seq_50_input=features['seq_input'],
                            sparse_ids_input=features['sparse_input'],
                            sparse_wgt_input=features['sparse_wgt_input']
                            )[output_tensor_name]
            inferred_class = tf.math.argmax(probs).numpy()
            step_time = time.time() - start_t
            if step >= args.warm_step:
                step_times.append(step_time)
    except tf.errors.OutOfRangeError:
        pass
    end_test = time.time()
    print("Total time cost %d" % (end_test - start_first))
    print("End Process At ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

    print("\nAverage step time: %.1f msec" % (mean(step_times) * 1e3))
    print("Average throughput: %d samples/sec" % (
            args.total_step / (end_test - start_first)
    ))
Step 5. Experiment results

| Model | Size | Evaluation metrics | QPS | Max request latency | Improvement |
| --- | --- | --- | --- | --- | --- |
| teacher_model | 101.91 M | Test: Loss(0.05863), AUC(0.68112); Valid: Loss(0.05974), AUC(0.69937) | 60.85317 | 24.26386 ms | - |
| distillation_model | 13.997 M | Test: Loss(0.05614), AUC(0.70414); Valid: Loss(0.05809), AUC(0.71531) | 149.97829 | 7.56311 ms | +1.46x |
| student_model | 39.354 M | Test: Loss(0.05600), AUC(0.70073); Valid: Loss(0.05783), AUC(0.71431) | 137.67856 | 9.01294 ms | +1.26x |

III. Problems Encountered

  • Abnormal output structure made distillation impossible.
    Solution: adjust the model's output structure from the original dict to a single value.
  • The distillation model could not be saved.
    Solution: save the student model directly. The TensorFlow 2 APIs are messy, and saving the distiller model itself runs into internal API conflicts. A minimal sketch of both fixes follows below.
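A minimal sketch of both fixes (the two-layer model below is hypothetical; only the pattern matters):

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Fix 1: output a single tensor instead of a dict, so the loss functions
# in train_step can consume the prediction directly.
inputs = keras.Input(shape=(16,))
hidden = layers.Dense(32, activation="relu")(inputs)
prob = layers.Dense(1, activation="sigmoid")(hidden)
student_model = keras.Model(inputs, prob)  # not keras.Model(inputs, {"prob": prob})

# Fix 2: save the inner student model rather than the distiller wrapper,
# which avoids the API conflicts hit when saving the wrapper itself.
# tf.saved_model.save(distiller.student_model, "models/distiller_model")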

Original statement: This article was published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission. For infringement concerns, contact cloudcommunity@tencent.com for removal.
