首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >tf.segment_max错误:段ids未按1递增

tf.segment_max错误:段ids未按1递增
EN

Stack Overflow用户
提问于 2016-10-18 21:38:02
回答 1查看 552关注 0票数 4

tf.segment_max()和其他网段运维通常是require the segment IDs to be consecutive。将此操作应用于动态生成的批处理并使用tf.unique()定义段时,段ID可能不连续,从而产生错误。两种情况(错误/无错误)如下所示:

代码语言:javascript
运行
复制
with tf.Session() as sess:
    data = tf.constant([[1, 3, 5], [6, 2, 7], [9, 9, 2], [9, 5, 1]])
    #labels = ['a','r','r','d'] # this works
    labels = ['a','r','d','r']  # error: segment ids are not increasing by 1
    y, idx = tf.unique(labels) 
    maxs = tf.segment_max(data, idx)
    rval = sess.run([idx, maxs])
    print('indices: ', rval[0])
    print('maxs: ', rval[1])

一般的段ID不连续的情况如何处理?

EN

回答 1

Stack Overflow用户

发布于 2021-06-21 06:21:38

考虑到这个问题已经发布了很长时间,我不确定OP是否已经找到了答案。然而,我将为任何可能最终遇到同样问题的人提供解决方案。

这里的诀窍是,使用tf.math.unsorted_segment_max().It非常简单。例如,考虑以下代码-

代码语言:javascript
运行
复制
# Data
data = tf.constant([0.4, 0.5, 0.2, 0.7, 0.1], dtype=tf.float32)
# Unique segements with Segment IDs
unique_elems, segment_ids = tf.unique(tf.constant([1, 2, 1, 1, 2], dtype=tf.int32))
# Count the number of unique segments.
nb_segments = tf.size(tf.unique(segment_ids)[0])
# Result
result = tf.math.unsorted_segment_max(
    data=data, segment_ids=segment_ids, num_segments=nb_segments
)
tf.print(result) 
# result turns out to be [0.7, 0.5]  

在上面的代码中,我们首先使用tf.unique()方法查找具有相同值的索引。然后,我们使用tf.unsorted_segment_max()方法找到与这些索引片段对应的data张量中的最大值。

你可以阅读更多关于它的here

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/40109854

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档