前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >解决Keras中循环使用K.ctc_decode内存不释放的问题

解决Keras中循环使用K.ctc_decode内存不释放的问题

作者头像
砸漏
发布2020-10-21 14:36:12
1.7K0
发布2020-10-21 14:36:12
举报
文章被收录于专栏:恩蓝脚本

如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢。

代码语言:javascript
复制
data = generator(...)
model = init_model(...)
for i in range(NUM):
  x, y = next(data)
  _y = model.predict(x)
  shape = _y.shape
  input_length = np.ones(shape[0]) * shape[1]
  ctc_decode = K.ctc_decode(_y, input_length)[0][0]
  out = K.get_value(ctc_decode)

原因

每次执行ctc_decode时都会向计算图中添加一个节点,这样会导致计算图逐渐变大,从而影响计算速度和内存。

PS:有资料说是由于get_value导致的,其中也给出了解决方案。

但是我将ctc_decode放在循环体之外就不再出现内存和速度问题,这是否说明get_value影响其实不大呢?

解决方案

通过K.function封装K.ctc_decode,只需初始化一次,只向计算图中添加一个计算节点,然后多次调用该节点(函数)

代码语言:javascript
复制
data = generator(...)
model = init_model(...)
x = model.output  # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.input, input_length], [ctc_decode[0][0]])
for i in range(NUM):
  _x, _y = next(data)
  out = decode([_x, np.ones(1)])

补充知识:CTC_loss和CTC_decode的模型封装代码避免节点不断增加

该问题可以参考上面的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。

测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。

代码语言:javascript
复制
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
'''
用于计算CTC loss
'''
def ctc_lambda_func(self,args):
"""Runs CTC loss algorithm on each batch element.
# Arguments
y_true: tensor `(samples, max_string_length)` 真实标签
y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
input_length: tensor `(samples, 1)` 每一个y_pred的长度
label_length: tensor `(samples, 1)` 每一个y_true的长度
# Returns
Tensor with shape (samples,1) 包含了每一个样本的ctc loss
"""
y_true, y_pred, input_length, label_length = args
# y_pred = y_pred[:, :, :]
# y_pred = y_pred[:, 2:, :]
return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)
def __call__(self, args):
'''
ctc_decode 每次创建会生成一个节点,这里参考了上面的内容
将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
'''
y_true = Input(shape=(None,))
y_pred = Input(shape=(None,None))
input_length = Input(shape=(1,))
label_length = Input(shape=(1,))
lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")
# return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
return model(args)
def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
"""Runs CTC loss algorithm on each batch element.
# Arguments
y_true: tensor `(samples, max_string_length)`
containing the truth labels.
y_pred: tensor `(samples, time_steps, num_categories)`
containing the prediction, or output of the softmax.
input_length: tensor `(samples, 1)` containing the sequence length for
each batch item in `y_pred`.
label_length: tensor `(samples, 1)` containing the sequence length for
each batch item in `y_true`.
# Returns
Tensor with shape (samples,1) containing the
CTC loss of each element.
"""
label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))
y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)
# 注意这里的True是为了忽略解码失败的情况,此时loss会变成nan直到下一个个batch
return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
labels=sparse_labels,
sequence_length=input_length,
ignore_longer_outputs_than_inputs=True), 1)
# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
代码语言:javascript
复制
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTCDecodeLayer(Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _ctc_decode(self,args):
base_pred, in_len = args
in_len = K.squeeze(in_len,axis=-1)
r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
r1 = r[0][0]
prob = r[1][0]
return [r1,prob]
def call(self, inputs, **kwargs):
return self._ctc_decode(inputs)
def compute_output_shape(self, input_shape):
return [(None,None),(1,)]
class CTCDecode():
'''用与CTC 解码,得到真实语音序列
2019年7月18日所写,对ctc_decode使用模型进行了封装,从而在初始化完成后不会再有新节点的产生
'''
def __init__(self):
base_pred = Input(shape=[None,None],name="pred")
feature_len = Input(shape=[1,],name="feature_len")
r1, prob = CTCDecodeLayer()([base_pred,feature_len])
self.model = Model([base_pred,feature_len],[r1,prob])
pass
def ctc_decode(self,base_pred,in_len,return_prob = False):
'''
:param base_pred:[sample,timestamp,vector]
:param in_len: [sample,1]
:return:
'''
result,prob = self.model.predict([base_pred,in_len])
if return_prob:
return result,prob
return result
def __call__(self,base_pred,in_len,return_prob = False):
return self.ctc_decode(base_pred,in_len,return_prob)
# 使用方法:(注意shape,是batch级的输入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len) 

以上这篇解决Keras中循环使用K.ctc_decode内存不释放的问题就是小编分享给大家的全部内容了,希望能给大家一个参考。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020-09-11 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档