前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PointNet代码解读

PointNet代码解读

作者头像
点云乐课堂
发布2020-05-18 14:23:24
1.2K0
发布2020-05-18 14:23:24
举报
文章被收录于专栏:3D点云深度学习3D点云深度学习

今天写点代码方面的内容,昨天已经简单讲解过paper了,只看文章的话,理解的效果一般,所以今天结合代码再来充分认识PointNet。代码分为分类、分割两部分,本文以分类为例。

关于论文的讲解,感兴趣的可以点这里

网络结构

这部分代码位于pointnet_cls.py中。

代码语言:javascript
复制
def get_model(point_cloud, is_training, bn_decay=None):  
   """ Classification PointNet,input is BxNx3, output Bx40 """
   batch_size =point_cloud.get_shape()[0].value  
   num_point =point_cloud.get_shape()[1].value  
   end_points = {}     with tf.variable_scope('transform_net1') assc:  
       transform = input_transform_net(point_cloud,is_training, bn_decay, K=3)  
   point_cloud_transformed =tf.matmul(point_cloud, transform)  
   input_image =tf.expand_dims(point_cloud_transformed, -1)#在最后增加一个维度     net = tf_util.conv2d(input_image, 64,[1,3],  
                        padding='VALID', stride=[1,1],  
                        bn=True,is_training=is_training,  
                        scope='conv1',bn_decay=bn_decay)  
   net = tf_util.conv2d(net, 64, [1,1],  
                        padding='VALID',stride=[1,1],  
                        bn=True,is_training=is_training,  
                        scope='conv2',bn_decay=bn_decay)     with tf.variable_scope('transform_net2') assc:  
       transform = feature_transform_net(net,is_training, bn_decay, K=64)  
   end_points['transform'] = transform  
   net_transformed = tf.matmul(tf.squeeze(net,axis=[2]), transform)  
   net_transformed =tf.expand_dims(net_transformed, [2])     net = tf_util.conv2d(net_transformed, 64,[1,1],  
                        padding='VALID',stride=[1,1],  
                        bn=True,is_training=is_training,  
                        scope='conv3',bn_decay=bn_decay)  
   net = tf_util.conv2d(net, 128, [1,1],  
                        padding='VALID',stride=[1,1],  
                        bn=True,is_training=is_training,  
                        scope='conv4',bn_decay=bn_decay)  
   net = tf_util.conv2d(net, 1024, [1,1],  
                        padding='VALID',stride=[1,1],  
                        bn=True,is_training=is_training,  
                        scope='conv5',bn_decay=bn_decay)     # Symmetric function: max pooling  
   net = tf_util.max_pool2d(net,[num_point,1],  
                            padding='VALID',scope='maxpool')     net = tf.reshape(net, [batch_size,-1])  
   net = tf_util.fully_connected(net, 512,bn=True, is_training=is_training,  
                                 scope='fc1',bn_decay=bn_decay)  
   net = tf_util.dropout(net, keep_prob=0.7,is_training=is_training,  
                         scope='dp1')  
   net = tf_util.fully_connected(net, 256,bn=True, is_training=is_training,  
                                 scope='fc2',bn_decay=bn_decay)  
   net = tf_util.dropout(net, keep_prob=0.7,is_training=is_training,  
                         scope='dp2')  
   net = tf_util.fully_connected(net, 40,activation_fn=None, scope='fc3')     return net, end_points

模型可以分为特征提取和分类两大块。

特征提取

在特征提取部分,与论文中描述的相同,T-net——>mlp(64,64)——>T-net——>mlp(64,128,1024)。首先上来是个T-net用来把点云摆放到一个合适的角度;接下来是两层conv2d卷积层,第一个用(1,3)的卷积核把(B,N,3,1)变为(B,N,1,64),第二个相当于全连接层,对数据结构没影响;再接下来又是一个T-net用于特征的对齐;然后是3层mlp用来升维得到特征。特征提取部分到此结束。

分类任务

接下来是分类任务部分(当然也可以换成分割任务),3个全连接层,最终得到40个类别。

至此网络框架就介绍完了。

PointNet中使用了maxpooling和T-net,作者文章中起到关键作用的是maxpooling,而T-net对性能的提升作用也还是有的。

接下来就重点分析PointNet中的T-net代码,这部分代码位于transform_nets.py脚本中。

代码语言:javascript
复制
def feature_transform_net(inputs, is_training, bn_decay=None, K=64):  
   """ Feature Transform Net,input is BxNx1xK
       Return:
           Transformation matrix of size KxK"""  
   batch_size =inputs.get_shape()[0].value  
   num_point =inputs.get_shape()[1].value     net = tf_util.conv2d(inputs, 64,[1,1],  
                        padding='VALID',stride=[1,1],  
                        bn=True, is_training=is_training,  
                        scope='tconv1',bn_decay=bn_decay)  
   net = tf_util.conv2d(net, 128, [1,1],  
                        padding='VALID',stride=[1,1],  
                        bn=True,is_training=is_training,  
                        scope='tconv2',bn_decay=bn_decay)  
   net = tf_util.conv2d(net, 1024, [1,1],  
                        padding='VALID',stride=[1,1],  
                        bn=True,is_training=is_training,  
                        scope='tconv3',bn_decay=bn_decay)  
   net = tf_util.max_pool2d(net,[num_point,1],#池化窗口是[num_point,1]  
                            padding='VALID',scope='tmaxpool')     net = tf.reshape(net, [batch_size, -1])#变成两维  
   net = tf_util.fully_connected(net, 512,bn=True, is_training=is_training,  
                                 scope='tfc1',bn_decay=bn_decay)  
   net = tf_util.fully_connected(net, 256,bn=True, is_training=is_training,  
                                 scope='tfc2',bn_decay=bn_decay)     with tf.variable_scope('transform_feat') assc:  
       weights = tf.get_variable('weights',[256, K*K],  
                                initializer=tf.constant_initializer(0.0),  
                                dtype=tf.float32)  
       biases = tf.get_variable('biases',[K*K],  
                               initializer=tf.constant_initializer(0.0),  
                               dtype=tf.float32)  
       biases +=tf.constant(np.eye(K).flatten(), dtype=tf.float32)  
       transform = tf.matmul(net,weights)  
       transform = tf.nn.bias_add(transform,biases)     transform = tf.reshape(transform,[batch_size, K, K])  
   return transform

代码主体部分,前三个conv2d用来升维,接着一个max_pool2d把1024个点的特征做了maxpooling,融合成一点。然后跟两个fully_connected把维度降到256,再然后是跟[256,K*K]的权值相乘再加K*K维的偏移,达到[batch_size, K*K],最后变形成[batch_size,K, K],大功告成。

下一篇会讲讲PointNet++,由于是改进版,所以可能会结合代码一起介绍。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-05-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 3D点云深度学习 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档