专栏首页小鹏的专栏tf44:tensorflow CRF的使用

tf44:tensorflow CRF的使用

版权声明:本文为博主原创文章,未经博主允许不得转载。有问题可以加微信:lp9628(注明CSDN)。 https://blog.csdn.net/u014365862/article/details/81298373

MachineLP的Github(欢迎follow):https://github.com/MachineLP

CRF的应用还是挺多的,像前期deeplab的语义分割、bilstm+crf做词性标注。

CRF简单的例子:

# coding=utf-8

import numpy as np
import tensorflow as tf

# 参数设置
num_examples = 10
num_words = 20
num_features = 100
num_tags = 5

# 构建随机特征
x = np.random.rand(num_examples, num_words, num_features).astype(np.float32)

# 构建随机tag
y = np.random.randint(
    num_tags, size=[num_examples, num_words]).astype(np.int32)

# 获取样本句长向量(因为每一个样本可能包含不一样多的词),在这里统一设为 num_words - 1,真实情况下根据需要设置
sequence_lengths = np.full(num_examples, num_words - 1, dtype=np.int32)

# 训练,评估模型
with tf.Graph().as_default():
    with tf.Session() as session:
        x_t = tf.constant(x)
        y_t = tf.constant(y)
        sequence_lengths_t = tf.constant(sequence_lengths)

        # 在这里设置一个无偏置的线性层
        weights = tf.get_variable("weights", [num_features, num_tags])
        matricized_x_t = tf.reshape(x_t, [-1, num_features])
        matricized_unary_scores = tf.matmul(matricized_x_t, weights)
        unary_scores = tf.reshape(matricized_unary_scores,
                                  [num_examples, num_words, num_tags])

        # 计算log-likelihood并获得transition_params
        log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
            unary_scores, y_t, sequence_lengths_t)

        # 进行解码(维特比算法),获得解码之后的序列viterbi_sequence和分数viterbi_score
        viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(
            unary_scores, transition_params, sequence_lengths_t)

        loss = tf.reduce_mean(-log_likelihood)
        train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

        session.run(tf.global_variables_initializer())

        mask = (np.expand_dims(np.arange(num_words), axis=0) <     # np.arange()创建等差数组
                np.expand_dims(sequence_lengths, axis=1))          # np.expand_dims()扩张维度

        # 得到一个num_examples*num_words的二维数组,数据类型为布尔型,目的是对句长进行截断

        # 将每个样本的sequence_lengths加起来,得到标签的总数
        total_labels = np.sum(sequence_lengths)

        # 进行训练
        for i in range(1000):
            tf_viterbi_sequence, _ = session.run([viterbi_sequence, train_op])
            if i % 100 == 0:
                correct_labels = np.sum((y == tf_viterbi_sequence) * mask)
                accuracy = 100.0 * correct_labels / float(total_labels)
                print("Accuracy: %.2f%%" % accuracy)

转自:https://blog.csdn.net/guolindonggld/article/details/79044574

Bi-LSTM

使用TensorFlow构建Bi-LSTM时经常是下面的代码:

cell_fw = tf.contrib.rnn.LSTMCell(num_units=100)
cell_bw = tf.contrib.rnn.LSTMCell(num_units=100)

(outputs, output_states) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, 
sequence_length=300)

首先下面是我画的Bi-LSTM示意图:

其实LSTM使用起来很简单,就是输入一排的向量,然后输出一排的向量。构建时只要设定两个超参数:num_unitssequence_length

LSTMCell

tf.contrib.rnn.LSTMCell(
    num_units,
    use_peepholes=False,
    cell_clip=None,
    initializer=None,
    num_proj=None,
    proj_clip=None,
    num_unit_shards=None,
    num_proj_shards=None,
    forget_bias=1.0,
    state_is_tuple=True,
    activation=None,
    reuse=None
)

上面的LSTM Cell只有一个超参数需要设定,num_units,即输出向量的维度。

bidirectional_dynamic_rnn()

