前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >挑战程序竞赛系列(30):3.4矩阵的幂

挑战程序竞赛系列(30):3.4矩阵的幂

作者头像
用户1147447
发布2019-05-26 09:26:57
3850
发布2019-05-26 09:26:57
举报
文章被收录于专栏:机器学习入门

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1434651

挑战程序竞赛系列(30):3.4矩阵的幂

详细代码可以fork下Github上leetcode项目,不定期更新。

练习题如下:

POJ 3734: Blocks

矩阵的幂入门题,写出递推式即可,题解:需要记录红色和绿色的状态,分成三个状态:

  • a:红色和绿色均为偶数时
  • b:红色和绿色恰为一个奇数(注意互斥)
  • c:红色和绿色均为奇数

这样当加入下一个木块时,就可以写出状态转移方程了,有点像HMM中的状态转移啊。。。

状态转移方程:

代码语言:javascript
复制
a = 2a + b;
b = 2a + 2b + 2c;
c = 2c + b;

矩阵幂技术在于把上述转移状态写成矩阵的形式,因为每个状态只和前几个状态相关而不是所有状态,这点很关键,于是有:

⎛⎝⎜aibici⎞⎠⎟=⎛⎝⎜220121022⎞⎠⎟i⎛⎝⎜a0b0c0⎞⎠⎟

\begin{pmatrix} a_i \ b_i \ c_i \ \end{pmatrix} = \begin{pmatrix} 2 & 1 & 0 \ 2 & 2 & 2 \ 0 & 1 & 2 \ \end{pmatrix}^i \begin{pmatrix} a_0 \ b_0 \ c_0 \ \end{pmatrix}

