前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >如何根据thucnews中的海量文章数据集训练一个根据文章生成题目的seq2seq模型

如何根据thucnews中的海量文章数据集训练一个根据文章生成题目的seq2seq模型

原创
作者头像
用户1750490
修改2020-01-13 12:25:59
1.2K0
修改2020-01-13 12:25:59
举报
文章被收录于专栏:钛问题钛问题

声明本文代码方案来自苏剑林老师的bert4keras,代码来源链接

https://github.com/bojone/bert4keras

首先安装bert4keras pip install git+https://www.github.com/bojone/bert4keras.git 基于苏剑林老师的bert4keras进行小幅度改动

https://www.github.com/bojone/bert4keras.git

特别感谢腾讯钛提供的免费的32GB显存的机器。希望腾讯钛能一直给我提供机器。对应的我会给腾讯钛写好多好多的技术博客的呦。 下载 thucnews数据集 thucnews文件需要自己申请才可以下载的呦,非商业用途仅为了技术交流哦。 #! -*- coding: utf-8 -*- # albert做Seq2Seq任务,采用UNILM方案

苏剑林老师的原文如下。 # 介绍链接:https://kexue.fm/archives/6933

代码语言:javascript
复制
from __future__ import print_function

import codecs
import glob
import json
import os

import numpy as np
from tqdm import tqdm

from bert4keras.backend import keras, K
from bert4keras.bert import build_bert_model
from bert4keras.optimizers import Adam
from bert4keras.snippets import DataGenerator
from bert4keras.snippets import parallel_apply, sequence_padding
from bert4keras.tokenizer import Tokenizer, load_vocab

seq2seq_config = 'seq2seq_config.json'
min_count = 64
max_len = 128
batch_size = 16
steps_per_epoch = 1000
epochs = 10000

# bert配置
config_path = 'albert_small_zh_google/albert_config_small_google.json'
checkpoint_path = 'albert_small_zh_google/albert_model.ckpt'
dict_path = 'albert_small_zh_google/vocab.txt'

# 训练样本。THUCNews数据集,每个样本保存为一个txt。
txts = glob.glob('thuctc/THUCNews/*/*.txt')


_token_dict = load_vocab(dict_path)  # 读取词典
_tokenizer = Tokenizer(_token_dict, do_lower_case=True)  # 建立临时分词器

if os.path.exists(seq2seq_config):

    tokens = json.load(open(seq2seq_config))

else:

    def _batch_texts():
        texts = []
        for txt in txts:
            text = codecs.open(txt, encoding='utf-8').read()
            texts.append(text)
            if len(texts) == 100:
                yield texts
                texts = []
        if texts:
            yield texts

    def _tokenize_and_count(texts):
        _tokens = {}
        for text in texts:
            for token in _tokenizer.tokenize(text):
                _tokens[token] = _tokens.get(token, 0) + 1
        return _tokens

    tokens = {}

    def _total_count(result):
        for k, v in result.items():
            tokens[k] = tokens.get(k, 0) + v

    # 10进程来完成词频统计
    parallel_apply(
        func=_tokenize_and_count,
        iterable=tqdm(_batch_texts(), desc=u'构建词汇表中'),
        workers=10,
        max_queue_size=100,
        callback=_total_count,
    )

    tokens = [(i, j) for i, j in tokens.items() if j >= min_count]
    tokens = sorted(tokens, key=lambda t: -t[1])
    tokens = [t[0] for t in tokens]
    json.dump(tokens,
              codecs.open(seq2seq_config, 'w', encoding='utf-8'),
              indent=4,
              ensure_ascii=False)

token_dict, keep_words = {}, []  # keep_words是在bert中保留的字表

for t in ['[PAD]', '[UNK]', '[CLS]', '[SEP]']:
    token_dict[t] = len(token_dict)
    keep_words.append(_token_dict[t])

for t in tokens:
    if t in _token_dict and t not in token_dict:
        token_dict[t] = len(token_dict)
        keep_words.append(_token_dict[t])

tokenizer = Tokenizer(token_dict, do_lower_case=True)  # 建立分词器


class data_generator(DataGenerator):
    """数据生成器
    """
    def __iter__(self, random=False):
        idxs = list(range(len(self.data)))
        if random:
            np.random.shuffle(idxs)
        batch_token_ids, batch_segment_ids = [], []
        for i in idxs:
            txt = self.data[i]
            text = codecs.open(txt, encoding='utf-8').read()
            text = text.split('\n')
            if len(text) > 1:
                title = text[0]
                content = '\n'.join(text[1:])
                token_ids, segment_ids = tokenizer.encode(content,
                                                          title,
                                                          max_length=max_len)
                batch_token_ids.append(token_ids)
                batch_segment_ids.append(segment_ids)
            if len(batch_token_ids) == self.batch_size or i == idxs[-1]:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                yield [batch_token_ids, batch_segment_ids], None
                batch_token_ids, batch_segment_ids = [], []


