专栏首页量子位1美元训练BERT,教你如何薅谷歌TPU羊毛 | 附Colab代码

1美元训练BERT,教你如何薅谷歌TPU羊毛 | 附Colab代码

晓查 发自 凹非寺 量子位 出品 | 公众号 QbitAI

BERT是谷歌去年推出的NLP模型,一经推出就在各项测试中碾压竞争对手,而且BERT是开源的。只可惜训练BERT的价格实在太高,让人望而却步。

之前需要用64个TPU训练4天才能完成,后来谷歌用并行计算优化了到只需一个多小时,但是需要的TPU数量陡增,达到了惊人的1024个。

那么总共要多少钱呢?谷歌云TPU的使用价格是每个每小时6.5美元,训练完成训练完整个模型需要近4万美元,简直就是天价。

现在,有个羊毛告诉你,在Medium上有人找到了薅谷歌羊毛的办法,只需1美元就能训练BERT,模型还能留存在你的谷歌云盘中,留作以后使用。

准备工作

为了薅谷歌的羊毛,您需要一个Google云存储(Google Cloud Storage)空间。按照Google 云TPU快速入门指南,创建Google云平台(Google Cloud Platform)帐户和Google云存储账户。新的谷歌云平台用户可获得300美元的免费赠送金额。

在TPUv2上预训练BERT-Base模型大约需要54小时。Google Colab并非设计用于执行长时间运行的作业,它会每8小时左右中断一次训练过程。对于不间断的训练,请考虑使用付费的不间断使用TPUv2的方法。

也就是说,使用Colab TPU,你可以在以1美元的价格在Google云盘上存储模型和数据,以几乎可忽略成本从头开始预训练BERT模型。

以下是整个过程的代码下面的代码,可以在Colab Jupyter环境中运行。

设置训练环境

首先,安装训练模型所需的包。Jupyter允许使用’!’直接从笔记本执行bash命令:

!pip install sentencepiece
!git clone https://github.com/google-research/bert

导入包并在Google云中授权:

import os
import sys
import json
import nltk
import random
import logging
import tensorflow as tf
import sentencepiece as spm

from glob import glob
from google.colab import auth, drive
from tensorflow.keras.utils import Progbar

sys.path.append("bert")

from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder

auth.authenticate_user()

# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s :  %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
log.handlers = [sh]

if 'COLAB_TPU_ADDR' in os.environ:
  log.info("Using TPU runtime")
  USE_TPU = True
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is ' + TPU_ADDRESS)
    # Upload credentials to TPU.
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)

else:
  log.warning('Not connected to TPU runtime')
  USE_TPU = False

下载原始文本数据

接下来从网络上获取文本数据语料库。在本次实验中,我们使用OpenSubtitles数据集,该数据集包括65种语言。

与更常用的文本数据集(如维基百科)不同,它不需要任何复杂的预处理,提供预格式化,一行一个句子。

AVAILABLE =  {'af','ar','bg','bn','br','bs','ca','cs',
              'da','de','el','en','eo','es','et','eu',
              'fa','fi','fr','gl','he','hi','hr','hu',
              'hy','id','is','it','ja','ka','kk','ko',
              'lt','lv','mk','ml','ms','nl','no','pl',
              'pt','pt_br','ro','ru','si','sk','sl','sq',
              'sr','sv','ta','te','th','tl','tr','uk',
              'ur','vi','ze_en','ze_zh','zh','zh_cn',
              'zh_en','zh_tw','zh_zh'}

LANG_CODE = "en" #@param {type:"string"}

assert LANG_CODE in AVAILABLE, "Invalid language code selected"

!wget http://opus.nlpl.eu/download.php?f=OpenSubtitles/v2016/mono/OpenSubtitles.raw.'$LANG_CODE'.gz -O dataset.txt.gz
!gzip -d dataset.txt.gz
!tail dataset.txt

你可以通过设置代码随意选择你需要的语言。出于演示目的,代码只默认使用整个语料库的一小部分。在实际训练模型时,请务必取消选中DEMO_MODE复选框,使用大100倍的数据集。

当然,100M数据足以训练出相当不错的BERT基础模型。

DEMO_MODE = True #@param {type:"boolean"}
if DEMO_MODE:
 CORPUS_SIZE = 1000000
else:
 CORPUS_SIZE = 100000000 #@param {type: "integer"}

!(head -n $CORPUS_SIZE dataset.txt) > subdataset.txt
!mv subdataset.txt dataset.txt

预处理文本数据

我们下载的原始文本数据包含标点符号,大写字母和非UTF符号,我们将在继续下一步之前将其删除。在推理期间,我们将对新数据应用相同的过程。

如果你需要不同的预处理方式(例如在推理期间预期会出现大写字母或标点符号),请修改以下代码以满足你的需求。

