前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >resnet源码pytorch_pytorch conv1d

resnet源码pytorch_pytorch conv1d

作者头像
全栈程序员站长
发布2022-11-08 14:56:41
2930
发布2022-11-08 14:56:41
举报
文章被收录于专栏:全栈程序员必看
代码语言:javascript
复制
代码语言:javascript
复制
代码语言:javascript
复制
# Pytorch 0.4.0 ResNet34实现cifar10分类.
# @Time: 2018/6/17
# @Author: xfLi
import torchvision as tv
import torch as t
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
t.set_num_threads(8)
class ResidualBloak(nn.Module):
#残差块
def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
super(ResidualBloak, self).__init__()
self.left = nn.Sequential(
nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),
nn.BatchNorm2d(outchannel),
nn.ReLU(inplace=True),
nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),
nn.BatchNorm2d(outchannel))
self.right = shortcut
def forward(self, x):
out = self.left(x)
residual = x if self.right is None else self.right(x)
out += residual
return F.relu(out)
class ResNet34(nn.Module):
#  实现主module:ResNet34  
#  ResNet34 包含多个layer,每个layer又包含多个residual block  
#  用子module来实现residual block,用_make_layer函数来实现layer 
def __init__(self, num_classes):
super(ResNet34, self).__init__()
#前几层图像转换
self.pre = nn.Sequential(
nn.Conv2d(3, 16, 3, 1, 1, bias=False),
nn.BatchNorm2d(16),
nn.ReLU(inplace=True),
nn.MaxPool2d(3, 2, 1))
# 重复的layer,分别有3,4,6,3个residual block
self.layer1 = self._make_layer(16, 16, 3, stride=1)
self.layer2 = self._make_layer(16, 32, 4, stride=1)
self.layer3 = self._make_layer(32, 64, 6, stride=1)
self.layer4 = self._make_layer(64, 64, 3, stride=1)
#分类用的全连接
self.fc = nn.Linear(256, num_classes)
def _make_layer(self, inchannel, outchannel, block_num, stride=1):
#构建layer,包含多个residual block
shortcut = nn.Sequential(
nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
nn.BatchNorm2d(outchannel))
layer = []
layer.append(ResidualBloak(inchannel, outchannel, stride, shortcut))
for i in range(1, block_num):
layer.append(ResidualBloak(outchannel, outchannel))
return nn.Sequential(*layer)
def forward(self, x):
x = self.pre(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = F.avg_pool2d(x, 7)
x = x.view(x.size(0), -1)
return self.fc(x)
def getData(): # 定义对数据的预处理  
transform = transforms.Compose([
transforms.Resize(40),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32),
transforms.ToTensor()])
#训练集
trainset = tv.datasets.CIFAR10(root='/data/', train=True, transform=transform, download=True)
trainset_loader = DataLoader(trainset, batch_size=4, shuffle=True)
#测试集
testset = tv.datasets.CIFAR10(root='/data/', train=False, transform=transform, download=True)
testset_loader = DataLoader(testset, batch_size=4, shuffle=False)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
return trainset_loader, testset_loader, classes
def train(): #训练
trainset_loader, testset_loader, _ = getData() #获取数据
net = ResNet34(10)
print(net)
criterion = nn.CrossEntropyLoss()
optimizer = t.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) #优化器
for epoch in range(1):
for step, (inputs,labels) in enumerate(trainset_loader):
optimizer.zero_grad() #梯度清零
output = net(inputs)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
if step % 10 ==9:
acc = test(net, testset_loader)
print('Epoch', epoch, '|step ', step, 'loss: %.4f' %loss.item(), 'test accuracy:%.4f' %acc)
print('Finished Training')
return net
def test(net, testdata): #测试集
correct, total = .0, .0
for inputs, label in testdata:
net.eval()
output = net(inputs)
_, predicted = t.max(output, 1) #分类结果
total += label.size(0)
correct += (predicted == label).sum()
return float(correct) / total
if __name__ == '__main__':
net = train()

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/185357.html原文链接:https://javaforall.cn

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022年10月6日 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档