(outputs, output_states) = tf.nn.bidirectional_dynamic_rnn(
    cell_fw,
    cell_bw,
    inputs,
    sequence_length=None,
    initial_state_fw=None,
    initial_state_bw=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

这个函数唯一需要设定的超参数就是序列长度sequence_length

输入: inputs的shape通常是[batch_size, sequence_length, dim_embedding]。

输出: outputs是一个(output_fw, output_bw)元组,output_fw和output_bw的shape都是[batch_size, sequence_length, num_units]

output_states是一个(output_state_fw, output_state_bw) 元组,分别是前向和后向最后一个Cell的Output,output_state_fw和output_state_bw的类型都是LSTMStateTuple,这个类有两个属性c和h,分别表示Memory Cell和Hidden State,如下图:

CRF

对于序列标注问题,通常会在LSTM的输出后接一个CRF层:将LSTM的输出通过线性变换得到维度为[batch_size, max_seq_len, num_tags]的张量,这个张量再作为一元势函数(Unary Potentials)输入到CRF层。

# 将两个LSTM的输出合并
output_fw, output_bw = outputs
output = tf.concat([output_fw, output_bw], axis=-1)

# 变换矩阵,可训练参数
W = tf.get_variable("W", [2 * num_units, num_tags])

# 线性变换
matricized_output = tf.reshape(output, [-1, 2 * num_units])
matricized_unary_scores = tf.matmul(matricized_output , W)
unary_scores = tf.reshape(matricized_unary_scores, [batch_size, max_seq_len, num_tags])

损失函数

# Loss函数
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(unary_scores, tags, sequence_lengths)
loss = tf.reduce_mean(-log_likelihood)

其中 tags:维度为[batch_size, max_seq_len]的矩阵,也就是Golden标签,注意这里的标签都是以索引方式表示的。 sequence_lengths:维度为[batch_size]的向量,记录了每个序列的长度。

log_likelihood:维度为[batch_size]的向量,每个元素代表每个给定序列的Log-Likelihood。 transition_params :维度为[num_tags, num_tags]的转移矩阵。注意这里的转移矩阵不像传统的HMM概率转移矩阵那样要求每个元素非负且每一行的和为1,这里的每个元素取值范围是实数(正负都可以)。

解码

decode_tags, best_score = tf.contrib.crf.crf_decode(unary_scores, transition_params, sequence_lengths)

其中 decode_tags:维度为[batch_size, max_seq_len]的矩阵,包含最高分的标签序列。 best_score :维度为[batch_size]的向量,包含最高分数。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • tf API 研读3:Building Graphs

    tensorflow是通过计算图的方式建立网络。 比喻说明: 结构:计算图建立的只是一个网络框架。编程时框架中不会出现任何的实际值,所有权重(weight)和偏...

    MachineLP
  • tf API 研读4:Inputs and Readers

    tensorflow中数据的读入相关类或函数: 占位符(Placeholders) tf提供一种占位符操作,在执行时需要为其提供数据data。 操作 描...

    MachineLP
  • tf24: GANs—生成明星脸

    GANs是Generative Adversarial Networks的简写,中文翻译为生成对抗网络,它最早出现在2014年Goodfellow发表的论文中...

    MachineLP
  • Golang之(if)流程控制

    超蛋lhy
  • 【leetcode刷题】T210-最大交换

    https://leetcode-cn.com/problems/maximum-swap/

    木又AI帮
  • 软件分享 | ZoomIt 4.5 演示辅助工具使用教程

    ZoomIt是一款非常实用的投影演示辅助软件。它源自Sysinternals公司,后来此公司被微软收购,因此,有些网友也称ZoomIt为微软放大镜。ZoomIt...

    课代表
  • 【Go 语言社区】POJ 1047 Round and Round We Go 循环数新解

    题目描述: 给定一字符串表示的高精度数,判断它是否是可循环的。如果假设字符串num的长为n,则将num从1开始乘到n,如果每次得到的结果包含的字符元素都和a是相...

    李海彬
  • 身份证号码校验算法

    1、数字含义 中国大陆第二代身份证号码由18位数据或字母组成,每位数据都有特定的含义,结果如下: ? 每组数字都有不同的含义: 第1至2位数字代表所在省(直辖市...

    Crossin先生
  • 李沐新书中文版上线,零基础也可以《动手学深度学习》| 这不是0.7版

    不变的是,作为一本交互式教材,除了书本身,它也提供简易的代码实现,以及一个高手如林的讨论区,让大家互相取暖切磋技艺。

    量子位
  • 『 懒人10分钟—Linux学习篇(三)』文件/目录的权限

    想必经常和服务器打交道的朋友,对于Linux可谓又爱又恨。对于项目组、运维人员、或者有多人需要对服务器进行操作的人,离不开Linux关于权限的管理。Linux的...

    23号杂货铺

扫码关注云+社区

领取腾讯云代金券