
Caffe2 - (10) Creating a Training Dataset

AIHGF
Published 2019-02-18 10:19:00

Caffe2 - Creating a Training Dataset

Caffe2 stores model training data in a binary DB, as a sequence of key-value pairs:

```
key1 value1 key2 value2 key3 value3 ...
```
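To make the flat key-value layout concrete, here is a stdlib-only sketch that packs string keys and byte-string values into one binary blob with length prefixes, then reads them back. The record format and the `write_records` / `read_records` helpers are hypothetical illustrations; this is not the actual on-disk format of minidb, LevelDB or LMDB.

```python
import io
import struct


def write_records(pairs):
    """Pack (key, value) pairs as [len(key)][key][len(value)][value]... (hypothetical format)."""
    buf = io.BytesIO()
    for key, value in pairs:
        key_bytes = key.encode("utf-8")
        buf.write(struct.pack("<I", len(key_bytes)))  # 4-byte little-endian length
        buf.write(key_bytes)
        buf.write(struct.pack("<I", len(value)))
        buf.write(value)
    return buf.getvalue()


def read_records(blob):
    """Yield (key, value) pairs back from a blob produced by write_records."""
    buf = io.BytesIO(blob)
    while True:
        header = buf.read(4)
        if not header:  # end of the blob
            break
        (key_len,) = struct.unpack("<I", header)
        key = buf.read(key_len).decode("utf-8")
        (value_len,) = struct.unpack("<I", buf.read(4))
        yield key, buf.read(value_len)


blob = write_records([("key1", b"value1"), ("key2", b"value2"), ("key3", b"value3")])
records = list(read_records(blob))
print(records)
```

The store itself only ever sees opaque strings; any structure (shapes, types) has to come from how the value bytes are encoded, which is the job TensorProtos takes on.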

In the DB, both keys and values are stored as strings; the TensorProtos protocol buffer is used to turn a value string back into structured data:

TensorProtos protocol buffer: it wraps Tensors (multi-dimensional arrays) together with each tensor's data type and shape information.
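The three things a tensor record has to carry (shape, element type, flat data) can be mimicked in plain Python to show why a single value string is enough to rebuild a tensor. The `TensorRecord` class and its `to_bytes` / `from_bytes` methods below are hypothetical stand-ins built on the standard `struct` module; the real message is `caffe2_pb2.TensorProto`, and its wire format is protobuf, not this one.

```python
import struct
from dataclasses import dataclass
from typing import List, Union


@dataclass
class TensorRecord:
    """Hypothetical stand-in for a TensorProto: shape + element type + flat data."""
    dims: List[int]
    dtype: str  # struct format code: 'f' = float32, 'i' = int32
    data: List[Union[float, int]]

    def to_bytes(self) -> bytes:
        # Header: number of dims, each dim, then a 1-byte dtype code.
        header = struct.pack("<I", len(self.dims))
        header += struct.pack("<%dI" % len(self.dims), *self.dims)
        header += self.dtype.encode("ascii")
        body = struct.pack("<%d%s" % (len(self.data), self.dtype), *self.data)
        return header + body

    @classmethod
    def from_bytes(cls, blob: bytes) -> "TensorRecord":
        (ndim,) = struct.unpack_from("<I", blob, 0)
        offset = 4
        dims = list(struct.unpack_from("<%dI" % ndim, blob, offset))
        offset += 4 * ndim
        dtype = blob[offset:offset + 1].decode("ascii")
        offset += 1
        count = 1  # a 0-d (scalar) tensor still holds one element
        for d in dims:
            count *= d
        data = list(struct.unpack_from("<%d%s" % (count, dtype), blob, offset))
        return cls(dims, dtype, data)


# One Iris sample: a float32 feature vector of shape (4,) and a scalar int32 label.
feature = TensorRecord([4], "f", [5.1, 3.5, 1.4, 0.2])
label = TensorRecord([], "i", [0])
restored = TensorRecord.from_bytes(feature.to_bytes())
print(restored.dims, [round(v, 4) for v in restored.data])
```

Storing values as float32 means 5.1 comes back as the nearest float32, not the exact decimal, which is also what happens when NumPy float32 features go through a TensorProto.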

The TensorProtosDBInput operator can then load the data in batches for SGD training.

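Take the UCI Iris dataset as an example. It is a flower classification dataset: each iris is described by 4 real-valued features, and the task is to classify it into one of three species.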

Dataset format:

```
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
...
```
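Each line is four comma-separated numbers followed by the species name, so it can be parsed without NumPy. The sketch below uses the standard `csv` module and a label mapping matching the one used in the code that follows; `parse_iris` is a hypothetical helper, and skipping blank rows assumes the downloaded file may end with an empty line.

```python
import csv
import io

# Same species -> class-id mapping as the Caffe2 example.
LABEL_DICT = {"Iris-setosa": 0, "Iris-versicolor": 1, "Iris-virginica": 2}


def parse_iris(text):
    """Split iris.data-style text into (features, labels), skipping blank lines."""
    features, labels = [], []
    for row in csv.reader(io.StringIO(text)):
        if not row:  # csv.reader yields [] for an empty line
            continue
        features.append([float(v) for v in row[:4]])  # 4 real-valued features
        labels.append(LABEL_DICT[row[4]])              # species name -> 0/1/2
    return features, labels


sample = (
    "5.1,3.5,1.4,0.2,Iris-setosa\n"
    "4.9,3.0,1.4,0.2,Iris-setosa\n"
    "4.7,3.2,1.3,0.2,Iris-setosa\n"
    "4.6,3.1,1.5,0.2,Iris-setosa\n"
    "5.0,3.6,1.4,0.2,Iris-setosa\n"
    "\n"
)
features, labels = parse_iris(sample)
print(len(features), labels)
```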
```python
import numpy as np
import matplotlib.pyplot as plt
from caffe2.python import core, utils, workspace
from caffe2.proto import caffe2_pb2
print("Necessities imported!")

# Load the text data, downloaded from
# https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
# Blank lines are skipped (the downloaded file can end with an empty line).
with open('iris_data.txt') as f:
    raw_datas = [line.strip() for line in f if line.strip()]
num_datas = len(raw_datas)

features = np.zeros((num_datas, 4), dtype=np.float32)  # one sample per row
labels = np.zeros((num_datas,), dtype=np.int32)         # np.int is removed in recent NumPy
label_dict = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}

for idx, data in enumerate(raw_datas):
    # print(data)
    fields = data.split(',')
    features[idx] = [float(v) for v in fields[:4]]  # 4 real-valued features
    labels[idx] = label_dict[fields[4]]             # species name -> class id

# train: 100
# test: 50
random_index = np.random.permutation(num_datas)  # shuffle the sample order
features = features[random_index]
labels = labels[random_index]

train_features = features[:100]
train_labels = labels[:100]
test_features = features[100:]
test_labels = labels[100:]

# Visualize the first two features, colored by label
legend = ['rx', 'b+', 'go']
plt.title("Training data distribution, feature 0 and 1")
for i in range(3):
    plt.plot(train_features[train_labels == i, 0], train_features[train_labels == i, 1], legend[i])
plt.figure()
plt.title("Testing data distribution, feature 0 and 1")
for i in range(3):
    plt.plot(test_features[test_labels == i, 0], test_features[test_labels == i, 1], legend[i])
plt.show()
```
[Figure: Training data distribution, feature 0 and 1]
[Figure: Testing data distribution, feature 0 and 1]
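The important detail in the shuffle step is that features and labels are reindexed with the same permutation, so each sample keeps its label. A stdlib-only version of the same shuffle-then-split is sketched below; `shuffle_split` and the `seed` argument are assumptions added for illustration (the original uses `np.random.permutation` without a seed).

```python
import random


def shuffle_split(features, labels, num_train, seed=None):
    """Shuffle with one shared permutation so each feature stays paired with its label."""
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    order = list(range(len(features)))
    random.Random(seed).shuffle(order)  # plays the role of np.random.permutation
    train_idx, test_idx = order[:num_train], order[num_train:]
    train_features = [features[i] for i in train_idx]
    train_labels = [labels[i] for i in train_idx]
    test_features = [features[i] for i in test_idx]
    test_labels = [labels[i] for i in test_idx]
    return train_features, train_labels, test_features, test_labels


# Toy stand-in for the 150 Iris rows: feature i is [i], and its label is i // 50.
toy_features = [[i] for i in range(150)]
toy_labels = [i // 50 for i in range(150)]
tr_x, tr_y, te_x, te_y = shuffle_split(toy_features, toy_labels, num_train=100, seed=0)
print(len(tr_x), len(te_x))
```

The toy labels mirror how iris.data is ordered (50 setosa, then 50 versicolor, then 50 virginica), which is why shuffling before the split matters: taking the first 100 unshuffled rows would leave every Iris-virginica sample in the test set. A random split does not guarantee equal class counts; a stratified split would.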

Next, put the data into a Caffe2 DB. Each key has the form train_xxx, and each value is a TensorProtos that stores the sample's two tensors: the feature and the label.
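Two details of the train_xxx keys matter: every key must actually be distinct, and zero padding keeps string order equal to numeric order, which is the order in which sorted stores such as LevelDB and LMDB iterate by default. The helper `make_keys` below is hypothetical; the second part shows a pitfall with mixing `%`-style placeholders and `str.format`.

```python
def make_keys(prefix, count, width=3):
    """Zero-padded keys (train_000, train_001, ...) so string order equals numeric order."""
    return [f"{prefix}_{i:0{width}d}" for i in range(count)]


keys = make_keys("train", 100)
print(keys[:3], keys[-1])

# Without padding, lexicographic order puts train_10 before train_2.
unpadded = ["train_%d" % i for i in range(12)]
print(sorted(unpadded)[:4])

# str.format ignores %-style placeholders, so every "key" is the same literal string.
broken = ["train_%03d".format(i) for i in range(3)]
print(broken)
```

A width of 3 only covers up to 1000 samples; for larger datasets the width has to grow with the sample count.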

```python
# Test: build a TensorProtos protocol buffer from numpy arrays
feature_and_label = caffe2_pb2.TensorProtos()
feature_and_label.protos.extend([utils.NumpyArrayToCaffe2Tensor(features[0]), utils.NumpyArrayToCaffe2Tensor(labels[0])])
print('This is what the tensor proto looks like for a feature and its label:')
print(str(feature_and_label))
print('This is the compact string that gets written into the db:')
print(feature_and_label.SerializeToString())

# Write the data into a DB
def write_db(db_type, db_name, features, labels):
    db = core.C.create_db(db_type, db_name, core.C.Mode.write)
    transaction = db.new_transaction()
    for i in range(features.shape[0]):
        feature_and_label = caffe2_pb2.TensorProtos()
        feature_and_label.protos.extend([utils.NumpyArrayToCaffe2Tensor(features[i]), utils.NumpyArrayToCaffe2Tensor(labels[i])])
        # Use %-formatting: 'train_%03d'.format(i) would leave every key as the literal "train_%03d"
        transaction.put('train_%03d' % i, feature_and_label.SerializeToString())

    # Releasing the transaction and the db commits the writes and closes the DB
    del transaction
    del db

write_db("minidb", "iris_train.minidb", train_features, train_labels)
write_db("minidb", "iris_test.minidb", test_features, test_labels)


# Build a net to check that the DB can be loaded
net_proto = core.Net("example_iris_net")
dbreader = net_proto.CreateDB([], "dbreader", db="iris_train.minidb", db_type="minidb")
net_proto.TensorProtosDBInput([dbreader], ["X", "Y"], batch_size=16)

print("The net looks like this:")
print(str(net_proto.Proto()))

workspace.CreateNet(net_proto)

workspace.RunNet(net_proto.Proto().name)
print("The first batch of feature is:")
print(workspace.FetchBlob("X"))
print("The first batch of label is:")
print(workspace.FetchBlob("Y"))
```
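With `batch_size=16`, each run of the net pulls the next 16 records; since each record holds a feature of shape (4,) and a scalar label, X should come out as 16 x 4 and Y as 16 labels. The stdlib sketch below only models which record indices land in each batch, under the assumption that the DB reader restarts from the first record once it runs past the last one. `batch_indices` is a hypothetical helper, not a Caffe2 API.

```python
def batch_indices(num_records, batch_size, num_batches):
    """Record indices for consecutive batches, assuming the reader wraps around at the end."""
    batches = []
    cursor = 0
    for _ in range(num_batches):
        batch = []
        for _ in range(batch_size):
            batch.append(cursor)
            cursor = (cursor + 1) % num_records  # wrap back to the first record
        batches.append(batch)
    return batches


# 100 training records, batches of 16: the 7th batch crosses the end of the DB.
batches = batch_indices(num_records=100, batch_size=16, num_batches=7)
print(batches[6])
```

Because 100 is not a multiple of 16, batch boundaries shift from one pass over the data to the next, so an "epoch" in this setup is a count of records rather than a whole number of batches.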

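This article is part of the Tencent Cloud self-media sharing program and is shared from the author's personal site/blog.
Originally published on 2018-01-08. For infringement concerns, contact cloudcommunity@tencent.com for removal.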