# 基于交通灯数据集的端到端分类

pytorch：0.4.0

torchsummary: pip install torchsummary

cv2: pip install opencv-python

matplotlib

numpy

## 2.代码实战

### 2.1 model.py

import torch.nn as nn
from torchsummary import summary

class A2NN(nn.Module):
    """Small CNN that maps a (batch, 3, 32, 32) image to 9 class logits."""

    def __init__(self):
        super(A2NN, self).__init__()
        # Four conv stages; the three 2x2 max-pools shrink 32x32 to 4x4.
        self.main = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # 64 channels over a 4x4 feature map -> 9 traffic-light classes.
        self.linear = nn.Linear(4 * 4 * 64, 9)

    def forward(self, inp):
        """Return (batch, 9) logits for input ``inp``."""
        features = self.main(inp)
        flat = features.view(features.shape[0], -1)
        return self.linear(flat)

if __name__ == "__main__":
    # BUG FIX: the original named this instance ``nn``, shadowing the
    # ``torch.nn`` module imported above.
    model = A2NN()
    summary(model, (3, 32, 32))

model代码不复杂，很简单，这里不多介绍，缺少基础的朋友还请自行补基础。

### 2.2 dataset.py

import torch
import cv2
import torch.utils.data as data

# Maps the class-name substring found in an image's path to its integer
# label.  9 labels total (0..8), matching the model's 9 output logits.
class_light = {
'Red Circle': 0,
'Green Circle': 1,
'Red Left': 2,
'Green Left': 3,
'Red Up': 4,
'Green Up': 5,
'Red Right': 6,
'Green Right': 7,
# NOTE(review): both "Negative" variants deliberately share label 8 --
# presumably a single "other/off" class; confirm this is intentional.
'Red Negative': 8,
'Green Negative': 8
}

class Traffic_Light(data.Dataset):
    """Dataset of traffic-light image crops.

    Each entry of ``dataset_names`` is an image path whose path contains one
    of the ``class_light`` keys; the label is derived from that substring.
    """

    def __init__(self, dataset_names, img_resize_shape):
        super(Traffic_Light, self).__init__()
        self.dataset_names = dataset_names        # list of image file paths
        self.img_resize_shape = img_resize_shape  # (w, h) passed to cv2.resize

    def __getitem__(self, ind):
        # BUG FIX: the original never loaded the image, so ``img`` was
        # referenced before assignment.
        img = cv2.imread(self.dataset_names[ind])
        img = cv2.resize(img, self.img_resize_shape)
        # BUG FIX: ``x - 127.5 / 127.5`` only subtracted 1 because of operator
        # precedence; normalize to roughly [-1, 1] as intended.
        img = (img.transpose(2, 0, 1) - 127.5) / 127.5
        label = None
        for key in class_light.keys():
            if key in self.dataset_names[ind]:
                label = class_light[key]
        if label is None:
            raise ValueError(
                'no known class name in path: {}'.format(self.dataset_names[ind]))
        # BUG FIX: the original returned nothing; return the sample pair.
        # pylint: disable=E1101,E1102
        return torch.from_numpy(img), torch.tensor(label)

    def __len__(self):
        return len(self.dataset_names)

if __name__ == '__main__':
    from glob import glob
    import os

    path = 'TL_Dataset/Green Up/'
    names = glob(os.path.join(path, '*.png'))
    dataset = Traffic_Light(names, (32, 32))
    # BUG FIX: ``dataload`` was never defined in the original (the DataLoader
    # line was lost); wrap the dataset in a DataLoader before iterating.
    dataload = data.DataLoader(dataset, batch_size=32, shuffle=True)
    for ind, (inp, label) in enumerate(dataload):
        print("{}-inp_size:{}-label_size:{}".format(ind, inp.numpy().shape,
                                                    label.numpy().shape))

### 2.3 util.py

import os
from glob import glob

