Knowledge Distillation was first proposed by Hinton in the 2015 paper 《Distilling the Knowledge in a Neural Network》. The core idea is to let a teacher model guide the training of a student model: the representations learned by a complex, high-capacity teacher are "distilled" and transferred to a student with fewer parameters and weaker learning capacity, yielding a student model that is fast yet still expressive.
Because knowledge distillation involves two models during training (the teacher and the student) and up to three losses (teacher loss, student loss, distillation loss), how the training of the different models is balanced is an important factor for the final model quality.
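For reference, the student-side objective from Hinton's paper can be sketched as follows (the teacher is assumed to be pre-trained; α and the temperature T are hyperparameters that balance the two terms):

```latex
L_{student} = \alpha \,\mathrm{CE}\!\left(y,\ \mathrm{softmax}(z_s)\right)
            + (1-\alpha)\, T^2 \,\mathrm{KL}\!\left(\mathrm{softmax}\!\left(\tfrac{z_t}{T}\right) \,\Big\|\, \mathrm{softmax}\!\left(\tfrac{z_s}{T}\right)\right)
```

Here z_t and z_s are the teacher's and student's logits; the T² factor keeps the gradient scale comparable across temperatures (the Keras example later in this post omits it).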
In MD, the teacher and the student process the same input features, but the teacher is more complex than the student; for example, a deeper teacher network guides the learning of a student that uses a shallow network. MD goes back to 《Distilling the Knowledge in a Neural Network》, published in 2015 by Hinton together with Jeff Dean of Google.
Hinton is an expert across many fields and a Turing Award winner; Jeff Dean was roughly Google's 20th employee and leads Google AI, and TensorFlow was developed under his leadership.
In PFD, the teacher and the student use the same network structure but process different input features. In Figure 1(b), the PFD student only processes the regular features, while the teacher processes both the regular features and the privileged features.
Privileged features are features with high discriminative power that can only be obtained offline. The features involved mainly fall into four groups: user behavior features, user profile features, item features, and cross features.
In the MD scheme the teacher's and the student's outputs are fitted to each other, whereas in the PFD scheme the teacher's and the student's feature distributions are fitted. PFD was proposed by Alibaba in 2019 in 《Privileged Features Distillation at Taobao Recommendations》.
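As a concrete illustration, a minimal PFD-style setup in Keras could look as follows. This is only a sketch, not the paper's actual network: the feature sizes, layer widths, and the `privileged_input` name are made up, and for simplicity the distillation term here is placed on the predictions.

```python
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical feature dimensions, for illustration only
regular_in = keras.Input(shape=(64,), name="regular_input")
priv_in = keras.Input(shape=(16,), name="privileged_input")

# Teacher: same backbone as the student, but sees regular + privileged features
t_h = layers.Dense(128, activation="relu")(layers.Concatenate()([regular_in, priv_in]))
t_h = layers.Dense(64, activation="relu")(t_h)
teacher = keras.Model([regular_in, priv_in],
                      layers.Dense(1, activation="sigmoid")(t_h), name="pfd_teacher")

# Student: identical structure, regular features only
s_in = keras.Input(shape=(64,), name="regular_only_input")
s_h = layers.Dense(128, activation="relu")(s_in)
s_h = layers.Dense(64, activation="relu")(s_h)
student = keras.Model(s_in, layers.Dense(1, activation="sigmoid")(s_h), name="pfd_student")

bce = keras.losses.BinaryCrossentropy()

def pfd_loss(y_true, x_regular, x_priv, alpha=0.5):
    """Student fits both the true label and the teacher's (privileged-view) prediction."""
    y_teacher = teacher([x_regular, x_priv], training=False)  # soft target from the privileged view
    y_student = student(x_regular, training=True)
    return alpha * bce(y_true, y_student) + (1 - alpha) * bce(y_teacher, y_student)
```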
In industry, some online models face very strict response-time requirements, which limits how complex the model can be; this limit on complexity can reduce the model's learning capacity and hurt its quality. To address this, Alimama proposed Rocket Training in 2018: a complex booster model assists the training of a lightweight model, and at test time only the trained small model is used for inference. Paper: 《Rocket Launching: A Universal and Efficient Framework for Training Well-performing Light Net》.
Why is it called Rocket Training? Because the whole process from training to serving resembles a rocket launch.
The figure above shows the Rocket Training architecture, with the teacher (booster) network on the left and the student (light) network on the right. Its main characteristics (following the paper) are:
- the light network and the booster network are trained simultaneously on the same data and share the low-level parameters;
- besides its own cross-entropy loss, the light network is pulled toward the booster's output through a hint (distillation) loss;
- a gradient block keeps the hint loss from updating the booster, so the large network is not dragged down by the small one;
- only the light network is kept for online inference.
The distillation (hint) error can be written as follows:
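As a sketch, the generic joint-training objective used by this family of methods (both Rocket Launching and the PFD paper take this form; λ weights the distillation term) is:

```latex
L = \underbrace{L_t\big(y,\ \hat{y}_t\big)}_{\text{teacher loss}}
  + \underbrace{L_s\big(y,\ \hat{y}_s\big)}_{\text{student loss}}
  + \lambda\,\underbrace{L_d\big(\hat{y}_t,\ \hat{y}_s\big)}_{\text{distillation loss}}
```

Here ŷ_t and ŷ_s are the teacher's and student's predictions; the gradient of the distillation term is usually blocked from updating the teacher, and Rocket Launching instantiates L_d as a squared error between the two networks' logits.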
For details, see 《Privileged Features Distillation at Taobao Recommendations》.
Note: the official Keras distillation example is a standard MD scheme; the teacher and the student take the same input, and the teacher's knowledge is transferred to the student by fitting the student's outputs to the teacher's outputs.
Reference paper: 《Distilling the Knowledge in a Neural Network》
import time
import random

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data
        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)
        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)
            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_prediction = self.student(x, training=False)
        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)
        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

    def call(self, inputs, training=None, mask=None):
        # Delegate inference to the student so `predict()` works on the distiller
        return self.student(inputs, training=training)
# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)
# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate the teacher. The teacher network is larger,
# so it needs more epochs to avoid underfitting.
teacher.fit(x_train, y_train, epochs=6)
teacher.evaluate(x_test, y_test)
# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)
# Distill teacher to student
distiller.fit(x_train, y_train, epochs=5)
# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
# Train the student from scratch as usual, for comparison
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
def model_test(model):
    num_images = 256
    start = time.perf_counter()
    for _ in range(num_images):
        index = random.randint(0, x_test.shape[0] - 1)  # randint is inclusive on both ends
        x = x_test[index].reshape(1, 28, 28, 1)  # add a batch dimension
        y = y_test[index]
        predict = model.predict(x)
        predict = np.argmax(predict)  # index of the largest score
    end = time.perf_counter()
    time_ir = end - start
    print(
        f"model in Inference Engine/CPU: {time_ir/num_images:.4f} "
        f"seconds per image, FPS: {num_images/time_ir:.2f}"
    )
model_test(teacher)
model_test(distiller)
model_test(student_scratch)
Model | Size | Evaluation | Performance (-n 200) |
---|---|---|---|
teacher model | 5.46 M | loss: 0.0565, acc: 0.9855 | 0.0561 s per image, FPS: 17.82 |
distiller model | 0.80 M | loss: 0.0525, acc: 0.9801 | 0.0502 s per image, FPS: 19.86 |
student model | 0.80 M | loss: 0.06131, acc: 0.97129 | 0.0502 s per image, FPS: 19.92 |
Here we directly reuse a standard wide&deep + DIN model, with the number of cross layers set to 3 and a GateNet added on top.
Compared with the teacher, the student mainly reduces the number of cross layers and the number of units in each dense layer of the DIN block.
The definition of the distillation model is given below; it follows the same MD scheme.
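For orientation, the asymmetry described above could be captured by two hyperparameter sets along these lines (the names and numbers below are illustrative placeholders, not the actual configuration used in this experiment):

```python
# Illustrative placeholders only; the real layer sizes are not listed in this post
teacher_params = {
    "cross_layer_num": 3,            # cross layers in the wide&deep part
    "din_hidden_units": [128, 64],   # dense units inside the DIN attention block
    "use_gate_net": True,            # GateNet is used in both networks
}
student_params = {
    "cross_layer_num": 1,            # fewer cross layers than the teacher
    "din_hidden_units": [32, 16],    # smaller dense layers in the DIN part
    "use_gate_net": True,
}
```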
import tensorflow as tf
from tensorflow import keras
# Model distillation: the MD scheme
class KnowledgeDistillation(keras.Model):
    def __init__(self, student_model, teacher_model):
        super(KnowledgeDistillation, self).__init__()
        self.teacher_model = teacher_model
        self.student_model = student_model

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
    ):
        super(KnowledgeDistillation, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha

    def train_step(self, data):
        # Unpack data
        x, y = data
        # Forward pass of teacher
        teacher_predictions = self.teacher_model(x, training=False)
        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student_model(x, training=True)
            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            # KD loss
            distillation_loss = self.distillation_loss_fn(teacher_predictions, student_predictions)
            total_loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        # Compute gradients
        trainable_vars = self.student_model.trainable_variables
        gradients = tape.gradient(total_loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_prediction = self.student_model(x, training=False)
        # Calculate the loss
        loss = self.student_loss_fn(y, y_prediction)
        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": loss})
        return results

    def call(self, inputs, training=None, mask=None):
        return self.student_model(inputs, training=training)
During training, the only difference is that KL divergence is no longer used to measure the gap between the teacher's and student's predictions; BinaryCrossentropy is used instead.
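In other words, the objective that the KnowledgeDistillation class optimizes becomes (with ŷ_t and ŷ_s the teacher's and student's predicted probabilities, and α = 0.2 as set in the compile call below):

```latex
L = \alpha\,\mathrm{BCE}\big(y,\ \hat{y}_s\big) + (1-\alpha)\,\mathrm{BCE}\big(\hat{y}_t,\ \hat{y}_s\big)
```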
optimizer = tf.keras.optimizers.Adam(learning_rate=params["learning_rate"])
loss = tf.keras.losses.BinaryCrossentropy()
metrics = [tf.keras.metrics.AUC()]
teacher_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

# Initialize and compile distiller
distiller = KnowledgeDistillation(student_model=student_model, teacher_model=teacher_model)
distiller.compile(
    optimizer=optimizer,
    metrics=metrics,
    student_loss_fn=loss,
    distillation_loss_fn=loss,
    alpha=0.2,
)

path = "models/distiller_model"
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir="./logs", histogram_freq=5, profile_batch=3),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_student_loss', min_lr=1e-6),
    tf.keras.callbacks.EarlyStopping(monitor='val_student_loss', patience=1, mode="min"),
    tf.keras.callbacks.ModelCheckpoint(
        path, monitor='val_student_loss', verbose=1, save_best_only=True, mode='min'
    ),
]

# Distill teacher to student
distiller.fit(x=train_dataset, validation_data=valid_dataset, epochs=20, callbacks=callbacks)

# Save the distilled student model
tf.saved_model.save(distiller.student_model, path)
# -*- coding: utf-8 -*-
import os
import argparse
import time
from statistics import mean, median

os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"

import tensorflow as tf


def decode_line(line):
    line = tf.strings.regex_replace(line, "\x00", "")  # tf.Tensor:shape=(1,)
    columns = tf.strings.split([line], ' ')  # tf.RaggedTensor:shape=(1, None)
    # get columns
    labels = tf.strings.to_number(columns.values[7], out_type=tf.float32)  # tf.RaggedTensor:shape=()
    features = tf.strings.split(columns.values[8], '&')
    # get dense
    dense = tf.strings.split(features[0], ',')[:587]  # tf.EagerTensor:shape=(`field_size`, None)
    dense = tf.strings.to_number(dense, out_type=tf.float32)
    dense = tf.reshape(dense, shape=[-1, ])
    # get sparse
    pairs = tf.strings.split(features[1], ',')  # tf.EagerTensor:shape=(`field_size`, None)
    id_vals = tf.strings.split(pairs, ':').to_tensor()
    feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
    feat_ids = tf.strings.to_number(feat_ids, out_type=tf.int32)  # tf.Tensor:shape=(`field_size`, 1)
    feat_vals = tf.strings.to_number(feat_vals, out_type=tf.float32)  # tf.Tensor:shape=(`field_size`, 1)
    feat_ids = tf.reshape(feat_ids, [-1])  # tf.Tensor:shape=(`field_size`,)
    feat_vals = tf.reshape(feat_vals, [-1])  # tf.Tensor:shape=(`field_size`,)
    # get seq
    seq_50 = tf.strings.split(features[2], '#')
    seq_50 = tf.strings.split(seq_50, ',')
    seq_50 = tf.strings.to_number(seq_50, out_type=tf.int32).to_tensor()
    # return
    return {"dense_input": dense, "sparse_input": feat_ids, "sparse_wgt_input": feat_vals,
            "seq_input": seq_50, }, labels


def input_fn(filenames, batch_size=256, batch_shuffle=False):
    """Directly read from `filenames`."""
    ds = tf.data.TextLineDataset(filenames).map(decode_line, num_parallel_calls=20)
    if batch_shuffle:
        ds = ds.shuffle(buffer_size=batch_size)
    ds = ds.batch(batch_size)
    return ds


def getFileSize(filePath, size=0):
    for root, dirs, files in os.walk(filePath):
        for f in files:
            size += os.path.getsize(os.path.join(root, f))
    return size
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', dest="model_path", help="model_path",
                        default="/Users/guirong/go/src/ai/kd/models/teacher_model", type=str)
    parser.add_argument('--batch_size', dest="batch_size", type=int, default=100, help='Batch size')
    parser.add_argument('--total_step', dest="total_step", type=int, default=1000, help='total_step')
    parser.add_argument('--warm_step', dest="warm_step", type=int, default=500, help='warm_step')
    parser.add_argument('--train_files', dest="train_files", type=str,
                        default="/Users/guirong/go/src/ai/data/0000000000/train_10005.libsvm",
                        help='train files')
    parser.add_argument('--test_files', dest="test_files", type=str,
                        default="/Users/guirong/go/src/ai/data/0000000000/test_10005.libsvm",
                        help='test files')
    parser.add_argument('--val_files', dest="val_files", type=str,
                        default="/Users/guirong/go/src/ai/data/0000000000/valid_10005.libsvm",
                        help='valid files')
    args = parser.parse_args()

    print("\n=========================================")
    print("Inference using Native TensorFlow\n")
    print("Model Path:", args.model_path)
    print("Test File Path:", args.test_files)
    print("Model size:", getFileSize(args.model_path) / 1024 / 1024, "M")
    print("Batch size:", args.batch_size)
    print("Total Step:", args.total_step)
    print("Warm Step:", args.warm_step)
    time.sleep(2)

    root = tf.saved_model.load(os.path.join(args.model_path))
    predict = root.signatures['serving_default']
    output_tensor_name = list(predict.structured_outputs.keys())[0]

    ds = input_fn(filenames=args.test_files, batch_size=args.batch_size)
    iterator = iter(ds)
    features, labels = iterator.get_next()

    start_test = time.time()
    print("\nStart Process At ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    start_first = start_test
    try:
        step_times = list()
        for step in range(1, args.total_step + 1):
            start_t = time.time()
            if step % args.warm_step == 0:
                print("Processing step: %04d ..." % step)
                print("time cost %d" % (start_t - start_test))
                start_test = time.time()
            probs = predict(dense_input=features['dense_input'],
                            seq_50_input=features['seq_input'],
                            sparse_ids_input=features['sparse_input'],
                            sparse_wgt_input=features['sparse_wgt_input']
                            )[output_tensor_name]
            inferred_class = tf.math.argmax(probs).numpy()
            step_time = time.time() - start_t
            if step >= args.warm_step:
                step_times.append(step_time)
    except tf.errors.OutOfRangeError:
        pass

    end_test = time.time()
    print("Total time cost %d" % (end_test - start_first))
    print("End Process At ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    print("\nAverage step time: %.1f msec" % (mean(step_times) * 1e3))
    print("Average throughput: %d samples/sec" % (
        args.total_step / (end_test - start_first)
    ))
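Switching the --model_path argument lets the same script benchmark the teacher model, the distilled student, and the plain student in turn (with --batch_size, --total_step, and --warm_step controlling the measurement), which makes it straightforward to compare the three models as in the table below.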
Model | Size | Metrics | QPS | Max request latency | Speedup |
---|---|---|---|---|---|
teacher_model | 101.91 M | Test: Loss(0.05863), AUC(0.68112); Valid: Loss(0.05974), AUC(0.69937) | 60.85317 | 24.26386 ms | - |
distillation_model | 13.997 M | Test: Loss(0.05614), AUC(0.70414); Valid: Loss(0.05809), AUC(0.71531) | 149.97829 | 7.56311 ms | +1.46x |
student_model | 39.354 M | Test: Loss(0.05600), AUC(0.70073); Valid: Loss(0.05783), AUC(0.71431) | 137.67856 | 9.01294 ms | +1.26x |