专栏首页机器学习入门挑战程序竞赛系列(30):3.4矩阵的幂

挑战程序竞赛系列(30):3.4矩阵的幂

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u014688145/article/details/76310181

挑战程序竞赛系列(30):3.4矩阵的幂

详细代码可以fork下Github上leetcode项目,不定期更新。

练习题如下:

POJ 3734: Blocks

矩阵的幂入门题,写出递推式即可,题解:需要记录红色和绿色的状态,分成三个状态:

  • a:红色和绿色均为偶数时
  • b:红色和绿色恰为一个奇数(注意互斥)
  • c:红色和绿色均为奇数

这样当加入下一个木块时,就可以写出状态转移方程了,有点像HMM中的状态转移啊。。。

状态转移方程:

a = 2a + b;
b = 2a + 2b + 2c;
c = 2c + b;

矩阵幂技术在于把上述转移状态写成矩阵的形式,因为每个状态只和前几个状态相关而不是所有状态,这点很关键,于是有:

⎛⎝⎜aibici⎞⎠⎟=⎛⎝⎜220121022⎞⎠⎟i⎛⎝⎜a0b0c0⎞⎠⎟

\begin{pmatrix} a_i \\ b_i \\ c_i \\ \end{pmatrix} = \begin{pmatrix} 2 & 1 & 0 \\ 2 & 2 & 2 \\ 0 & 1 & 2 \\ \end{pmatrix}^i \begin{pmatrix} a_0 \\ b_0 \\ c_0 \\ \end{pmatrix}

当然可以思考下为什么矩阵的幂的时间复杂度为O(logn)O(\log n),关键在于求解AnA^n的过程加快了速度,传统的乘法需要循环n次,但我们可以利用二进制转十进制的性质,用快速幂来计算A的n次。

