前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >挑战程序竞赛系列(31):4.5剪枝

挑战程序竞赛系列(31):4.5剪枝

作者头像
用户1147447
发布2019-05-26 09:27:37
4320
发布2019-05-26 09:27:37
举报
文章被收录于专栏:机器学习入门

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1434655

挑战程序竞赛系列(31):4.5剪枝

详细代码可以fork下Github上leetcode项目,不定期更新。

练习题如下:

POJ 1011: Sticks

变态的DFS搜索,需要剪枝否则TLE,初始版本如下:

代码语言:javascript
复制
    void solve() {
        while (true){
            int n = ni();
            if (n == 0) break;
            int[] sticks = new int[n];
            int min = 0;
            int sum = 0;
            int max = 0;
            for (int i = 0; i < n; ++i){
                int len = ni();
                max = Math.max(max, len);
                sum += len;
                sticks[i] = len;
            }
            Arrays.sort(sticks);
            Set<Integer> mem = new HashSet<Integer>();
            for (int i = n; i >= 1; --i){
                if (sum % i == 0){
                    min = sum / i;
                    if (min >= max && dfs(sticks, min, 0, new boolean[n], mem)) break;
                }
            }
            out.println(min);
        }
    }

    public boolean dfs(int[] sticks, int min, int sum, boolean[] visited, Set<Integer> mem){
        if (mem.size() == sticks.length){
            return min == sum;
        }
        if (sum > min) return false;
        if (sum == min){
            if (dfs(sticks, min, 0, visited, mem)) return true;
            else return false;
        }
        for (int i = sticks.length - 1; i >= 0; --i){
            int rem = min - sum;
            if (!visited[i] && sticks[i] <= rem){
                visited[i] = true;
                mem.add(i);
                if (dfs(sticks, min, sum + sticks[i], visited, mem)){
                    return true;
                }
                else{
                    visited[i] = false;
                    mem.remove(i);
                }
            }
        }
        return false;
    }

代码细节可以忽略,visited和mem可以合并,做了一些简单的剪枝处理,但始终超时。思路是遍历各种组合,且当所有元素被使用后,看是否能够找到所有长度一致的木棒。

遍历超时很大一部分的原因在于dfs中有个for循环,对于重复长度的棒子过滤的不够干净,浪费了大量的搜素资源。我们可以采用map对长度进行统计,这样重复长度的棒子大可不必搜索,省时省力。

依旧遍历,对每种可能的组合进行搜索,搜索时记录拼接完成的棒子个数,个数 * 可能长度 = 总和时,遍历结束。

具体细节可以参考博文:http://www.hankcs.com/program/algorithm/poj-1011-sticks.html,不作赘述。

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.HashSet;
import java.util.InputMismatchException;
import java.util.Set;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/P1011.txt";

    int[] in;
    int candicate;
    void solve() {
        while (true){
            int n = ni();
            if (n == 0) break;
            in = new int[51];
            finish = false;
            int sum = 0;
            candicate = 0;
            int max = 0;
            for (int i = 0; i < n; ++i){
                int len = ni();
                max = Math.max(max, len);
                sum += len;
                in[len] ++;
            }
            candicate = max;
            while (true){
                if (sum % candicate == 0){
                    check(sum / candicate, candicate, max);
                }
                if (finish) break;
                ++candicate;
            }
            out.println(candicate);
        }
    }

    boolean finish;
    public void check(int count, int len, int plen){
        --in[plen];
        if (count == 0){
            finish = true;
        }
        if (!finish){
            len -= plen; //剩余长度
            if (len != 0){
                int nextPlen = Math.min(len, plen);
                for (; nextPlen > 0; --nextPlen){
                    if (in[nextPlen] != 0){
                        check(count, len, nextPlen);
                    }
                }
            }
            else{
                int max = 50;
                while (max > 0 && in[max] == 0) --max;
                check(count - 1, candicate, max); //当前剩余棒子的最大长度
            }
        }
        ++in[plen];
    }


    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

提供一组变态数据,可以自己测测来玩玩,上述算法需要3s搜索出答案。

