前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >明月机器学习系列030:特殊二分图的最优匹配算法

明月机器学习系列030:特殊二分图的最优匹配算法

作者头像
明月AI
发布2021-10-28 14:27:50
8130
发布2021-10-28 14:27:50
举报
文章被收录于专栏:野生AI架构师

1. 缘起


最近开发文档识别与比对,经常遇到的一个问题就是谁跟谁应该配对在一起,例如:

  • 两个页面上的文本行,哪行跟哪行应该是对应的?
  • 两份文档中都有若干个表格,哪个表格跟哪个表格应该是对应的?
  • 两个表格都会包含若干的单元格,这些单元格哪个跟哪个是对应的?

开始时,想得比较简单,因为看上去问题也不复杂嘛。

2. 算法的第一个版本


把问题抽象一下,其实不管是单元格,表格,还是文本行都可以看成是一个个的元素,于是我们的问题就成了在两个有序的序列中寻找一个最优的匹配,每个元素最多能跟一个元素进行匹配(可以没有匹配),如下图:

上图显示的就是左右两个集合中的一种匹配,但是这种匹配的形式跟我们的要求是不符合的。我们的场景下,我们匹配的左右两边并不是无序集合,而是两个有序的序列。例如从上图来理解,如果1和6匹配上之后,2则只能和6后面的7或者8进行匹配,所以上图的4这个元素开倒车不符合规则(如果把4和5这两个元素之间的边去掉的话,则是满足条件)。

定义:边就是两个之间的连线。

2.1 算法的目标

我们既然要找到最优的匹配,但是怎么才算是最优呢?这就是要求我们先定义一个数值指标,以此来衡量优劣。这也比较简单,对每条连线的权重求和,以此作为衡量指标。连线的权重则采用两个元素之间的相似性,如果这两个元素是两行文本,我们可以直接使用编辑距离计算相似性(至于那种距离更加合适,就得看具体的场景了。余弦距离,杰卡德距离,甚至语义距离等都可以,只要合适)。

简单说就是,对元素之间的相似性得分进行求和

2.2 算法思路

有了目标,那看起来就比较简单了,直接从左边元素随机取一个子集,然后再右边元素也随机取一个相同元素个数的子集,再按顺序对应上,就能计算一个得分指标。

这个思路简单,实现也简单,直接上代码:

代码语言:javascript
复制
    def less_match(self):
        """当数据比较小时,可以使用穷举匹配"""
        max_score = 0
        items = np.array([])
        max_num = min(len(self.seq1), len(self.seq2))
        for num in range(max_num, 0, -1):
            tmp_score, tmp_items = self.match_num(num)
            if tmp_score >= max_score:
                # 两个空元素配对在一起,去掉之后得分不会改变,但是空元素不应该配对在一起
                items = tmp_items
                max_score = tmp_score
            else:
                # 如果不能产生更好的值,则退出
                break

        return items
        
     def match_num(self, num):
        """从两个序列中分别提取num个元素进行匹配"""
        # print('match num: ', num)
        max_score = 0
        comb_match = None
        # 提取两个序列的下标子集合
        comb1 = combinations(range(len(self.seq1)), num)
        comb2 = list(combinations(range(len(self.seq2)), num))
        for comb_i in comb1:
            for comb_j in comb2:
                tmp_score = self.cal_comb_score(comb_i, comb_j)
                if tmp_score is None:
                    continue
                if tmp_score >= max_score:
                    max_score = tmp_score
                    comb_match = (comb_i, comb_j)

        # 生成配对items
        if comb_match is None:
            return 0, []
        items = np.array((comb_match[0], comb_match[1])).T
        return max_score, items

    def cal_comb_score(self, comb_i, comb_j):
        """计算集合得分"""
        where = (np.array(comb_i, dtype=int), np.array(comb_j, dtype=int))
        scores = self.scores[where]
        if self.min_score is not None and np.min(scores) < self.min_score:
            return None
        return np.sum(scores)

暴力出奇迹(指时间),很快就完成了第一个版本。这个版本其实已经做了部分的剪枝,已经部分计算已经提前,例如元素之间的距离就是预先计算好放到scores中。