regex_tokenizer = nltk.RegexpTokenizer("\w+")

def normalize_text(text):
  # lowercase text
  text = str(text).lower()
  # remove non-UTF
  text = text.encode("utf-8", "ignore").decode()
  # remove punktuation symbols
  text = " ".join(regex_tokenizer.tokenize(text))
  return text

def count_lines(filename):
  count = 0
  with open(filename) as fi:
    for line in fi:
      count += 1
  return count

现在让我们预处理整个数据集:

RAW_DATA_FPATH = "dataset.txt" #@param {type: "string"}
PRC_DATA_FPATH = "proc_dataset.txt" #@param {type: "string"}

# apply normalization to the dataset
# this will take a minute or two

total_lines = count_lines(RAW_DATA_FPATH)
bar = Progbar(total_lines)

with open(RAW_DATA_FPATH,encoding="utf-8") as fi:
  with open(PRC_DATA_FPATH, "w",encoding="utf-8") as fo:
    for l in fi:
      fo.write(normalize_text(l)+"\n")
      bar.add(1)

构建词汇表

下一步,我们将训练模型学习一个新的词汇表,用于表示我们的数据集。

BERT文件使用WordPiece分词器,在开源中不可用。我们将在unigram模式下使用SentencePiece分词器。虽然它与BERT不直接兼容,但是通过一个小的处理方法,可以使它工作。

SentencePiece需要相当多的运行内存,因此在Colab中的运行完整数据集会导致内核崩溃。

为避免这种情况,我们将随机对数据集的一小部分进行子采样,构建词汇表。另一个选择是使用更大内存的机器来执行此步骤。

此外,SentencePiece默认情况下将BOS和EOS控制符号添加到词汇表中。我们通过将其索引设置为-1来禁用它们。

VOC_SIZE的典型值介于32000和128000之间。如果想要更新词汇表,并在预训练阶段结束后对模型进行微调,我们会保留NUM_PLACEHOLDERS个token。

MODEL_PREFIX = "tokenizer" #@param {type: "string"}
VOC_SIZE = 32000 #@param {type:"integer"}
SUBSAMPLE_SIZE = 12800000 #@param {type:"integer"}
NUM_PLACEHOLDERS = 256 #@param {type:"integer"}

SPM_COMMAND = ('--input={} --model_prefix={} '
               '--vocab_size={} --input_sentence_size={} '
               '--shuffle_input_sentence=true ' 
               '--bos_id=-1 --eos_id=-1').format(
               PRC_DATA_FPATH, MODEL_PREFIX, 
               VOC_SIZE - NUM_PLACEHOLDERS, SUBSAMPLE_SIZE)

spm.SentencePieceTrainer.Train(SPM_COMMAND)

现在,让我们看看如何让SentencePiece在BERT模型上工作。

下面是使用来自官方的预训练英语BERT基础模型的WordPiece词汇表标记的语句。

>>> wordpiece.tokenize("Colorless geothermal substations are generating furiously")

['color',
 '##less',
 'geo',
 '##thermal',
 'sub',
 '##station',
 '##s',
 'are',
 'generating',
 'furiously']

WordPiece标记器在“##”的单词中间预置了出现的子字。在单词开头出现的子词不变。如果子词出现在单词的开头和中间,则两个版本(带和不带’##’)都会添加到词汇表中。

SentencePiece创建了两个文件:tokenizer.model和tokenizer.vocab。让我们来看看它学到的词汇:

def read_sentencepiece_vocab(filepath):
  voc = []
  with open(filepath, encoding='utf-8') as fi:
    for line in fi:
      voc.append(line.split("\t")[0])
  # skip the first <unk> token
  voc = voc[1:]
  return voc

snt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))
print("Learnt vocab size: {}".format(len(snt_vocab)))
print("Sample tokens: {}".format(random.sample(snt_vocab, 10)))

运行结果:

Learnt vocab size: 31743 
Sample tokens: ['▁cafe', '▁slippery', 'xious', '▁resonate', '▁terrier', '▁feat', '▁frequencies', 'ainty', '▁punning', 'modern']

SentencePiece与WordPiece的运行结果完全相反。从文档中可以看出:SentencePiece首先使用元符号“_”将空格转义为空格,如下所示:

Hello_World。

然后文本被分段为小块:

[Hello] [_Wor] [ld] [.]

在空格之后出现的子词(也是大多数词开头的子词)前面加上“_”,而其他子词不变。这排除了仅出现在句子开头而不是其他地方的子词。然而,这些案件应该非常罕见。

因此,为了获得类似于WordPiece的词汇表,我们需要执行一个简单的转换,从包含它的标记中删除“_”,并将“##”添加到不包含它的标记中。

