前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >BERT源码分析(二)

BERT源码分析(二)

作者头像
zenRRan
发布2019-08-14 17:55:00
8430
发布2019-08-14 17:55:00
举报

写在前面

BERT的使用可以分为两个步骤:pre-trainingfine-tuning。pre-training的话可以很好地适用于自己特定的任务,但是训练成本很高(four days on 4 to 16 Cloud TPUs),对于大对数从业者而言不太好实现从零开始(from scratch)。不过Google已经发布了各种预训练好的模型可供选择,只需要进行对特定任务的Fine-tuning即可。

今天我们就继续按照原始论文的框架,来一起读读BERT预训练的源码。BERT预训练过程分为两个具体子任务:Masked LMNext Sentence Prediction

tokenization.py

create_pretraining_data.py

xrun_pretraining

1、分词(tokenization.py)

tokenization.py是对原始文本语料的处理,分为BasicTokenizer和WordpieceTokenizer两类。

BasicTokenizer

根据空格,标点进行普通的分词,最后返回的是关于词的列表,对于中文而言是关于字的列表。

代码语言:javascript
复制
 1class BasicTokenizer(object):
 2  def __init__(self, do_lower_case=True):
 3    self.do_lower_case = do_lower_case
 4
 5  def tokenize(self, text):
 6    text = convert_to_unicode(text)
 7    text = self._clean_text(text)
 8    # 增加中文支持
 9    text = self._tokenize_chinese_chars(text)
