前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Caffe2 - (十六) 创建 LMDB 数据库

Caffe2 - (十六) 创建 LMDB 数据库

作者头像
AIHGF
发布2019-02-18 09:43:38
9280
发布2019-02-18 09:43:38
举报
文章被收录于专栏:AIUAI

Caffe2 - 创建 lmdb

Caffe2 提供了将数据转换为 lmdb 的 Demo.

代码语言:javascript
复制
## @package lmdb_create_example
# Module caffe2.python.examples.lmdb_create_example
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import numpy as np

import lmdb 
from caffe2.proto import caffe2_pb2
from caffe2.python import workspace, model_helper

'''
基于随机生成的 image data 和 labels,创建 lmdb database.
'''

def create_db(output_file):
    print(">>> Write database...")
    LMDB_MAP_SIZE = 1 << 40   # MODIFY
    env = lmdb.open(output_file, map_size=LMDB_MAP_SIZE)

    checksum = 0
    with env.begin(write=True) as txn:
        for j in range(0, 128):
            # MODIFY: add your own data reader / creator
            label = j % 10
            width = 64
            height = 32

            img_data = np.random.rand(3, width, height) # 随机生成的 image data

            # Create TensorProtos
            tensor_protos = caffe2_pb2.TensorProtos()
            img_tensor = tensor_protos.protos.add()
            img_tensor.dims.extend(img_data.shape)
            img_tensor.data_type = 1

            flatten_img = img_data.reshape(np.prod(img_data.shape))
            img_tensor.float_data.extend(flatten_img)

            label_tensor = tensor_protos.protos.add()
            label_tensor.data_type = 2
            label_tensor.int32_data.append(label)
            txn.put('{}'.format(j).encode('ascii'),
                    tensor_protos.SerializeToString() )

            checksum += np.sum(img_data) * label
            if (j % 16 == 0):
                print("Inserted {} rows".format(j))

    print("Checksum/write: {}".format(int(checksum)))
    return checksum


def read_db_with_caffe2(db_file, expected_checksum):
    print(">>> Read database...")
    model = model_helper.ModelHelper(name="lmdbtest")
    batch_size = 32
    data, label = model.TensorProtosDBInput([], 
                                            ["data", "label"], 
                                            batch_size=batch_size,
                                            db=db_file, 
                                            db_type="lmdb")

    checksum = 0

    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)

    for _ in range(0, 4):
        workspace.RunNet(model.net.Proto().name)

        img_datas = workspace.FetchBlob("data")
        labels = workspace.FetchBlob("label")
        for j in range(batch_size):
            checksum += np.sum(img_datas[j, :]) * labels[j]

    print("Checksum/read: {}".format(int(checksum)))
    assert np.abs(expected_checksum - checksum < 0.1), \
        "Read/write checksums dont match"


def main():
    parser = argparse.ArgumentParser(description="Example LMDB creation" )
    parser.add_argument("--output_file", type=str, default=None,
                        help="Path to write the database to",
                        required=True)

    args = parser.parse_args()
    checksum = create_db(args.output_file)

    # For testing reading:
    read_db_with_caffe2(args.output_file, checksum)


if __name__ == '__main__':
    main()
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2018年01月26日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Caffe2 - 创建 lmdb
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档