model = build_bert_model(
    config_path,
    checkpoint_path,
    application='seq2seq',
    model='albert',
    keep_words=keep_words,  # 只保留keep_words中的字,精简原字表
)

model.summary()

# 交叉熵作为loss,并mask掉输入部分的预测
y_in = model.input[0][:, 1:]  # 目标tokens
y_mask = model.input[1][:, 1:]
y = model.output[:, :-1]  # 预测tokens,预测与目标错开一位
cross_entropy = K.sparse_categorical_crossentropy(y_in, y)
cross_entropy = K.sum(cross_entropy * y_mask) / K.sum(y_mask)

model.add_loss(cross_entropy)
model.compile(optimizer=Adam(1e-5))


def gen_sent(s, topk=2, title_max_len=32):
    """beam search解码
    每次只保留topk个最优候选结果;如果topk=1,那么就是贪心搜索
    """
    content_max_len = max_len - title_max_len
    token_ids, segment_ids = tokenizer.encode(s, max_length=content_max_len)
    target_ids = [[] for _ in range(topk)]  # 候选答案id
    target_scores = [0] * topk  # 候选答案分数
    for i in range(title_max_len):  # 强制要求输出不超过title_max_len字
        _target_ids = [token_ids + t for t in target_ids]
        _segment_ids = [segment_ids + [1] * len(t) for t in target_ids]
        _probas = model.predict([_target_ids, _segment_ids
                                 ])[:, -1, 3:]  # 直接忽略[PAD], [UNK], [CLS]
        _log_probas = np.log(_probas + 1e-6)  # 取对数,方便计算
        _topk_arg = _log_probas.argsort(axis=1)[:, -topk:]  # 每一项选出topk
        _candidate_ids, _candidate_scores = [], []
        for j, (ids, sco) in enumerate(zip(target_ids, target_scores)):
            # 预测第一个字的时候,输入的topk事实上都是同一个,
            # 所以只需要看第一个,不需要遍历后面的。
            if i == 0 and j > 0:
                continue
            for k in _topk_arg[j]:
                _candidate_ids.append(ids + [k + 3])
                _candidate_scores.append(sco + _log_probas[j][k])
        _topk_arg = np.argsort(_candidate_scores)[-topk:]  # 从中选出新的topk
        target_ids = [_candidate_ids[k] for k in _topk_arg]
        target_scores = [_candidate_scores[k] for k in _topk_arg]
        best_one = np.argmax(target_scores)
        if target_ids[best_one][-1] == 3:
            return tokenizer.decode(target_ids[best_one])
    # 如果title_max_len字都找不到结束符,直接返回
    return tokenizer.decode(target_ids[np.argmax(target_scores)])


def just_show():
    s1 = u'夏天来临,皮肤在强烈紫外线的照射下,晒伤不可避免,因此,晒后及时修复显得尤为重要,否则可能会造成长期伤害。专家表示,选择晒后护肤品要慎重,芦荟凝胶是最安全,有效的一种选择,晒伤严重者,还请及 时 就医 。'
    s2 = u'8月28日,网络爆料称,华住集团旗下连锁酒店用户数据疑似发生泄露。从卖家发布的内容看,数据包含华住旗下汉庭、禧玥、桔子、宜必思等10' \
         u'余个品牌酒店的住客信息。泄露的信息包括华住官网注册资料、酒店入住登记的身份信息及酒店开房记录,住客姓名、手机号、邮箱、身份证号、登录账号密码等。卖家对这个约5' \
         u'亿条数据打包出售。第三方安全平台威胁猎人对信息出售者提供的三万条数据进行验证,认为数据真实性非常高。当天下午 ,华 住集 ' \
         u'团发声明称,已在内部迅速开展核查,并第一时间报警。当晚,上海警方消息称,接到华住集团报案,警方已经介入调查。 '
    for s in [s1, s2]:
        print(u'生成标题:', gen_sent(s))
    print()


class Evaluate(keras.callbacks.Callback):
    def __init__(self):
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        # 保存最优
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save_weights('best_model.weights')
        # 演示效果
        just_show()


if __name__ == '__main__':

    evaluator = Evaluate()
    train_generator = data_generator(txts, batch_size)

    model.fit_generator(train_generator.forfit(),
                        steps_per_epoch=steps_per_epoch,
                        epochs=epochs,
                        callbacks=[evaluator])

else:

    model.load_weights('best_model.weights')

文章首发于知乎,欢迎转载。

代码语言:javascript
复制

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档