我在试着建立一个CNN回归模型。输入的数据是10年的卫星图像。
输入形状是表示[Year, Image shape, Image Shape, Channels/Bands]
的[10, 256,256, 10]
模型的输出是一个介于0-1之间的数字,0-1是图像中区域的百分比值。
以下是使用的参数
CHANNELS=5
BATCH_SIZE=16
INPUT_SHAPE=(10,IMG_SIZE,IMG_SIZE,CHANNELS)
SAMPLES=100
LR=1e-7
EPOCHES=10
我使用Conv3D层作为输入层,因为它提供了向模型提供体积数据的能力,并使用密集层作为输出。
Model: sequential_FLATTEN_100_5_16_SGD_1e-07_30_v1
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv3d_3 (Conv3D) (None, 10, 254, 254, 32) 1472
_________________________________________________________________
max_pooling3d_3 (MaxPooling3 (None, 10, 127, 127, 32) 0
_________________________________________________________________
conv3d_4 (Conv3D) (None, 10, 125, 125, 64) 18496
_________________________________________________________________
max_pooling3d_4 (MaxPooling3 (None, 10, 62, 62, 64) 0
_________________________________________________________________
flatten_3 (Flatten) (None, 2460160) 0
_________________________________________________________________
dense_24 (Dense) (None, 256) 629801216
_________________________________________________________________
dense_25 (Dense) (None, 1) 257
=================================================================
Total params: 629,821,441
Trainable params: 629,821,441
Non-trainable params: 0
_________________________________________________________________
该模型在训练集上给出了以下分数:
mean_absolute_error: 0.09013315520024737
mean_squared_error: 0.11449361186977994
explained_variance_score: -0.2407465861253424
r2_score: -0.9382254392540899
在验证集上:
mean_absolute_error: 0.1923245317002776
mean_squared_error: 0.2579017795812263
explained_variance_score: -5.067052299015521
r2_score: -5.4177061135705475
我还尝试了一个不同的模型,如下所示:在这个模型中,只有第一层是Conv3D,其余的都是密集层
Model: "sequential_FLATTEN_100_5_16_Adam_1e-07_30_v1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv3d_1 (Conv3D) (None, 4, 250, 250, 32) 54912
_________________________________________________________________
max_pooling3d_1 (MaxPooling3 (None, 1, 83, 83, 32) 0
_________________________________________________________________
flatten_2 (Flatten) (None, 220448) 0
_________________________________________________________________
dense_16 (Dense) (None, 512) 112869888
_________________________________________________________________
dense_17 (Dense) (None, 256) 131328
_________________________________________________________________
dense_18 (Dense) (None, 128) 32896
_________________________________________________________________
dense_19 (Dense) (None, 64) 8256
_________________________________________________________________
dense_20 (Dense) (None, 32) 2080
_________________________________________________________________
dense_21 (Dense) (None, 16) 528
_________________________________________________________________
dense_22 (Dense) (None, 8) 136
_________________________________________________________________
dense_23 (Dense) (None, 1) 9
=================================================================
Total params: 113,100,033
Trainable params: 113,100,033
Non-trainable params: 0
_________________________________________________________________
这给了我在训练集上的以下分数:
mean_absolute_error: 0.08475626941395917
mean_squared_error: 0.1637630610914996
explained_variance_score: 0.19943303382780664
r2_score: 0.19214565669613703
在验证集上:
mean_absolute_error: 0.15135902269457854
mean_squared_error: 0.2650686092962602
explained_variance_score: -1.7471740284409094
r2_score: -1.7776585146674124
如你所见,该模型的MAE和MSE很低,但R2得分和解释方差得分也很低。
我怎样才能改进这些结果?此外,当样本大小增加时,模型开始为所有输入预测相似的值。
发布于 2021-02-25 15:20:12
我只是注意到对于这样的任务,参数的数量是如此之大。可能正在遭受Vanishing or Exploding gradient的折磨。尽量减少特征提取器的维数。您还可以应用dropout和正则化。
发布于 2021-02-25 20:23:31
发布于 2021-02-25 21:51:22
如果红线表示模型的预测线,那么我认为您的代码缺少一些优化来更多地概括预测,本质上它至少应该是一条曲线。参考下面的url,它是CNN回归的教程:
https://www.datatechnotes.com/2019/12/how-to-fit-regression-data-with-cnn.html?m=1
在上面的教程中,输出图是泛化的,这减少了误差。也许你可以尝试一下这些愚蠢的步骤。
https://stackoverflow.com/questions/66363862
复制相似问题