def get_train_val_names(dataset_path, remove_names, val_ratio=0.1):
    """Collect *.png paths under each class folder and split train/val.

    NOTE(review): the scraped source lost the function header and the split
    logic.  The signature is reconstructed from the call site
    ``get_train_val_names(args['dataset_path'], args['remove_names'])`` and a
    90/10 per-class split is assumed -- confirm against the original project.
    """
    train_names = []
    val_names = []
    dataset_paths = os.listdir(dataset_path)
    # Drop non-class entries (e.g. the test-set folder, README files).
    for n in remove_names:
        dataset_paths.remove(n)
    for path in dataset_paths:
        sub_dataset_path = os.path.join(dataset_path, path)
        sub_dataset_names = glob(os.path.join(sub_dataset_path, '*.png'))
        sub_dataset_len = len(sub_dataset_names)
        # Last ``val_ratio`` of each class folder goes to validation.
        split = sub_dataset_len - int(sub_dataset_len * val_ratio)
        train_names += sub_dataset_names[:split]
        val_names += sub_dataset_names[split:]
    return {'train': train_names, 'val': val_names}

def check_folder(path):
    """Create ``path`` (including missing parents) if it does not exist."""
    # makedirs + exist_ok handles nested paths and is race-free, unlike the
    # original exists()-then-mkdir pair.
    os.makedirs(path, exist_ok=True)

### 2.4 trainer.py

model构造好了，数据集也准备好了，现在就需要准备如果训练了，这就是trainer.py文件的作用，trainer.py构建了Trainer类，通过传入训练的一系列参数，调用Trainer.train函数进行训练，并返回loss，代码如下：

import torch.nn as nn

class Trainer:
    """Runs one training epoch over ``dataload`` and returns the mean loss."""

    def __init__(self, model, dataload, epoch, lr, device):
        self.model = model
        # BUG FIX: the original accepted ``dataload`` but never stored it,
        # so ``self.dataload`` did not exist in __epoch.
        self.dataload = dataload
        self.epoch = epoch
        self.lr = lr
        self.device = device
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        # BUG FIX: the original stepped ``self.optimizer`` without ever
        # creating one.  Adam is assumed -- TODO confirm the original choice.
        from torch import optim
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    def __epoch(self, epoch):
        """One pass over the training data; returns mean per-batch loss."""
        self.model.train()
        loss_sum = 0
        for ind, (inp, label) in enumerate(self.dataload):
            inp = inp.float().to(self.device)
            label = label.long().to(self.device)
            out = self.model(inp)
            loss = self.criterion(out, label)
            # BUG FIX: gradients must be cleared each step or they accumulate
            # across batches.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            loss_sum += loss.item()
            print('epoch{}_step{}_train_loss_: {}'.format(epoch,
                                                          ind,
                                                          loss.item()))
        return loss_sum/(ind+1)

    def train(self):
        """Run the configured epoch and return its mean training loss."""
        train_loss = self.__epoch(self.epoch)
        return train_loss

### 2.5 validator.py

trainer.py文件是用来进行训练数据集的，训练过程中，我们是需要有验证集来判断我们模型的训练效果，所以这里有validator.py文件，里面封装了Validator类，与Trainer.py类似，但不同的是，我们不训练，不更新参数，model处于eval模式，代码上会有一些跟Trainer不一样,通过调用Validator.eval函数返回loss，代码如下：

import torch.nn as nn

class Validator:
    """Evaluates ``model`` on ``dataload`` without updating any weights."""

    def __init__(self, model, dataload, epoch, device, batch_size):
        self.model = model
        # BUG FIX: the original accepted ``dataload`` but never stored it,
        # so ``self.dataload`` did not exist in __epoch.
        self.dataload = dataload
        self.epoch = epoch
        self.device = device
        self.batch_size = batch_size
        self.criterion = nn.CrossEntropyLoss().to(self.device)

    def __epoch(self, epoch):
        """One pass over the validation data; returns {'val_loss': mean}."""
        import torch
        self.model.eval()
        loss_sum = 0
        # no_grad: validation never back-propagates, so skip autograd
        # bookkeeping and save memory.
        with torch.no_grad():
            for ind, (inp, label) in enumerate(self.dataload):
                inp = inp.float().to(self.device)
                label = label.long().to(self.device)
                out = self.model(inp)
                loss = self.criterion(out, label)
                loss_sum += loss.item()
        return {'val_loss': loss_sum/(ind+1)}

    def eval(self):
        """Run the configured epoch and return its loss dict."""
        val_loss = self.__epoch(self.epoch)
        return val_loss

### 2.6 logger.py

import matplotlib.pyplot as plt
import os

