Loading [MathJax]/jax/output/CommonHTML/config.js
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >训练TensorFlow期间面临的问题(BatchNormV3错误)

训练TensorFlow期间面临的问题(BatchNormV3错误)
EN

Stack Overflow用户
提问于 2021-09-21 04:09:37
回答 1查看 43关注 0票数 1

在用于机器翻译的变压器网络的训练期间,GPU显示此错误。为什么会出现这个问题?

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
Traceback (most recent call last):
  File "D:/Transformer_MC__translation/model.py", line 64, in <module>
    output = model(train, label)
  File "C:\Users\Devanshu\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1012, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "D:\Transformer_MC__translation\transformer.py", line 36, in call
    enc_src = self.encoder(src, src_mask)
  File "C:\Users\Devanshu\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1012, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "D:\Transformer_MC__translation\encoder.py", line 23, in call
    output = layer(output, output, output, mask)
  File "C:\Users\Devanshu\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1012, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "D:\Transformer_MC__translation\transformerblock.py", line 22, in call
    x = self.dropout(self.norm1(attention+query))
  File "C:\Users\Devanshu\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1012, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "C:\Users\Devanshu\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\keras\layers\normalization.py", line 1293, in call
    outputs, _, _ = nn.fused_batch_norm(
  File "C:\Users\Devanshu\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\util\dispatch.py", line 201, in wrapper
    return target(*args, **kwargs)
  File "C:\Users\Devanshu\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\ops\nn_impl.py", line 1660, in fused_batch_norm
    y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
  File "C:\Users\Devanshu\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\ops\gen_nn_ops.py", line 4255, in fused_batch_norm_v3
    _ops.raise_from_not_ok_status(e, name)
  File "C:\Users\Devanshu\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\python\framework\ops.py", line 6862, in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InternalError: cuDNN launch failure : input shape ([1,4928,256,1]) [Op:FusedBatchNormV3]

这是编码器块

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import tensorflow as tf
from selfattention import SelfAttention
from transformerblock import TransformerBlock

class DecoderBlock(tf.keras.layers.Layer):
    def __init__(self, embed_size, head, forward_expansion, dropout):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, head)
        self.norm = tf.keras.layers.LayerNormalization()
        self.transformer_block = TransformerBlock(embed_size, head, dropout=dropout, forward_expansion=forward_expansion)
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, inputs, key, value, src_mask, trg_mask):
        attention = self.attention(inputs, inputs, inputs, trg_mask)
        # skip connection
        query = self.dropout(self.norm(attention + inputs))
        print(query.shape)

        output = self.transformer_block(value, key, query, src_mask)

        return output

attention+input的输出形状为(64,80,250) (批量大小,检测长度,单词大小)

EN

回答 1

Stack Overflow用户

发布于 2021-09-21 04:19:27

您可以进行的解决问题的可能尝试。我曾经遇到过这个问题,当时我试图使用非常大的批处理大小,并通过减少它解决了这个问题。

  • 减少batch_size参数。逐渐增加它(2,4,8,10 etc.)
  • Sometimes当出现这样的cuDNN内部错误时,是由于库installations.

中的不匹配

确保正确安装了所有依赖项(TF+CUDNN+CUDA),并在确定安装正确后减少batch_size

在您的情况下,我怀疑问题是由于大批量。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69268651

