前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >torch01:torch基础

torch01:torch基础

作者头像
MachineLP
发布2022-05-09 15:04:42
2920
发布2022-05-09 15:04:42
举报
文章被收录于专栏:小鹏的专栏小鹏的专栏

MachineLP的Github(欢迎follow):https://github.com/MachineLP

MachineLP的博客目录:小鹏的博客目录

本小节介绍torch的基础操作和流程:

(1)计算表达式的梯度值。

(2)数组与tensor。

(3)构建输入管道。

(4)加载预训练的模型。

(5)保存和加载权重。

代码部分:

(0)import

代码语言:javascript
复制
# coding=utf-8
import torch
import torchvision
import torch.nn as nn
import numpy as np 
import torchvision.transforms as transforms

print (torch.__version__)

(1)计算梯度值

代码语言:javascript
复制
# 创建tensor
x = torch.tensor(1, requires_grad=True)
w = torch.tensor(2, requires_grad=True)
b = torch.tensor(3, requires_grad=True)

# 构建模型, 建立计算图
y = w * x + b

# 计算梯度
y.backward()

# 输出计算后的梯度值
print ('x:grad', x.grad)
print ('w:grad', w.grad)
print ('b:grad', b.grad)

# 创建两个tensor
x = torch.randn(10, 3)
y = torch.randn(10, 2)

# 搭建全连接层
linear = nn.Linear(3,2)

# 打印模型权重值
print ('w', linear.weight)
print ('b', linear.bias)

# 构建你需要的损失函数和优化算法
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(linear.parameters(), lr=0.01)

# 前向计算
pred = linear(x)


# 计算loss
loss = criterion(pred, y)
print('loss: ', loss.item())

loss.backward()
# 打印输出梯度
print ('dL/dw: ', linear.weight.grad) 
print ('dL/db: ', linear.bias.grad)

# 梯度下降
optimizer.step()

# 梯度下降后,再打印权重值就会减小。
print ('w', linear.weight)
print ('b', linear.bias)


# 梯度下降后的预测值和loss
pred = linear(x)
loss = criterion(pred, y)
print('loss after 1 step optimization: ', loss.item())

(2)数组与tensor。

代码语言:javascript
复制
# 创建数组, 转数组为tensor
x = np.array([[1, 2], [3, 4]])
y = torch.from_numpy(x)
# 转tensor为数组
z = y.numpy()

(3)构建输入管道。

代码语言:javascript
复制
# 下载 CIFAR-10 数据
train_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                             train=True, 
                                             transform=transforms.ToTensor(),
                                             download=True)

# 样本和标签
image, label = train_dataset[0]
print (image.size())
print (label)

# 通过队列的形式加载数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=64, 
                                           shuffle=True)

# 创建迭代器,为每次训练提供训练数据
data_iter = iter(train_loader)

# Mini-batch 样本和标签
images, labels = data_iter.next()

# 另外一种方式
for images, labels in train_loader:
    # 训练的代码
    pass
代码语言:javascript
复制
# 在你自己的数据上构建高效数据加载的方式
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # TODO
        # 1. Initialize file paths or a list of file names. 
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0 

# 
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
                                           batch_size=64, 
                                           shuffle=True)

(4)加载预训练的模型。

代码语言:javascript
复制
# 下载和加载预训练的模型ResNet-18.
resnet = torchvision.models.resnet18(pretrained=True)

# 只进行fine-tune top层:
for param in resnet.parameters():
    param.requires_grad = False

# Replace the top layer for finetuning.
resnet.fc = nn.Linear(resnet.fc.in_features, 100)  # 100 is an example.

# Forward pass.
images = torch.randn(64, 3, 224, 224)
outputs = resnet(images)
print (outputs.size())     # (64, 100)

(5)保存和加载权重。

代码语言:javascript
复制
# 保存和加载模型
torch.save(resnet, 'model.ckpt')
model = torch.load('model.ckpt')

# 只保存和加载模型参数
torch.save(resnet.state_dict(), 'params.ckpt')
resnet.load_state_dict(torch.load('params.ckpt'))

总结:

加餐:

在数据上进行加载数据:

其中,train.txt中的数据格式:

代码语言:javascript
复制
gender/0male/0(2).jpg 1
 gender/0male/0(3).jpeg 1
 gender/0male/0(1).jpg 0
代码语言:javascript
复制
# coding = utf-8
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt
from PIL import Image


def default_loader(path):
    # 注意要保证每个batch的tensor大小时候一样的。
    return Image.open(path).convert('RGB')


class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\n')
            # line = line.rstrip()
            words = line.split(' ')
            imgs.append((words[0],int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
    
    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img,label
    
    def __len__(self):
        return len(self.imgs)

def get_loader(dataset='train.txt', crop_size=178, image_size=128, batch_size=2, mode='train', num_workers=1):
    """Build and return a data loader."""
    transform = []
    if mode == 'train':
        transform.append(transforms.RandomHorizontalFlip())
    transform.append(transforms.CenterCrop(crop_size))
    transform.append(transforms.Resize(image_size))
    transform.append(transforms.ToTensor())
    transform.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = transforms.Compose(transform)
    train_data=MyDataset(txt=dataset, transform=transform)
    data_loader = DataLoader(dataset=train_data,
                                  batch_size=batch_size,
                                  shuffle=(mode=='train'),
                                  num_workers=num_workers)
    return data_loader
# 注意要保证每个batch的tensor大小时候一样的。
# data_loader = DataLoader(train_data, batch_size=2,shuffle=True)
data_loader = get_loader('train.txt')
print(len(data_loader))

def show_batch(imgs):
    grid = utils.make_grid(imgs)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title('Batch from dataloader')


for i, (batch_x, batch_y) in enumerate(data_loader):
    if(i<4):
        print(i, batch_x.size(),batch_y.size())
        show_batch(batch_x)
        plt.axis('off')
        plt.show()

总结:

以上是torch的基础部分,总体的流程已经有了,上手就很快了。

参考:

(1)https://github.com/yunjey/pytorch-tutorial

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018-05-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档