class Logger:
    """Saves a train/val loss learning-curve plot under ``save_path``."""

    def __init__(self, save_path):
        self.save_path = save_path

    def update(self, Kwarg):
        """Redraw and persist the learning curve from the loss history."""
        self.__plot(Kwarg)

    def __plot(self, Kwarg):
        out_path = os.path.join(self.save_path, 'learning_curve.png')
        plt.clf()
        # Plot both series with their display label and color.
        for key, series_label, color in (('train_losses', 'Train', 'g'),
                                         ('val_losses', 'Val', 'b')):
            plt.plot(Kwarg[key], label=series_label, color=color)
        plt.title('learning_curve')
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.legend()
        plt.savefig(out_path)

### 2.7 main.py

main.py文件将上面所有的东西结合到一起，代码如下：

import torch
import argparse

from model import A2NN
from dataset import Traffic_Light
from utils import get_train_val_names, check_folder
from trainer import Trainer
from validator import Validator
from logger import Logger

def main():
    """Train A2NN on the traffic-light dataset, validating and logging each epoch."""
    # NOTE(review): the scraped source lost most argparse lines and both
    # DataLoader constructors; the argument set below is reconstructed from
    # how ``args`` is used -- confirm defaults against the original project.
    parse = argparse.ArgumentParser()
    parse.add_argument('--save_path', type=str, default='logs')
    parse.add_argument('--dataset_path', type=str, default='TL_Dataset')
    parse.add_argument('--remove_names', type=str, nargs='*',
                       default=['README.txt', 'Testset'])
    parse.add_argument('--img_resize_shape', type=int, nargs=2, default=(32, 32))
    parse.add_argument('--batch_size', type=int, default=32)
    parse.add_argument('--num_workers', type=int, default=4)
    parse.add_argument('--epochs', type=int, default=10)
    parse.add_argument('--lr', type=float, default=1e-3)
    parse.add_argument('--save_model', action='store_true', default=True)

    args = vars(parse.parse_args())
    # cv2.resize expects a tuple; argparse nargs delivers a list.
    args['img_resize_shape'] = tuple(args['img_resize_shape'])

    check_folder(args['save_path'])

    # pylint: disable=E1101
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101

    model = A2NN().to(device)

    names = get_train_val_names(args['dataset_path'], args['remove_names'])

    train_dataset = Traffic_Light(names['train'], args['img_resize_shape'])
    val_dataset = Traffic_Light(names['val'], args['img_resize_shape'])

    # BUG FIX: the DataLoader constructor lines were lost in scraping.
    import torch.utils.data
    train_dataload = torch.utils.data.DataLoader(train_dataset,
                                                 batch_size=args['batch_size'],
                                                 shuffle=True,
                                                 num_workers=args['num_workers'])
    val_dataload = torch.utils.data.DataLoader(val_dataset,
                                               batch_size=args['batch_size'],
                                               shuffle=True,
                                               num_workers=args['num_workers'])

    loss_logger = Logger(args['save_path'])

    logger_dict = {'train_losses': [],
                   'val_losses': []}

    for epoch in range(args['epochs']):
        print('<Main> epoch{}'.format(epoch))
        trainer = Trainer(model, train_dataload, epoch, args['lr'], device)
        train_loss = trainer.train()
        if args['save_model']:
            state = model.state_dict()
            torch.save(state, 'logs/nn_state.t7')
        # BUG FIX: the Validator construction line was truncated in scraping.
        validator = Validator(model, val_dataload, epoch,
                              device, args['batch_size'])
        val_loss = validator.eval()
        logger_dict['train_losses'].append(train_loss)
        logger_dict['val_losses'].append(val_loss['val_loss'])

        loss_logger.update(logger_dict)

if __name__ == '__main__':
    main()

### 2.8 compute_prec.py和submit.py

compute_prec.py代码如下：

import torch
import numpy as np
import argparse

from model import A2NN
from dataset import Traffic_Light
from utils import get_train_val_names, check_folder