我们还添加了一些BERT架构所需的特殊控制符号。按照惯例,我们把它们放在词汇的开头。 另外,我们在词汇表中添加了一些占位符token。

如果你希望使用新的用于特定任务的token来更新预先训练的模型,那么这些方法是很有用的。

在这种情况下,占位符token被替换为新的token,重新生成预训练数据,并且对新数据进行微调。

def parse_sentencepiece_token(token):
    if token.startswith("▁"):
        return token[1:]
    else:
        return "##" + token

bert_vocab = list(map(parse_sentencepiece_token, snt_vocab))

ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
bert_vocab = ctrl_symbols + bert_vocab

bert_vocab += ["[UNUSED_{}]".format(i) for i in range(VOC_SIZE - len(bert_vocab))]
print(len(bert_vocab))

最后,我们将获得的词汇表写入文件。

VOC_FNAME = "vocab.txt" #@param {type:"string"}

with open(VOC_FNAME, "w") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")

现在,让我们看看新词汇在实践中是如何运作的:

>>> testcase = "Colorless geothermal substations are generating furiously"
>>> bert_tokenizer = tokenization.FullTokenizer(VOC_FNAME)
>>> bert_tokenizer.tokenize(testcase)
['color',  
 '##less',  
 'geo',  
 '##ther',  
 '##mal',  
 'sub',  
 '##station',  
 '##s',  
 'are',  
 'generat',  
 '##ing',  
 'furious',  
 '##ly']

创建分片预训练数据(生成预训练数据)

通过手头的词汇表,我们可以为BERT模型生成预训练数据。

由于我们的数据集可能非常大,我们将其拆分为碎片:

mkdir ./shards
split -a 4 -l 256000 -d $PRC_DATA_FPATH ./shards/shard_

现在,对于每个部分,我们需要从BERT仓库调用create_pretraining_data.py脚本,需要使用xargs命令。

在开始生成之前,我们需要设置一些参数传递给脚本。你可以从自述文件中找到有关它们含义的更多信息。

MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
MAX_PREDICTIONS = 20 #@param {type:"integer"}
DO_LOWER_CASE = True #@param {type:"boolean"}

PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
# controls how many parallel processes xargs can create
PROCESSES = 2 #@param {type:"integer"}

运行此操作可能需要相当长的时间,具体取决于数据集的大小。

XARGS_CMD = ("ls ./shards/ | "
             "xargs -n 1 -P {} -I{} "
             "python3 bert/create_pretraining_data.py "
             "--input_file=./shards/{} "
             "--output_file={}/{}.tfrecord "
             "--vocab_file={} "
             "--do_lower_case={} "
             "--max_predictions_per_seq={} "
             "--max_seq_length={} "
             "--masked_lm_prob={} "
             "--random_seed=34 "
             "--dupe_factor=5")

XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}', PRETRAINING_DIR, '{}', 
                             VOC_FNAME, DO_LOWER_CASE, 
                             MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)

tf.gfile.MkDir(PRETRAINING_DIR)
!$XARGS_CMD

为数据和模型设置GCS存储,将数据和模型存储到云端

为了保留来之不易的训练模型,我们会将其保留在Google云存储中。

在Google云存储中创建两个目录,一个用于数据,一个用于模型。在模型目录中,我们将放置模型词汇表和配置文件。

在继续操作之前,请配置BUCKET_NAME变量,否则将无法训练模型。

BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
tf.gfile.MkDir(MODEL_DIR)

if not BUCKET_NAME:
  log.warning("WARNING: BUCKET_NAME is not set. "
              "You will not be able to train the model.")

下面是BERT-base的超参数配置示例:

# use this for BERT-base

bert_base_config = {
  "attention_probs_dropout_prob": 0.1, 
  "directionality": "bidi", 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12, 
  "num_hidden_layers": 12, 
  "pooler_fc_size": 768, 
  "pooler_num_attention_heads": 12, 
  "pooler_num_fc_layers": 3, 
  "pooler_size_per_head": 128, 
  "pooler_type": "first_token_transform", 
  "type_vocab_size": 2, 
  "vocab_size": VOC_SIZE
}

with open("{}/bert_config.json".format(MODEL_DIR), "w") as fo:
  json.dump(bert_base_config, fo, indent=2)

with open("{}/{}".format(MODEL_DIR, VOC_FNAME), "w") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")

现在,我们已准备好将模型和数据存储到谷歌云当中:

if BUCKET_NAME:
  !gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME

在云TPU上训练模型

注意,之前步骤中的某些参数在此处不用改变。请确保在整个实验中设置的参数完全相同。

BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
VOC_FNAME = "vocab.txt" #@param {type:"string"}

# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param

# Training procedure config
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 2e-5
TRAIN_STEPS = 1000000 #@param {type:"integer"}
SAVE_CHECKPOINTS_STEPS = 2500 #@param {type:"integer"}
NUM_TPU_CORES = 8