10
11    orig_tokens = whitespace_tokenize(text)
12    split_tokens = []
13    for token in orig_tokens:
14      if self.do_lower_case:
15        token = token.lower()
16        token = self._run_strip_accents(token)
17      split_tokens.extend(self._run_split_on_punc(token))
18
19    output_tokens = whitespace_tokenize(" ".join(split_tokens))
20    return output_tokens
21
22  def _run_strip_accents(self, text):
23    # 对text进行归一化
24    text = unicodedata.normalize("NFD", text)
25    output = []
26    for char in text:
27      cat = unicodedata.category(char)
28      # 把category为Mn的去掉
29      # refer: https://www.fileformat.info/info/unicode/category/Mn/list.htm
30      if cat == "Mn":
31        continue
32      output.append(char)
33    return "".join(output)
34
35  def _run_split_on_punc(self, text):
36    # 用标点切分,返回list
37    chars = list(text)
38    i = 0
39    start_new_word = True
40    output = []
41    while i < len(chars):
42      char = chars[i]
43      if _is_punctuation(char):
44        output.append([char])
45        start_new_word = True
46      else:
47        if start_new_word:
48          output.append([])
49        start_new_word = False
50        output[-1].append(char)
51      i += 1
52
53    return ["".join(x) for x in output]
54
55  def _tokenize_chinese_chars(self, text):
56    # 按字切分中文,实现就是在字两侧添加空格
57    output = []
58    for char in text:
59      cp = ord(char)
60      if self._is_chinese_char(cp):
61        output.append(" ")
62        output.append(char)
63        output.append(" ")
64      else:
65        output.append(char)
66    return "".join(output)
67
68  def _is_chinese_char(self, cp):
69    # 判断是否是汉字
70    # refer:https://www.cnblogs.com/straybirds/p/6392306.html
71    if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
72        (cp >= 0x3400 and cp <= 0x4DBF) or  #
73        (cp >= 0x20000 and cp <= 0x2A6DF) or  #
74        (cp >= 0x2A700 and cp <= 0x2B73F) or  #
75        (cp >= 0x2B740 and cp <= 0x2B81F) or  #
76        (cp >= 0x2B820 and cp <= 0x2CEAF) or
77        (cp >= 0xF900 and cp <= 0xFAFF) or  #
78        (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
79      return True
80
81    return False
82
83  def _clean_text(self, text):
84    # 去除无意义字符以及空格
85    output = []
86    for char in text:
87      cp = ord(char)
88      if cp == 0 or cp == 0xfffd or _is_control(char):
89        continue
90      if _is_whitespace(char):
91        output.append(" ")
92      else:
93        output.append(char)
94    return "".join(output)
WordpieceTokenizer

WordpieceTokenizer是将BasicTokenizer的结果进一步做更细粒度的切分。做这一步的目的主要是为了去除未登录词对模型效果的影响。这一过程对中文没有影响,因为在前面BasicTokenizer里面已经切分成以字为单位的了。

代码语言:javascript
复制
 1class WordpieceTokenizer(object):
 2  def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
 3    self.vocab = vocab
 4    self.unk_token = unk_token
 5    self.max_input_chars_per_word = max_input_chars_per_word
 6
 7  def tokenize(self, text):
 8    """使用贪心的最大正向匹配算法
 9    例如:
10      input = "unaffable"
11      output = ["un", "##aff", "##able"]
12    """
13    text = convert_to_unicode(text)
14
15    output_tokens = []
16    for token in whitespace_tokenize(text):
17      chars = list(token)
18      if len(chars) > self.max_input_chars_per_word:
19        output_tokens.append(self.unk_token)
20        continue
21
22      is_bad = False
23      start = 0
24      sub_tokens = []
25      while start < len(chars):
26        end = len(chars)
27        cur_substr = None
28        while start < end:
29          substr = "".join(chars[start:end])
30          if start > 0:
31            substr = "##" + substr
32          if substr in self.vocab:
33            cur_substr = substr
34            break
35          end -= 1
36        if cur_substr is None:
37          is_bad = True
38          break
39        sub_tokens.append(cur_substr)
40        start = end
41
42      if is_bad:
43        output_tokens.append(self.unk_token)
44      else:
45        output_tokens.extend(sub_tokens)
46    return output_tokens

我们用一个例子来看代码的执行过程。比如假设输入是”unaffable”。我们跳到while循环部分,这是start=0,end=len(chars)=9,也就是先看看unaffable在不在词典里,如果在,那么直接作为一个WordPiece,如果不再,那么end-=1,也就是看unaffabl在不在词典里,最终发现”un”在词典里,把un加到结果里。

接着start=2,看affable在不在,不在再看affabl,…,最后发现 ##aff 在词典里。注意:##表示这个词是接着前面的,这样使得WordPiece切分是可逆的——我们可以恢复出“真正”的词。

FullTokenizer

BERT分词的主要接口,包含了上述两种实现。

代码语言:javascript
复制
 1class FullTokenizer(object):
 2  def __init__(self, vocab_file, do_lower_case=True):
 3    # 加载词表文件为字典形式
 4    self.vocab = load_vocab(vocab_file)
 5    self.inv_vocab = {v: k for k, v in self.vocab.items()}
 6    self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
 7    self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
 8
 9  def tokenize(self, text):
10    split_tokens = []
11    # 调用BasicTokenizer粗粒度分词
12    for token in self.basic_tokenizer.tokenize(text):
13      # 调用WordpieceTokenizer细粒度分词
14      for sub_token in self.wordpiece_tokenizer.tokenize(token):
15        split_tokens.append(sub_token)
16
17    return split_tokens
18
19  def convert_tokens_to_ids(self, tokens):
20    return convert_by_vocab(self.vocab, tokens)
21
22  def convert_ids_to_tokens(self, ids):
23    return convert_by_vocab(self.inv_vocab, ids)

2、训练数据生成(create_pretraining_data.py)

这个文件的这作用就是将原始输入语料转换成模型预训练所需要的数据格式TFRecoed。

参数设置
代码语言:javascript
复制
 1flags.DEFINE_string("input_file", None,
 2                    "Input raw text file (or comma-separated list of files).")
 3
 4flags.DEFINE_string("output_file", None,
 5    "Output TF example file (or comma-separated list of files).")
 6
 7flags.DEFINE_string("vocab_file", None,
 8                    "The vocabulary file that the BERT model was trained on.")
 9
10flags.DEFINE_bool( "do_lower_case", True,
11    "Whether to lower case the input text. Should be True for uncased "
12    "models and False for cased models.")
13
14flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
15
16flags.DEFINE_integer("max_predictions_per_seq", 20,
17                     "Maximum number of masked LM predictions per sequence.")
18
19flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
20
21flags.DEFINE_integer( "dupe_factor", 10,
22    "Number of times to duplicate the input data (with different masks).")
23
24flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
25
26flags.DEFINE_float("short_seq_prob", 0.1,
27    "Probability of creating sequences which are shorter than the maximum length.")

这里就说几个参数

  • dupe_factor: 重复参数,即对于同一个句子,我们可以设置不同位置的【MASK】次数。比如对于句子Hello world, this is bert.,为了充分利用数据,第一次可以mask成Hello [MASK], this is bert.,第二次可以变成Hello world, this is [MASK[.
  • max_predictions_per_seq: 一个句子里最多有多少个[MASK]标记
  • masked_lm_prob: 多少比例的Token被MASK掉
  • short_seq_prob: 长度小于“max_seq_length”的样本比例。因为在fine-tune过程里面输入的target_seq_length是可变的(小于等于max_seq_length),那么为了防止过拟合也需要在pre-train的过程当中构造一些短的样本。
Main入口

首先来看构造数据的整体流程,

代码语言:javascript
复制
 1def main(_):
 2  tf.logging.set_verbosity(tf.logging.INFO)
 3
 4  tokenizer = tokenization.FullTokenizer(
 5      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
 6
 7  input_files = []
 8  for input_pattern in FLAGS.input_file.split(","):
 9    input_files.extend(tf.gfile.Glob(input_pattern))
10
11  tf.logging.info("*** Reading from input files ***")
12  for input_file in input_files:
13    tf.logging.info("  %s", input_file)
14
15  rng = random.Random(FLAGS.random_seed)
16  instances = create_training_instances(
17      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
18      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
19      rng)
20
21  output_files = FLAGS.output_file.split(",")
22  tf.logging.info("*** Writing to output files ***")
23  for output_file in output_files:
24    tf.logging.info("  %s", output_file)
25
26  write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
27                                  FLAGS.max_predictions_per_seq, output_files)
  • 构造tokenizer对输入语料进行分词处理(Tokenizer部分会在后续说明)
  • 经过create_training_instances函数构造训练instance
  • 调用write_instance_to_example_files函数以TFRecord格式保存数据 下面我们一一解析这些函数。
构造训练样本

首先定义了一个训练样本的类

代码语言:javascript
复制
 1class TrainingInstance(object):
 2
 3  def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
 4               is_random_next):
 5    self.tokens = tokens
 6    self.segment_ids = segment_ids
 7    self.is_random_next = is_random_next
 8    self.masked_lm_positions = masked_lm_positions
 9    self.masked_lm_labels = masked_lm_labels
10
11  def __str__(self):
12    s = ""
13    s += "tokens: %s\n" % (" ".join(
14        [tokenization.printable_text(x) for x in self.tokens]))
15    s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
16    s += "is_random_next: %s\n" % self.is_random_next
17    s += "masked_lm_positions: %s\n" % (" ".join(
18        [str(x) for x in self.masked_lm_positions]))
19    s += "masked_lm_labels: %s\n" % (" ".join(
20        [tokenization.printable_text(x) for x in self.masked_lm_labels]))
21    s += "\n"
22    return s
23
24  def __repr__(self):
25    return self.__str__()

构造训练样本的代码如下。在源码包中Google提供了一个实例训练样本输入(sample_text.txt),输入文件格式为:

  • 每行一个句子,这应该是实际的句子,不应该是整个段落或者段落的随机片段(span),因为我们需要使用句子边界来做下一个句子的预测。
  • 不同文档之间用一个空行分割。
  • 我们认为同一文档的句子之间是有关系的,不同文档句子之间没有关系。
代码语言:javascript
复制
 1def create_training_instances(input_files, tokenizer, max_seq_length,
 2                              dupe_factor, short_seq_prob, masked_lm_prob,
 3                              max_predictions_per_seq, rng):
 4  all_documents = [[]]
 5  # all_documents是list的list,第一层list表示document,
 6  # 第二层list表示document里的多个句子。
 7  for input_file in input_files:
 8    with tf.gfile.GFile(input_file, "r") as reader:
 9      while True:
10        line = tokenization.convert_to_unicode(reader.readline())
11        if not line:
12          break
13        line = line.strip()
14
15        # 空行表示文档分割
16        if not line:
17          all_documents.append([])
18        tokens = tokenizer.tokenize(line)
19        if tokens:
20          all_documents[-1].append(tokens)
21
22  # 删除空文档
23  all_documents = [x for x in all_documents if x]
24  rng.shuffle(all_documents)
25
26  vocab_words = list(tokenizer.vocab.keys())
27  instances = []
28  # 重复dupe_factor次
29  for _ in range(dupe_factor):
30    for document_index in range(len(all_documents)):
31      instances.extend(
32          create_instances_from_document(
33              all_documents, document_index, max_seq_length, short_seq_prob,
34              masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
35
36  rng.shuffle(instances)
37  return instances

上面的函数会调用create_instances_from_document来实现从一个文档中抽取多个训练样本。

代码语言:javascript
复制
  1def create_instances_from_document(
  2    all_documents, document_index, max_seq_length, short_seq_prob,
  3    masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
  4
  5  document = all_documents[document_index]
  6
  7  # 为[CLS], [SEP], [SEP]预留三个空位
  8  max_num_tokens = max_seq_length - 3
  9
 10  target_seq_length = max_num_tokens
 11  # 以short_seq_prob的概率随机生成(2~max_num_tokens)的长度
 12  if rng.random() < short_seq_prob:
 13    target_seq_length = rng.randint(2, max_num_tokens)
 14
 15  #
 16  instances = []
 17  current_chunk = []
 18  current_length = 0
 19  i = 0
 20  while i < len(document):
 21    segment = document[i]
 22    current_chunk.append(segment)
 23    current_length += len(segment)
 24    # 将句子依次加入current_chunk中,直到加完或者达到限制的最大长度
 25    if i == len(document) - 1 or current_length >= target_seq_length:
 26      if current_chunk:
 27        # `a_end`是第一个句子A结束的下标
 28        a_end = 1
 29        # 随机选取切分边界
 30        if len(current_chunk) >= 2:
 31          a_end = rng.randint(1, len(current_chunk) - 1)
 32
 33        tokens_a = []
 34        for j in range(a_end):
 35          tokens_a.extend(current_chunk[j])
 36
 37        tokens_b = []
 38        # 是否随机next
 39        is_random_next = False
 40        # 构建随机的下一句
 41        if len(current_chunk) == 1 or rng.random() < 0.5:
 42          is_random_next = True
 43          target_b_length = target_seq_length - len(tokens_a)
 44
 45          # 随机的挑选另外一篇文档的随机开始的句子
 46          # 但是理论上有可能随机到的文档就是当前文档,因此需要一个while循环
 47          # 这里只while循环10次,理论上还是有重复的可能性,但是我们忽略
 48          for _ in range(10):
 49            random_document_index = rng.randint(0, len(all_documents) - 1)
 50            if random_document_index != document_index:
 51              break
 52
 53          random_document = all_documents[random_document_index]
 54          random_start = rng.randint(0, len(random_document) - 1)
 55          for j in range(random_start, len(random_document)):
 56            tokens_b.extend(random_document[j])
 57            if len(tokens_b) >= target_b_length:
 58              break
 59          # 对于上述构建的随机下一句,我们并没有真正地使用它们
 60          # 所以为了避免数据浪费,我们将其“放回”
 61          num_unused_segments = len(current_chunk) - a_end
 62          i -= num_unused_segments
 63        # 构建真实的下一句
 64        else:
 65          is_random_next = False
 66          for j in range(a_end, len(current_chunk)):
 67            tokens_b.extend(current_chunk[j])
 68        # 如果太多了,随机去掉一些
 69        truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
 70
 71        assert len(tokens_a) >= 1
 72        assert len(tokens_b) >= 1
 73
 74        tokens = []
 75        segment_ids = []
 76        # 处理句子A
 77        tokens.append("[CLS]")
 78        segment_ids.append(0)
 79        for token in tokens_a:
 80          tokens.append(token)
 81          segment_ids.append(0)
 82        # 句子A结束,加上【SEP】
 83        tokens.append("[SEP]")
 84        segment_ids.append(0)
 85        # 处理句子B
 86        for token in tokens_b:
 87          tokens.append(token)
 88          segment_ids.append(1)
 89        # 句子B结束,加上【SEP】
 90        tokens.append("[SEP]")
 91        segment_ids.append(1)
 92
 93        # 调用 create_masked_lm_predictions来随机对某些Token进行mask
 94        (tokens, masked_lm_positions,
 95         masked_lm_labels) = create_masked_lm_predictions(
 96             tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
 97        instance = TrainingInstance(
 98            tokens=tokens,
 99            segment_ids=segment_ids,
100            is_random_next=is_random_next,
101            masked_lm_positions=masked_lm_positions,
102            masked_lm_labels=masked_lm_labels)
103        instances.append(instance)
104      current_chunk = []
105      current_length = 0
106    i += 1
107
108  return instances

上面代码有点长,在关键的地方我都注释上了。下面我们结合一个具体的例子来看代码的实现过程。以提供的sample_text.txt中语料为例,只截取了一部分,下图包含了两个文档,第一个文档中有6个句子,第二个有4个句子:

create_instances_from_document分析的是一个文档,我们就以上述第一个为例。

  1. 算法首先会维护一个chunk,不断加入document中的元素,也就是句子(segment),直到加载完或者chunk中token数大于等于最大限制,这样做的目的是使得padding的尽量少,训练效率更高。
  2. 现在chunk建立完毕之后,假设包括了前三个句子,算法会随机选择一个切分点,比如2。接下来构建predict next判断: (1) 如果是正样本,前两个句子当成是句子A,后一个句子当成是句子B; (2) 如果是负样本,前两个句子当成是句子A,无关的句子从其他文档中随机抽取
  3. 得到句子A和句子B之后,对其填充tokens和segment_ids,这里会加入特殊的[CLS]和[SEP]标记
  4. 对句子进行mask操作(下一节中描述)
随机MASK

对Tokens进行随机mask是BERT的一大创新点。使用mask的原因是为了防止模型在双向循环训练的过程中“预见自身”。于是,文章中选取的策略是对输入序列中15%的词使用[MASK]标记掩盖掉,然后通过上下文去预测这些被mask的token。但是为了防止模型过拟合地学习到【MASK】这个标记,对15%mask掉的词进一步优化:

  • 以80%的概率用[MASK]替换:
    • hello world, this is bert. ----> hello world, this is [MASK].
  • 以10%的概率随机替换:
    • hello world, this is bert. ----> hello world, this is python.
  • 以10%的概率不进行替换:
    • hello world, this is bert. ----> hello world, this is bert.
代码语言:javascript
复制
 1def create_masked_lm_predictions(tokens, masked_lm_prob,
 2                                 max_predictions_per_seq, vocab_words, rng):
 3
 4  cand_indexes = []
 5  # [CLS]和[SEP]不能用于MASK
 6  for (i, token) in enumerate(tokens):
 7    if token == "[CLS]" or token == "[SEP]":
 8      continue
 9    cand_indexes.append(i)
10
11  rng.shuffle(cand_indexes)
12
13  output_tokens = list(tokens)
14
15  num_to_predict = min(max_predictions_per_seq,
16                       max(1, int(round(len(tokens) * masked_lm_prob))))
17
18  masked_lms = []
19  covered_indexes = set()
20  for index in cand_indexes:
21    if len(masked_lms) >= num_to_predict:
22      break
23    if index in covered_indexes:
24      continue
25    covered_indexes.add(index)
26
27    masked_token = None
28    # 80% of the time, replace with [MASK]
29    if rng.random() < 0.8:
30      masked_token = "[MASK]"
31    else:
32      # 10% of the time, keep original
33      if rng.random() < 0.5:
34        masked_token = tokens[index]
35      # 10% of the time, replace with random word
36      else:
37        masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
38
39    output_tokens[index] = masked_token
40
41    masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
42
43  # 按照下标重排,保证是原来句子中出现的顺序
44  masked_lms = sorted(masked_lms, key=lambda x: x.index)
45
46  masked_lm_positions = []
47  masked_lm_labels = []
48  for p in masked_lms:
49    masked_lm_positions.append(p.index)
50    masked_lm_labels.append(p.label)
51
52  return (output_tokens, masked_lm_positions, masked_lm_labels)
保存tfrecord数据

最后是将上述步骤处理好的数据保存为tfrecord文件。整体逻辑比较简单,代码如下

代码语言:javascript
复制
 1def write_instance_to_example_files(instances, tokenizer, max_seq_length,
 2                                    max_predictions_per_seq, output_files):
 3
 4  writers = []
 5  for output_file in output_files:
 6    writers.append(tf.python_io.TFRecordWriter(output_file))
 7
 8  writer_index = 0
 9
10  total_written = 0
11  for (inst_index, instance) in enumerate(instances):
12      # 将输入转成word-ids
13    input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
14    # 记录实际句子长度
15    input_mask = [1] * len(input_ids)
16    segment_ids = list(instance.segment_ids)
17    assert len(input_ids) <= max_seq_length
18
19    # padding
20    while len(input_ids) < max_seq_length:
21      input_ids.append(0)
22      input_mask.append(0)
23      segment_ids.append(0)
24
25    assert len(input_ids) == max_seq_length
26    assert len(input_mask) == max_seq_length
27    assert len(segment_ids) == max_seq_length
28
29    masked_lm_positions = list(instance.masked_lm_positions)
30    masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
31    masked_lm_weights = [1.0] * len(masked_lm_ids)
32
33    while len(masked_lm_positions) < max_predictions_per_seq:
34      masked_lm_positions.append(0)
35      masked_lm_ids.append(0)
36      masked_lm_weights.append(0.0)
37
38    next_sentence_label = 1 if instance.is_random_next else 0
39
40    features = collections.OrderedDict()
41    features["input_ids"] = create_int_feature(input_ids)
42    features["input_mask"] = create_int_feature(input_mask)
43    features["segment_ids"] = create_int_feature(segment_ids)
44    features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
45    features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
46    features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
47    features["next_sentence_labels"] = create_int_feature([next_sentence_label])
48
49    # 生成训练样本
50    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
51
52    # 输出到文件
53    writers[writer_index].write(tf_example.SerializeToString())
54    writer_index = (writer_index + 1) % len(writers)
55
56    total_written += 1
57
58    # 打印前20个样本
59    if inst_index < 20:
60      tf.logging.info("*** Example ***")
61      tf.logging.info("tokens: %s" % " ".join(
62          [tokenization.printable_text(x) for x in instance.tokens]))
63
64      for feature_name in features.keys():
65        feature = features[feature_name]
66        values = []
67        if feature.int64_list.value:
68          values = feature.int64_list.value
69        elif feature.float_list.value:
70          values = feature.float_list.value
71        tf.logging.info(
72            "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
73
74  for writer in writers:
75    writer.close()
76
77  tf.logging.info("Wrote %d total instances", total_written)
测试代码
代码语言:javascript
复制
 1python create_pretraining_data.py \
 2  --input_file=./sample_text_zh.txt \
 3  --output_file=/tmp/tf_examples.tfrecord \
 4  --vocab_file=$BERT_BASE_DIR/vocab.txt \
 5  --do_lower_case=True \
 6  --max_seq_length=128 \
 7  --max_predictions_per_seq=20 \
 8  --masked_lm_prob=0.15 \
 9  --random_seed=12345 \
10  --dupe_factor=5

因为我之前下载的词表是中文的,所以就网上随便找了几篇新闻进行测试。结果如下

这是其中的一个样例:

小结一哈

主要介绍BERT的自带分词组件以及pretraining数据生成过程,属于整个项目的准备部分。 没想到代码这么多,pretraining训练的部分就不放在这一篇里了,请见下篇~

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-08-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 深度学习自然语言处理 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 写在前面
  • 1、分词(tokenization.py)
    • BasicTokenizer
      • WordpieceTokenizer
        • FullTokenizer
        • 2、训练数据生成(create_pretraining_data.py)
          • 参数设置
            • Main入口
              • 构造训练样本
                • 随机MASK
                  • 保存tfrecord数据
                    • 测试代码
                    • 小结一哈
                    领券
                    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档