网络流算法Push-relabel的Python实现

网络流的背景我就不多说了,就是在一个有向图中找出最大的流量,有意思的是,该问题的对偶问题为最小割,找到一种切分,使得图的两边的流通量最小,而且通常对偶问题是原问题的一个下界,但最小割正好等于最大流,即切割的边就是最大流中各个path饱和边的一个组合。说得可能比较含糊,这里想要了解清楚还是查阅相关资料吧。

最大流最原始最经典的解法就是FF算法,算法复杂度为O(mC),C为边的容量的总和,m为边数。而今天讲的Push-relabel算法是90年代提出的高效算法,复杂度为O(n^3),其实网络流最关键的步骤就是添加反向边,得出剩余图。而其他的改进就是为了在寻找增广路径时尽可能贪心,流量尽可能大。

好了,开始讲Push-relabel的主要思想,首先构造一个函数excess,代表每个节点保存的流量,就是等于该节点的入流量-出流量,正常来说,s的保存流量为负,t的保存流量为正,其他节点的保存流量均为0,而算法的最终目标就是这个,此外还定义一个height函数(h),表示每个节点的高度。然后,初始化过程是,h(s)=n,h(v)=0,对于所有不为s的节点,f(s, u)=c(s, u),对于所有从s出发的边都默认饱和,这是上界。接着,就是Push-relabel的过程了,首先遍历图中所有节点,如果存在非t的且excess大于0的节点v,则查看v出发的所有边(v, w),如果h(v)>h(w),则可以将label,即excess的流量,传递给w,如果该边为正向边,传的大小为bottleneck=min{excess(v), c(v,w) - f(v, w)},否则bottleneck=min{excess(v), f(v, w)},传完之后,继续寻找excess大于0的节点,注意,如果v有边,但所有边都是h(v)<h(w),则将v的高度提升1,继续寻找。

源代码如下:

注意图的输入格式需满足DIMACS格式。

__author__ = 'xanxus'
nodeNum, edgeNum = 0, 0
arcs = []


class Arc(object):
    def __init__(self):
        self.src = -1
        self.dst = -1
        self.cap = -1


s, t = -1, -1
with open('sample.dimacs') as f:
    for line in f.readlines():
        line = line.strip()
        if line.startswith('p'):
            tokens = line.split(' ')
            nodeNum = int(tokens[2])
            edgeNum = tokens[3]
        if line.startswith('n'):
            tokens = line.split(' ')
            if tokens[2] == 's':
                s = int(tokens[1])
            if tokens[2] == 't':
                t = int(tokens[1])
        if line.startswith('a'):
            tokens = line.split(' ')
            arc = Arc()
            arc.src = int(tokens[1])
            arc.dst = int(tokens[2])
            arc.cap = int(tokens[3])
            arcs.append(arc)

nodes = [-1] * nodeNum
for i in range(s, t + 1):
    nodes[i - s] = i
adjacent_matrix = [[0 for i in range(nodeNum)] for j in range(nodeNum)]
forward_matrix = [[0 for i in range(nodeNum)] for j in range(nodeNum)]
for arc in arcs:
    adjacent_matrix[arc.src - s][arc.dst - s] = arc.cap
    forward_matrix[arc.src - s][arc.dst - s] = arc.cap
flow_matrix = [[0 for i in range(nodeNum)] for j in range(nodeNum)]

height = [0] * nodeNum
height[0] = nodeNum
for i in range(len(adjacent_matrix)):
    flow_matrix[0][i] = adjacent_matrix[0][i]
    adjacent_matrix[0][i] = 0
    adjacent_matrix[i][0] = flow_matrix[0][i]


def excess(v):
    in_flow, out_flow = 0, 0
    for i in range(len(flow_matrix)):
        in_flow += flow_matrix[i][v]
        out_flow += flow_matrix[v][i]
    return in_flow - out_flow


def exist_excess():
    for v in range(len(flow_matrix)):
        if excess(v) > 0 and v != t - s:
            return v
    return None


v = exist_excess()
while v:
    has_lower_height = False
    for j in range(len(adjacent_matrix)):
        if adjacent_matrix[v][j] != 0 and height[v] > height[j]:
            has_lower_height = True
            if forward_matrix[v][j] != 0:
                bottleneck = min([excess(v), adjacent_matrix[v][j]])
                flow_matrix[v][j] += bottleneck
                adjacent_matrix[v][j] -= bottleneck
                adjacent_matrix[j][v] += bottleneck
            else:
                bottleneck = min([excess(v), flow_matrix[j][v]])
                flow_matrix[j][v] -= bottleneck
                adjacent_matrix[v][j] -= bottleneck
                adjacent_matrix[j][v] += bottleneck
    if not has_lower_height:
        height[v] += 1
    v = exist_excess()
for arc in arcs:
    print 'f %d %d %d' % (arc.src, arc.dst, flow_matrix[arc.src - s][arc.dst - s])

希望对大家有所帮助。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏coolblog.xyz技术专栏

科普:String hashCode 方法为什么选择数字31作为乘子

某天,我在写代码的时候,无意中点开了 String hashCode 方法。然后大致看了一下 hashCode 的实现,发现并不是很复杂。但是我从源码中发现了一...

714190
来自专栏ACM算法日常

不同路径(动态规划)- leetcode 62

最近在看leetcode的题目,都是面试题,需要面试的同学可以努力刷这里的题目,因为很多公司的面试笔试题都是参考这个上面的。相对OJ上的题目,感...

18940
来自专栏算法修养

矩阵快速幂小结

      矩阵快速幂大概是用来解决这样一类问题,当你知道了一个递推式比如a[n]=a[n-1]+a[n-2] 题目要求你求出a[n]。如果n大于1亿怎么办? ...

33150
来自专栏生信宝典

R语言学习 - 箱线图(小提琴图、抖动图、区域散点图)

箱线图 箱线图是能同时反映数据统计量和整体分布,又很漂亮的展示图。在2014年的Nature Method上有2篇Correspondence论述了使用箱线图的...

1.1K100
来自专栏Petrichor的专栏

深度学习: global pooling (全局池化)

今天看SPPNet论文时,看到“global pooling”一词,不是很明白是啥概念。上网查了一下定义,在StackOverflow 上找到了答案:

55530
来自专栏机器之心

深度 | 从概念到实践,我们该如何构建自动微分库

30880
来自专栏深度学习自然语言处理

【笔记】高效率但却没用过的一些numpy函数

最近在看源码的时候,碰到了一些大佬们常用,但自己暂时还没用过的numpy函数,特意来总结下。

7520
来自专栏数据派THU

独家 | 一文读懂R中的探索性数据分析(附R代码)

探索性数据分析(EDA)是数据项目的第一步。我们将创建一个代码模板来实现这一功能。

17920
来自专栏机器之心

资源 | Tensorlang:基于TensorFlow的可微编程语言

319110
来自专栏数据结构与算法

深海中的STL—mt19937

mt19937 当你第一眼看到这玩意儿的时候 肯定禁不住吐槽:纳尼?这是什么鬼? 确实,这个东西鲜为人知,但是它却有着卓越的性能 简介 mt19937是c++1...

32540

扫码关注云+社区

领取腾讯云代金券