专栏首页mathorPyTorch训练神经网络玩游戏

PyTorch训练神经网络玩游戏

Game rules

很简单的一个小游戏,名字叫"FizzBuzz",游戏规则如下:

从1开始数数,当遇到3的倍数的时候,说fizz,当遇到5的倍数的时候,说buzz,当遇到15的倍数的时候,就说fizzbuzz,其他情况则正常数数

Game conversion to classification problem

可以想到,在这个游戏中,总共只有四类,fizzbuzzbuzzfizznumber

所以我们先定义一个函数,这个函数的作用是将输入的数字,离散为这四类中的某一类

def fizz_buzz_encode(i):
    if i % 15 == 0:
        return 3
    elif i % 5 == 0:
        return 2
    elif i % 3 == 0:
        return 1
    else:
        return 0

有了encode函数,还需要一个decode函数,参数是个数字,以及这个数字的类别,返回是这个数字应该喊什么,比方说decode(15, 3),返回的就应该是fizzbuzz,再比如decode(7, 0),就应该返回7

def fizz_buzz_decode(i, label):
    return [str(i), 'fizz', 'buzz', 'fizzbuzz'][label]

写个测试函数测试一下

def helper(i):
    print(fizz_buzz_decode(i, fizz_buzz_encode(i)))
    
for i in range(1, 16):
    helper(i)
输出:
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz

Generate training set

import numpy as np
import torch
from torch import nn

对于一个神经网络,我们的输入是一个数字,我们要他返回的是这个数字属于哪个类别(知道哪个类别之后调用decode函数就行了)

但其实输入如果单纯是个十进制数字特征不够明显,我们可以尝试把十进制转换为二进制,将01编码作为输入

NUM_DIGITS = 10
def binary_encode(i, NUM_DIGITS): # 将一个十进制数转换为二进制
    return np.array([i >> d & 1 for d in range(NUM_DIGITS)][::-1])

#print(binary_encode(15, NUM_DIGITS))

然后生成训练集Xy,我把$[101,1024]$之间的所有整数转为二进制作为X_train,掉用encode函数生成的标签作为y_train

X_train = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])
y_train = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)])

Construct neural network

首先设计网络结构

然后利用PyTorch定义模型

NUM_HIDDEN = 100 # 隐藏层100个神经元
model = nn.Sequential( # 网络结构:Input -> Hidden_Layer1 -> OutPut
    nn.Linear(NUM_DIGITS, NUM_HIDDEN, bias = False), # z = w1*x, 其中w1.shape=(10, 100), x.shape=(923, 10)
    nn.ReLU(), # z = relu(z), 其中z.shape=(923, 100)
    nn.Linear(NUM_HIDDEN, 4, bias = False) # y_pred = z*w2, 其中z.shape(923, 100), w2.shape=(100, 4)
    # 输出的是个923*4的矩阵
)

定义Loss_Function和梯度下降的方法

loss_fn = nn.CrossEntropyLoss() # 专为分类问题设计的Loss
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1) # lr is learning_rate

开始训练模型

BATCH_SIZE = 128
for epoch in range(10000):
    for start in range(0, len(X_train), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = X_train[start:end]
        batchY = y_train[start:end]
        
        y_pred = model(batchX)
        loss = loss_fn(y_pred, batchY)
        
        print('Epoch', epoch, loss.item())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

如果关于BATCH_SIZEEPOCH不清楚作用,可以看这篇文章

训练最终结果如下图,我们说,如果一个人通过瞎猜玩这个游戏,那他每次的正确率只有$\frac{1}{4}$,但是从训练结果来看,很明显我们的网络的准确度比瞎猜要高很多

训练完以后生成测试数据X_test

X_test = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 100)])

然后用训练好的模型对测试数据进行预测,生成y_test,假设测试数据有100个,那y_test的大小就是(100, 4),4列分别对应每个类型的概率,我们取出最大概率对应的下标值,带入decode函数,就能看到他在测试数据上的表现了