复制
相关文章
边缘计算面临的问题
目前边缘计算已经得到了各行各业的广泛重视,并且在很多应用场景下开花结果。根据边缘计算领域特定的特点,本文认为6个方向是未来几年迫切需要解决的问题:编程模型、软硬件选型、基准程序与标准、动态调度、与垂直行业的紧密结合以及边缘节点的落地。
边缘计算
2019/07/03
2.7K0
边缘计算面临的问题
春节期间,读者留言最多的问题
这几天我抽空看了以前文章的留言,很多读者对动态规划问题的 base case、备忘录初始值等问题存在疑问。
labuladong
2021/09/23
3150
并发面临的问题小结
在单核CPU机器下,也可以支持并发多线程执行代码,这个时候CPU会为每一个线程分配对应的时间片,通过在指定的时间片内执行对应的线程程序代码,时间片一到,线程再继续争抢CPU资源重复上述动作,CPU需要不断地进行来回切换上下文以便能够执行到争抢到资源的线程,开发人员可以在linux系统下通过vmstat查看的context switch,即cs表示上下文
keithl
2020/03/10
6580
并发面临的问题小结
Tensorflow中遇到的错误
TypeError: Input 'b' of 'MatMul' Op has type float32 that does not match type int32 of argument 'a'. loss = tf.reduce_mean( tf.nn.nce_loss(nce_weights, nce_biases, embed, train_labels, num_sampled, vocabulary_size)) 解决方案,修改 embed, train_
听城
2018/04/27
2.3K0
tensorflow timeout错误
本文由腾讯云+社区自动同步,原文地址 https://stackoverflow.club/socket-timeout-error-tensorflow/
羽翰尘
2019/11/20
9810
tensorflow错误总结
本文主要介绍了在编写 TensorFlow 代码时可能会遇到的一些常见错误,包括如何正确设置输入 shape、如何指定数据 dtype、如何创建常量以及如何使用 tf.get_variable() 等。同时,还提供了一些解决这些问题的 tips 和技巧,以帮助开发者更好地使用 TensorFlow。
ke1th
2018/01/02
6990
存储Tensorflow训练网络的参数
训练一个神经网络的目的是啥?不就是有朝一日让它有用武之地吗?可是,在别处使用训练好的网络,得先把网络的参数(就是那些variables)保存下来,怎么保存呢?其实,tensorflow已经给我们提供了很方便的API,来帮助我们实现训练参数的存储与读取,如果想了解详情,请看晦涩难懂的官方API,接下来我简单介绍一下我的理解。 保存与读取数据全靠下面这个类实现: class tf.train.Saver 当我们需要存储数据时,下面2条指令就够了 saver = tf.train.Saver() save_pat
用户1332428
2018/03/30
1.1K0
使用TensorFlow训练WDL模型性能问题定位与调优
总第237篇 2018年 第29篇 简介 TensorFlow是Google研发的第二代人工智能学习系统,能够处理多种深度学习算法模型,以功能强大和高可扩展性而著称。TensorFlow完全开源,所以很多公司都在使用,但是美团点评在使用分布式TensorFlow训练WDL模型时,发现训练速度很慢,难以满足业务需求。 经过对TensorFlow框架和Hadoop的分析定位,发现在数据输入、集群网络和计算内存分配等层面出现性能瓶颈。主要原因包括TensorFlow数据输入接口效率低、PS/Worker算子分
美团技术团队
2018/06/07
2.8K3
TensorFlow版本带来的concat错误
TypeError: Expected int32, got list containing Tensors of type '_Message' instead.
Cloudox
2021/11/23
6080
安装 tensorflow 1.1.0;以及安装其他相似版本tensorflow遇到的问题;tensorflow 1.13.2 cuda-10环境变量配置问题;Tensorflow 指定训练时如何指定
tensorboard --logdir=/tmp/tensorflow/mnist/logs/mnist_with_summaries/ 
西湖醋鱼
2020/12/30
7180
tensorflow对象检测框架训练VOC数据集常见的两个问题
Tensorflow自从发布了object detection API这套对象检测框架以来,成为很多做图像检测与对象识别开发者手中的神兵利器,因为他不需要写一行代码,就可以帮助开发者训练出一个很好的自定义对象检测器(前提是有很多标注数据)。我之前曾经写过几篇文章详细介绍了tensorflow对象检测框架的安装与使用,感兴趣可以看如下几篇文章!
OpenCV学堂
2019/04/29
2.1K2
tensorflow对象检测框架训练VOC数据集常见的两个问题
AI 技术讲座精选:如何在时序预测问题中在训练期间更新LSTM网络
使用神经网络解决时间序列预测问题的好处是网络可以在获得新数据时对权重进行更新。 在本教程中,你将学习如何使用新数据更新长短期记忆(LTCM)递归神经网络。 在学完本教程后,你将懂得: 如何用新数据更
AI科技大本营
2018/04/26
1.5K0
AI 技术讲座精选:如何在时序预测问题中在训练期间更新LSTM网络
TensorFlow 组合训练数据(batching)
摘要总结:本文主要介绍了使用TensorFlow从TFRecord文件中读取数据,并将其组合成batch进行训练的过程。首先介绍了TensorFlow和TFRecord的基本概念,然后详细讲解了从TFRecord文件中读取数据的过程,包括使用TensorFlow的队列和线程进行数据读取和组合成batch的过程。最后通过一个例子演示了如何使用TensorFlow读取和组合成batch进行训练的过程。
chaibubble
2018/01/02
2K0
TensorFlow 组合训练数据(batching)
实例分割综述_实例分割面临的问题
广泛应用于深度学习中提取特征的卷积操作具有不变性,这限制了网络精确定位目标的能力。
全栈程序员站长
2022/09/23
3950
tensorflow常见错误记录
tensorflow TypeError: run() got multiple values for argument 'feed_dict' 原因分析:造成此错误的原因为:run()函数接收的fetches参数为一个列表、元组、或者字典,此错误是因为要获取的对象被当作多个参数,正确用法: a = tf.constant([10, 20]) b = tf.constant([1.0, 2.0]) # 'fetches' can be a singleton
听城
2018/11/09
6050
内外网数据交换面临的问题
近年来全球网络安全威胁态势的加速严峻,企业的网络安全体系建设正从“以合规为导向”转变到“以风险为导向”,从原来的“保护安全边界”转换到“保护核心数据资产”的思路上来。
企业文件数据交换
2019/08/26
2.3K0
内外网数据交换面临的问题
TensorFlow应用实战 | 编写训练的python文件
一个错误的个人使用,因为我的TensorFlow版本较老。keras并没有被集成进来。
用户1332428
2018/07/30
6160
TensorFlow应用实战 | 编写训练的python文件
tensorflow版本的tansformer训练IWSLT数据集
代码来源:https://github.com/Kyubyong/transformer
西西嘛呦
2020/08/26
1.9K0
【TensorFlow】使用迁移学习训练自己的模型
最近在研究tensorflow的迁移学习,网上看了不少文章,奈何不是文章写得不清楚就是代码有细节不对无法运行,下面给出使用迁移学习训练自己的图像分类及预测问题全部操作和代码,希望能帮到刚入门的同学。
刘早起
2020/04/23
2.2K0
点击加载更多

相似问题

tensorflow在训练期间没有改进

223

Tensorflow Slim在训练期间的调试

312

安装tensorflow所面临的问题

16

导入Tensorflow时面临的问题

112

Tensorflow (2.9.1):在训练期间更改层上的“可训练”属性

11
添加站长 进交流群

领取专属 10元无门槛券

AI混元助手 在线答疑

扫码加入开发者社群
关注 腾讯云开发者公众号

洞察 腾讯核心技术

剖析业界实践案例

扫码关注腾讯云开发者公众号
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
查看详情【社区公告】 技术创作特训营有奖征文