代码如下:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3734.txt";

    static final int MOD = 10007;
    void solve() {
        int T = ni();
        for (int t = 0; t < T; ++t){
            int n = ni();
            int[][] a = {{2, 1, 0},{2, 2, 2},{0, 1, 2}};
            Mat A = new Mat(a);
            A = A.pow(A, n, MOD);
            out.println(A.mat[0][0]);
        }
    }

    class Mat{
        int[][] mat;
        int n;
        int m;

        public Mat(int[][] arra){
            this.mat = arra;
            this.n = arra.length;
            this.m = arra[0].length;
        }

        public Mat mul(Mat A, Mat B, int MOD){
            int[][] a = A.mat;
            int[][] b = B.mat;
            int[][] res = new int[A.n][B.m];
            for (int i = 0; i < A.n; ++i){
                for (int j = 0; j < B.m; ++j){
                    for (int ll = 0; ll < A.m; ++ll){
                        res[i][j] = (res[i][j] + a[i][ll] * b[ll][j]) % MOD;
                    }
                }
            }
            return new Mat(res);
        }

        public Mat pow(Mat A, int n, int MOD){
            int[][] one = new int[A.n][A.m];
            for (int i = 0; i < A.n; ++i) one[i][i] = 1;
            Mat res = new Mat(one);

            while (n > 0){
                if (n % 2 != 0){
                    res = mul(res, A, MOD);
                }
                n >>= 1;
                A = mul(A, A, MOD);
            }
            return res;
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3420: Quad Tiling

参考博文:http://blog.sina.com.cn/s/blog_69c3f0410100vnhj.html

思路:关键看怎么找递推式了,起初找递推的方式比较幼稚,出现大量子问题重复情况,而这种再做进一步递推式不知道如何干净去重,有点蛋疼。

它的思路是根据2*1的木块在4行中可能出现的轮廓来构建,进行完美贴合,呵呵哒,所以说不一定要以“正确的完美的递推式”来递推出答案,(递推就一定要保证每个n正确的情况下才能完成么?它只要是其中几种情况的一个解即可),思维很重要啊!

所以如上可以构成6种合法轮廓,如下图:

接着根据这六种情况就可以写出递推式了:

an+1=an+bn+cn+dxn+dyn

a_{n + 1} = a_n + b_n + c_n + dx_n + dy_n

bn+1=an

b_{n + 1} = a_n

cn+1=an+e

c_{n + 1} = a_n + e

dxn+1=an+dyn

dx_{n + 1} = a_n + dy_n

dyn+1=an+dxn

dy_{n + 1} = a_n + dx_n

en+1=cn

e_{n + 1} = c_n

当然令 d = dx + dy,可得

dn+1=2an+dn

d_{n + 1} = 2a_n + d_n

于是我们得到了A矩阵为:

A=⎛⎝⎜⎜⎜⎜⎜⎜1112010000100011001000100⎞⎠⎟⎟⎟⎟⎟⎟

A = \begin{pmatrix} 1 & 1 & 1 & 1 & 0\\ 1 & 0 & 0 & 0 & 0 \\ 1 & 0 & 0 & 0 & 1\\ 2 & 0 & 0 & 1 & 0\\ 0 & 0 & 1 & 0 & 0\\ \end{pmatrix}

代码如下:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3420.txt";

    void solve() {
        while (true){
            int N = ni();
            int M = ni();
            if (N + M == 0) break;
            int[][] a = {{1,1,1,1,0},{1,0,0,0,0},{1,0,0,0,1},{2,0,0,1,0},{0,0,1,0,0}};
            Mat A = new Mat(a);
            A = A.pow(A, N, M);
            out.println(A.mat[0][0]);
        }

    }

    class Mat{
        int[][] mat;
        int n;
        int m;

        public Mat(int[][] mat){
            this.mat = mat;
            this.n = mat.length;
            this.m = mat[0].length;
        }

        public Mat mul(Mat A, Mat B, int MOD){
            int[][] a = A.mat;
            int[][] b = B.mat;
            int[][] res = new int[A.n][B.m];
            for (int i = 0; i < A.n; ++i){
                for (int j = 0; j < B.m; ++j){
                    for (int ll = 0; ll < A.m; ++ll){
                        res[i][j] = (res[i][j] + a[i][ll] * b[ll][j]) % MOD;
                    }
                }
            }
            return new Mat(res);
        }

        public Mat pow(Mat A, int n, int MOD){
            int[][] one = new int[A.n][A.n];
            for (int i = 0; i < A.n; ++i) one[i][i] = 1;
            Mat res = new Mat(one);
            while (n > 0){
                if ((n & 1) != 0){
                    res = mul(res, A, MOD);
                }
                n >>= 1;
                A = mul(A, A, MOD);
            }
            return res;
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3735: Training Little cats

如果能够想到矩阵幂来做,就不难了。无非就是如何根据这些操作来构造一个矩阵,就拿case为例:

3 1 6
g 1
g 2
g 2
s 1 2
g 3
e 2
0 0 0

有三只猫,可以当作变量a,b,c
g 1 : a = a + 1
如果看成矩阵
a   1 0 0 1   0
b = 0 1 0 0 * 0
c   0 0 1 0   0
1   0 0 0 1   1

得a = a + 1
同理,s 1 2 无非就是把元素i和j对应的位置交换下:
a   0 1 0 1   0
b = 1 0 0 0 * 0
c   0 0 1 0   0
1   0 0 0 1   1

e 2
令矩阵[1][1] = 0即可
a   0 1 0 1   0
b = 1 0 0 0 * 0
c   0 0 0 0   0
1   0 0 0 1   1
得 c = 0

每个操作可以单独和初始向量相乘,保证矩阵相乘的正确性,最后构造的最先乘,最后再幂乘m次。

注意两点:long防止溢出wa,稀疏矩阵加个判断,否则TLE。

代码如下:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.Stack;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3735.txt";

    int N;
    void solve() {
        while (true){
            N = ni();
            int M = ni();
            int K = ni();
            if (N + M + K == 0) break;
            Stack<Mat> stack = new Stack<Mat>();
            for (int i = 0; i < K; ++i){
                char c = nc();
                if (c == 'g'){
                    stack.push(createMat(c, ni() - 1, 0));
                }
                else if (c == 's'){
                    stack.push(createMat(c, ni() - 1, ni() - 1));
                }
                else{
                    stack.push(createMat(c, ni() - 1, 0));
                }
            }

            long[][] one = new long[N + 1][N + 1];
            for (int i = 0; i < N + 1; ++i) one[i][i] = 1;
            Mat A = new Mat(one);
            while (!stack.isEmpty()){
                A = mul(A, stack.pop());
            }
            A = pow(A, M);
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < N; ++i){
                sb.append(" " + A.mat[i][N]);
            }
            out.println(sb.deleteCharAt(0).toString());
        }
    }

    public Mat createMat(char command, int i, int j){
        long[][] one = new long[N + 1][N + 1];
        for (int l = 0; l < one.length; ++l) one[l][l] = 1;
        switch (command) {
        case 'g':
            one[i][N] = 1;
            break;
        case 's':
            one[i][i] = 0;
            one[j][j] = 0;
            one[i][j] = 1;
            one[j][i] = 1;
            break;
        case 'e':
            one[i][i] = 0;
            break;
        default:
            break;
        }
        return new Mat(one);
    }

    class Mat{
        long[][] mat;
        int n;
        int m;
        public Mat(long[][] mat){
            this.mat = mat;
            this.n = mat.length;
            this.m = mat[0].length;
        }
    }

    public Mat mul(Mat A, Mat B){
        long[][] a = A.mat;
        long[][] b = B.mat;
        long[][] res = new long[A.n][B.m];
        for (int i = 0; i < A.n; ++i){
            for (int ll = 0; ll < A.m; ++ll){
                if (a[i][ll] != 0){
                    for (int j = 0; j < B.m; ++j){
                        res[i][j] += a[i][ll] * b[ll][j];
                    }
                }
            }
        }
        return new Mat(res);
    }

    public Mat pow(Mat A, int n){
        long[][] one = new long[A.n][A.n];
        for (int i = 0; i < A.n; ++i) one[i][i] = 1;
        Mat res = new Mat(one);
        while (n > 0){
            if ((n & 1) != 0){
                res = mul(res, A);
            }
            n >>= 1;
            A = mul(A, A);
        }
        return res;
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

当然你也可以在生成矩阵时,直接对原始矩阵进行操作,不过这是代码量的优化,无关乎算法,具体代码参考博文:http://www.hankcs.com/program/algorithm/poj-3735-training-little-cats-time.html

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 挑战程序竞赛系列(26):3.5二分图匹配(1)

    版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.n...

    用户1147447
  • K-th Smallest Prime Fraction

    思路1: 一种聪明的做法,如果A = [1, 7, 23, 29, 47],那么有:

    用户1147447
  • 挑战程序竞赛系列(81):4.3 LCA(1)

    挑战程序竞赛系列(81):4.3 LCA(1) 传送门:POJ 2763: Housewife Wind 题意: XX村里有n个小屋,小屋之间有双向可达的道...

    用户1147447
  • 2019 CCPC 重现赛 1006 基环树

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 ...

    用户2965768
  • Python3 基础学习之数值进制转换

        这个函数在上篇里表示强转,并没有输入n这个参数。当n不输入的时候默认是n=10。

    ZY_FlyWay
  • 剑指OFFER之重建二叉树(九度OJ1385)

    题目描述: 输入某二叉树的前序遍历和中序遍历的结果,请重建出该二叉树。假设输入的前序遍历和中序遍历的结果中都不含重复的数字。例如输入前序遍历序列{1,2,4,7...

    用户1154259
  • 【优秀题解】题号:1179

    关注我们 今天给大家带来一份优秀题解(题号:1179): ? 解题思路 1 设共n=7个站,第一站上车a=5人,最后一站下车32人,设第二站上车人数为x(其...

    编程范 源代码公司
  • 每日算法系列【EOJ 3031】二进制倒置

    给定一个整数 、将 的 334 位二进制表示形式(不包括开头可能的值为 0 的位, 表示为 1 位 0)前后倒置,输出倒置后的二进制数对应的整数。

    godweiyang
  • LeetCode Contest 180

    ShenduCC
  • zoj 2521 LED Display

    题意:开灯,每个数字都由好几个灯组成,其中一些数字灭掉某些灯可以成为另一个数字,如0灭掉3个灯可以变成7,         现给你一组数字,如何组合可以形成最少...

    用户1624346

扫码关注云+社区

领取腾讯云代金券