上一节我们说到了
convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer)
这个函数,里面又分别调用了:
loc, mas, e1_mas, e2_mas = prepare_extra_data(mapping_a, example.locations, FLAGS.max_distance)
而在prepare_extra_data里面调用了两个函数:
convert_entity_row(mapping, e, max_distance)
find_lo_hi(mapping, lo)
我们一步步从prepare_extra_data里面看起:
res = np.zeros([FLAGS.max_seq_length, FLAGS.max_seq_length], dtype=np.int8)
mas = np.zeros([FLAGS.max_seq_length, FLAGS.max_seq_length], dtype=np.int8)
e1_mas = np.zeros([FLAGS.max_num_relations, FLAGS.max_seq_length], dtype=np.int8)
e2_mas = np.zeros([FLAGS.max_num_relations, FLAGS.max_seq_length], dtype=np.int8)
先总体对这些是什么有个大概的了解: (1)res:存储的是相对位置,是一个[128,128]的数组,这里的128是句子的最大长度。这个数组记录的是实体和其它词之间的相对位置。 (2)mas:存储的是实体的mask矩阵,也就是每个句子中实体出现的位置就是1,其它的就是0,也是一个[128,128]的数组 (3)e1_mas:在每一对关系中实体1的掩码矩阵,维度是[12,128],其中12是设置的每个句子中最大的关系(实体对)数量,即max_num_relations。 (4)e2_mas:在每一对关系中实体2的掩码矩阵,维度是[12,128],其中12同样是设置的最大关系数量。
entities = set()
for loc in locs:
entities.add(loc[0])
entities.add(loc[1])
for e in entities:
(lo, hi) = e
relative_position, _ = convert_entity_row(mapping, e, max_distance)
sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
if sub_lo1 == 0 and sub_hi1 == 0:
continue
if sub_lo2 == 0 and sub_hi2 == 0:
continue
# col
res[:, sub_lo1:sub_hi2+1] = np.expand_dims(relative_position, -1)
mas[1:, sub_lo1:sub_hi2+1] = 1
我们先看下输出:
example.text_a = a large database . Traditional information retrieval techniques use a histogram of keywords as the document representation but oral communication may offer additional indices such as the time and is shown on a large database of TV shows . Emotions and other indices
tokens_a = ['a', 'large', 'database', '.', 'traditional', 'information', 'retrieval', 'techniques', 'use', 'a', 'his', '##to', '##gram', 'of', 'key', '##words', 'as', 'the', 'document', 'representation', 'but', 'oral', 'communication', 'may', 'offer', 'additional', 'indices', 'such', 'as', 'the', 'time', 'and', 'is', 'shown', 'on', 'a', 'large', 'database', 'of', 'tv', 'shows', '.', 'emotions', 'and', 'other', 'indices']
mapping_a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 11, 11, 12, 13, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]
example.locations = [((13, 13), (6, 8)), ((19, 20), (24, 24)), ((37, 38), (35, 35))]
entities = {(6, 8), (13, 13), (35, 35), (37, 38), (24, 24), (19, 20)}
对于每一个实体的位置,调用了relative_position, _ =convert_entity_row(mapping, e, max_distance),这个函数:
def convert_entity_row(mapping, loc, max_distance):
    """Convert an entity span into a relative-distance row.

    Args:
        mapping: list mapping each wordpiece position to its original word
            index (sub-tokens of one word share the same value).
        loc: (lo, hi) entity span in original-word indices, inclusive.
        max_distance: distance clipping threshold.

    Returns:
        (res, mas), two lists of length FLAGS.max_seq_length:
        res[i] is the clipped relative distance of position i to the span
        (0 inside the span, 1..max_distance approaching from the left,
        max_distance+1..2*max_distance leaving to the right, max_distance
        for far-left words, 2*max_distance for far-right words and for
        padding positions past len(mapping));
        mas[i] is 1 exactly on the span's wordpiece positions, else 0.
    """
    lo, hi = loc
    res = []
    mas = []
    for pos in range(FLAGS.max_seq_length):
        if pos >= len(mapping):
            # Padding region beyond the actual tokens.
            res.append(2 * max_distance)
            mas.append(0)
            continue
        word_idx = mapping[pos]
        if lo <= word_idx <= hi:
            # Inside the entity span.
            res.append(0)
            mas.append(1)
        elif lo - max_distance <= word_idx < lo:
            # Within max_distance to the left of the span.
            res.append(lo - word_idx)
            mas.append(0)
        elif hi < word_idx <= hi + max_distance:
            # Within max_distance to the right; offset so the left and
            # right sides occupy disjoint value ranges.
            res.append(word_idx - hi + max_distance)
            mas.append(0)
        elif word_idx < lo - max_distance:
            # Too far to the left.
            res.append(max_distance)
            mas.append(0)
        else:
            # Too far to the right.
            res.append(2 * max_distance)
            mas.append(0)
    return res, mas
