[Relation Extraction: mre-in-one-pass] Loading Data (Part 2)

Author: 西西嘛呦
Published: 2021-04-12 11:13:13
Column: 数据分析与挖掘 (Data Analysis and Mining)

Continued from the previous part, Loading Data (Part 1).

In the previous part we got as far as the function:

convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer)

Internally, it calls:

loc, mas, e1_mas, e2_mas = prepare_extra_data(mapping_a, example.locations, FLAGS.max_distance)

prepare_extra_data in turn calls two helper functions:

convert_entity_row(mapping, e, max_distance)
find_lo_hi(mapping, lo)

Let's walk through prepare_extra_data step by step:

  • It starts by defining four arrays:
  res = np.zeros([FLAGS.max_seq_length, FLAGS.max_seq_length], dtype=np.int8)
  mas = np.zeros([FLAGS.max_seq_length, FLAGS.max_seq_length], dtype=np.int8)
  
  e1_mas = np.zeros([FLAGS.max_num_relations, FLAGS.max_seq_length], dtype=np.int8)
  e2_mas = np.zeros([FLAGS.max_num_relations, FLAGS.max_seq_length], dtype=np.int8)

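Here is a rough overview of what each one holds (with max_seq_length = 128 and max_num_relations = 12 in this run):
(1) res: the relative positions, a [128, 128] array where 128 is the maximum sequence length. It records the relative position between each entity and every other token.
(2) mas: the entity mask matrix, also a [128, 128] array. The rows and columns that correspond to entity tokens are set to 1 (except at index 0), and everything else is 0.
(3) e1_mas: the mask of entity 1 for each relation, of shape [12, 128], where 12 is the configured maximum number of relations per example.
(4) e2_mas: the mask of entity 2 for each relation, with the same [12, 128] shape.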

  • Collect the set of entities that take part in any relation:
entities = set()
for loc in locs:
    entities.add(loc[0])
    entities.add(loc[1])
  • Next comes the key step:
  for e in entities:
    (lo, hi) = e
    relative_position, _ = convert_entity_row(mapping, e, max_distance)
    sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
    sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
    if sub_lo1 == 0 and sub_hi1 == 0:
      continue
    if sub_lo2 == 0 and sub_hi2 == 0:
      continue
    # col
    res[:, sub_lo1:sub_hi2+1] = np.expand_dims(relative_position, -1)
    mas[1:, sub_lo1:sub_hi2+1] = 1

Let's first look at the intermediate values for one example:

example.text_a = a large database . Traditional information retrieval techniques use a histogram of keywords as the document representation but oral communication may offer additional indices such as the time and is shown on a large database of TV shows . Emotions and other indices
tokens_a = ['a', 'large', 'database', '.', 'traditional', 'information', 'retrieval', 'techniques', 'use', 'a', 'his', '##to', '##gram', 'of', 'key', '##words', 'as', 'the', 'document', 'representation', 'but', 'oral', 'communication', 'may', 'offer', 'additional', 'indices', 'such', 'as', 'the', 'time', 'and', 'is', 'shown', 'on', 'a', 'large', 'database', 'of', 'tv', 'shows', '.', 'emotions', 'and', 'other', 'indices']
mapping_a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 11, 11, 12, 13, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]
example.locations = [((13, 13), (6, 8)), ((19, 20), (24, 24)), ((37, 38), (35, 35))]
entities = {(6, 8), (13, 13), (35, 35), (37, 38), (24, 24), (19, 20)}
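
As a quick check, the entity set above can be reproduced from the printed example.locations value. This is only a minimal sketch; the variable name locations stands in for example.locations and is not taken from the repository.

```python
# Relation pairs copied from the example.locations output above.
# Each item is (entity1_span, entity2_span); each span is (lo, hi) in word indices.
locations = [((13, 13), (6, 8)), ((19, 20), (24, 24)), ((37, 38), (35, 35))]

entities = set()
for e1, e2 in locations:
    entities.add(e1)
    entities.add(e2)

# Sets are unordered, so sort before printing for a stable view.
print(sorted(entities))
```

Because a set is used, an entity that takes part in several relations is only processed once in the loops that follow.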

For each entity span, the code calls relative_position, _ = convert_entity_row(mapping, e, max_distance). The function is defined as:

def convert_entity_row(mapping, loc, max_distance):
  """
  convert an entity span(lo,hi) to a relative distance vector of shape [max_seq_length]
  """
  lo, hi = loc
  res = [max_distance] * FLAGS.max_seq_length
  mas = [0] * FLAGS.max_seq_length
  for i in range(FLAGS.max_seq_length):
    if i < len(mapping):
      val = mapping[i]
      if val < lo - max_distance:
        res[i] = max_distance
      elif val < lo:
        res[i] = lo - val
      elif val <= hi:
        res[i] = 0
        mas[i] = 1
      elif val <= hi + max_distance:
        res[i] = val - hi + max_distance
      else:
        res[i] = 2 * max_distance
    else:
      res[i] = 2 * max_distance
  return res, mas

For the entity (6, 8), the output of convert_entity_row is:

lo = 6
hi = 8
res = [4, 4, 4, 3, 2, 1, 0, 0, 0, 5, 6, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
mas = [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
relative_position = [4, 4, 4, 3, 2, 1, 0, 0, 0, 5, 6, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
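
The output above can be reproduced outside the training script with a small standalone sketch. It follows the same branching as convert_entity_row, but it is an illustration rather than the repository's code: the function name relative_positions is hypothetical, and max_seq_length is passed as an argument instead of being read from FLAGS.

```python
def relative_positions(mapping, span, max_distance, max_seq_length):
    """Clipped relative position of every token with respect to an entity span."""
    lo, hi = span
    # Padding positions (beyond len(mapping)) keep the value 2 * max_distance.
    res = [2 * max_distance] * max_seq_length
    mas = [0] * max_seq_length
    for i, val in enumerate(mapping[:max_seq_length]):
        if val < lo - max_distance:
            res[i] = max_distance              # far to the left: clipped
        elif val < lo:
            res[i] = lo - val                  # left side: 1 .. max_distance
        elif val <= hi:
            res[i] = 0                         # inside the entity
            mas[i] = 1
        elif val <= hi + max_distance:
            res[i] = val - hi + max_distance   # right side: max_distance+1 .. 2*max_distance
        else:
            res[i] = 2 * max_distance          # far to the right: clipped
    return res, mas


# mapping_a from the output above: word index of each WordPiece token.
mapping_a = list(range(12)) + [11, 11, 12, 13, 13] + list(range(14, 45))
res, mas = relative_positions(mapping_a, (6, 8), max_distance=4, max_seq_length=128)
print(res[:15])  # [4, 4, 4, 3, 2, 1, 0, 0, 0, 5, 6, 7, 7, 7, 8]
```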

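The maximum distance is set to 4 here. In res, tokens inside the entity get relative position 0. For a token to the left of the entity, the value is its distance to the entity's left boundary as long as that distance is at most max_distance; anything farther away is clipped to max_distance. The right side works the same way, except that the values are offset by max_distance: they start at max_distance + 1 and are capped at 2 * max_distance, which is also the value assigned to padding positions. Note that because of WordPiece tokenization, a word that is split into several pieces gives all of its pieces the same position, as with the 7, 7, 7 above. If this is hard to follow, the output above makes it concrete. Next, consider find_lo_hi: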

def find_lo_hi(mapping, value):
  """
  find the boundary of a value in a list
  will return (0,0) if no such value in the list
  """
  try:
    lo = mapping.index(value)
    hi = min(len(mapping) - 1 - mapping[::-1].index(value), FLAGS.max_seq_length)
    return (lo, hi)
  except ValueError:  # list.index raises ValueError when the value is absent
    return (0,0)

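Because WordPiece may split one word into several sub-tokens, the same word index can occur at several consecutive positions in mapping. find_lo_hi therefore returns the first position of the value as lo and, by searching the reversed list, its last position as hi.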
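
A standalone sketch of this lookup, run against the mapping_a printed earlier, may help. It mirrors find_lo_hi but takes max_seq_length as an argument instead of reading FLAGS, which is an assumption made so the snippet runs on its own.

```python
def find_lo_hi(mapping, value, max_seq_length):
    """First and last token position whose word index equals `value`."""
    try:
        lo = mapping.index(value)
    except ValueError:
        return (0, 0)  # same sentinel as the original code
    # Index of the last occurrence, found by searching the reversed list.
    hi = len(mapping) - 1 - mapping[::-1].index(value)
    return (lo, min(hi, max_seq_length))


mapping_a = list(range(12)) + [11, 11, 12, 13, 13] + list(range(14, 45))

print(find_lo_hi(mapping_a, 11, 128))   # (11, 13): word 11 covers three pieces
print(find_lo_hi(mapping_a, 13, 128))   # (15, 16): word 13 covers two pieces
print(find_lo_hi(mapping_a, 999, 128))  # (0, 0): not found
```

One consequence of the (0, 0) sentinel is that a value whose only occurrence is at position 0 also yields (0, 0), so the calling loop cannot tell it apart from a missing value and skips it.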

  • Next, the position information is written into matrix form by the following two loops:
  for e in entities:
    (lo, hi) = e
    relative_position, _ = convert_entity_row(mapping, e, max_distance)
    sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
    sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
    if sub_lo1 == 0 and sub_hi1 == 0:
      continue
    if sub_lo2 == 0 and sub_hi2 == 0:
      continue
    # col
    res[:, sub_lo1:sub_hi2+1] = np.expand_dims(relative_position, -1)
    mas[1:, sub_lo1:sub_hi2+1] = 1

  for e in entities:
    (lo, hi) = e
    relative_position, _ = convert_entity_row(mapping, e, max_distance)
    sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
    sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
    if sub_lo1 == 0 and sub_hi1 == 0:
      continue
    if sub_lo2 == 0 and sub_hi2 == 0:
      continue
    # row
    res[sub_lo1:sub_hi2+1, :] = relative_position
    mas[sub_lo1:sub_hi2+1, 1:] = 1

The resulting res matrix looks like this:

[[0 0 0 0 0 0 4 4 4 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 4 4 4 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 4 4 4 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 3 3 3 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 2 2 2 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [4 4 4 3 2 1 0 0 0 5 6 7 7 7 8 ... 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 [4 4 4 3 2 1 0 0 0 5 6 7 7 7 8 ... 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 [4 4 4 3 2 1 0 0 0 5 6 7 7 7 8 ... 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 [0 0 0 0 0 0 5 5 5 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 6 6 6 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 7 7 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 7 7 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 7 7 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 ...
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
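
The cross-shaped pattern in this matrix comes from NumPy broadcasting in the two assignments. The following toy example uses a 6 x 6 matrix and a made-up relative-position vector (both are assumptions chosen for readability, not values from the dataset) to show the column fill followed by the row fill for a single entity.

```python
import numpy as np

n = 6
lo, hi = 2, 3                       # token positions covered by the entity
rel = np.array([2, 1, 0, 0, 3, 4], dtype=np.int8)  # hypothetical relative positions

res = np.zeros([n, n], dtype=np.int8)
mas = np.zeros([n, n], dtype=np.int8)

# Column fill: rel becomes shape (n, 1) and is broadcast across the entity columns,
# so row i of those columns holds rel[i].
res[:, lo:hi + 1] = np.expand_dims(rel, -1)
mas[1:, lo:hi + 1] = 1              # row 0 is left unmasked

# Row fill: rel has shape (n,) and is broadcast down the entity rows,
# so each entity row is a full copy of rel.
res[lo:hi + 1, :] = rel
mas[lo:hi + 1, 1:] = 1              # column 0 is left unmasked

print(res)
print(mas)
```

The row fill runs after the column fill, so where an entity's rows cross its own columns the values come from rel directly (0 in this toy case, since those positions are inside the entity).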

Finally come the entity mask matrices:

  for idx, (e1,e2) in enumerate(locs):
    # e1
    (lo, hi) = e1
    _, mask = convert_entity_row(mapping, e1, max_distance)
    e1_mas[idx] = mask
    # e2
    (lo, hi) = e2
    _, mask = convert_entity_row(mapping, e2, max_distance)
    e2_mas[idx] = mask

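Result (the matrix printed below has 1s along both the rows and the columns at positions 6 to 8, which matches mas as filled by the two loops above rather than a [12, 128] e1_mas or e2_mas):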

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 ...
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
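
For comparison, here is a minimal sketch of what the per-relation masks e1_mas and e2_mas contain. The helper names span_mask and relation_masks, the short mapping, and the two relation pairs are all hypothetical; the mask rule (1 where lo <= word index <= hi) is the same one convert_entity_row applies.

```python
def span_mask(mapping, span, max_seq_length):
    """1 for every token whose word index falls inside the span, else 0."""
    lo, hi = span
    mask = [0] * max_seq_length
    for i, val in enumerate(mapping[:max_seq_length]):
        if lo <= val <= hi:
            mask[i] = 1
    return mask


def relation_masks(mapping, locs, max_num_relations, max_seq_length):
    """One mask row per relation; unused relation slots stay all zero."""
    e1_mas = [[0] * max_seq_length for _ in range(max_num_relations)]
    e2_mas = [[0] * max_seq_length for _ in range(max_num_relations)]
    for idx, (e1, e2) in enumerate(locs):
        e1_mas[idx] = span_mask(mapping, e1, max_seq_length)
        e2_mas[idx] = span_mask(mapping, e2, max_seq_length)
    return e1_mas, e2_mas


mapping = [0, 1, 2, 2, 3, 4, 5]               # word 2 is split into two pieces
locs = [((1, 2), (4, 4)), ((5, 5), (0, 0))]   # two toy relations
e1_mas, e2_mas = relation_masks(mapping, locs, max_num_relations=3, max_seq_length=8)

for row in e1_mas:
    print(row)
```

Row idx of e1_mas and e2_mas together identify the two entities of relation idx, which is why these arrays have one row per relation instead of one row per token.
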
  • Back in convert_single_example:
  label_id = [label_map[label] for label in example.labels]
  label_id = label_id + [0] * (FLAGS.max_num_relations - len(label_id))
  cls_mask = [1] * example.num_relations + [0] * (FLAGS.max_num_relations - example.num_relations)

The maximum number of relations per example is set to 12 here. First, the result:

labels: 5 5 2 0 0 0 0 0 0 0 0 0
cls_mask:[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]

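That is, a single sentence can contain several entity pairs, each with its own relation. label_id is padded with 0 up to max_num_relations, and cls_mask marks which slots hold real relations. Finally, all of this is wrapped in an InputFeatures object and returned.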
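
The padding logic can be tried in isolation. In this sketch the label names and the label_map contents are invented for illustration, and the number of real relations is taken from len(labels) rather than example.num_relations.

```python
def pad_labels(labels, label_map, max_num_relations):
    """Map label names to ids, then pad ids and build the matching mask."""
    label_id = [label_map[label] for label in labels]
    num_relations = len(label_id)
    padding = max_num_relations - num_relations
    label_id = label_id + [0] * padding
    cls_mask = [1] * num_relations + [0] * padding
    return label_id, cls_mask


label_map = {"REL_A": 2, "REL_B": 5}      # hypothetical label names and ids
labels = ["REL_B", "REL_B", "REL_A"]      # three relations in one sentence

label_id, cls_mask = pad_labels(labels, label_map, max_num_relations=12)
print(label_id)  # [5, 5, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]
print(cls_mask)  # [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```

If 0 is also a valid label id, the padded entries are indistinguishable from real labels by value alone, so cls_mask is what tells real relations from padding.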

  • Back in file_based_convert_examples_to_features:
def file_based_convert_examples_to_features(
    examples, label_list, max_seq_length, tokenizer, output_file):
  """Convert a set of `InputExample`s to a TFRecord file."""

  writer = tf.python_io.TFRecordWriter(output_file)

  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

    feature = convert_single_example(ex_index, example, label_list,
                                     max_seq_length, tokenizer)

    def create_int_feature(values):
      f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
      return f

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    features["loc"] = create_int_feature(feature.loc)
    features["mas"] = create_int_feature(feature.mas)
    features["e1_mas"] = create_int_feature(feature.e1_mas)
    features["e2_mas"] = create_int_feature(feature.e2_mas)
    features["cls_mask"] = create_int_feature(feature.cls_mask)
    features["label_ids"] = create_int_feature(feature.label_id)

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()

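There is not much to add here: each field is converted to an int64 tf.train.Feature, packed into a tf.train.Example, and written to a TFRecord file for training. That completes the data-processing part of mre-in-one-pass.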
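
One detail worth keeping in mind is that tf.train.Int64List stores a flat sequence of integers, so a 2-D matrix such as loc or mas has to be flattened before it is written and reshaped after it is read back. Whether the repository flattens inside convert_single_example is not shown in the code above, so the following is only a plain-Python sketch of the idea with hypothetical helper names.

```python
def flatten(matrix):
    """Row-major flattening of a 2-D list into a flat list of Python ints."""
    return [int(v) for row in matrix for v in row]


def unflatten(values, num_cols):
    """Inverse of flatten for a known number of columns."""
    return [values[i:i + num_cols] for i in range(0, len(values), num_cols)]


loc = [[4, 3, 0],
       [3, 0, 5],
       [0, 5, 6]]          # toy 3 x 3 matrix standing in for the [128, 128] loc

flat = flatten(loc)        # this is the form an int64 list can hold
restored = unflatten(flat, num_cols=3)
print(flat)
print(restored == loc)
```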

Reference code: https://sourcegraph.com/github.com/helloeve/mre-in-one-pass/-/blob/run_classifier.py#L550

This article was shared from the author's personal site/blog, where it was originally published on 2021-04-08.