if BUCKET_NAME:
  BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
else:
  BUCKET_PATH = "."

BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)

VOCAB_FILE = os.path.join(BERT_GCS_DIR, VOC_FNAME)
CONFIG_FILE = os.path.join(BERT_GCS_DIR, "bert_config.json")

INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)

bert_config = modeling.BertConfig.from_json_file(CONFIG_FILE)
input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR,'*tfrecord'))

log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
log.info("Using {} data shards".format(len(input_files)))

准备训练运行配置,建立评估器和输入函数,启动BERT!

model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=INIT_CHECKPOINT,
      learning_rate=LEARNING_RATE,
      num_train_steps=TRAIN_STEPS,
      num_warmup_steps=10,
      use_tpu=USE_TPU,
      use_one_hot_embeddings=True)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)

run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=BERT_GCS_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=USE_TPU,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE)

train_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=MAX_SEQ_LENGTH,
        max_predictions_per_seq=MAX_PREDICTIONS,
        is_training=True)

执行!

estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)

最后,使用默认参数训练模型需要100万步,约54小时的运行时间。如果内核由于某种原因重新启动,可以从断点处继续训练。

以上就是是在云TPU上从头开始预训练BERT的指南。

下一步

好的,我们已经训练好了模型,接下来可以做什么?

1、使用预训练的模型作为通用的自然语言理解模块;

2、针对某些特定的分类任务微调模型;

3、使用BERT作为构建块,去创建另一个深度学习模型。

传送门

原文地址: https://towardsdatascience.com/pre-training-bert-from-scratch-with-cloud-tpu-6e2f71028379

Colab代码: https://colab.research.google.com/drive/1nVn6AFpQSzXBt8_ywfx6XR8ZfQXlKGAz

作者系网易新闻·网易号“各有态度”签约作者

本文分享自微信公众号 - 量子位(QbitAI),作者:关注前沿科技

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-07-27

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 伯克利弹跳机器人再进化:超精准着陆,指哪打哪

    机器人名叫Salto-1P,来自加州伯克利,曾经被IEEE Spectrum热烈地称赞为“最不可思议的弹跳机器人”。

    量子位
  • 为里皮献策:国足再进一步,试试人工智能 | 附机器人世界杯集锦

    唐旭、问耕 发自 凹非寺 量子位·QbitAI 报道 一场有着多重意义的比赛昨晚结束,国足坐镇长沙击败韩国队。 ? “天亮了”,赛后李毅大帝在朋友圈和微博上说。...

    量子位
  • “蚁人”不再是科幻!MIT最新研究,能把任何材料物体缩小1000倍 | Science

    漫威世界中,蚁人是蚂蚁大小的超级英雄,靠一件“变身服”,人类就能在更微观的世界里大干一场。

    量子位
  • Java性能调优-JDK命令行

    类似Linux的ps,但是jps只用于列出Java的进程 可以方便查看Java进程的启动类,传入参数和JVM参数等 直接运行,不加参数,列出Java程序的进...

    JavaEdge
  • C# 动态加载卸载 DLL

    我最近做的软件,需要检测dll或exe是否混淆,需要反射获得类名,这时发现,C#可以加载DLL,但不能卸载DLL。于是在网上找到一个方法,可以动态加载DLL,不...

    林德熙
  • matlab绘图(四)

    感谢大家关注matlab爱好者公众号,今天给大家介绍matlab常用的二维和三维绘图命令。在聊天栏中回复“012”、“二维”或“三维” 即可快速获取本文章。通过...

    艾木樨
  • 马哥linux | Linux系统性能和使用活动监控工具 sysstat

    Sysstat是一个非常方便的工具,它带有众多的系统资源监控工具,用于监控系统的性能和使用情况。我们在日常使用的工具中有相当一部分是来自sysstat工具包的。...

    小小科
  • 难解?SAP云平台集成前路何方?

    实际上,连SAP全球云平台产品营销VP Dan Lahl都表示,这是SAP的一个弱点。为了解决这个问题,SAP今年推出了一系列SAP Cloud Platfor...

    人称T客
  • 一个SAP顾问在美国的这些年

    今天的文章来自我的老乡宋浩,之前作为SAP顾问在美国工作多年。如今即将加入SAP成都研究院S4CRM开发团队。我们都是大邑人。

    Jerry Wang
  • 使用SAP iRPA Studio创建的本地项目,如何部署到SAP云平台上?

    以前看威尔-史密斯主演的《我是传奇》,影片里的人类世界被病毒肆虐之后,荒草丛生满目疮痍,只剩主人公一个人一只狗,好可怕。

    Jerry Wang

扫码关注云+社区

领取腾讯云代金券