的输出是:
lo = 6
hi = 8
res = [4, 4, 4, 3, 2, 1, 0, 0, 0, 5, 6, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
mas = [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
relative_position = [4, 4, 4, 3, 2, 1, 0, 0, 0, 5, 6, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
设置最大距离为4。在res中,对于实体而言,其相对位置为0,当实体左边的字和实体左边边界的距离小于定义的最大距离时,值就是距离值,否则左边的就都是最大距离值。同理右边也是这样,只不过是从最大值开始,到最大值的两倍结束,需要注意的是由于是wordpiece拆分的,对于一个单词而言,如果拆分成了几个,那么他们的位置是一致的,比如上面的7,7,7。如果不好理解的话,直接看上面的结果就能理解了。 对于:
def find_lo_hi(mapping, value):
    """Find the boundary positions of a value in a list.

    Wordpiece tokenization makes one word index appear at several
    consecutive positions of `mapping`, so the boundary is the first and
    last occurrence of `value`.

    Args:
        mapping: list of original-word indices, one per wordpiece position.
        value: the original-word index to locate.

    Returns:
        (lo, hi) — first and last position of `value` in `mapping`, with
        `hi` clamped to FLAGS.max_seq_length. Returns (0, 0) when `value`
        does not occur in `mapping`.
    """
    try:
        # Only this lookup can legitimately fail (value absent).
        lo = mapping.index(value)
    except ValueError:
        # Was a bare `except:`, which also swallowed unrelated errors
        # (e.g. a misspelled FLAGS attribute) as a silent (0, 0).
        return (0, 0)
    # Last occurrence via reverse search; cannot raise since `value` exists.
    hi = min(len(mapping) - 1 - mapping[::-1].index(value), FLAGS.max_seq_length)
    return (lo, hi)
对于这个函数,由于进行了wordpiece拆分,一个单词可能对应多个子词位置,也就是说同一个值会在mapping中连续出现多次。因此对于hi,我们要通过反向索引找到该值最后一次出现的位置,才能覆盖整个实体的子词范围。
for e in entities:
(lo, hi) = e
relative_position, _ = convert_entity_row(mapping, e, max_distance)
sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
if sub_lo1 == 0 and sub_hi1 == 0:
continue
if sub_lo2 == 0 and sub_hi2 == 0:
continue
# col
res[:, sub_lo1:sub_hi2+1] = np.expand_dims(relative_position, -1)
mas[1:, sub_lo1:sub_hi2+1] = 1
for e in entities:
(lo, hi) = e
relative_position, _ = convert_entity_row(mapping, e, max_distance)
sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
if sub_lo1 == 0 and sub_hi1 == 0:
continue
if sub_lo2 == 0 and sub_hi2 == 0:
continue
# row
res[sub_lo1:sub_hi2+1, :] = relative_position
mas[sub_lo1:sub_hi2+1, 1:] = 1
结果是这样的:
[[0 0 0 0 0 0 4 4 4 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 4 4 4 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 4 4 4 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 3 3 3 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 2 2 2 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[4 4 4 3 2 1 0 0 0 5 6 7 7 7 8 ... 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[4 4 4 3 2 1 0 0 0 5 6 7 7 7 8 ... 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[4 4 4 3 2 1 0 0 0 5 6 7 7 7 8 ... 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[0 0 0 0 0 0 5 5 5 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 6 6 6 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 7 7 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 7 7 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 7 7 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
...
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
最后就是实体的掩码矩阵了:
for idx, (e1,e2) in enumerate(locs):
# e1
(lo, hi) = e1
_, mask = convert_entity_row(mapping, e1, max_distance)
e1_mas[idx] = mask
# e2
(lo, hi) = e2
_, mask = convert_entity_row(mapping, e2, max_distance)
e2_mas[idx] = mask
结果:
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
...
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
label_id = [label_map[label] for label in example.labels]
label_id = label_id + [0] * (FLAGS.max_num_relations - len(label_id))
cls_mask = [1] * example.num_relations + [0] * (FLAGS.max_num_relations - example.num_relations)
这里定义了一个最大关系数量:12。先看结果:
labels: 5 5 2 0 0 0 0 0 0 0 0 0
cls_mask:[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
也就是说,一句话中可能包含多对存在关系的实体,每一对对应一个关系标签。 最终将这些信息包装为InputFeatures类并返回。
def file_based_convert_examples_to_features(
        examples, label_list, max_seq_length, tokenizer, output_file):
    """Convert a set of `InputExample`s to a TFRecord file.

    Args:
        examples: list of `InputExample`s to serialize.
        label_list: label vocabulary forwarded to `convert_single_example`.
        max_seq_length: maximum wordpiece sequence length.
        tokenizer: wordpiece tokenizer forwarded to `convert_single_example`.
        output_file: path of the TFRecord file to write.
    """

    def create_int_feature(values):
        # Wrap an int sequence as a tf.train int64 feature.
        # Hoisted out of the loop: the original re-defined this closure on
        # every example for no benefit.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

    writer = tf.python_io.TFRecordWriter(output_file)
    try:
        for (ex_index, example) in enumerate(examples):
            if ex_index % 10000 == 0:
                tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
            feature = convert_single_example(ex_index, example, label_list,
                                             max_seq_length, tokenizer)
            # OrderedDict keeps a deterministic feature order in the record.
            features = collections.OrderedDict()
            features["input_ids"] = create_int_feature(feature.input_ids)
            features["input_mask"] = create_int_feature(feature.input_mask)
            features["segment_ids"] = create_int_feature(feature.segment_ids)
            features["loc"] = create_int_feature(feature.loc)
            features["mas"] = create_int_feature(feature.mas)
            features["e1_mas"] = create_int_feature(feature.e1_mas)
            features["e2_mas"] = create_int_feature(feature.e2_mas)
            features["cls_mask"] = create_int_feature(feature.cls_mask)
            features["label_ids"] = create_int_feature(feature.label_id)
            tf_example = tf.train.Example(features=tf.train.Features(feature=features))
            writer.write(tf_example.SerializeToString())
    finally:
        # Guarantee the record file is closed even if conversion of an
        # example raises; the original leaked the writer on error.
        writer.close()
也没什么好说的,转换成tensorflow中训练所需的张量后存储起来就行了。 至此,mre-in-one-pass的数据处理部分就完了。
参考代码:https://sourcegraph.com/github.com/helloeve/mre-in-one-pass/-/blob/run_classifier.py#L550