前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用TensorFlow 2.0的简单BERT

使用TensorFlow 2.0的简单BERT

作者头像
代码医生工作室
发布2019-11-12 15:40:25
8.4K0
发布2019-11-12 15:40:25
举报
文章被收录于专栏:相约机器人

作者 | Gailly Nemes

来源 | Medium

这篇文章展示了使用TensorFlow 2.0的BERT [1]嵌入的简单用法。由于TensorFlow 2.0最近已发布,该模块旨在使用基于高级Keras API的简单易用的模型。在一本很长的NoteBook中描述了BERT的先前用法,该NoteBook实现了电影评论预测。在这篇文章中,将看到一个使用Keras和最新的TensorFlow和TensorFlow Hub模块的简单BERT嵌入生成器。所有代码都可以在Google Colab上找到。

https://colab.research.google.com/github/google-research/bert/blob/master/predicting_movie_reviews_with_bert_on_tf_hub.ipynb#scrollTo=LL5W8gEGRTAf

https://colab.research.google.com/drive/1hMLd5-r82FrnFnBub-B-fVW78Px4KPX1

使用该bert-embedding 模块使用预先训练的无大小写BERT基本模型生成句子级和令牌级嵌入。在这里,仅需几个步骤即可实现该模块的用法。

Module imports

将使用最新的TensorFlow(2.0+)和TensorFlow Hub(0.7+),因此,可能需要在系统中进行升级。对于模型创建,使用高级Keras API模型类(新集成到tf.keras中)。

BERT令牌生成器仍来自BERT python模块。

代码语言:javascript
复制
import tensorflow_hub as hub
import tensorflow as tf
from bert.tokenization import FullTokenizer
from tensorflow.keras.models import Model

模型

将基于TensorFlow Hub上的示例实现一个模型。在这里,可以看到 bert_layer 可以像其他任何Keras层一样在更复杂的模型中使用。

该模型的目标是使用预训练的BERT生成嵌入向量。因此,仅需要BERT层所需的输入,并且模型仅将BERT层作为隐藏层。当然,在BERT层内部,有一个更复杂的体系结构。

该hub.KerasLayer函数将预训练的模型导入为Keras层。

代码语言:javascript
复制
max_seq_length = 128  # Your choice here.
input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                       name="input_word_ids")
input_mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                   name="input_mask")
segment_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                    name="segment_ids")
bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1",
                            trainable=True)
pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])
 
model = Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=[pooled_output, sequence_output])

BERT在Keras中的嵌入模型

预处理

BERT层需要3个输入序列:

  • 令牌ID:句子中的每个令牌。从BERT vocab字典中还原它
  • 掩码ID:为每个令牌掩蔽仅用于序列填充的令牌(因此每个序列具有相同的长度)。
  • 段ID:0表示一个句子序列,如果序列中有两个句子则为1,第二个句子为1。
代码语言:javascript
复制
def get_masks(tokens, max_seq_length):
    """Mask for padding"""
    if len(tokens)>max_seq_length:
        raise IndexError("Token length more than max seq length!")
    return [1]*len(tokens) + [0] * (max_seq_length - len(tokens))
 
 
def get_segments(tokens, max_seq_length):
    """Segments: 0 for the first sequence, 1 for the second"""
    if len(tokens)>max_seq_length:
        raise IndexError("Token length more than max seq length!")
    segments = []
    current_segment_id = 0
    for token in tokens:
        segments.append(current_segment_id)
        if token == "[SEP]":
            current_segment_id = 1
    return segments + [0] * (max_seq_length - len(tokens))
 
 
def get_ids(tokens, tokenizer, max_seq_length):
    """Token ids from Tokenizer vocab"""
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = token_ids + [0] * (max_seq_length-len(token_ids))
    return input_ids

用于根据标记和最大序列长度生成输入的函数

预测

通过这些步骤,可以为句子生成BERT上下文化嵌入向量!不要忘记添加[CLS]和[SEP]分隔符以保持原始格式!

代码语言:javascript
复制
s = "This is a nice sentence."
stokens = tokenizer.tokenize(s)
stokens = ["[CLS]"] + stokens + ["[SEP]"]
 
input_ids = get_ids(stokens, tokenizer, max_seq_length)
input_masks = get_masks(stokens, max_seq_length)
input_segments = get_segments(stokens, max_seq_length)
 
pool_embs, all_embs = model.predict([[input_ids],[input_masks],[input_segments]])

Bert嵌入生成器正在使用

合并嵌入作为句子级嵌入

原始论文建议使用[CLS]分隔符来表示整个句子,因为每个句子都有一个[CLS]标记,并且由于它是上下文嵌入,因此可以表示整个句子。在bert_layer从TensorFlow集线器返回与针对整个输入序列的表示不同的合并输出。

为了比较两个嵌入,使用余弦相似度。样本语句“这是一个不错的语句。”中的合并嵌入与第一个标记的嵌入之间的差异为0.0276。

总结

这篇文章介绍了一个简单的,基于Keras的,基于TensorFlow 2.0的高级BERT嵌入模型。TensorFlow Hub上还提供了其他模型,例如ALBERT。

可以在Google Colab上访问所有代码。

https://colab.research.google.com/drive/1hMLd5-r82FrnFnBub-B-fVW78Px4KPX1

参考文献

[1] Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2018). Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.

https://arxiv.org/abs/1810.04805

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-10-31,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 相约机器人 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档