首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >CNN中的过度拟合

CNN中的过度拟合
EN

Data Science用户
提问于 2018-07-08 16:18:49
回答 1查看 1.4K关注 0票数 3

我正在训练STL-10数据集上的VGG网。

我得到了前5位验证精度约98%,前1位验证精度约83%。

但是前1名和前5名的训练准确率都达到了100%。

这是否意味着网络太合适了?还是不想?

代码::

代码语言:javascript
运行
复制
def conv2d(inp,name,kshape,s):
    with tf.variable_scope(name) as scope:
        kernel = get_weights('weights',shape=kshape)
        conv = tf.nn.conv2d(inp,kernel,[1,s,s,1],'SAME')
        bias = get_bias('biases',shape=kshape[3])
        preact = tf.nn.bias_add(conv,bias)
        convlayer = tf.nn.relu(preact,name=scope.name)
    return convlayer

def maxpool(inp,name,k,s):
    return tf.nn.max_pool(inp,ksize=[1,k,k,1],strides=[1,s,s,1],padding='SAME',name=name)

def loss(logits,labels):
    labels = tf.reshape(tf.cast(labels,tf.int64),[-1])
    #print labels.get_shape().as_list(),logits.get_shape().as_list()
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,logits=logits,name='cross_entropy_per_example')
    cross_entropy_mean = tf.reduce_mean(cross_entropy,name='cross_entropy')
    total_loss = tf.add(tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)),cross_entropy_mean,name='total_loss')
    return total_loss

def top_1_acc(logits,true_labels):
    pred_labels = tf.argmax(logits,1)
    true_labels = tf.cast(true_labels,tf.int64)
    #print pred_labels.get_shape().as_list(),true_labels
    correct_pred = tf.cast(tf.equal(pred_labels, true_labels), tf.float32)
    accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
    return accuracy

with tf.device('/gpu:0'):
    conv1 = conv2d(feed_images,'conv1',[3,3,3,64],1)
    conv2 = conv2d(conv1,'conv2',[3,3,64,64],1)
    pool1 = maxpool(conv2,'pool1',2,2)
    #size = [N,48,48,64]
    conv3 = conv2d(pool1,'conv3',[3,3,64,128],1)
    conv4 = conv2d(conv3,'conv4',[3,3,128,128],1)
    pool2 = maxpool(conv4,'pool2',2,2)
    #size = [N,24,24,128]
    conv5 = conv2d(pool2,'conv5',[3,3,128,256],1)
    conv6 = conv2d(conv5,'conv6',[3,3,256,256],1)
    pool3 = maxpool(conv6,'pool3',2,2)
    #size = [N,12,12,256]
    conv7 = conv2d(pool3,'conv7',[3,3,256,512],1)
    conv8 = conv2d(conv7,'conv8',[3,3,512,512],1)
    pool4 = maxpool(conv8,'pool4',2,2)
    #size = [N,6,6,512]
    conv9 = conv2d(pool4,'conv9',[3,3,512,512],1)
    conv10 = conv2d(conv9,'conv10',[3,3,512,512],1)
    pool5 = maxpool(conv10,'pool5',2,2)
    #size = [N,3,3,512]
    flattened_pool5 = tf.contrib.layers.flatten(pool5)
    fc1 = tf.contrib.layers.fully_connected(flattened_pool5,1024,weights_regularizer=tf.contrib.layers.l2_regularizer(tf.constant(0.001, dtype=tf.float32)))
    dropout1 = tf.nn.dropout(fc1,keep_prob)
    fc2 = tf.contrib.layers.fully_connected(dropout1,1024,weights_regularizer=tf.contrib.layers.l2_regularizer(tf.constant(0.001, dtype=tf.float32)))
    dropout2 = tf.nn.dropout(fc2,keep_prob)
    logits = tf.contrib.layers.fully_connected(dropout2,10,activation_fn=None,weights_regularizer=tf.contrib.layers.l2_regularizer(tf.constant(0.001, dtype=tf.float32)))

    cost = loss(logits,feed_labels)

    opt_mom = tf.train.MomentumOptimizer(learning_rate=lr,momentum=0.9)
    opt = opt_mom.minimize(cost)

    acc = top_1_acc(logits,feed_labels)
EN

回答 1

Data Science用户

发布于 2018-07-09 11:38:46

根据您的准确性,$12 $差额是引入高方差问题,这意味着您是过度拟合。由于VGG16的参数数量太多,而且您有一个比ImageNet小的中等大小的数据集,过度拟合是很明显的。尽量减少模型瓶颈中的参数数,减少全连接网络、卷积层和完全连接层之间的连接。另外,尝试使用AdamOptimizer哪个更好。也试着训练更多的时代。

票数 1
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/34164

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档