with torch.no_grad():
    y_test = model(X_test)

#y_test.max(1)[1]
predicts = zip(range(0, 101), list(y_test.max(1)[1].data.tolist()))

print([fizz_buzz_decode(i, x) for i, x in predicts])
输出:
['0', '1', 'fizz', '3', 'buzz', 'fizz', '6', '7', 'fizz', 'buzz', '10', 'fizz', '12', '13', 'fizzbuzz', '15', '16', 'fizz', '18', 'buzz', '20', '21', '22', 'fizz', 'buzz', '25', 'fizz', '27', '28', 'fizzbuzz', '30', 'fizz', 'fizz', '33', 'buzz', 'fizz', '36', '37', 'fizz', 'buzz', '40', 'fizz', '42', '43', 'fizzbuzz', '45', '46', 'fizz', '48', 'buzz', 'fizz', '51', '52', 'fizz', 'fizz', '55', 'fizz', '57', '58', 'fizzbuzz', '60', '61', 'fizz', '63', 'buzz', 'fizz', '66', '67', 'fizz', 'buzz', '70', 'fizz', '72', '73', '74', '75', '76', 'fizz', '78', 'buzz', 'fizz', '81', '82', 'fizz', 'buzz', '85', 'fizz', '87', '88', 'fizzbuzz', '90', '91', 'fizz', '93', 'buzz', 'fizz', '96', '97', 'fizz']

最终测试的效果并不是特别好,但是从一些数据当中可以看到,我们这个网络实际还是找到了这个游戏的部分规律。单从fizzbuzz的结果来看,虽然他并没有准确的达到每次都在15的倍数输出,但是它隐约知道在15的倍数附近要输出

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 搜索(2)

    mathor
  • 从暴力递归到动态规划

     动态规划没有那么难,但是很多老师在讲课的过程中讲的并不好,由此写下一篇文章记录学习过程

    mathor
  • DataIO & ByteArrayIo

    mathor
  • 成功部署云计算的关键10个技巧

    对于一些组织来说,云计算的使用已经开始,而无论他们知道与否,希望成功部署云计算,制定一个计划是有所帮助的。 Consultancy Cloud Technolo...

    静一
  • 白话易懂 编辑带你通俗解读云计算到底是什么

    本文,我们不谈那些云计算专业难懂的话题,我们用一些简单易懂的辞藻来和大家聊聊云计算市场的一些具体情况,以及云计算技术究竟与我们的工作和生活有何联系。我们都知道,...

    静一
  • 高德开放平台升级至2.0版,微软推出实时翻译应用冲破语言难关 | 大数据24小时

    数据猿导读 高德开放平台升级至2.0版,全面提升产品的大数据能力;微软推出实时翻译应用 Translator live,利用机器学习技术冲破语言障碍;世界高铁网...

    数据猿
  • 2016年中国SaaS市场必看的八大趋势|研报

    T客汇官网:tikehui.com 撰文 |人称T客 火热的2016年接近尾声,中国的SaaS市场从资本泡沫开始趋于理性,从注重市场驱动趋向于注重产品驱动,那么...

    人称T客
  • QCon全球软件开发大会随笔(一)

    难得有机会参加了这次QCon2018上海站,有幸见到了很多大牛,一天培训下来,虽然觉得很累,但真的是满满的干货。「也算对得起这个票价了」

    Bug生活2048
  • iOS 中的CIFilter(基础用法)

    本文大部分内容均来自:Core Image Tutorial: Getting Started Core Image 是一个很强大的库,PS图片时用到的各种滤...

    Haley_Wong
  • Top 30 战队揭晓,GeekPwn 云安全挑战赛第一阶段收官

    今天的推送,相信经常关注云安全的粉丝们,已经搓手手期盼很久了~ 腾讯安全与GeekPwn联合举办的GeekPwn云安全挑战赛,作为国内首个基于真实云平台的安全...

    云鼎实验室

扫码关注云+社区

领取腾讯云代金券