前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Keras 学习笔记(一)编码与简单快速上手

Keras 学习笔记(一)编码与简单快速上手

作者头像
种花家的奋斗兔
发布2020-11-13 10:40:54
3780
发布2020-11-13 10:40:54
举报

1. 使用Keras对类别进行编码,如one-hot

参考 keras中to_categorical函数解析

简单来说,to_categorical就是将类别向量转换为二进制(只有0和1)的矩阵类型表示。其表现为将原有的类别向量转换为独热编码的形式。先上代码看一下效果:

代码语言:javascript
复制
from keras.utils.np_utils import *
#类别向量定义
b = [0,1,2,3,4,5,6,7,8]
#调用to_categorical将b按照9个类别来进行转换
b = to_categorical(b, 9)
print(b)
 
执行结果如下:
[[1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1.]]

2. 快速开始:30 秒上手 Keras

Keras 的核心数据结构是 model,一种组织网络层的方式。最简单的模型是 Sequential 顺序模型,它由多个网络层线性堆叠。对于更复杂的结构,你应该使用 Keras 函数式 API,它允许构建任意的神经网络图。

Sequential 模型如下所示:

代码语言:javascript
复制
from keras.models import Sequential

model = Sequential()

可以简单地使用 .add() 来堆叠模型:

代码语言:javascript
复制
from keras.layers import Dense

model.add(Dense(units=64, activation='relu', input_dim=100))
model.add(Dense(units=10, activation='softmax'))

在完成了模型的构建后, 可以使用 .compile() 来配置学习过程:

代码语言:javascript
复制
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])

如果需要,你还可以进一步地配置你的优化器。Keras 的核心原则是使事情变得相当简单,同时又允许用户在需要的时候能够进行完全的控制(终极的控制是源代码的易扩展性)。

代码语言:javascript
复制
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(lr=0.01, momentum=0.9, nesterov=True))

现在,你可以批量地在训练数据上进行迭代了:

代码语言:javascript
复制
# x_train 和 y_train 是 Numpy 数组 -- 就像在 Scikit-Learn API 中一样。
model.fit(x_train, y_train, epochs=5, batch_size=32)

或者,你可以手动地将批次的数据提供给模型:

代码语言:javascript
复制
model.train_on_batch(x_batch, y_batch)

只需一行代码就能评估模型性能:

代码语言:javascript
复制
loss_and_metrics = model.evaluate(x_test, y_test, batch_size=128)

或者对新的数据生成预测:

代码语言:javascript
复制
classes = model.predict(x_test, batch_size=128)
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019/11/25 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 使用Keras对类别进行编码,如one-hot
  • 2. 快速开始:30 秒上手 Keras
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档