前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >ptb_reader源码解析

ptb_reader源码解析

作者头像
ke1th
发布2019-05-28 11:15:55
3890
发布2019-05-28 11:15:55
举报

版权声明:本文为博主原创文章,转载请注明出处。 https://cloud.tencent.com/developer/article/1436897

源码来自git 。正在学习tensorflow,所以在此记一下笔记

代码语言:javascript
复制
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


"""Utilities for parsing PTB text files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

import numpy as np
import tensorflow as tf


def _read_words(filename):
  with tf.gfile.GFile(filename, "r") as f:
    return f.read().replace("\n", "<eos>").split() #读取文件, 将换行符替换为 <eos>, 然后将文件按空格分割。 返回一个 1-D list


def _build_vocab(filename):  #用于建立字典
  data = _read_words(filename)
  counter = collections.Counter(data) #输出一个字典: key是word, value是这个word出现的次数
  count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
#counter.items() 会返回一个tuple列表, tuple是(key, value), 按 value的降序,key的升序排列
  words, _ = list(zip(*count_pairs)) #感觉这个像unzip 就是把key放在一个tuple里,value放在一个tuple里
  word_to_id = dict(zip(words, range(len(words))))#对每个word进行编号, 按照之前words输出的顺序(value降序,key升序)
  return word_to_id  #返回dict, key:word, value:id


def _file_to_word_ids(filename, word_to_id): #将file表示为word_id的形式
  data = _read_words(filename)
  return [word_to_id[word] for word in data]

def ptb_raw_data(data_path=None):
  """Load PTB raw data from data directory "data_path".
  Reads PTB text files, converts strings to integer ids,
  and performs mini-batching of the inputs.
  The PTB dataset comes from Tomas Mikolov's webpage:
  http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
  Args:
    data_path: string path to the directory where simple-examples.tgz has
      been extracted.
  Returns:
    tuple (train_data, valid_data, test_data, vocabulary)
    where each of the data objects can be passed to PTBIterator.
  """

  train_path = os.path.join(data_path, "ptb.train.txt")
  valid_path = os.path.join(data_path, "ptb.valid.txt")
  test_path = os.path.join(data_path, "ptb.test.txt")

  word_to_id = _build_vocab(train_path) #使用训练集确定word id
  train_data = _file_to_word_ids(train_path, word_to_id)
  valid_data = _file_to_word_ids(valid_path, word_to_id)
  test_data = _file_to_word_ids(test_path, word_to_id)
  vocabulary = len(word_to_id)#字典的大小
  return train_data, valid_data, test_data, vocabulary


def ptb_iterator(raw_data, batch_size, num_steps):
  """Iterate on the raw PTB data.
  This generates batch_size pointers into the raw PTB data, and allows
  minibatch iteration along these pointers.
  Args:
    raw_data: one of the raw data outputs from ptb_raw_data.
    batch_size: int, the batch size.
    num_steps: int, the number of unrolls.
  Yields:
    Pairs of the batched data, each a matrix of shape [batch_size, num_steps].
    The second element of the tuple is the same data time-shifted to the
    right by one.
  Raises:
    ValueError: if batch_size or num_steps are too high.
  """
  raw_data = np.array(raw_data, dtype=np.int32)#raw data : train_data | vali_data | test data

  data_len = len(raw_data) #how many words in the data_set
  batch_len = data_len // batch_size
  data = np.zeros([batch_size, batch_len], dtype=np.int32)#batch_len 就是几个word的意思
  for i in range(batch_size):
    data[i] = raw_data[batch_len * i:batch_len * (i + 1)]

  epoch_size = (batch_len - 1) // num_steps

  if epoch_size == 0:
    raise ValueError("epoch_size == 0, decrease batch_size or num_steps")

  for i in range(epoch_size):
    x = data[:, i*num_steps:(i+1)*num_steps]
    y = data[:, i*num_steps+1:(i+1)*num_steps+1]
  yield (x, y)

一直很懵逼怎么emdeding的, 这里并没有进行embeding。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2016年10月16日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档