但是显然这个版本存在巨大的性能问题。

3. 优化版本


上面的算法在数据量小的时候,还没有问题,但是数据量稍大一点,因为取集合的方式是指数级的,想不废都难。

3.1 剪枝优化

剪枝1. 在我们的场景中,相似度得分大于0,但是其值却很小的边通常是没有意义的,这样我们就可以通过阈值参数直接过滤掉这部分的边。

剪枝2. 仔细分析上面的暴力算法,就会发现其实很多计算是多余的,因为在我们的场景中,一个元素通常只会和另一个序列中附近的元素产生联系,和位置相差比较远的元素产生联系的可能性是很小的,但是在计算编辑距离时,却有可能联系在一起。例如左右两边的序列都有50个元素,左边的第一个元素值恰好和右边元素的最后一个元素的值完全相同,这时他们这两个元素的相似性得分最大,但是这基本是不可能的。于是我们可以考虑将位置因素整合到权重得分上。

代码语言:javascript
复制
        # 剪枝:其值却很小的边通常是没有意义的
        # self.min_score: 这个是算法的参数,可以根据不同的场景选择不同的阈值
        where_i, where_j = np.where(self.scores > self.min_score)
        len_j = len(where_j)
        # 优化得分: 将位置影响整合到边的权重上
        for j, val_j in enumerate(where_j):
            # 正常来说,where_j是按顺序排序的
            # 如果前面有比当前值大,或者后面有比当前值小,这两种情况都是不常见的,可以减少其权重
            err_num = np.count_nonzero(where_j[:j] > val_j)
            err_num += np.count_nonzero(where_j[j:] < val_j)
            self.scores[where_i[j], val_j] *= (len_j-err_num)/(len_j)

这段代码实现了前面两个剪枝的方式。这里融合位置的方式设计上比较特别,具体可以看代码注释。

剪枝3. 基于第一点的分析,我们还可以在预先计算相似性得分的,只计算相邻位置的元素之间的边的相似性得分,其他的全部置为0。

代码语言:javascript
复制
        # 计算得分
        len1, len2 = len(seq1), len(seq2)
        # 计算窗口的开始和结束位置
        start, end = -window, window
        if len2 >= len1:
            end += len2 - len1
        else:
            start += len2 - len1

        scores = np.zeros((len1, len2))
        for i, s1 in enumerate(seq1):
            # 一个元素通常只会和另一个序列中相邻的元素产生联系
            w_start, w_end = max(0, i+start), min(len2, i+end)
            scores[i][w_start:w_end] = [score_func(s1, s2) for s2 in seq2[w_start:w_end]]

3.2 计算优化

元素与元素之间的边的权重已经计算出来了,我们不再使用遍历集合这种暴力的方式,而是先找连通子图,然后在每个连通图的内部删掉一些多余的边,使得每个元素最多只和一个元素联通,并且保证每个联通子图删掉多余的边之后,相似度得分是最高的。简单说就是保证每个联通子图的最优来保证全局最优(当然这不一定成立,但是概率很小,而且即使不是全局最优,也和全局最优相差不多了,所以可以忽略)。连通图计算可以直接使用networkx包中的connected_components函数。

代码行数比较多,就不凑字数了,具体看:https://github.com/ibbd-dev/python-ibbd-algo/blob/master/ibbd_algo/sequence.py

经过这个优化,在我们的场景下,性能基本没什么问题了。

4. 后续思考


后来查资料得知,图论里专门有一种叫二分图,还有相关的算法,不过我们的场景却比较特别,算是一种特殊的二分图吧。研究一下现有的二分图,应该还是有改进空间的。

附录:

源码:https://github.com/ibbd-dev/python-ibbd-algo/blob/master/ibbd_algo/sequence.py

20201230:这个文章上个月就开始写了,只是一直在草稿了,今晚算是补充完整了,自己也梳理了一遍。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-12-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 野生AI架构师 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
行业文档识别
行业文档识别(Document Optical Character Recognition,Document OCR)基于行业前沿的深度学习技术,支持将图片上的文字内容,智能识别为结构化的文本,可应用于智能核保、智能理赔、试题批改等多种行业场景,大幅提升信息处理效率。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档