前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >卷积网络与全连接网络比较分析

卷积网络与全连接网络比较分析

作者头像
算法与编程之美
发布2023-08-22 13:11:24
1560
发布2023-08-22 13:11:24
举报
文章被收录于专栏:算法与编程之美

1 问题

卷积网络与全连接网络比较分析。

2 方法

在全连接网络的5个周期内

from torchvision import datasetsfrom torchvision.transforms import ToTensorfrom torch import nnimport torchfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as pltfrom collections import defaultdictclass MyNet(nn.Module): def __init__(self) -> None: super().__init__() self.flatten = nn.Flatten() self.fc1 = nn.Linear(in_features=784, out_features=512) self.fc2 = nn.Linear(in_features=512, out_features=10) def forward(self, x): x = self.flatten(x) x = self.fc1(x) out = self.fc2(x) return outdef train(dataloader,net,loss_fn,optimizer): size = len((dataloader.dataset)) epoch_loss=0.0 batch_num=len(dataloader) net.train() correct = 0 for batch_ind, (X,y) in enumerate(dataloader): X,y=X.to(device),y.to(device) pred = net(X) loss=loss_fn(pred,y) optimizer.zero_grad() loss.backward() optimizer.step() epoch_loss+=loss.item() correct+=(pred.argmax(1)==y).type(torch.float).sum().item() if batch_ind % 100==0: print(f'[{batch_ind+1 :>5d} / {batch_num :>5d}], loss:{loss.item()}') avg_loss=epoch_loss/batch_num avg_accurcy = correct / size return avg_accurcy,avg_lossdef test(dataloder,net,loss): size=len(dataloder.dataset) batch_num = len(dataloder) net.eval() losses=0 correct = 0 with torch.no_grad(): for X,y in dataloder: X,y = X.to(device),y.to(device) pred = net(X) loss = loss_fn(pred,y) losses+=loss.item() correct+=(pred.argmax(1) ==y).type(torch.int).sum().item() accuracy = correct / size avg_loss = losses /batch_num print(f'accuracy is {accuracy*100}%') return accuracy,avg_lossclass getData: def __init__(self): self.train_ds = datasets.MNIST( root='data', download=True, train=True, transform=ToTensor(), ) self.text_ds = datasets.MNIST( root='data', download=True, train=False, transform=ToTensor(), ) self.train_loader = DataLoader( dataset=self.train_ds, batch_size=128, shuffle=True, ) self.test_loader = DataLoader( dataset=self.text_ds, batch_size=128, )if __name__=='__main__': batch_size=128 device='cuda' if torch.cuda.is_available() else 'cpu' train_ds = datasets.MNIST( root='data', download=True, train=True, transform=ToTensor(), ) text_ds = datasets.MNIST( root='data', download=True, train=False, transform=ToTensor(), ) train_ds,val_ds= torch.utils.data.random_split(train_ds,[50000,10000]) # (3) train_loader = DataLoader( dataset=train_ds, batch_size=batch_size, shuffle=True, ) val_loder = DataLoader( dataset=val_ds, batch_size=batch_size, ) test_loader = DataLoader( dataset=text_ds, batch_size=batch_size, ) net=MyNet().to(device) optimizer=torch.optim.SGD(net.parameters(),lr=0.15) loss_fn=nn.CrossEntropyLoss() train_loss=[] train_acc_list,train_loss_list,val_acc_list,val_loss_list=[],[],[],[] best_acc=0 for epoch in range(10): print('-'*50) print(f'eopch:{epoch+1}') train_accuracy,train_loss=train(train_loader,net,loss_fn,optimizer) val_accuracy,val_loss =test(train_loader,net,loss_fn) print(f'train acc:{train_accuracy},train_val:{train_loss}') train_acc_list.append(train_accuracy) train_loss_list.append(train_loss) val_acc_list.append(val_accuracy) val_loss_list.append(val_loss) if val_accuracy > best_acc: best_acc = val_accuracy torch.save(net.state_dict(),'model_best.pth') net.load_state_dict(torch.load('model_best.pth')) print('the best val_acc is:') test(test_loader,net,loss_fn)

全连接网络训练100个周期

import torchfrom torch import nnx = torch.rand(128,3,28,28)conv1 = nn.Conv2d(in_channels = 3,out_channels = 16,kernel_size = 3,stride = 1,padding = 1)conv2 = nn.Conv2d(in_channels = 16,out_channels = 32,kernel_size = 3,stride = 1,padding = 1)x = conv1(x)out = conv2(x)print(out.shape)

3 结语

我们通过对比全连接和卷积的学习过程最后的精确度等因素,发现卷积比全连接神经网络更适合做图像处理,在这个过程中,全连接模型中会有很多参数,这对于图像的要求太高,如果图像出现变动,会导致模型改动较大。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-05-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法与编程之美 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档