def main():
    """Compute top-1 precision of the trained A2NN on the validation split."""
    # NOTE(review): the argparse lines and DataLoader constructor were lost in
    # scraping; reconstructed from usage -- confirm against the original.
    parse = argparse.ArgumentParser()
    parse.add_argument('--save_path', type=str, default='logs')
    parse.add_argument('--dataset_path', type=str, default='TL_Dataset')
    parse.add_argument('--remove_names', type=str, nargs='*',
                       default=['README.txt', 'Testset'])
    parse.add_argument('--img_resize_shape', type=int, nargs=2, default=(32, 32))
    parse.add_argument('--num_workers', type=int, default=4)
    parse.add_argument('--model_path', type=str, default='logs/nn_state.t7')

    args = vars(parse.parse_args())
    args['img_resize_shape'] = tuple(args['img_resize_shape'])

    check_folder(args['save_path'])

    # pylint: disable=E1101
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101

    model = A2NN().to(device)
    # BUG FIX: without loading the trained checkpoint this measured the
    # precision of a randomly initialized network.
    model.load_state_dict(torch.load(args['model_path'], map_location=device))
    model.eval()

    names = get_train_val_names(args['dataset_path'], args['remove_names'])

    val_dataset = Traffic_Light(names['val'], args['img_resize_shape'])

    # BUG FIX: the DataLoader constructor line was lost in scraping.
    import torch.utils.data
    val_dataload = torch.utils.data.DataLoader(val_dataset,
                                               batch_size=1,
                                               num_workers=args['num_workers'])

    count = 0
    for ind, (inp, label) in enumerate(val_dataload):
        inp = inp.float().to(device)
        label = label.long().to(device)
        output = model(inp)
        output = np.argmax(output.to('cpu').detach().numpy(), axis=1)
        label = label.to('cpu').numpy()
        count += 1 if output == label else 0

    print('precision: {}'.format(count/(ind+1)))

if __name__ == "__main__":
    main()

submit.py代码如下：

import torch
import numpy as np
import argparse
import os
import cv2

from model import A2NN
from utils import check_folder

def main():
    """Classify the 20000 test images and write '<name> <class>' lines to result.txt."""
    # NOTE(review): most argparse lines were lost in scraping; reconstructed
    # from usage -- confirm defaults against the original project.
    parse = argparse.ArgumentParser()
    parse.add_argument('--save_path', type=str, default='logs')
    parse.add_argument('--img_resize_shape', type=int, nargs=2, default=(32, 32))
    parse.add_argument('--model_path', type=str, default='logs/nn_state.t7')
    parse.add_argument('--dataset_path', type=str,
                       default='TL_Dataset/Testset/')

    args = vars(parse.parse_args())
    args['img_resize_shape'] = tuple(args['img_resize_shape'])

    check_folder(args['save_path'])

    # pylint: disable=E1101
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101

    model = A2NN().to(device)
    # BUG FIX: load the trained weights; otherwise every prediction comes
    # from a randomly initialized network.
    model.load_state_dict(torch.load(args['model_path'], map_location=device))
    model.eval()

    txt_path = os.path.join(args['save_path'], 'result.txt')
    with open(txt_path, 'w') as f:
        for i in range(20000):
            name = os.path.join(args['dataset_path'], '{}.png'.format(i))
            # BUG FIX: the image was never read from disk in the original.
            img = cv2.imread(name)
            img = cv2.resize(img, args['img_resize_shape'])
            # BUG FIX: precedence -- (x - 127.5) / 127.5, not x - (127.5/127.5).
            img = (img.transpose(2, 0, 1) - 127.5) / 127.5
            img = torch.unsqueeze(torch.from_numpy(img).float(), dim=0)
            img = img.to(device)
            output = model(img).to('cpu').detach().numpy()
            img_class = np.argmax(output, axis=1)
            # BUG FIX: basename instead of the fragile, separator-dependent
            # name.split('/')[2].
            f.write(os.path.basename(name) + ' ' + str(img_class[0]))
            f.write('\n')

if __name__ == "__main__":
    main()

## 3. 代码如下运行

$ python main.py 如果还想计算精确度，在训练完数据集之后，运行命令： $ python compute_prec.py

49 篇文章13 人订阅

0 条评论

## 相关文章

1055

3466

### HDUOJ------Worm

Worm Time Limit: 1000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java...

3378

### Pandas使用 (一）

What is pandas Pandas是python中用于处理矩阵样数据的功能强大的包，提供了R中的dataframe和vector的操作，使得我们在使用p...

4599

### Pandas，让Python像R一样处理数据，但快

What is pandas Pandas是python中用于处理矩阵样数据的功能强大的包，提供了R中的dataframe和vector的操作，使得我们在使用p...

2565

### Convolutional Neural Networks: Application

X_train_orig, Y_train_orig, X_test_orig, Y_test_orig, classes = load_dataset()

722

3935

### Python数据分析(二): Pandas技巧 (2)

Pandas的第一部分: http://www.cnblogs.com/cgzl/p/7681974.html github地址: https://github...

3016

3066

943