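Today's post is about code. Yesterday I gave a brief walkthrough of the paper, but reading the paper alone only gets you so far, so today we go through the code to understand PointNet more fully. The code has two parts, classification and segmentation; this post uses classification as the example.
If you are interested in the walkthrough of the paper, click here.
Network structure
This part of the code lives in pointnet_cls.py.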
# TensorFlow 1.x code; tf_util and transform_nets are modules from the PointNet repository.
import tensorflow as tf
import tf_util
from transform_nets import input_transform_net, feature_transform_net

def get_model(point_cloud, is_training, bn_decay=None):
    """ Classification PointNet, input is BxNx3, output Bx40 """
    batch_size = point_cloud.get_shape()[0].value
    num_point = point_cloud.get_shape()[1].value
    end_points = {}

    # Input transform: predict a 3x3 matrix and apply it to the raw points
    with tf.variable_scope('transform_net1') as sc:
        transform = input_transform_net(point_cloud, is_training, bn_decay, K=3)
    point_cloud_transformed = tf.matmul(point_cloud, transform)
    input_image = tf.expand_dims(point_cloud_transformed, -1)  # add a trailing dimension

    net = tf_util.conv2d(input_image, 64, [1,3],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv1', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 64, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv2', bn_decay=bn_decay)

    # Feature transform: predict a 64x64 matrix and apply it to the per-point features
    with tf.variable_scope('transform_net2') as sc:
        transform = feature_transform_net(net, is_training, bn_decay, K=64)
    end_points['transform'] = transform
    net_transformed = tf.matmul(tf.squeeze(net, axis=[2]), transform)
    net_transformed = tf.expand_dims(net_transformed, [2])

    net = tf_util.conv2d(net_transformed, 64, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv3', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 128, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv4', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 1024, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv5', bn_decay=bn_decay)

    # Symmetric function: max pooling
    net = tf_util.max_pool2d(net, [num_point,1],
                             padding='VALID', scope='maxpool')

    # Classification head: flatten, then three fully connected layers
    net = tf.reshape(net, [batch_size, -1])
    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
                                  scope='fc1', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training,
                          scope='dp1')
    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
                                  scope='fc2', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training,
                          scope='dp2')
    net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')

    return net, end_points
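The model can be split into two big blocks: feature extraction and classification.
Feature extraction
The feature extraction part matches the description in the paper: T-net -> mlp(64,64) -> T-net -> mlp(64,128,1024). It starts with a T-net that brings the point cloud into a suitable orientation. Next come two conv2d layers: the first uses a [1,3] kernel to turn (B,N,3,1) into (B,N,1,64); the second is a [1,1] convolution, which amounts to a fully connected layer applied to each point and leaves the tensor shape unchanged. Then comes another T-net that aligns the features, followed by a 3-layer mlp that raises the dimension to produce the features. That is the end of the feature extraction part.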
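The shape bookkeeping above can be checked with a small NumPy sketch. This is not the TensorFlow code from pointnet_cls.py: the helper `shared_linear`, the random weights and the tiny sizes are made up for illustration, batch norm is left out, and ReLU is included on the assumption that it is the default activation in tf_util.conv2d. The point is only that a [1,3] kernel over a (B,N,3,1) input with VALID padding behaves like one 3 -> 64 linear map applied to every point.

```python
import numpy as np

def shared_linear(points, weight, bias):
    """Apply the same linear map to every point: (B, N, C_in) -> (B, N, C_out)."""
    return points @ weight + bias

rng = np.random.default_rng(0)
B, N = 2, 16                       # toy batch size and point count
cloud = rng.normal(size=(B, N, 3))

# conv1: a [1,3] kernel over (B, N, 3, 1) acts as a 3 -> 64 map per point
w1, b1 = rng.normal(size=(3, 64)), np.zeros(64)
# conv2: a [1,1] kernel acts as a 64 -> 64 map per point
w2, b2 = rng.normal(size=(64, 64)), np.zeros(64)

h1 = np.maximum(shared_linear(cloud, w1, b1), 0.0)   # ReLU (assumed default)
h2 = np.maximum(shared_linear(h1, w2, b2), 0.0)

# Put back the width-1 axis that the TensorFlow tensors carry: (B, N, 1, 64)
h2_4d = h2[:, :, np.newaxis, :]
print(h1.shape, h2_4d.shape)  # (2, 16, 64) (2, 16, 1, 64)
```

Because every point goes through the same weights, the output for one point does not depend on any other point; the points only interact later, in the pooling step.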
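Classification
Next is the classification part (it could of course be swapped for a segmentation head). After max pooling, three fully connected layers (512, 256, 40) produce the final scores for 40 classes.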
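As a rough sketch of the shapes in this head, the NumPy snippet below pushes a fake global feature through three dense layers. The helper `dense` and its random weights are hypothetical, and batch norm and dropout are left out, so this mirrors only the layer sizes of the TensorFlow code, not its behaviour.

```python
import numpy as np

def dense(x, out_dim, rng, relu=True):
    """Hypothetical fully connected layer with random weights (no bias, no batch norm)."""
    w = rng.normal(scale=0.01, size=(x.shape[1], out_dim))
    y = x @ w
    return np.maximum(y, 0.0) if relu else y

rng = np.random.default_rng(0)
global_feat = rng.normal(size=(2, 1024))   # stands in for the max-pooled, reshaped feature

h = dense(global_feat, 512, rng)           # fc1
h = dense(h, 256, rng)                     # fc2
logits = dense(h, 40, rng, relu=False)     # fc3 has no activation
pred = logits.argmax(axis=1)               # index of the highest-scoring class
print(logits.shape, pred.shape)  # (2, 40) (2,)
```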
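That covers the network architecture.
PointNet uses max pooling and T-net. According to the authors' paper, max pooling is what plays the key role, while T-net still contributes some improvement in performance.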
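One way to see why max pooling matters is that it is a symmetric function: the result does not change when the points are reordered. Here is a minimal NumPy check, with a made-up helper `global_feature` and toy random features (8 channels instead of 1024 to keep it small):

```python
import numpy as np

def global_feature(point_features):
    """Max over the point axis: (B, N, C) -> (B, C)."""
    return point_features.max(axis=1)

rng = np.random.default_rng(0)
features = rng.normal(size=(2, 1024, 8))   # toy per-point features

# Shuffle the order of the points inside every cloud
perm = rng.permutation(features.shape[1])
shuffled = features[:, perm, :]

same = np.array_equal(global_feature(features), global_feature(shuffled))
print(same)  # True
```

A point cloud is an unordered set, so an aggregation step with this property gives the same global feature no matter how the points happen to be listed in the file.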
Next, let's look closely at the T-net code in PointNet, which lives in the transform_nets.py script.
# TensorFlow 1.x code; tf_util is a module from the PointNet repository.
import numpy as np
import tensorflow as tf
import tf_util

def feature_transform_net(inputs, is_training, bn_decay=None, K=64):
    """ Feature Transform Net, input is BxNx1xK
        Return:
            Transformation matrix of size KxK """
    batch_size = inputs.get_shape()[0].value
    num_point = inputs.get_shape()[1].value

    net = tf_util.conv2d(inputs, 64, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='tconv1', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 128, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='tconv2', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 1024, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='tconv3', bn_decay=bn_decay)
    net = tf_util.max_pool2d(net, [num_point,1],  # pooling window is [num_point,1]
                             padding='VALID', scope='tmaxpool')

    net = tf.reshape(net, [batch_size, -1])  # flatten to 2-D
    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
                                  scope='tfc1', bn_decay=bn_decay)
    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
                                  scope='tfc2', bn_decay=bn_decay)

    with tf.variable_scope('transform_feat') as sc:
        weights = tf.get_variable('weights', [256, K*K],
                                  initializer=tf.constant_initializer(0.0),
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', [K*K],
                                 initializer=tf.constant_initializer(0.0),
                                 dtype=tf.float32)
        # Add the flattened identity matrix to the bias
        biases += tf.constant(np.eye(K).flatten(), dtype=tf.float32)
        transform = tf.matmul(net, weights)
        transform = tf.nn.bias_add(transform, biases)

    transform = tf.reshape(transform, [batch_size, K, K])
    return transform
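In the main body of the code, the first three conv2d layers raise the feature dimension, then max_pool2d max-pools the features of all the points (1024 of them here) into a single feature vector. Two fully_connected layers then bring the dimension down to 256. The result is multiplied by the [256, K*K] weights and offset by the K*K-dimensional bias, giving [batch_size, K*K], and finally reshaped to [batch_size, K, K]. Done. Note that the weights start at zero and the identity matrix is added to the bias, so the predicted transform starts out as the identity.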
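The last step can be replayed in NumPy to see what the zero initialization plus the identity bias produces before any training. This is a sketch, not the TensorFlow code: `tnet_head` is a made-up name, K=3 is chosen for readability, and the 256-dimensional input is random.

```python
import numpy as np

def tnet_head(feat, weights, biases, K):
    """Map (B, 256) features to (B, K, K) transforms, mirroring the final step."""
    biases = biases + np.eye(K).flatten()   # identity added to the bias
    transform = feat @ weights + biases     # (B, K*K)
    return transform.reshape(-1, K, K)

B, K = 4, 3
rng = np.random.default_rng(0)
feat = rng.normal(size=(B, 256))            # stands in for the output of tfc2
weights = np.zeros((256, K * K))            # same initial values as the TF variables
biases = np.zeros(K * K)

transform = tnet_head(feat, weights, biases, K)
is_identity = np.allclose(transform, np.eye(K))
print(transform.shape, is_identity)  # (4, 3, 3) True
```

With this initialization the T-net leaves its input untouched at the start of training, whatever features come in, and only moves away from the identity as the weights are updated.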
The next post will cover PointNet++. Since it is the improved version, I will probably go through it together with the code.