当然可以思考下为什么矩阵的幂的时间复杂度为O(logn)O(\log n),关键在于求解AnA^n的过程加快了速度,传统的乘法需要循环n次,但我们可以利用二进制转十进制的性质,用快速幂来计算A的n次。

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3734.txt";

    static final int MOD = 10007;
    void solve() {
        int T = ni();
        for (int t = 0; t < T; ++t){
            int n = ni();
            int[][] a = {{2, 1, 0},{2, 2, 2},{0, 1, 2}};
            Mat A = new Mat(a);
            A = A.pow(A, n, MOD);
            out.println(A.mat[0][0]);
        }
    }

    class Mat{
        int[][] mat;
        int n;
        int m;

        public Mat(int[][] arra){
            this.mat = arra;
            this.n = arra.length;
            this.m = arra[0].length;
        }

        public Mat mul(Mat A, Mat B, int MOD){
            int[][] a = A.mat;
            int[][] b = B.mat;
            int[][] res = new int[A.n][B.m];
            for (int i = 0; i < A.n; ++i){
                for (int j = 0; j < B.m; ++j){
                    for (int ll = 0; ll < A.m; ++ll){
                        res[i][j] = (res[i][j] + a[i][ll] * b[ll][j]) % MOD;
                    }
                }
            }
            return new Mat(res);
        }

        public Mat pow(Mat A, int n, int MOD){
            int[][] one = new int[A.n][A.m];
            for (int i = 0; i < A.n; ++i) one[i][i] = 1;
            Mat res = new Mat(one);

            while (n > 0){
                if (n % 2 != 0){
                    res = mul(res, A, MOD);
                }
                n >>= 1;
                A = mul(A, A, MOD);
            }
            return res;
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3420: Quad Tiling

参考博文:http://blog.sina.com.cn/s/blog_69c3f0410100vnhj.html

思路:关键看怎么找递推式了,起初找递推的方式比较幼稚,出现大量子问题重复情况,而这种再做进一步递推式不知道如何干净去重,有点蛋疼。

它的思路是根据2*1的木块在4行中可能出现的轮廓来构建,进行完美贴合,呵呵哒,所以说不一定要以“正确的完美的递推式”来递推出答案,(递推就一定要保证每个n正确的情况下才能完成么?它只要是其中几种情况的一个解即可),思维很重要啊!

所以如上可以构成6种合法轮廓,如下图:

接着根据这六种情况就可以写出递推式了:

an+1=an+bn+cn+dxn+dyn

a_{n + 1} = a_n + b_n + c_n + dx_n + dy_n

bn+1=an

b_{n + 1} = a_n

cn+1=an+e

c_{n + 1} = a_n + e

dxn+1=an+dyn

dx_{n + 1} = a_n + dy_n

dyn+1=an+dxn

dy_{n + 1} = a_n + dx_n

en+1=cn

e_{n + 1} = c_n

当然令 d = dx + dy,可得

dn+1=2an+dn

d_{n + 1} = 2a_n + d_n

于是我们得到了A矩阵为:

A=⎛⎝⎜⎜⎜⎜⎜⎜1112010000100011001000100⎞⎠⎟⎟⎟⎟⎟⎟

A = \begin{pmatrix} 1 & 1 & 1 & 1 & 0\ 1 & 0 & 0 & 0 & 0 \ 1 & 0 & 0 & 0 & 1\ 2 & 0 & 0 & 1 & 0\ 0 & 0 & 1 & 0 & 0\ \end{pmatrix}

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3420.txt";

    void solve() {
        while (true){
            int N = ni();
            int M = ni();
            if (N + M == 0) break;
            int[][] a = {{1,1,1,1,0},{1,0,0,0,0},{1,0,0,0,1},{2,0,0,1,0},{0,0,1,0,0}};
            Mat A = new Mat(a);
            A = A.pow(A, N, M);
            out.println(A.mat[0][0]);
        }

    }

    class Mat{
        int[][] mat;
        int n;
        int m;

        public Mat(int[][] mat){
            this.mat = mat;
            this.n = mat.length;
            this.m = mat[0].length;
        }

        public Mat mul(Mat A, Mat B, int MOD){
            int[][] a = A.mat;
            int[][] b = B.mat;
            int[][] res = new int[A.n][B.m];
            for (int i = 0; i < A.n; ++i){
                for (int j = 0; j < B.m; ++j){
                    for (int ll = 0; ll < A.m; ++ll){
                        res[i][j] = (res[i][j] + a[i][ll] * b[ll][j]) % MOD;
                    }
                }
            }
            return new Mat(res);
        }

        public Mat pow(Mat A, int n, int MOD){
            int[][] one = new int[A.n][A.n];
            for (int i = 0; i < A.n; ++i) one[i][i] = 1;
            Mat res = new Mat(one);
            while (n > 0){
                if ((n & 1) != 0){
                    res = mul(res, A, MOD);
                }
                n >>= 1;
                A = mul(A, A, MOD);
            }
            return res;
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3735: Training Little cats

如果能够想到矩阵幂来做,就不难了。无非就是如何根据这些操作来构造一个矩阵,就拿case为例:

代码语言:javascript
复制
3 1 6
g 1
g 2
g 2
s 1 2
g 3
e 2
0 0 0

有三只猫,可以当作变量a,b,c
g 1 : a = a + 1
如果看成矩阵
a   1 0 0 1   0
b = 0 1 0 0 * 0
c   0 0 1 0   0
1   0 0 0 1   1

得a = a + 1
同理,s 1 2 无非就是把元素i和j对应的位置交换下:
a   0 1 0 1   0
b = 1 0 0 0 * 0
c   0 0 1 0   0
1   0 0 0 1   1

e 2
令矩阵[1][1] = 0即可
a   0 1 0 1   0
b = 1 0 0 0 * 0
c   0 0 0 0   0
1   0 0 0 1   1
得 c = 0

每个操作可以单独和初始向量相乘,保证矩阵相乘的正确性,最后构造的最先乘,最后再幂乘m次。

注意两点:long防止溢出wa,稀疏矩阵加个判断,否则TLE。

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.Stack;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3735.txt";

    int N;
    void solve() {
        while (true){
            N = ni();
            int M = ni();
            int K = ni();
            if (N + M + K == 0) break;
            Stack<Mat> stack = new Stack<Mat>();
            for (int i = 0; i < K; ++i){
                char c = nc();
                if (c == 'g'){
                    stack.push(createMat(c, ni() - 1, 0));
                }
                else if (c == 's'){
                    stack.push(createMat(c, ni() - 1, ni() - 1));
                }
                else{
                    stack.push(createMat(c, ni() - 1, 0));
                }
            }

            long[][] one = new long[N + 1][N + 1];
            for (int i = 0; i < N + 1; ++i) one[i][i] = 1;
            Mat A = new Mat(one);
            while (!stack.isEmpty()){
                A = mul(A, stack.pop());
            }
            A = pow(A, M);
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < N; ++i){
                sb.append(" " + A.mat[i][N]);
            }
            out.println(sb.deleteCharAt(0).toString());
        }
    }

    public Mat createMat(char command, int i, int j){
        long[][] one = new long[N + 1][N + 1];
        for (int l = 0; l < one.length; ++l) one[l][l] = 1;
        switch (command) {
        case 'g':
            one[i][N] = 1;
            break;
        case 's':
            one[i][i] = 0;
            one[j][j] = 0;
            one[i][j] = 1;
            one[j][i] = 1;
            break;
        case 'e':
            one[i][i] = 0;
            break;
        default:
            break;
        }
        return new Mat(one);
    }

    class Mat{
        long[][] mat;
        int n;
        int m;
        public Mat(long[][] mat){
            this.mat = mat;
            this.n = mat.length;
            this.m = mat[0].length;
        }
    }

    public Mat mul(Mat A, Mat B){
        long[][] a = A.mat;
        long[][] b = B.mat;
        long[][] res = new long[A.n][B.m];
        for (int i = 0; i < A.n; ++i){
            for (int ll = 0; ll < A.m; ++ll){
                if (a[i][ll] != 0){
                    for (int j = 0; j < B.m; ++j){
                        res[i][j] += a[i][ll] * b[ll][j];
                    }
                }
            }
        }
        return new Mat(res);
    }

    public Mat pow(Mat A, int n){
        long[][] one = new long[A.n][A.n];
        for (int i = 0; i < A.n; ++i) one[i][i] = 1;
        Mat res = new Mat(one);
        while (n > 0){
            if ((n & 1) != 0){
                res = mul(res, A);
            }
            n >>= 1;
            A = mul(A, A);
        }
        return res;
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

当然你也可以在生成矩阵时,直接对原始矩阵进行操作,不过这是代码量的优化,无关乎算法,具体代码参考博文:http://www.hankcs.com/program/algorithm/poj-3735-training-little-cats-time.html

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2017年07月29日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 挑战程序竞赛系列(30):3.4矩阵的幂
    • POJ 3734: Blocks
      • POJ 3420: Quad Tiling
        • POJ 3735: Training Little cats
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档