
ptb_reader source code walkthrough

Copyright notice: this is an original article by the blogger; please credit the source when reposting. https://blog.csdn.net/u012436149/article/details/52828782

The source code comes from git. I'm currently learning TensorFlow, so I'm writing down some notes on it here.

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


"""Utilities for parsing PTB text files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

import numpy as np
import tensorflow as tf


def _read_words(filename):
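  # Read the whole file, replace every newline with the token "<eos>", then split
  # on whitespace. Returns a flat 1-D list of words.
  # Note: "<eos>" only becomes a standalone token when the text around each
  # newline is space-separated; otherwise it is glued to the neighboring word.
  # (tf.gfile.GFile is the TF 1.x API; in TF 2.x it is tf.io.gfile.GFile.)
  with tf.gfile.GFile(filename, "r") as f:
    return f.read().replace("\n", "<eos>").split()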


def _build_vocab(filename):  # build the vocabulary (word -> id mapping)
  data = _read_words(filename)
  counter = collections.Counter(data)  # dict-like: key = word, value = number of occurrences
  # counter.items() yields (word, count) tuples; sort by count descending,
  # breaking ties by word ascending.
  count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
  words, _ = list(zip(*count_pairs))  # "unzip": all words in one tuple, all counts in another
  # Number the words in that sorted order (count descending, word ascending),
  # so the most frequent word gets id 0.
  word_to_id = dict(zip(words, range(len(words))))
  return word_to_id  # dict: key = word, value = id


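def _file_to_word_ids(filename, word_to_id):  # represent a file as a list of word ids
  data = _read_words(filename)
  # Raises KeyError for any word that is not in word_to_id (i.e. not seen in training).
  return [word_to_id[word] for word in data]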

def ptb_raw_data(data_path=None):
  """Load PTB raw data from data directory "data_path".
  Reads PTB text files, converts strings to integer ids,
  and performs mini-batching of the inputs.
  The PTB dataset comes from Tomas Mikolov's webpage:
  http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
  Args:
    data_path: string path to the directory where simple-examples.tgz has
      been extracted.
  Returns:
    tuple (train_data, valid_data, test_data, vocabulary)
    where each of the data objects can be passed to PTBIterator.
  """

  train_path = os.path.join(data_path, "ptb.train.txt")
  valid_path = os.path.join(data_path, "ptb.valid.txt")
  test_path = os.path.join(data_path, "ptb.test.txt")

  word_to_id = _build_vocab(train_path)  # the word ids are determined from the training set only
  train_data = _file_to_word_ids(train_path, word_to_id)
  valid_data = _file_to_word_ids(valid_path, word_to_id)
  test_data = _file_to_word_ids(test_path, word_to_id)
  vocabulary = len(word_to_id)  # vocabulary size
  return train_data, valid_data, test_data, vocabulary


def ptb_iterator(raw_data, batch_size, num_steps):
  """Iterate on the raw PTB data.
  This generates batch_size pointers into the raw PTB data, and allows
  minibatch iteration along these pointers.
  Args:
    raw_data: one of the raw data outputs from ptb_raw_data.
    batch_size: int, the batch size.
    num_steps: int, the number of unrolls.
  Yields:
    Pairs of the batched data, each a matrix of shape [batch_size, num_steps].
    The second element of the tuple is the same data time-shifted to the
    right by one.
  Raises:
    ValueError: if batch_size or num_steps are too high.
  """
  raw_data = np.array(raw_data, dtype=np.int32)  # raw_data: train_data | valid_data | test_data

  data_len = len(raw_data)  # how many words are in the data set
  batch_len = data_len // batch_size
  # Split the word stream into batch_size contiguous rows of batch_len words each;
  # leftover words at the end are dropped.
  data = np.zeros([batch_size, batch_len], dtype=np.int32)
  for i in range(batch_size):
    data[i] = raw_data[batch_len * i:batch_len * (i + 1)]

  epoch_size = (batch_len - 1) // num_steps

  if epoch_size == 0:
    raise ValueError("epoch_size == 0, decrease batch_size or num_steps")

  for i in range(epoch_size):
    x = data[:, i*num_steps:(i+1)*num_steps]
    y = data[:, i*num_steps+1:(i+1)*num_steps+1]
    yield (x, y)  # inside the loop, so one (x, y) pair is produced per step
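
To make the output of these functions concrete, here is a small sketch that re-implements the vocabulary and batching logic in plain Python/NumPy, with no TensorFlow and no file I/O, on a made-up toy corpus. The names `build_vocab` and `batch_iterator` are hypothetical stand-ins for `_build_vocab` and `ptb_iterator`. The row-by-row loop that fills `data` is written as a single `reshape`, which yields the same matrix. The toy text pads each line with spaces so that `<eos>` ends up as its own token.

```python
import collections

import numpy as np


def build_vocab(words):
    """Hypothetical stand-in for _build_vocab: ids ordered by (-count, word)."""
    counter = collections.Counter(words)
    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
    vocab_words, _ = zip(*count_pairs)  # "unzip" into (words, counts)
    return {w: i for i, w in enumerate(vocab_words)}


def batch_iterator(raw_data, batch_size, num_steps):
    """Hypothetical stand-in for ptb_iterator; yields one (x, y) pair per step."""
    raw_data = np.array(raw_data, dtype=np.int32)
    batch_len = len(raw_data) // batch_size
    # Row i is the i-th contiguous slice of the word stream; the tail is dropped.
    data = raw_data[:batch_size * batch_len].reshape(batch_size, batch_len)
    epoch_size = (batch_len - 1) // num_steps
    if epoch_size == 0:
        raise ValueError("epoch_size == 0, decrease batch_size or num_steps")
    for i in range(epoch_size):
        x = data[:, i * num_steps:(i + 1) * num_steps]
        y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]  # x shifted right by one
        yield x, y


# Toy corpus: each line is space-padded so "<eos>" is split off as a separate token.
text = " the cat sat \n the dog sat \n"
words = text.replace("\n", "<eos>").split()

word_to_id = build_vocab(words)
ids = [word_to_id[w] for w in words]
print(word_to_id)  # {'<eos>': 0, 'sat': 1, 'the': 2, 'cat': 3, 'dog': 4}
print(ids)         # [2, 3, 1, 0, 2, 4, 1, 0]
for x, y in batch_iterator(ids, batch_size=2, num_steps=3):
    print(x.tolist(), y.tolist())  # [[2, 3, 1], [2, 4, 1]] [[3, 1, 0], [4, 1, 0]]
```

With 8 ids and `batch_size=2`, each row holds 4 words, so `epoch_size = (4 - 1) // 3 = 1` and only one batch comes out. Each target `y` is the next word after the corresponding `x`. Because the iterator is a generator, the `ValueError` for a too-large `num_steps` is raised only when iteration starts, not when the function is called.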

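I had long been confused about how the embedding is done. Note that no embedding happens here: the reader only converts words into integer ids.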
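
Turning ids into dense vectors happens later, in the model. In TensorFlow's accompanying `ptb_word_lm.py`, this is done with an embedding variable and `tf.nn.embedding_lookup`. Conceptually, an embedding lookup is just row indexing into a `[vocab_size, embedding_size]` matrix. The NumPy sketch below shows this. The sizes, the random initialization, and the example batch are illustrative assumptions, not values from the PTB code.

```python
import numpy as np

vocab_size = 5       # e.g. len(word_to_id); illustrative
embedding_size = 4   # illustrative assumption
rng = np.random.default_rng(0)
# Row k is the vector for word id k (in a real model this matrix is trained).
embedding = rng.standard_normal((vocab_size, embedding_size)).astype(np.float32)

# A [batch_size, num_steps] batch of word ids, shaped like the x from ptb_iterator.
x = np.array([[2, 3, 1], [2, 4, 1]], dtype=np.int32)

# Embedding lookup as integer-array indexing: every id is replaced by its row.
inputs = embedding[x]
print(inputs.shape)  # (2, 3, 4)
```

The result has shape `[batch_size, num_steps, embedding_size]`, and equal ids always map to the same row. For example, the two occurrences of id 2 above get identical vectors.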