代码语言:javascript
复制
64
40 40 30 35 35 26 15 40 40 40 40 40 40 40 40 40 40 40 40 40 40
40 40 43 42 42 41 10 4 40 40 40 40 40 40 40 40 40 40 40 40 40
40 25 39 46 40 10 4 40 40 37 18 17 16 15 40 40 40 40 40 40 40 
40
0

ans:
454
[3503ms]

POJ 只需128ms走完全部测试数据,数据有点水啊。

POJ 2046: Gap

一道模拟题,用BFS广搜就好了,关键抓住填入空格的规则,只有一种情况,只允许填入左侧的下一个数字,所以在当前board下只会出现四种状态,没有什么搜索策略,按照轮次搜即可。

BFS的一个好处在于,能够以最短的距离搜到终止状态,也是此题的关键。不过还需要注意,当我们定义board的状态时,可以从整体出发,需要重写hashCode和equal方法,方便记录状态的访问情况,好题。

步骤:

  • 定义状态
  • 考虑状态的终止条件
  • 考虑状态的切换规则
  • 重写hashCode和equals方法

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.HashSet;
import java.util.InputMismatchException;
import java.util.LinkedList;
import java.util.Queue;
import java.util.Set;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P2046.txt";

    class Game{
        int[][] board = new int[4][8];
        int turn;
        public Game(int[][] board){
            this.board = board;
            this.turn = 0;
            int[] y = find(11);
            swap(new int[]{0, 0}, y);
            y = find(21);
            swap(new int[]{1, 0}, y);
            y = find(31);
            swap(new int[]{2, 0}, y);
            y = find(41);
            swap(new int[]{3, 0}, y);
        }

        public Game(Game newGame){
            for (int i = 0; i < 4; ++i){
                for (int j = 0; j < 8; ++j){
                    this.board[i][j] = newGame.board[i][j];
                }
            }
            this.turn = newGame.turn;
        }

        public boolean canFill(int i, int j){
            if (board[i][j] != 0) return false;
            if (board[i][j - 1] != 0 && (board[i][j - 1] % 10) != 7) return true;
            return false;
        }

        public boolean done(){
            for (int i = 0; i < 4; ++i){
                if (board[i][7] != 0) return false;
            }
            for (int i = 0; i < 4; ++i){
                for (int j = 0; j < 7; ++j){
                    if (board[i][j] != (i + 1) * 10 + (j + 1)) return false;
                }
            }
            return true;
        }

        public void fillGap(int i, int j){
            int key = board[i][j - 1] + 1;
            int[] pos = find(key);
            swap(new int[]{i, j}, pos);
            this.turn ++;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj instanceof Game){
                Game that = (Game)obj;
                for (int i = 0; i < 4; ++i){
                    for (int j = 0; j < 8; ++j){
                        if (board[i][j] != that.board[i][j]) return false;
                    }
                }
                return true;
            }
            else return false;
        }

        public int[] find(int key){
            for (int i = 0; i < 4; ++i){
                for (int j = 0; j < 8; ++j){
                    if (board[i][j] == key) return new int[]{i, j};
                }
            }
            return new int[]{-1, -1};
        }

        public void swap(int[] x, int[] y){
            int tmp = board[x[0]][x[1]];
            board[x[0]][x[1]] = board[y[0]][y[1]];
            board[y[0]][y[1]] = tmp;
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < 4; ++i){
                for (int j = 0; j < 8; ++j){
                    sb.append(board[i][j] + (j + 1 == 8 ? "\n" : " "));
                }
            }
            return sb.toString();
        }

        @Override
        public int hashCode() {
            int hash = 0;
            for (int i = 0; i < 4; ++i){
                for (int j = 1; j < 8; ++j){
                    hash += board[i][j];
                    hash <<= 1;
                }
            }
            return hash;
        }
    }

    void solve() {
        int T = ni();
        while (T --> 0){
            int[][] board = new int[4][8];
            for (int i = 0; i < 4; ++i){
                for (int j = 1; j < 8; ++j){
                    board[i][j] = ni();
                }
            }
            Game game = new Game(board);
            Queue<Game> queue = new LinkedList<Game>();
            Set<Game> visited = new HashSet<Game>();
            if (game.done()){
                out.println(0);
                continue;
            }
            queue.offer(game);
            int ans = -1;
            boolean end = false;
            outer: while (!queue.isEmpty() && !end){
                Game gg = queue.poll();
                if (visited.contains(gg)) continue;
                visited.add(gg);
                for (int i = 0; i < 4; ++i){
                    for (int j = 1; j < 8; ++j){
                        if (gg.canFill(i, j)){
                            Game tmp = new Game(gg);
                            tmp.fillGap(i, j);
                            if (tmp.done()){
                                ans = tmp.turn;
                                end = true;
                                continue outer;
                            }
                            else queue.offer(tmp);
                        }
                    }
                }
            }
            out.println(ans);
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3134: Power Calculus

第一次遇到迭代加深算法,有点难以理解。刚开始使用BFS,但发现细节处理上有问题。此题有些关键地方,比如同一轮生成的解,不能结合使用,只能使用前几层的解和当前层解的组合,或许可以如此想象,在构造轮次时,只有一条链,这种构造路径难道不是DFS?没错,就是它,但是何时终止呢?

剪枝算法告诉我们,每个给定的n都有一个上界,就拿快速幂的例子来说,举例13,至多也就这些操作:

代码语言:javascript
复制
13-1=12
12/2=6
6/2=3
3-1=2
2/2=1

于是我们迭代找解的时候可以根据此上界进行剪枝,对数据预处理下,有上界函数:

代码语言:javascript
复制
    public int upper(int n){
        int cnt = 0;
        while (n > 0){
            if ((n & 1) != 0){
                cnt ++;
            }
            n >>= 1;
            cnt ++;
        }
        return cnt - 2;
    }

接着就是构一条生成路径了,采用DFS,巧妙之处在于这种DFS刚好能够模拟这种构造状态,神奇,比如:

代码语言:javascript
复制
1 -> 0
表示x需要耗费0次,初始状态

那么自然地:
2 = 1 + 1
表示: x^2 = x * x

所以有1生成了2

那么3怎么来? 1 + 2 = 3

但是当前层还会有 2 + 2 = 4

但神奇的是,1 + 2 = 3 和 2 + 2 = 4不会同一时刻遍历,而是分为两次dfs调用。

这就解决了之前BFS的一个bug,呵呵哒。

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P3134.txt";

    int MAX_N = 1024;
    int MAX_D = 20;
    int[] exp = new int[MAX_D];
    int[] ans = new int[MAX_N];

    void solve(){
        Arrays.fill(exp, 1);
        for (int i = 2; i < MAX_N; ++i){
            ans[i] = upper(i);
        }
        dfs(0);
        while (true){
            int n = ni();
            if (n == 0) break;
            out.println(ans[n]);
        }
    }

    public void dfs(int d) {
        if (d > MAX_D) {
            return; 
        }
        for (int i = 0; i <= d; i++) {
            exp[d + 1] = exp[i] + exp[d]; // 乘法
            if (exp[d + 1] < MAX_N && ans[exp[d + 1]] >= d + 1) { //这层的解要是被更新的话,继续更新下下层的解
                ans[exp[d + 1]]  = d + 1; //更新解
                dfs(d + 1);
            }
            exp[d + 1] = exp[d] - exp[i]; // 除法
            if (exp[d + 1] > 0 && ans[exp[d + 1]] >= d + 1) {
                ans[exp[d + 1]] = d + 1;
                dfs(d + 1);
            }
        }
    }

    public int upper(int n){
        int cnt = 0;
        while (n > 0){
            if ((n & 1) != 0){
                cnt ++;
            }
            n >>= 1;
            cnt ++;
        }
        return cnt - 2;
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

这就厉害了,类似于BFS,但能够精确的控制每层解的非法组合。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2017年08月01日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 挑战程序竞赛系列(31):4.5剪枝
    • POJ 1011: Sticks
      • POJ 2046: Gap
        • POJ 3134: Power Calculus
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档