之前在基于Tensorflow的神经网络解决用户流失概率问题写了一个MLPs的网络,很多人在问,其实这个网络看起来很清晰,但是却写的比较冗长,这边优化了一个版本更方便大家修改后直接使用。
多层感知机网络
直接和大家过一遍核心部分:
din_all = tf.layers.batch_normalization(inputs=din_all, name='b1')
layer_1 = tf.layers.dense(din_all, self.layers_nodes[0], activation=tf.nn.sigmoid,use_bias=True,kernel_regularizer=tf.contrib.layers.l2_regularizer(self.regularzation_rate),name='f1')
layer_1 = tf.nn.dropout(layer_1, keep_prob=self.drop_rate[0])
layer_2 = tf.layers.dense(layer_1, self.layers_nodes[1], activation=tf.nn.sigmoid,use_bias=True,kernel_regularizer=tf.contrib.layers.l2_regularizer(self.regularzation_rate),name='f2')
layer_2 = tf.nn.dropout(layer_2, keep_prob=self.drop_rate[1])
layer_3 = tf.layers.dense(layer_2, self.layers_nodes[2], activation=tf.nn.sigmoid,use_bias=True,kernel_regularizer=tf.contrib.layers.l2_regularizer(self.regularzation_rate),name='f3')
上次我们计算过程中,通过的是先定义好多层网络中每层的weight,在通过tf.matual
进行层与层之间的计算,最后再通过tf.contrib.layers.l2_regularizer进行正则;而这次我们直接通过图像识别中经常使用的全连接(FC)的接口,只需要确定每层的节点数,通过layers_nodes
进行声明,自动可以计算出不同层下的weight,更加清晰明了。另外,还增加了dropout的部分,降低过拟合的问题。
tf.layers.dense
接口信息如下:
tf.layers.dense(
inputs,
units,
activation=None,
use_bias=True,
kernel_initializer=None,
bias_initializer=tf.zeros_initializer(),
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
trainable=True,
name=None,
reuse=None
)
除此之外,之前我们定义y和y_的时候把1转化为1,0,转化为了0,1,增加了工程量,这次我们通过:
cross_entropy_mean = -tf.reduce_mean(self.y_ * tf.log(self.output + 1e-24))
self.loss = cross_entropy_mean
直接进行计算,避免了一些无用功。
最后,之前对于梯度的值没有进行限制,会导致整体模型的波动过大,这次优化中也做了修改,如果大家需要也可以参考一下:
# 我们用learning_rate_base作为速率η,来训练梯度下降的loss函数解,对梯度进行限制后计算loss
opt = tf.train.GradientDescentOptimizer(self.learning_rate_base)
trainable_params = tf.trainable_variables()
gradients = tf.gradients(self.loss, trainable_params)
clip_gradients, _ = tf.clip_by_global_norm(gradients, 5)
self.train_op = opt.apply_gradients(zip(clip_gradients, trainable_params), global_step=self.global_step)
MLPs是入门级别的神经网络算法,实际的工业开发中使用的频率也不高,后面我准备和大家过一下常见的FM、FFM、DeepFM、NFM、DIN、MLR等在工业开发中更为常见的网络,欢迎大家持续关注。
完整代码已经上传到Github中。