专栏首页钛问题albert做Seq2Seq任务 采用UNILM方案
原创

albert做Seq2Seq任务 采用UNILM方案

#! -*- coding: utf-8 -*-
# albert做Seq2Seq任务,采用UNILM方案
# 介绍链接:https://kexue.fm/archives/6933

from __future__ import print_function

import codecs
import glob
import json
import os

import numpy as np
from tqdm import tqdm

from bert4keras.backend import keras, K
from bert4keras.bert import build_bert_model
from bert4keras.optimizers import Adam
from bert4keras.snippets import DataGenerator
from bert4keras.snippets import parallel_apply, sequence_padding
from bert4keras.tokenizer import Tokenizer, load_vocab

seq2seq_config = 'seq2seq_config.json'
min_count = 64
max_len = 128
batch_size = 16
steps_per_epoch = 1000
epochs = 10000

# bert配置
config_path = 'albert_small_zh_google/albert_config_small_google.json'
checkpoint_path = 'albert_small_zh_google/albert_model.ckpt'
dict_path = 'albert_small_zh_google/vocab.txt'

# 训练样本。THUCNews数据集,每个样本保存为一个txt。
txts = glob.glob('thuctc/THUCNews/*/*.txt')


_token_dict = load_vocab(dict_path)  # 读取词典
_tokenizer = Tokenizer(_token_dict, do_lower_case=True)  # 建立临时分词器

if os.path.exists(seq2seq_config):

    tokens = json.load(open(seq2seq_config))

else:

    def _batch_texts():
        texts = []
        for txt in txts:
            text = codecs.open(txt, encoding='utf-8').read()
            texts.append(text)
            if len(texts) == 100:
                yield texts
                texts = []
        if texts:
            yield texts

    def _tokenize_and_count(texts):
        _tokens = {}
        for text in texts:
            for token in _tokenizer.tokenize(text):
                _tokens[token] = _tokens.get(token, 0) + 1
        return _tokens

    tokens = {}

    def _total_count(result):
        for k, v in result.items():
            tokens[k] = tokens.get(k, 0) + v

    # 10进程来完成词频统计
    parallel_apply(
        func=_tokenize_and_count,
        iterable=tqdm(_batch_texts(), desc=u'构建词汇表中'),
        workers=10,
        max_queue_size=100,
        callback=_total_count,
    )

    tokens = [(i, j) for i, j in tokens.items() if j >= min_count]
    tokens = sorted(tokens, key=lambda t: -t[1])
    tokens = [t[0] for t in tokens]
    json.dump(tokens,
              codecs.open(seq2seq_config, 'w', encoding='utf-8'),
              indent=4,
              ensure_ascii=False)

token_dict, keep_words = {}, []  # keep_words是在bert中保留的字表

for t in ['[PAD]', '[UNK]', '[CLS]', '[SEP]']:
    token_dict[t] = len(token_dict)
    keep_words.append(_token_dict[t])

for t in tokens:
    if t in _token_dict and t not in token_dict:
        token_dict[t] = len(token_dict)
        keep_words.append(_token_dict[t])

tokenizer = Tokenizer(token_dict, do_lower_case=True)  # 建立分词器


class data_generator(DataGenerator):
    """数据生成器
    """
    def __iter__(self, random=False):
        idxs = list(range(len(self.data)))
        if random:
            np.random.shuffle(idxs)
        batch_token_ids, batch_segment_ids = [], []
        for i in idxs:
            txt = self.data[i]
            text = codecs.open(txt, encoding='utf-8').read()
            text = text.split('\n')
            if len(text) > 1:
                title = text[0]
                content = '\n'.join(text[1:])
                token_ids, segment_ids = tokenizer.encode(content,
                                                          title,
                                                          max_length=max_len)
                batch_token_ids.append(token_ids)
                batch_segment_ids.append(segment_ids)
            if len(batch_token_ids) == self.batch_size or i == idxs[-1]:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                yield [batch_token_ids, batch_segment_ids], None
                batch_token_ids, batch_segment_ids = [], []


model = build_bert_model(
    config_path,
    checkpoint_path,
    application='seq2seq',
    model='albert',
    keep_words=keep_words,  # 只保留keep_words中的字,精简原字表
)

model.summary()

# 交叉熵作为loss,并mask掉输入部分的预测
y_in = model.input[0][:, 1:]  # 目标tokens
y_mask = model.input[1][:, 1:]
y = model.output[:, :-1]  # 预测tokens,预测与目标错开一位
cross_entropy = K.sparse_categorical_crossentropy(y_in, y)
cross_entropy = K.sum(cross_entropy * y_mask) / K.sum(y_mask)

model.add_loss(cross_entropy)
model.compile(optimizer=Adam(1e-5))


