ResourceExhaustedError: OOM when allocating tensor with shape [16,224,224,256] and type bool

Stack Overflow user
Asked on 2020-11-10 00:22:24
1 answer · 209 views · 0 followers · 0 votes

I am very new to deep learning. I am implementing a CNN + VGG16 model for image colorization, but I am using too many layers in the model. When I use a batch size of 300, the model throws an error. Can you guide me on which layers to remove? The model also takes too much time to fit.

# Importing libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import cv2
from PIL import Image
from tensorflow.keras.layers import (Dense, Dropout, Input, InputLayer, Conv2D,
                                     Conv2DTranspose, DepthwiseConv2D, UpSampling2D,
                                     Reshape, Flatten, LeakyReLU, MaxPooling2D,
                                     AveragePooling2D)
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model

# Load data: the grayscale L channel and the a/b colour channels of LAB images
ab = np.load('ab1.npy')
gray = np.load('gray_scale.npy')
ab.shape  # (10000, 224, 224, 2)

def get_rbg(gray_imgs, ab_imgs, n=10):

    # Create an empty array to hold the stacked LAB images
    img1 = np.zeros((n, 224, 224, 3))

    img1[:, :, :, 0] = gray_imgs[0:n]   # L channel
    img1[:, :, :, 1:] = ab_imgs[0:n]    # a and b channels

    # Convert all the images to type uint8
    img1 = img1.astype("uint8")

    # Convert each LAB image to RGB
    imgs = []

    for i in range(0, n):
        imgs.append(cv2.cvtColor(img1[i], cv2.COLOR_LAB2RGB))

    # Convert the list of images into a numpy array
    imgs = np.array(imgs)

    return imgs

img_out = tf.keras.applications.vgg16.preprocess_input(get_rbg(gray_imgs=gray, ab_imgs=ab, n=100))  # 300

img_out.shape  # (100, 224, 224, 3)
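
(Note: the model code below references img_in, which is never defined in the posted snippet. The following is a minimal, hypothetical reconstruction, assuming img_in is the grayscale L channel replicated into 3 channels to match VGG16's expected input; it is not part of the original question.)

# Hypothetical reconstruction of img_in (not in the original post):
# replicate the grayscale channel into 3 channels to match VGG16's
# (224, 224, 3) input, then apply the same VGG16 preprocessing.
img_in = np.repeat(gray[:100, :, :, np.newaxis], 3, axis=-1).astype("float32")
img_in = tf.keras.applications.vgg16.preprocess_input(img_in)
img_in.shape  # (100, 224, 224, 3)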

Model definition:

model6 = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
model = Sequential()
model.add(InputLayer(input_shape=(img_in.shape[1], img_in.shape[2], 3)))
# Encoder: VGG16 truncated at its 10th-from-last layer (output: 56 x 56 x 256)
model.add(Model(inputs=model6.inputs, outputs=model6.layers[-10].output))
model.add(UpSampling2D((2, 2)))
model.add(UpSampling2D((2, 2)))
model.add(DepthwiseConv2D(32, (2, 2), activation=tf.nn.relu, padding='same'))
model.add(UpSampling2D((2, 2)))
model.add(DepthwiseConv2D(32, (2, 2), activation=tf.nn.relu, padding='same'))
model.add(layers.ReLU(0.3))
model.add(layers.Dropout(0.4))
model.add(UpSampling2D((2, 2)))
model.add(UpSampling2D((2, 2)))
model.add(DepthwiseConv2D(2, (2, 2), activation=tf.nn.relu, padding='same'))
model.add(layers.ReLU(0.3))
model.add(layers.Dropout(0.2))
model.add(UpSampling2D((2, 2)))
model.add(layers.ReLU(0.3))
model.add(layers.Dropout(0.2))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(layers.Dense(units=3))
print(model.summary())

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
functional_1 (Functional)    (None, 56, 56, 256)       1735488   
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 112, 112, 256)     0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 224, 224, 256)     0         
_________________________________________________________________
depthwise_conv2d (DepthwiseC (None, 112, 112, 256)     262400    
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 224, 224, 256)     0         
_________________________________________________________________
depthwise_conv2d_1 (Depthwis (None, 112, 112, 256)     262400    
_________________________________________________________________
re_lu (ReLU)                 (None, 112, 112, 256)     0         
_________________________________________________________________
dropout (Dropout)            (None, 112, 112, 256)     0         
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 224, 224, 256)     0         
_________________________________________________________________
up_sampling2d_4 (UpSampling2 (None, 448, 448, 256)     0         
_________________________________________________________________
depthwise_conv2d_2 (Depthwis (None, 224, 224, 256)     1280      
_________________________________________________________________
re_lu_1 (ReLU)               (None, 224, 224, 256)     0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 224, 224, 256)     0         
_________________________________________________________________
up_sampling2d_5 (UpSampling2 (None, 448, 448, 256)     0         
_________________________________________________________________
re_lu_2 (ReLU)               (None, 448, 448, 256)     0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 448, 448, 256)     0         
_________________________________________________________________
average_pooling2d (AveragePo (None, 224, 224, 256)     0         
_________________________________________________________________
dense (Dense)                (None, 224, 224, 3)       771       
=================================================================
Total params: 2,262,339
Trainable params: 2,262,339
Non-trainable params: 0
_________________________________________________________________
None
model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mse', metrics=[tf.keras.metrics.Accuracy()])
# if you encounter an OOM error, reduce the batch_size to 8
model.fit(img_in, img_out, epochs=5, batch_size=5)

Please tell me which layers I can remove from the model, so that it runs in less time and consumes fewer resources.


1 Answer

Stack Overflow user

Answered on 2020-11-10 02:34:53

Reduce the batch size - removing layers will hurt performance. A quick search shows that VGG16 needs roughly 200 MB for backpropagation per sample, which means that with a batch size of 300 you would need around 60 GB of RAM or VRAM.

Reference: https://forums.fast.ai/t/vgg16-memory-vs-parameter-in-3d-model/18117
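
As a rough sketch of the answer's arithmetic and suggestion (the ~200 MB-per-sample figure comes from the linked fast.ai thread; batch_size = 8 is the value already suggested in the comment in the question's code, and img_in/img_out are the arrays from the question):

# Back-of-the-envelope memory estimate: ~200 MB of activations per sample
per_sample_mb = 200
for batch_size in (300, 16, 8):
    print(batch_size, "->", round(per_sample_mb * batch_size / 1024, 1), "GB")
# 300 -> 58.6 GB, 16 -> 3.1 GB, 8 -> 1.6 GB

# Refit with a much smaller batch size instead of removing layers
model.fit(img_in, img_out, epochs=5, batch_size=8)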

Votes: 1
Original content provided by Stack Overflow.
Source: https://stackoverflow.com/questions/64755360
