版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1434651
详细代码可以fork下Github上leetcode项目,不定期更新。
练习题如下:
矩阵快速幂的入门题,写出递推式即可。题解:需要记录红色块和绿色块个数的奇偶性,分成三个状态:a 表示红、绿块数都为偶数,b 表示恰有一种为奇数,c 表示两种都为奇数。
这样当加入下一个木块时,就可以写出状态转移方程了,有点像HMM中的状态转移啊。。。
状态转移方程:
$$a_{i+1} = 2a_i + b_i$$
$$b_{i+1} = 2a_i + 2b_i + 2c_i$$
$$c_{i+1} = 2c_i + b_i$$
矩阵幂技术在于把上述转移状态写成矩阵的形式,因为每个状态只和前几个状态相关而不是所有状态,这点很关键,于是有:
$$\begin{pmatrix} a_i \\ b_i \\ c_i \end{pmatrix} = \begin{pmatrix} 2 & 1 & 0 \\ 2 & 2 & 2 \\ 0 & 1 & 2 \end{pmatrix}^i \begin{pmatrix} a_0 \\ b_0 \\ c_0 \end{pmatrix}$$
其中初始状态为 $a_0 = 1,\ b_0 = c_0 = 0$(零个木块时红、绿块数都是偶数),所以答案就是 $(A^i)_{00}$,即代码中输出的 `A.mat[0][0]`。
当然可以思考一下,为什么矩阵幂的时间复杂度是 $O(\log n)$:关键在于加快了求 $A^n$ 的过程。朴素做法需要连乘 n 次;而利用 n 的二进制展开(即快速幂),只需 $O(\log n)$ 次矩阵乘法就能求出 $A^n$。
代码如下:
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
/**
 * Colouring counter for the input in {@code 3734.txt} (POJ 3734, judging by the data file name).
 *
 * <p>Each appended block moves a 3-state vector (a, b, c) by the transition matrix built in
 * {@link #solve()}; every column of that matrix sums to 4, i.e. each block has four choices.
 * The answer is read from entry (0, 0) of A^n modulo {@link #MOD} (start vector (1, 0, 0)),
 * computed by binary exponentiation in O(log n) matrix products.
 */
public class Main {
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3734.txt";
    static final int MOD = 10007;

    /** Reads T, then T values of n, printing (A^n)[0][0] mod MOD for each on its own line. */
    void solve() {
        int T = ni();
        for (int t = 0; t < T; ++t) {
            int n = ni();
            // Column-vector convention: a[i][j] = number of ways to move from state j to state i.
            int[][] a = {{2, 1, 0}, {2, 2, 2}, {0, 1, 2}};
            Mat A = new Mat(a);
            A = A.pow(A, n, MOD);
            out.println(A.mat[0][0]);
        }
    }

    /** Dense int matrix with modular multiplication and exponentiation helpers. */
    class Mat {
        int[][] mat;
        int n; // number of rows
        int m; // number of columns

        /** Wraps {@code arra} without copying; it must contain at least one row. */
        public Mat(int[][] arra) {
            this.mat = arra;
            this.n = arra.length;
            this.m = arra[0].length;
        }

        /**
         * Returns {@code A * B} with every entry reduced modulo {@code MOD}.
         * Entries of A and B are expected to lie in [0, MOD).
         */
        public Mat mul(Mat A, Mat B, int MOD) {
            int[][] a = A.mat;
            int[][] b = B.mat;
            int[][] res = new int[A.n][B.m];
            for (int i = 0; i < A.n; ++i) {
                for (int j = 0; j < B.m; ++j) {
                    long sum = 0;
                    for (int ll = 0; ll < A.m; ++ll) {
                        // Widen before multiplying: (MOD - 1)^2 overflows int once MOD > 46341.
                        sum = (sum + (long) a[i][ll] * b[ll][j]) % MOD;
                    }
                    res[i][j] = (int) sum;
                }
            }
            return new Mat(res);
        }

        /**
         * Returns {@code A^n} (entries modulo {@code MOD}) by binary exponentiation.
         * {@code A} must be square; for {@code n <= 0} the identity matrix is returned.
         */
        public Mat pow(Mat A, int n, int MOD) {
            int[][] one = new int[A.n][A.n];
            for (int i = 0; i < A.n; ++i) one[i][i] = 1;
            Mat res = new Mat(one);
            while (n > 0) {
                if ((n & 1) != 0) {
                    res = mul(res, A, MOD);
                }
                n >>= 1;
                // Square only while bits remain; the last squaring would be discarded anyway.
                if (n > 0) {
                    A = mul(A, A, MOD);
                }
            }
            return res;
        }
    }

    /** Wires input (stdin on the judge, INPUT file locally), runs solve() and flushes output. */
    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);
        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    // Fast input: a 1 KiB buffer over `is`, consumed byte by byte.
    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    /**
     * Returns the next input byte, or -1 when no more bytes are available. Once the stream has
     * reported end-of-input (read returned -1), any further call throws InputMismatchException.
     */
    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    /** Treats every value outside printable ASCII 33..126 (including -1) as a separator. */
    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    /** Skips separators; returns the first non-separator byte, or -1 at end of input. */
    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    /** Reads the next token as a double. */
    private double nd() {
        return Double.parseDouble(ns());
    }

    /** Reads the next non-separator character. */
    private char nc() {
        return (char) skip();
    }

    /** Reads the next separator-delimited token. */
    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        // Whole-line variant: loop while !(isSpaceChar(b) && b != ' ') so spaces are kept.
        while (!(isSpaceChar(b))) {
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads up to n characters of the next token; the array is truncated if the token is shorter. */
    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads n tokens of up to m characters each as a character grid. */
    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    /** Reads n ints. */
    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    /** Parses the next optionally negative int, skipping leading non-digit, non-'-' bytes; no overflow check. */
    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Same as ni() but accumulates into a long. */
    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    // True when the ONLINE_JUDGE system property is set: read stdin and suppress tr() output.
    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    /** Debug print of the arguments to stdout, only when running locally. */
    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}
参考博文:http://blog.sina.com.cn/s/blog_69c3f0410100vnhj.html
思路:关键看怎么找递推式了,起初找递推的方式比较幼稚,出现大量子问题重复情况,而这种再做进一步递推式不知道如何干净去重,有点蛋疼。
它的思路是根据2*1的木块在4行中可能出现的轮廓来构建,进行完美贴合,呵呵哒,所以说不一定要以“正确的完美的递推式”来递推出答案,(递推就一定要保证每个n正确的情况下才能完成么?它只要是其中几种情况的一个解即可),思维很重要啊!
所以如上可以构成6种合法轮廓,如下图:
接着根据这六种情况就可以写出递推式了:
$$a_{n+1} = a_n + b_n + c_n + dx_n + dy_n$$
$$b_{n+1} = a_n$$
$$c_{n+1} = a_n + e_n$$
$$dx_{n+1} = a_n + dy_n$$
$$dy_{n+1} = a_n + dx_n$$
$$e_{n+1} = c_n$$
当然,令 $d = dx + dy$,可得
$$d_{n+1} = 2a_n + d_n$$
同时 $a_{n+1} = a_n + b_n + c_n + d_n$,这样状态就从 6 个减少到 5 个。
于是,取状态向量 $(a, b, c, d, e)^T$,得到转移矩阵 A 为:
$$A = \begin{pmatrix} 1 & 1 & 1 & 1 & 0 \\ 1 & 0 & 0 & 0 & 0 \\ 1 & 0 & 0 & 0 & 1 \\ 2 & 0 & 0 & 1 & 0 \\ 0 & 0 & 1 & 0 & 0 \end{pmatrix}$$
初始时只有 $a_0 = 1$,其余为 0,所以答案是 $(A^N)_{00}$。
代码如下:
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
/**
 * Tiling counter for the input in {@code 3420.txt} (POJ 3420, judging by the data file name).
 * For each pair (N, M) until "0 0" it prints the number of ways to tile a 4 x N board with
 * 2 x 1 dominoes, modulo M.
 *
 * <p>The 5-state transition matrix (states a, b, c, d = dx + dy, e from the accompanying
 * derivation) is raised to the N-th power; entry (0, 0) of A^N is the answer.
 */
public class Main {
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3420.txt";

    /** Processes "N M" pairs until a pair summing to 0, printing (A^N)[0][0] mod M for each. */
    void solve() {
        while (true) {
            int N = ni();
            int M = ni();
            if (N + M == 0) break;
            // States ordered (a, b, c, d, e); row i says how state i is reached from each state.
            int[][] a = {{1, 1, 1, 1, 0}, {1, 0, 0, 0, 0}, {1, 0, 0, 0, 1}, {2, 0, 0, 1, 0}, {0, 0, 1, 0, 0}};
            Mat A = new Mat(a);
            A = A.pow(A, N, M);
            out.println(A.mat[0][0]);
        }
    }

    /** Dense int matrix with modular multiplication and exponentiation helpers. */
    class Mat {
        int[][] mat;
        int n; // number of rows
        int m; // number of columns

        /** Wraps {@code mat} without copying; it must contain at least one row. */
        public Mat(int[][] mat) {
            this.mat = mat;
            this.n = mat.length;
            this.m = mat[0].length;
        }

        /**
         * Returns {@code A * B} with every entry reduced modulo {@code MOD}.
         * Entries of A and B are expected to lie in [0, MOD).
         */
        public Mat mul(Mat A, Mat B, int MOD) {
            int[][] a = A.mat;
            int[][] b = B.mat;
            int[][] res = new int[A.n][B.m];
            for (int i = 0; i < A.n; ++i) {
                for (int j = 0; j < B.m; ++j) {
                    long sum = 0;
                    for (int ll = 0; ll < A.m; ++ll) {
                        // Widen before multiplying: the modulus is read from input, and
                        // (MOD - 1)^2 overflows int once MOD > 46341 (e.g. 99999 * 99999).
                        sum = (sum + (long) a[i][ll] * b[ll][j]) % MOD;
                    }
                    res[i][j] = (int) sum;
                }
            }
            return new Mat(res);
        }

        /**
         * Returns {@code A^n} (entries modulo {@code MOD}) by binary exponentiation.
         * {@code A} must be square; for {@code n <= 0} the identity matrix is returned.
         */
        public Mat pow(Mat A, int n, int MOD) {
            int[][] one = new int[A.n][A.n];
            for (int i = 0; i < A.n; ++i) one[i][i] = 1;
            Mat res = new Mat(one);
            while (n > 0) {
                if ((n & 1) != 0) {
                    res = mul(res, A, MOD);
                }
                n >>= 1;
                // Square only while bits remain; the last squaring would be discarded anyway.
                if (n > 0) {
                    A = mul(A, A, MOD);
                }
            }
            return res;
        }
    }

    /** Wires input (stdin on the judge, INPUT file locally), runs solve() and flushes output. */
    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);
        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    // Fast input: a 1 KiB buffer over `is`, consumed byte by byte.
    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    /**
     * Returns the next input byte, or -1 when no more bytes are available. Once the stream has
     * reported end-of-input (read returned -1), any further call throws InputMismatchException.
     */
    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    /** Treats every value outside printable ASCII 33..126 (including -1) as a separator. */
    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    /** Skips separators; returns the first non-separator byte, or -1 at end of input. */
    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    /** Reads the next token as a double. */
    private double nd() {
        return Double.parseDouble(ns());
    }

    /** Reads the next non-separator character. */
    private char nc() {
        return (char) skip();
    }

    /** Reads the next separator-delimited token. */
    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        // Whole-line variant: loop while !(isSpaceChar(b) && b != ' ') so spaces are kept.
        while (!(isSpaceChar(b))) {
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads up to n characters of the next token; the array is truncated if the token is shorter. */
    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads n tokens of up to m characters each as a character grid. */
    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    /** Reads n ints. */
    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    /** Parses the next optionally negative int, skipping leading non-digit, non-'-' bytes; no overflow check. */
    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Same as ni() but accumulates into a long. */
    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    // True when the ONLINE_JUDGE system property is set: read stdin and suppress tr() output.
    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    /** Debug print of the arguments to stdout, only when running locally. */
    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}
如果能想到用矩阵幂来做,这题就不难了,无非是根据这些操作构造一个矩阵。以样例为例:
3 1 6
g 1
g 2
g 2
s 1 2
g 3
e 2
0 0 0
有三只猫,可以当作变量 a, b, c;再在状态向量末尾补一个常数 1,用来实现"加 1"操作。
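g 1:a = a + 1,写成矩阵形式为
$$\begin{pmatrix} a' \\ b' \\ c' \\ 1 \end{pmatrix} = \begin{pmatrix} 1 & 0 & 0 & 1 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix} \begin{pmatrix} a \\ b \\ c \\ 1 \end{pmatrix}$$
即得 a' = a + 1。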
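同理,s 1 2 无非就是在单位矩阵中把第 i 行和第 j 行交换一下(这里 i = 1, j = 2):
$$\begin{pmatrix} a' \\ b' \\ c' \\ 1 \end{pmatrix} = \begin{pmatrix} 0 & 1 & 0 & 0 \\ 1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix} \begin{pmatrix} a \\ b \\ c \\ 1 \end{pmatrix}$$
即 a' = b,b' = a。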
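e 2:令单位矩阵的 [1][1] = 0 即可(下标从 0 开始,第 2 只猫对应下标 1):
$$\begin{pmatrix} a' \\ b' \\ c' \\ 1 \end{pmatrix} = \begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix} \begin{pmatrix} a \\ b \\ c \\ 1 \end{pmatrix}$$
得 b' = 0。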
每个操作都可以单独写成一个矩阵,与状态向量相乘即完成该操作。一轮操作对应这些矩阵的乘积,要注意顺序:最后执行的操作的矩阵乘在最左边,最先执行的乘在最右边(紧挨着状态向量)。得到一轮的矩阵后,再求它的 m 次幂即可。
注意两点:一是要用 long,防止溢出导致 WA;二是这些矩阵非常稀疏,矩阵乘法中遇到零元素应直接跳过,否则会 TLE。
代码如下:
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.Stack;
/**
 * Command-simulation solver for the input in {@code 3735.txt} (POJ 3735, judging by the data
 * file name). Each test case has N counters (cats), a command sequence of length K, and a
 * repetition count M. The program prints each counter's value after the sequence has been run
 * M times, starting from all zeros.
 *
 * <p>Every command is an (N+1)x(N+1) matrix acting on the column vector
 * (cat_0, ..., cat_{N-1}, 1). One round is the product of those matrices, which is then raised
 * to the M-th power; column N of the result holds the final values.
 */
public class Main {
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201707/3735.txt";
    int N;

    /** Processes "N M K" test cases until all three are 0, printing one space-separated line each. */
    void solve() {
        while (true) {
            N = ni();
            int M = ni();
            int K = ni();
            if (N + M + K == 0) break;
            long[][] one = new long[N + 1][N + 1];
            for (int i = 0; i < N + 1; ++i) one[i][i] = 1;
            Mat A = new Mat(one);
            // Fold each command in as it is read. Left-multiplying puts later commands to the
            // left, so A = Op_K * ... * Op_1, without storing all K matrices first.
            for (int i = 0; i < K; ++i) {
                char c = nc();
                Mat op;
                if (c == 's') {
                    op = createMat(c, ni() - 1, ni() - 1);
                } else {
                    op = createMat(c, ni() - 1, 0);
                }
                A = mul(op, A);
            }
            A = pow(A, M);
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < N; ++i) {
                if (i > 0) sb.append(' ');
                sb.append(A.mat[i][N]);
            }
            out.println(sb.toString());
        }
    }

    /**
     * Builds the (N+1)x(N+1) matrix for one command with 0-based cat indices.
     * 'g' adds the constant 1 to cat i, 's' swaps cats i and j, and 'e' sets cat i to 0.
     * Any other command yields the identity.
     */
    public Mat createMat(char command, int i, int j) {
        long[][] one = new long[N + 1][N + 1];
        for (int l = 0; l < one.length; ++l) one[l][l] = 1;
        switch (command) {
            case 'g':
                one[i][N] = 1;
                break;
            case 's':
                one[i][i] = 0;
                one[j][j] = 0;
                one[i][j] = 1;
                one[j][i] = 1;
                break;
            case 'e':
                one[i][i] = 0;
                break;
            default:
                break;
        }
        return new Mat(one);
    }

    /** Dense long matrix; the wrapped array is not copied. */
    class Mat {
        long[][] mat;
        int n; // number of rows
        int m; // number of columns

        public Mat(long[][] mat) {
            this.mat = mat;
            this.n = mat.length;
            this.m = mat[0].length;
        }
    }

    /**
     * Returns {@code A * B} using exact long arithmetic (no modulus). Zero entries of A are
     * skipped, which keeps the product cheap for these sparse command matrices.
     */
    public Mat mul(Mat A, Mat B) {
        long[][] a = A.mat;
        long[][] b = B.mat;
        long[][] res = new long[A.n][B.m];
        for (int i = 0; i < A.n; ++i) {
            for (int ll = 0; ll < A.m; ++ll) {
                if (a[i][ll] != 0) {
                    for (int j = 0; j < B.m; ++j) {
                        res[i][j] += a[i][ll] * b[ll][j];
                    }
                }
            }
        }
        return new Mat(res);
    }

    /**
     * Returns {@code A^n} by binary exponentiation.
     * {@code A} must be square; for {@code n <= 0} the identity matrix is returned.
     */
    public Mat pow(Mat A, int n) {
        long[][] one = new long[A.n][A.n];
        for (int i = 0; i < A.n; ++i) one[i][i] = 1;
        Mat res = new Mat(one);
        while (n > 0) {
            if ((n & 1) != 0) {
                res = mul(res, A);
            }
            n >>= 1;
            // Square only while bits remain; the last squaring would be discarded anyway.
            if (n > 0) {
                A = mul(A, A);
            }
        }
        return res;
    }

    /** Wires input (stdin on the judge, INPUT file locally), runs solve() and flushes output. */
    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);
        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    // Fast input: a 1 KiB buffer over `is`, consumed byte by byte.
    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    /**
     * Returns the next input byte, or -1 when no more bytes are available. Once the stream has
     * reported end-of-input (read returned -1), any further call throws InputMismatchException.
     */
    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    /** Treats every value outside printable ASCII 33..126 (including -1) as a separator. */
    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    /** Skips separators; returns the first non-separator byte, or -1 at end of input. */
    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    /** Reads the next token as a double. */
    private double nd() {
        return Double.parseDouble(ns());
    }

    /** Reads the next non-separator character. */
    private char nc() {
        return (char) skip();
    }

    /** Reads the next separator-delimited token. */
    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        // Whole-line variant: loop while !(isSpaceChar(b) && b != ' ') so spaces are kept.
        while (!(isSpaceChar(b))) {
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads up to n characters of the next token; the array is truncated if the token is shorter. */
    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads n tokens of up to m characters each as a character grid. */
    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    /** Reads n ints. */
    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    /** Parses the next optionally negative int, skipping leading non-digit, non-'-' bytes; no overflow check. */
    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Same as ni() but accumulates into a long. */
    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    // True when the ONLINE_JUDGE system property is set: read stdin and suppress tr() output.
    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    /** Debug print of the arguments to stdout, only when running locally. */
    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}
当然你也可以在生成矩阵时,直接对原始矩阵进行操作,不过这是代码量的优化,无关乎算法,具体代码参考博文:http://www.hankcs.com/program/algorithm/poj-3735-training-little-cats-time.html