Author: 杰少, guest contributor at 炼丹笔记 (Alchemy Notes)
Introduction

TabNet is an interpretable deep learning framework for tabular data proposed by Google in 2019. Most people who tried it earlier found that it performed noticeably worse than traditional tree models, but it recently shone in the Optiver financial competition on Kaggle, where a pure TabNet scored extremely well in public notebooks, so I decided to revisit it.
In this article we walk through TabNet by dissecting its network structure piece by piece and give the corresponding code implementation; interested readers can use it to learn the model or to refresh their memory.
TabNet

01
Model Architecture

02
TabNet Pipeline
At each step, TabNet first applies a non-linear transformation to the data, then uses an Attentive Transformer to obtain a mask; the mask is applied to the features, which then pass through a Feature Transformer to produce both this step's output and the input for the next step's mask. In effect, TabNet performs multiple rounds of feature selection on the raw input and combines the information selected at each step for the final prediction.
The raw numerical inputs and the embeddings of the categorical features need no further preprocessing; a BN (batch normalization) layer is applied to them directly.
At each step, the normalized data is fed through FC+BN+GLU blocks, and each block's input is combined with its output in a scaled residual-addition form;
Here the mask is obtained by applying FC+BN, multiplying by the prior, and then applying Sparsemax (see the short demo below for why Sparsemax rather than softmax).
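Softmax would give every feature a non-zero weight; Sparsemax projects the logits onto the probability simplex and can zero features out exactly, which is what makes the masks genuinely sparse. A minimal illustration (my own, assuming tensorflow_addons is installed; not from the original post):

import tensorflow as tf
from tensorflow_addons.activations import sparsemax

logits = tf.constant([[1.0, 1.2, 0.1, -1.0]])
print(tf.nn.softmax(logits).numpy())  # all four entries strictly positive
print(sparsemax(logits).numpy())      # [[0.4 0.6 0.  0. ]] -- two features zeroed out exactly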
Code

# hide
import multiprocessing as mp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow_addons.activations import sparsemax
%matplotlib inline
def GLU(x):
    # gating used throughout TabNet; this one-tensor form (x * sigmoid(x))
    # is a simplification of the split-based GLU
    return x * tf.sigmoid(x)
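# Note (my own aside, not from the original post): the GLU in the TabNet
# paper splits a 2*units linear output into a value half and a sigmoid gate
# half; the x * sigmoid(x) form above is a simplification of that. A
# closer-to-paper variant would look like this (the preceding Dense layer
# would then need 2*units outputs):
def split_GLU(x):
    value, gate = tf.split(x, 2, axis=-1)
    return value * tf.sigmoid(gate)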
class FCBlock(layers.Layer):
    '''
    FC + BN + GLU
    '''
    def __init__(self, units):
        super().__init__()
        self.layer = layers.Dense(units)
        self.bn = layers.BatchNormalization()

    def call(self, x):
        return GLU(self.bn(self.layer(x)))
class SharedBlock(layers.Layer):
    '''
    Two FCBlocks whose outputs are combined as sqrt(0.5) * out1 + out2.
    '''
    def __init__(self, units, mult=tf.sqrt(0.5)):
        super().__init__()
        self.layer1 = FCBlock(units)
        self.layer2 = FCBlock(units)
        self.mult = mult

    def call(self, x):
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        return out2 + self.mult * out1
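# Quick shape check (my own sketch, not in the original post):
# SharedBlock maps (batch, features) -> (batch, units).
_blk = SharedBlock(16)
print(_blk(tf.random.normal((4, 32))).shape)   # (4, 16)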
class DecisionBlock(SharedBlock):
    '''
    Two FCBlocks with a scaled residual connection around each:
    out = sqrt(0.5) * in + FCBlock(in).
    '''
    def __init__(self, units, mult=tf.sqrt(0.5)):
        super().__init__(units, mult)

    def call(self, x):
        out1 = x * self.mult + self.layer1(x)
        out2 = out1 * self.mult + self.layer2(out1)
        return out2
class Prior(layers.Layer):
    '''
    Prior scale: P <- P * (gamma - mask), discouraging (but, for gamma > 1,
    not forbidding) the re-use of features selected at earlier steps.
    '''
    def __init__(self, gamma=1.1):
        super().__init__()
        self.gamma = gamma

    def reset(self):
        self.P = 1.0

    def call(self, mask):
        self.P = self.P * (self.gamma - mask)
        return self.P
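# Sanity check of the prior (my own, not in the original post): with
# gamma > 1 a fully-selected feature (mask = 1) keeps a residual prior of
# gamma - 1, so later steps may still re-select it; with gamma = 1 it would
# be excluded for good. The penalty compounds across steps.
_prior = Prior(gamma=1.3)
_prior.reset()
_mask = tf.constant([[1.0, 0.0]])      # feature 0 fully used, feature 1 untouched
print(_prior(_mask).numpy())           # [[0.3  1.3 ]]
print(_prior(_mask).numpy())           # [[0.09 1.69]]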
class AttentiveTransformer(layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.layer = layers.Dense(units)
        self.bn = layers.BatchNormalization()

    def call(self, x, prior):
        # FC + BN, scaled by the prior, then Sparsemax to get a sparse mask
        return sparsemax(prior * self.bn(self.layer(x)))
# collapse
class TabNet(keras.Model):
    def __init__(self, input_dim, output_dim, steps, n_d, n_a, gamma=1.3):
        super().__init__()
        # hyper-parameters
        self.n_d, self.n_a, self.steps = n_d, n_a, steps
        # input normalisation
        self.bn = layers.BatchNormalization()
        # Feature Transformer
        self.shared = SharedBlock(n_d+n_a)
        self.first_block = DecisionBlock(n_d+n_a)
        # note: list comprehensions are needed here -- [layer] * steps would
        # re-use the same layer (and its weights) at every step
        self.decision_blocks = [DecisionBlock(n_d+n_a) for _ in range(steps)]
        # Attentive Transformer
        self.attention = [AttentiveTransformer(input_dim) for _ in range(steps)]
        self.prior_scale = Prior(gamma)
        # final layer
        self.final = layers.Dense(output_dim)
        self.eps = 1e-8
        self.add_layer = layers.Add()
    @tf.function
    def call(self, x):
        self.prior_scale.reset()
        final_outs = []
        mask_losses = []
        # 1. input normalisation
        x = self.bn(x)
        # 2. first pass through the Feature Transformer to get the initial
        #    attention features
        attention = self.first_block(self.shared(x))[:, :self.n_a]
        # 3. every step is an Attentive Transformer + Feature Transformer pair
        for i in range(self.steps):
            # Attentive Transformer: sparse mask over the input features
            mask = self.attention[i](attention, self.prior_scale.P)
            # entropy regulariser pushing the masks towards sparsity
            entropy = mask * tf.math.log(mask + self.eps)
            mask_losses.append(
                -tf.reduce_sum(entropy, axis=-1) / self.steps
            )
            ######### right-hand side of the Attentive Transformer ##########
            # Feature Transformer on the masked input
            prior = self.prior_scale(mask)
            out = self.decision_blocks[i](self.shared(x * prior))
            # split into next step's attention features and this step's output
            attention, output = out[:, :self.n_a], out[:, self.n_a:]
            # ReLU before aggregating the step outputs
            final_outs.append(tf.nn.relu(output))
        final_out = self.add_layer(final_outs)
        mask_loss = self.add_layer(mask_losses)
        return self.final(final_out), mask_loss
    def mask_importance(self, x):
        # per-feature importances: each step's mask weighted by how much
        # that step contributes to the final decision output
        self.prior_scale.reset()
        feature_importance = 0
        x = self.bn(x)
        attention = self.first_block(self.shared(x))[:, :self.n_a]
        for i in range(self.steps):
            mask = self.attention[i](attention, self.prior_scale.P)
            prior = self.prior_scale(mask)
            out = self.decision_blocks[i](self.shared(x * prior))
            attention, output = out[:, :self.n_a], out[:, self.n_a:]
            step_importance = tf.reduce_sum(tf.nn.relu(output), axis=1, keepdims=True)
            feature_importance += mask * step_importance
        return feature_importance
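For orientation: n_d is the width of each step's decision output and n_a the width of the attention features handed to the next Attentive Transformer (the N_d and N_a of the paper). A quick shape check of one forward pass (my own sketch, not part of the original post):

tmp = TabNet(input_dim=8, output_dim=3, steps=2, n_d=4, n_a=4)
logits, sparsity = tmp(tf.random.normal((5, 8)))
print(logits.shape, sparsity.shape)   # (5, 3) (5,)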
# collapse
from tensorflow.keras.losses import SparseCategoricalCrossentropy
sce = SparseCategoricalCrossentropy(from_logits=True)
reg_sparse = 0.01

# one option: a single loss combining the cross-entropy and sparsity terms
def full_loss(y_true, y_pred):
    logits, mask_loss = y_pred
    return sce(y_true, logits) + reg_sparse * tf.reduce_mean(mask_loss)
# collapse
# the alternative used below: two separate losses combined via loss_weights
def mask_loss(y_true, mask_losses):
    # y_true is ignored; this just averages the entropy term returned by call
    return tf.reduce_mean(mask_losses)

model = TabNet(784, 10, 2, 10, 10, 1.3)
model.compile(
    'Adam',
    loss=[sce, mask_loss],
    loss_weights=[1, 0.01]
)
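Since call returns two outputs (the logits and the mask-entropy term) and mask_loss ignores its targets, the whole pipeline can be smoke-tested by passing the labels twice. A sketch on random data (my own; the shapes follow the flattened-MNIST-style setup above):

x = tf.random.normal((32, 784))
y = tf.random.uniform((32,), maxval=10, dtype=tf.int32)
model.fit(x, (y, y), epochs=1, verbose=0)   # second target is ignored by mask_loss

# per-feature importances for interpretability, averaged over the batch
importances = model.mask_importance(x)
top10 = tf.argsort(tf.reduce_mean(importances, axis=0), direction='DESCENDING')[:10]
print(top10)   # indices of the ten most strongly attended input features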

Summary

The TabNet described here achieved very strong results in Kaggle's Optiver competition; for anyone interested in applying neural networks to tabular data, TabNet is a network architecture well worth learning.