題目連結

Matrix fast exponentiation!

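The sequence satisfies $f(n) = f(n - 1) + f(n - 2)$, so one multiplication by the matrix below advances the pair $(f(1), f(0))$ by one step. Observe that when $n = 2$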

$ \begin{bmatrix} 1 & 1 \newline 1 & 0 \newline \end{bmatrix}^{2 - 1} * \begin{bmatrix} f(1) \newline f(0) \newline \end{bmatrix} = \begin{bmatrix} f(0) + f(1) \newline f(1) \newline \end{bmatrix} = \begin{bmatrix} f(2) \newline f(1) \newline \end{bmatrix} $

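Each further multiplication by the same matrix advances the pair by one more step, so by induction, for every $n \ge 1$ we get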

$ \begin{bmatrix} 1 & 1 \newline 1 & 0 \newline \end{bmatrix}^{n - 1} * \begin{bmatrix} f(1) \newline f(0) \newline \end{bmatrix} = \begin{bmatrix} f(n) \newline f(n - 1) \newline \end{bmatrix} $

AC Code

Bad coding style… GG

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef pair<int, int> ii;

/**
 * Dense integer matrix whose dimensions are fixed when it is constructed.
 *
 * The cells are exposed through the public `matrix` member, stored as one
 * inner vector per row, so callers address them as matrix[row][col].
 */
struct Matrix {
private:
	// Row and column counts recorded at construction time.
	int r, c;

public:
	// matrix[i] is row i; each row holds one entry per column.
	std::vector<std::vector<int>> matrix;

	// Builds an _r x _c matrix with every entry initialised to 0.
	Matrix(int _r, int _c)
		: r(_r), c(_c), matrix(_r, std::vector<int>(_c, 0))
	{
	}
};

// Modulus applied to every matrix product. It must hold a positive value
// before operator* is called; solve() assigns it for each test case.
int mod;

/**
 * Multiplies two matrices, reducing every entry modulo the global `mod`.
 *
 * @param a left operand
 * @param b right operand; its row count drives the inner loop, so every row
 *          of `a` must have at least that many entries
 * @return a matrix with a's row count and b's column count whose entries are
 *         the dot products reduced modulo `mod`
 */
Matrix operator*(const Matrix &a, const Matrix &b)
{
	const int rows = static_cast<int>(a.matrix.size());
	const int inner = static_cast<int>(b.matrix.size());
	// An empty right operand has no first row to measure, so treat it as
	// having zero columns instead of indexing b.matrix[0].
	const int cols = b.matrix.empty() ? 0 : static_cast<int>(b.matrix[0].size());

	Matrix product(rows, cols);

	for (int i = 0; i < rows; i++) {
		for (int j = 0; j < cols; j++) {
			long long sum = 0;
			for (int k = 0; k < inner; k++) {
				// Widen before multiplying: two entries close to `mod` can
				// exceed the int range, which would be undefined behaviour.
				const long long term = 1LL * a.matrix[i][k] * b.matrix[k][j] % mod;
				sum = (sum + term) % mod;
			}
			product.matrix[i][j] = static_cast<int>(sum);
		}
	}

	return product;
}

/**
 * Raises a square matrix to the n-th power by repeated squaring, using the
 * modular operator* for every multiplication.
 *
 * Note that in C++ `^` binds more loosely than `*`, so callers must
 * parenthesise: `(m ^ k) * v`.
 *
 * @param a square matrix to exponentiate
 * @param n exponent; any value <= 0 yields the identity matrix
 * @return a^n, with the entries reduced by operator*
 */
Matrix operator^(const Matrix &a, int n)
{
	const int size = static_cast<int>(a.matrix.size());

	// Start from the identity of the same order as `a` (not just 2x2).
	Matrix ans(size, size);
	for (int i = 0; i < size; i++)
		ans.matrix[i][i] = 1;

	Matrix base = a;
	while (n > 0) {
		// Fold in the current power of `a` when its bit is set. The factors
		// are all powers of the same matrix, so the order does not matter.
		if (n & 1)
			ans = base * ans;

		n >>= 1;
		// Square only while more bits remain; the last squaring is unused.
		if (n > 0)
			base = base * base;
	}

	return ans;
}

/**
 * Handles one test case: reads `a b n m` from stdin and prints term n of the
 * sequence f(0) = a, f(1) = b, f(k) = f(k - 1) + f(k - 2), reduced modulo
 * 10^m, followed by a newline.
 *
 * Sets the global `mod` to 10^m as a side effect. Returns without printing
 * anything if the four integers cannot be read.
 */
void solve()
{
	int a, b, n, m;
	if (scanf("%d %d %d %d", &a, &b, &n, &m) != 4)
		return;

	// mod = 10^m. NOTE(review): this overflows int for m >= 10; the input is
	// presumably bounded well below that - confirm against the problem limits.
	mod = 1;
	for (int i = 0; i < m; i++)
		mod *= 10;

	// The two base cases are answered directly. Each prints exactly one line
	// and, like every other answer, is reduced modulo 10^m.
	if (n == 0) {
		printf("%d\n", a % mod);
		return;
	}
	if (n == 1) {
		printf("%d\n", b % mod);
		return;
	}

	// Column vector (f(1), f(0)).
	Matrix init(2, 1);
	init.matrix[0][0] = b;
	init.matrix[1][0] = a;

	// Step matrix [[1, 1], [1, 0]]; the remaining cell keeps its initial 0.
	Matrix square(2, 2);
	square.matrix[0][0] = square.matrix[0][1] = square.matrix[1][0] = 1;

	// square^(n - 1) * (f(1), f(0)) = (f(n), f(n - 1)).
	Matrix ans = (square ^ (n - 1)) * init;
	printf("%d\n", ans.matrix[0][0]);
}

/**
 * Entry point: reads the number of test cases from stdin and calls solve()
 * once per case.
 */
int main()
{
	int ncase = 0;
	// Without a readable case count there is nothing to process.
	if (scanf("%d", &ncase) != 1)
		return 0;

	// `> 0` keeps a negative count from turning into a runaway loop.
	while (ncase-- > 0)
		solve();

	return 0;
}