def gen_sent(s, topk=2, title_max_len=32):
    """beam search解码
    每次只保留topk个最优候选结果;如果topk=1,那么就是贪心搜索
    """
    content_max_len = max_len - title_max_len
    token_ids, segment_ids = tokenizer.encode(s, max_length=content_max_len)
    target_ids = [[] for _ in range(topk)]  # 候选答案id
    target_scores = [0] * topk  # 候选答案分数
    for i in range(title_max_len):  # 强制要求输出不超过title_max_len字
        _target_ids = [token_ids + t for t in target_ids]
        _segment_ids = [segment_ids + [1] * len(t) for t in target_ids]
        _probas = model.predict([_target_ids, _segment_ids
                                 ])[:, -1, 3:]  # 直接忽略[PAD], [UNK], [CLS]
        _log_probas = np.log(_probas + 1e-6)  # 取对数,方便计算
        _topk_arg = _log_probas.argsort(axis=1)[:, -topk:]  # 每一项选出topk
        _candidate_ids, _candidate_scores = [], []
        for j, (ids, sco) in enumerate(zip(target_ids, target_scores)):
            # 预测第一个字的时候,输入的topk事实上都是同一个,
            # 所以只需要看第一个,不需要遍历后面的。
            if i == 0 and j > 0:
                continue
            for k in _topk_arg[j]:
                _candidate_ids.append(ids + [k + 3])
                _candidate_scores.append(sco + _log_probas[j][k])
        _topk_arg = np.argsort(_candidate_scores)[-topk:]  # 从中选出新的topk
        target_ids = [_candidate_ids[k] for k in _topk_arg]
        target_scores = [_candidate_scores[k] for k in _topk_arg]
        best_one = np.argmax(target_scores)
        if target_ids[best_one][-1] == 3:
            return tokenizer.decode(target_ids[best_one])
    # 如果title_max_len字都找不到结束符,直接返回
    return tokenizer.decode(target_ids[np.argmax(target_scores)])


def just_show():
    s1 = u'夏天来临,皮肤在强烈紫外线的照射下,晒伤不可避免,因此,晒后及时修复显得尤为重要,否则可能会造成长期伤害。专家表示,选择晒后护肤品要慎重,芦荟凝胶是最安全,有效的一种选择,晒伤严重者,还请及 时 就医 。'
    s2 = u'8月28日,网络爆料称,华住集团旗下连锁酒店用户数据疑似发生泄露。从卖家发布的内容看,数据包含华住旗下汉庭、禧玥、桔子、宜必思等10' \
         u'余个品牌酒店的住客信息。泄露的信息包括华住官网注册资料、酒店入住登记的身份信息及酒店开房记录,住客姓名、手机号、邮箱、身份证号、登录账号密码等。卖家对这个约5' \
         u'亿条数据打包出售。第三方安全平台威胁猎人对信息出售者提供的三万条数据进行验证,认为数据真实性非常高。当天下午 ,华 住集 ' \
         u'团发声明称,已在内部迅速开展核查,并第一时间报警。当晚,上海警方消息称,接到华住集团报案,警方已经介入调查。 '
    for s in [s1, s2]:
        print(u'生成标题:', gen_sent(s))
    print()


class Evaluate(keras.callbacks.Callback):
    def __init__(self):
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        # 保存最优
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save_weights('best_model.weights')
        # 演示效果
        just_show()


if __name__ == '__main__':

    evaluator = Evaluate()
    train_generator = data_generator(txts, batch_size)

    model.fit_generator(train_generator.forfit(),
                        steps_per_epoch=steps_per_epoch,
                        epochs=epochs,
                        callbacks=[evaluator])

else:

    model.load_weights('best_model.weights')

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 如何根据thucnews中的海量文章数据集训练一个根据文章生成题目的seq2seq模型

    首先安装bert4keras pip install git+https://www.github.com/bojone/bert4keras.git 基于苏...

    用户1750490
  • 利用bert系列预训练模型在非结构化数据抽取数据

    https://ai.baidu.com/broad/download?dataset=sked

    用户1750490
  • 金融领域的统计学方法应用

    上升一个等级就是 第一产业 制造业 第二产业 以及第三产业服务产业的动态问题。再到中央银行对整个产业的现金流限制。

    用户1750490
  • 如何根据thucnews中的海量文章数据集训练一个根据文章生成题目的seq2seq模型

    首先安装bert4keras pip install git+https://www.github.com/bojone/bert4keras.git 基于苏...

    用户1750490
  • 小明学习代码审计writeup

    根据链接的复制访问resetpwd.php,并查看网页源码,发现注释中有PHP代码:

    KevinBruce
  • 雷达数据处理和风场反演

    强对流活动通常会伴随降水、降雹和龙卷风等现象,气象雷达常用于探测上述天气现象,并可以根据雷达观测数据采用外推等方法进行短临预报。

    zhangqibot
  • Ubuntu 12.04 LTS 搭建svn,mysql,apache过程

    1.apt-get install subversion libapache2-svn libapache2-mod-auth-mysql apache2 my...

    苦咖啡
  • 如何检测处理器是否支持AES-NI指令集?

    本文介绍如何检测处理器是否支持AES-NI指令集,首先我们先了解一下什么是AES-NI指令集。

    隔离没老王
  • 三十天写三十个网站后,我学到的东西[每日前端夜话0x3C]

    上个学期用 JavaScript 写了一些好玩的网站,但开始用 React 或其他框架的时候,总觉得有点不踏实,应该要对原生的 JavaScript(或称 Va...

    疯狂的技术宅
  • Pytest全局用例共用之conftest.py详解

    1、可以跨.py文件调用,有多个.py文件调用时,可让conftest.py只调用了一次fixture,或调用多次fixture

    橙子探索测试

扫码关注云+社区

领取腾讯云代金券