前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >挑战程序竞赛系列(29):3.4熟练掌握动态规划

挑战程序竞赛系列(29):3.4熟练掌握动态规划

作者头像
用户1147447
发布2019-05-26 09:27:20
4540
发布2019-05-26 09:27:20
举报
文章被收录于专栏:机器学习入门机器学习入门

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1434653

挑战程序竞赛系列(29):3.4熟练掌握动态规划

详细代码可以fork下Github上leetcode项目,不定期更新。

练习题如下:

POJ 2441: Arrange the Bulls

开始状态压缩,做了几道,发现状态压缩都是带记忆的暴力枚举。解决问题的关键在对状态的抽象,从而可以降低直接暴力枚举的时间复杂度。这些题目往往都是些NP问题,如旅行商问题。

此题思路:对于车票而言,用一张少一张,很明显的一个阶段(DAG),所以不会走环路,那么用简单的DP就能解决。问题在于阶段中所有状态该如何寻找,很明显从城市a开始,因为此时没有使用任何车票,所以可以枚举任何一张车票和与a相连的城市状态,总共有:(剩余车票数 * 与城市a连接的城市总数)个状态,那么可以构造:

代码语言:javascript
复制
dp[s][v] // s表示当前剩余的所有车票,v表示已经抵达的目的城市,值存放了从a到v的最短路径

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/2686.txt";


    static final int INF = 1 << 29;
    void solve() {
        while (true){
            int n = ni();
            int m = ni();
            int p = ni();
            int a = ni();
            int b = ni();

            if (n + m + p + a + b == 0) break;

            a --;
            b --;
            int[] tickets = new int[n];
            for (int i = 0; i < n; ++i){
                tickets[i] = ni();
            }

            int[][] graph = new int[m][m];
            for (int i = 0; i < m; ++i) Arrays.fill(graph[i], -1);

            for (int i = 0; i < p; ++i){
                int from = ni();
                int to = ni();
                from --;
                to --;
                int dist = ni();
                graph[from][to] = dist;
                graph[to][from] = dist;
            }

            double[][] dp = new double[1 << n][m + 16];
            for (int i = 0; i < (1 << n); ++i) Arrays.fill(dp[i], INF);

            dp[(1 << n) - 1][a] = 0; //从城市a出发,且全票情况下的最短路径
            double res = INF;

            for (int s = (1 << n) - 1; s >= 0; --s){
                res = Math.min(res, dp[s][b]);
                for (int v = 0; v < m; ++v){
                    for (int t = 0; t < n; ++t){
                        //票还在集合当中,则从集合删除
                        if (((s >> t) & 1) != 0){
                            for (int u = 0; u < m; ++u){
                                if (graph[u][v] > 0){
                                    dp[s & ~(1 << t)][v] = Math.min(dp[s & ~(1 << t)][v], dp[s][u] + graph[u][v] / (1.0 * tickets[t])); 
                                }
                            }
                        }
                    }
                }
            }

            if (res == INF){
                out.println("Impossible");
            }
            else{
                out.printf("%.3f\n", res);
            }
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 2441: Arrange the Bulls

此题经历了三个版本,分别MLE,TLE,最后才AC。思路比较简单,每头牛的选择只与剩余地点集合有关,所以只要在选择时,确保该地点没有被选择,那么这头牛就可以选则这球场,并且是所有可选状态的总和。可以想象,选定即占位,对于后来的牛来说,并不在乎谁选了哪些球场,只在乎还有多少球场可选。

代码如下:

代码语言:javascript
复制
    void solve() {
        int N = ni();
        int M = ni();
        List<Integer>[] g = new ArrayList[N];
        for (int i = 0; i < N; ++i) g[i] = new ArrayList<Integer>();

        for (int i = 0; i < N; ++i){
            int m = ni();
            for (int j = 0; j < m; ++j){
                g[i].add(ni() - 1);
            }
        }

        int[][] dp = new int[N][1 << M];
        //阶段1
        for (int barn : g[0]){
            dp[0][0 | (1 << barn)] = 1;
        }

        for (int i = 1; i < N; ++i){
            for (int barn : g[i]){
                for (int s = 0; s < (1 << M); ++s){
                    if ((s >> barn & 1) == 0){
                        dp[i][s | (1 << barn)] += dp[(i - 1)][s];
                    }
                }
            }
        }

        int sum = 0;
        for (int s = 0; s < (1 << M); ++s){
            sum += dp[N - 1][s];
        }

        out.println(sum);
    }

数组开的太大,MLE了,此题的特色在于当前阶段之和前一阶段有关,所以可以采用滚动数组,代码如下:

代码语言:javascript
复制
    void solve() {
        int N = ni();
        int M = ni();
        List<Integer>[] g = new ArrayList[N];
        for (int i = 0; i < N; ++i) g[i] = new ArrayList<Integer>();

        for (int i = 0; i < N; ++i){
            int m = ni();
            for (int j = 0; j < m; ++j){
                g[i].add(ni() - 1);
            }
        }

        int[][] dp = new int[2][1 << M];
        //阶段1
        for (int barn : g[0]){
            dp[0][0 | (1 << barn)] = 1;
        }

        for (int i = 1; i < N; ++i){
            for (int barn : g[i]){
                for (int s = 0; s < (1 << M); ++s){
                    if ((s >> barn & 1) == 0){
                        dp[i % 2][s | (1 << barn)] += dp[(i - 1) % 2][s];
                    }
                }
            }
            Arrays.fill(dp[(i - 1) % 2], 0);
        }

        int sum = 0;
        for (int s = 0; s < (1 << M); ++s){
            sum += dp[(N - 1) % 2][s];
        }

        out.println(sum);
    }

TLE了,可以观察下循环,原因在于对每个阶段,都会有很多无效状态参与计算,很显然那些只会在阶段i+1之后出现的状态没有必要遍历,所以我们必须采取遍历大小为i的所有子集的算法。

《挑战》给了我们一个很好的算法:

嘿,在书P157讲的很详细,我们用到了枚举大小为k的子集方法:

代码语言:javascript
复制
for (int comb = (1 << i) - 1, x, y; comb < 1 << M; x = comb & -comb, y = comb + x, comb = ((comb & ~y) / x >> 1) | y){
    //遍历所有大小为k的子集,按升序遍历
}

原理可以参看书中的分析,主要是先找出最低位的1,接着把地位连续1全部变成0,接着检测出所有由1变成0的连续1,把它们的个数记录下来,移动到最右端和comb或一下,得到结果。

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.List;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/2441.txt";

    void solve() {
        int N = ni();
        int M = ni();
        List<Integer>[] g = new ArrayList[N];
        for (int i = 0; i < N; ++i) g[i] = new ArrayList<Integer>();

        for (int i = 0; i < N; ++i){
            int m = ni();
            for (int j = 0; j < m; ++j){
                g[i].add(ni() - 1);
            }
        }

        int[] dp = new int[1 << M];

        for (int u : g[0]){
            dp[0 | 1 << u] = 1;
        }

        for (int i = 1; i < N; ++i){
            for (int comb = (1 << i) - 1, x, y; comb < 1 << M; x = comb & -comb, y = comb + x, comb = ((comb & ~y) / x >> 1) | y){
                if (dp[comb] != 0){
                    for (int j : g[i]){
                        if ((comb & 1 << j) == 0){
                            dp[comb | (1 << j)] += dp[comb];
                        }
                    }
                }
            }
        }

        int sum = 0;
        for (int comb = (1 << N) - 1, x, y; comb < 1 << M; x = comb & -comb, y = comb + x, comb = ((comb & ~y) / x >> 1) | y){
            sum += dp[comb];
        }
        out.println(sum);
    }



    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

JAVA是真的比C++慢好几个数量级啊。。。

POJ 3254: Corn Fields

思路:依旧找阶段,最初定义阶段的方法,每次从集合中添加一个坑,但这种方式的状态转换不太好求,所以转换思路。

枚举第一行所有可能的种植情况,并把第一行的所有状态记录下来,此时遍历第二行的所有可能状态,找出第一行和第二行合法状态的交集,即为答案。

判断第一行和第二行是否合法可以采用if ((s1 & s2) == 0),这就表明不可能选择相邻元素,高明。

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3254.txt";

    static final int MOD = 1000000000; 
    int N = 0;
    int M = 0;
    void solve() {
        N = ni();
        M = ni();
        boolean[][] board = new boolean[N][M];
        for (int i = 0; i < N; ++i){
            for (int j = 0; j < M; ++j){
                if (ni() == 1) board[i][j] = true;
            }
        }

        int[][] dp = new int[N][1 << M];
        for (int i = 0; i < 1 << M; ++i){
            if (valid(i, board[0])){
                dp[0][i] = 1;
            }
        }

        for (int i = 1; i < N; ++i){
            for (int j = 0; j < 1 << M; ++j){
                if (valid(j, board[i])){
                    for (int s = 0; s < 1 << M; ++s){
                        if ((j & s) == 0){
                            dp[i][j] = (dp[i][j] + dp[i - 1][s]) % MOD;
                        }
                    }
                }
            }
        }

        int sum = 0;
        for (int i = 0; i < 1 << M; ++i){
            if (valid(i, board[N - 1])){
                sum = (sum + dp[N - 1][i]) % MOD;
            }
        }
        out.println(sum);
    }

    public boolean valid(int s, boolean[] board){
        for (int i = 0; i < M; ++i){
            if ((s & (1 << i)) != 0){
                if (!board[i]) return false;
                if ((i + 1 < M) && (s & (1 << (i + 1))) != 0) return false;
                if ((i - 1 >= 0) && (s & (1 << (i - 1))) != 0) return false;
            }
        }
        return true;
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 2836: Rectangle Covering

思路:首先枚举所有点两两组合成的矩形,并且得到该矩形包含的所有点(反证法,必然两两组合的方案是最小矩形,且一定在这两点的边界上),有了这些矩形集合,就可以从状态0(没有任何点构成矩形)不断扩展到(所有点构成的矩形)。状态压缩的关键在于,对于每一个中间状态只记录衍生过来的最小值即可。

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.List;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/2386.txt";

    class Rec{
        int covered;
        int area;

        public Rec(int covered, int area){
            this.covered = covered;
            this.area = area;
        }

        public void add(int i){
            covered |= 1 << i;
        }
    }

    public boolean inRec(int[] a, int[] b, int[] p){
        int minX = Math.min(a[0], b[0]);
        int maxX = Math.max(a[0], b[0]);
        int minY = Math.min(a[1], b[1]);
        int maxY = Math.max(a[1], b[1]);
        int x = p[0], y = p[1];
        return x >= minX && x <= maxX && y >= minY && y <= maxY;
    }

    static final int INF = 0x3f3f3f3f;
    void solve() {
        while (true){
            int n = ni();
            if (n == 0) break;

            int[][] points = new int[n][2];
            for (int i = 0; i < n; ++i){
                int x = ni();
                int y = ni();
                points[i] = new int[]{x, y};
            }

            List<Rec> recs = new ArrayList<Rec>();
            for (int i = 0; i < n; ++i){
                for (int j = i + 1; j < n; ++j){
                    Rec rec = new Rec(1 << i | 1 << j, Math.max(1, Math.abs(points[i][0] - points[j][0]))
                            * Math.max(1, Math.abs(points[i][1] - points[j][1])));
                    for (int k = 0; k < n; ++k){
                        if (inRec(points[i], points[j], points[k])){
                            rec.add(k);
                        }
                    }
                    recs.add(rec);
                }
            }

            int[] dp = new int[1 << n]; //所有点加入到集合中的状态总数
            Arrays.fill(dp, INF);
            dp[0] = 0;
            for (int s = 0; s < 1 << n; ++s){
                for (Rec rec : recs){
                    int ns = s | rec.covered;
                    if (dp[s] != INF && ns != s){
                        dp[ns] = Math.min(dp[ns], dp[s] + rec.area);
                    }
                }
            }
            out.println(dp[(1 << n) - 1]);
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3411: Paid Roads

我还纳闷,case中的110怎么来的,题目大意是说:从城市a到城市b的路费可以在城市c缴纳,也可以直接在城市a缴纳,这就意味着如果在城市c缴纳的路费较便宜,且之前已经抵达过城市c了,反正还是去b,直接在c把路费缴了更划算。

它是一个根据状态在随意转变的图,跟所去过的顶点有关,可以用集合S来表示去过顶点的集合(状态压缩),接着找状态转移即可:

代码语言:javascript
复制
dp[S][v]: 表示在状态s下,抵达城市v的最短路径

方法:采用类似的Floyd-Warshall松弛算法

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.List;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3411.txt";

    class Road{
        int a;
        int b;
        int c;
        int p;
        int r;
        public Road(int a, int b, int c, int p, int r){
            this.a = a;
            this.b = b;
            this.c = c;
            this.p = p;
            this.r = r;
        }
    }

    class Node implements Comparable<Node>{
        int to;
        int dist;

        @Override
        public int compareTo(Node that) {
            return this.dist - that.dist;
        }
    }

    static final int INF = 1 << 29;
    void solve() {
        int N = ni();
        int M = ni();
        Road[] roads = new Road[M];
        List<Road>[] g = new ArrayList[N];
        for (int i = 0; i < N; ++i) g[i] = new ArrayList<Road>();

        for (int i = 0; i < M; ++i){
            int a = ni() - 1;
            int b = ni() - 1;
            int c = ni() - 1;
            int p = ni();
            int r = ni();
            roads[i] = new Road(a, b, c, p, r);
            g[a].add(roads[i]);
        }
        int[][] distance = new int[1 << N][N];
        for (int i = 0; i < 1 << N; ++i) Arrays.fill(distance[i], INF);

        distance[0][0] = 0;
        for (int i = 0; i < N; ++i){
            for (int s = 0; s < 1 << N; ++s){
                for (int v = 0; v < N; ++v){
                    for (Road r : g[v]){
                        int ns = s | 1 << r.a | 1 << r.b | 1 << r.c;
                        int cost = 0;
                        if ((s & (1 << r.c)) == 0) cost = r.r;
                        else cost = Math.min(r.r, r.p);
                        if (distance[s][v] != INF && distance[ns][r.b] > distance[s][v] + cost){
                            distance[ns][r.b] = distance[s][v] + cost;
                        }
                    }
                }
            }
        }
        int min = INF;
        for (int s = 0; s < 1 << N; ++s){
            min = Math.min(min, distance[s][N - 1]);
        }
        if (min == INF) out.println("impossible");
        else out.println(min);
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

本来就想试试这种算法没想到AC了,时间复杂度较高,因为我们知道,在所有可能的路费中都是正值,因此可以采用Dijkstra的算法来找寻最短路径,这样能避免大量不必要的更新,当然这里的技巧在于当下次一抵达相同顶点时,可以看成另一个顶点,因为状态不同,吼吼。

代码如下:

代码语言:javascript
复制
    class Road{
        int a;
        int b;
        int c;
        int p;
        int r;
        public Road(int a, int b, int c, int p, int r){
            this.a = a;
            this.b = b;
            this.c = c;
            this.p = p;
            this.r = r;
        }
    }

    class Node implements Comparable<Node>{
        int v;
        int dist;
        int S;

        public Node(int v, int dist, int S){
            this.v = v;
            this.dist = dist;
            this.S = S;
        }

        @Override
        public int compareTo(Node that) {
            return this.dist - that.dist;
        }

        @Override
        public String toString() {
            return v + " " + dist + " " + S;
        }
    }

    static final int INF = 1 << 29;
    void solve() {
        int N = ni();
        int M = ni();
        Road[] roads = new Road[M];
        List<Road>[] g = new ArrayList[N];
        for (int i = 0; i < N; ++i) g[i] = new ArrayList<Road>();

        for (int i = 0; i < M; ++i){
            int a = ni() - 1;
            int b = ni() - 1;
            int c = ni() - 1;
            int p = ni();
            int r = ni();
            roads[i] = new Road(a, b, c, p, r);
            g[a].add(roads[i]);
        }

        boolean[][] visited = new boolean[1 << N][N];
        Node start = new Node(0, 0, 0); //顶点0 , 在状态0下的,最短距离为0
        Queue<Node> queue = new PriorityQueue<Node>();
        queue.offer(start);

        int ans = INF;
        while (!queue.isEmpty()){
            Node r = queue.poll();
            if (visited[r.S][r.v]) continue;
            if (r.v == N - 1){
                ans = r.dist;
                break;
            }
            visited[r.S][r.v] = true;
            int v = r.v;
            for (Road edge : g[v]){
                int ns = r.S | 1 << edge.c | 1 << edge.a | 1 << edge.b;
                int cost = 0;
                if ((r.S & 1 << edge.c) == 0) cost = edge.r;
                else cost = Math.min(edge.r, edge.p);
                int to = edge.b;
                queue.offer(new Node(to, cost + r.dist, ns));
            }
        }

        if (ans == INF){
            out.println("impossible");
        }
        else{
            out.println(ans);
        }

    }

POJ 1795: DNA Laboratory

此题做的辛苦,先用DFS发现策略是错误的,接着看如何使用状态压缩,可状态压缩还不够啊,后续还要还原最优DP路径,构造字符串,我就呵呵了。又是TLE,又是MLE,折腾了几个小时总算AC了,不过它的确是难得的好题,说说思路吧。

思路:首先单纯的进行认为的字符串拼接,我们能够得到两条规则:

  • 遇到带拼接的字符串包含与另一个字符串,则可以忽略被包含的字符串。
  • 两个字符串互不包含,这就意味着字符串i + 字符串j 和字符串j + 字符串i这两种情况都要试一试。

刚开始采用了DFS,把每个字符串的头和尾都拼接试一试,后来发现一个问题,如何确定i和j的顺序呢?或者说给定字符串集合{a,b,c},如何确定a,b,c的拼接顺序?这里就需要采用状态压缩来枚举所有的拼接顺序。

好了,假如现在有了字符串a,拼接b和拼接c都要试下,于是有了{ab,ac},okay,如果有了字符串b,则a和c都要试下,于是有{ba,bc},c俺就不例举了。综上:

代码语言:javascript
复制
a : {ab,ac}
b : {ba,bc}
c : {ca,cb}

拼接枚举出来之后,从中又可以得出结论,{ab} 和 {ba}对于答案来说是不需要区分先后顺序的!
我们只需要输出:min{ab,ba}即可

至于是ab还是ba?题目说的很清楚,输出长度较小的,若长度一致则按字典序大小输出(小的输出)。

所以问题就转换成了如何衡量{ab}和{ba}

对于计算机而言,需要可以量化的标准,所谓的代价,该代价我们可以用字符串拼接的增量表示。

具体增量是什么就不重复了,看代码一目了然。

okay,这样我们只要记录最后拼接的字符串是谁,就可以根据代价来选择全局的最短拼接长度

所以定义:
dp[S][j]: S表示如今有哪些字符串已经被拼接,且在该状态下以j结尾的最短拼接长度。

为什么需要定义j?
为了路径还原!以及为了衡量拼接代价,自行体会。

这样状态转移矩阵就跟着出来了。
dp[S | 1 << j][j] = min{dp[S | 1 << j][j], dp[S][i] + cost[i][j]};

显然维护的是在选择了相同集合的字符串中,求以j结尾的最短拼接长度,与其他字符串的选取顺序无关哦。

如:j = c, S = {a, b, c}
则枚举的结果为: {abc,bac},只要保证c在最后即可

但我们又知道,这只是c在最后的情况,还有b,还有a呢,所以只能用状态压缩咯。

最后如何构造最短拼接的路径呢?

路径会有多条!这是肯定的,因为在拼接长度一定的情况下,可以出现字典序不同的情况,此时就需要把所有这些拼接情况遍历出来,选择字典序最小的即可。

两种办法:DFS遍历和迭代

我用了迭代:
首先把最短路径在DP数组中标注出来,采用负数的形式,这样可以忽略哪些INF值和正值,而专注于构造负数的路径。

具体看代码吧,这部分还是比较容易理解的。

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.HashSet;
import java.util.InputMismatchException;
import java.util.Set;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/1795.txt";


    static final int INF = 1 << 29;
    static final int MAX_N = 15;
    int[][] cost = new int[MAX_N][MAX_N];
    int[][] dp = new int[1 << MAX_N][MAX_N];
    void solve() {
        int test = ni();
        for (int t = 0; t < test; ++t){
            int len = ni();
            String[] str = new String[len];
            for (int j = 0; j < len; ++j){
                str[j] = ns();
            }

            //预处理,去重,把包含的字符串去除
            for (int i = 0; i < len; ++i){
                for (int j = 0; j < len; ++j){
                    if (i == j) continue;
                    if (str[i].contains(str[j])){
                        str[j] = str[i];
                    }
                }
            }

            Set<String> set = new HashSet<String>(Arrays.asList(str));
            int N = set.size();
            String[] newStr = set.toArray(new String[0]);

            int[] lenStr = new int[N];
            for (int i = 0; i < N; ++i){
                lenStr[i] = newStr[i].length();
            }

            //i右拼接j左
            for (int i = 0; i < N; ++i){
                for (int j = 0; j < N; ++j){
                    for (int l = 0; l < Math.min(lenStr[i], lenStr[j]); ++l){
                        if (newStr[i].substring(lenStr[i] - l).equals(newStr[j].substring(0, l))){
                            cost[i][j] = lenStr[j] - l;
                        }
                    }
                }
            }

            //进行最短距离拼接
            for (int i = 0; i < 1 << N; ++i){
                Arrays.fill(dp[i], INF);
            }

            //拼接i所需要的最短距离
            for (int i = 0; i < N; ++i){
                dp[0 | 1 << i][i] = lenStr[i];
            }

            //遍历每种状态,对每种状态进行i和j的拼接
            for (int s = 0; s < 1 << N; ++s){
                for (int i = 0; i < N; ++i){
                    if (dp[s][i] != INF){
                        for (int j = 0; j < N; ++j){
                            if ((s & 1 << j) == 0){
                                dp[s | 1 << j][j] = Math.min(dp[s | 1 << j][j], dp[s][i] + cost[i][j]);
                            }
                        }
                    }
                }
            }

            int bestLen = INF;
            for (int i = 0; i < N; ++i){
                bestLen = Math.min(dp[(1 << N) - 1][i], bestLen);
            }

            for (int i = 0; i < N; ++i){
                if (dp[(1 << N) - 1][i] == bestLen){
                    dp[(1 << N) - 1][i] = -dp[(1 << N) - 1][i];
                }
            }

            for (int s = (1 << N) - 1; s >= 0; --s){
                for (int i = 0; i < N; ++i){
                    if (dp[s][i] < 0){
                        for (int j = 0; j < N; ++j){
                            if (i != j && (s & (1 << j)) != 0){
                                if (dp[s & ~(1 << i)][j] + cost[j][i] == -dp[s][i]){
                                    dp[s & ~(1 << i)][j] = -dp[s & ~(1 << i)][j];
                                }
                            }
                        }
                    }
                }
            }

            String res = new String(new char[]{'z' + 1});
            int append = 0;
            int last = -1;
            for (int i = 0; i < N; ++i){
                if (dp[append | 1 << i][i] < 0){
                    if (res.compareTo(newStr[i]) > 0){
                        res = newStr[i];
                        last = i;
                    }
                }
            }
            append |= 1 << last;
            for (int i = 0; i < N - 1; ++i){
                String tail = new String(new char[]{'z' + 1});
                int key = -1;
                for (int j = 0; j < N; ++j){
                    if ((append & 1 << j) == 0){
                        if (dp[append | 1 << j][j] < 0){
                            if (Math.abs(dp[append][last]) + cost[last][j] == Math.abs(dp[append | 1 << j][j])){
                                if (tail.compareTo(newStr[j].substring(lenStr[j] - cost[last][j])) > 0){
                                    key = j;
                                    tail = newStr[j].substring(lenStr[j] - cost[last][j]);
                                }
                            }
                        }
                    }
                }
                last = key;
                append |= 1 << key;
                res += tail;
            }

            out.println("Scenario #"+(t + 1)+":");
            out.println(res);
            out.println();
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017年07月27日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 挑战程序竞赛系列(29):3.4熟练掌握动态规划
    • POJ 2441: Arrange the Bulls
      • POJ 2441: Arrange the Bulls
        • POJ 3254: Corn Fields
          • POJ 2836: Rectangle Covering
            • POJ 3411: Paid Roads
              • POJ 1795: DNA Laboratory
              相关产品与服务
              文件存储
              文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
              领券
              问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档