Codeforces Educational Round 13 Problem D

CFEdu13D Iterated Linear Function

Solution Sketch

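Let $f(x) = ax + b$, and let $f_n$ denote the result of applying $f$ to $x$ exactly $n$ times, so $f_0 = x$ and $f_n = a f_{n-1} + b$. The task is to output $f_n \bmod (10^9 + 7)$, and the answer can be read off from the following matrix identity: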

$$\begin{aligned} \begin{bmatrix} a & b \\ 0 & 1 \\ \end{bmatrix}^{n} \begin{bmatrix} x \\ 1 \\ \end{bmatrix} = \begin{bmatrix} f_n\\ 1 \\ \end{bmatrix} \end{aligned}$$

Expanding the first few powers shows why the identity holds:

$n = 1$
$$\begin{aligned} \begin{bmatrix} a & b \\ 0 & 1 \\ \end{bmatrix} \begin{bmatrix} x \\ 1 \\ \end{bmatrix} = \begin{bmatrix} ax + b\\ 1 \\ \end{bmatrix} \end{aligned}$$

$n = 2$ (reusing the result for $n = 1$):
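$$\begin{aligned} \begin{bmatrix} a & b \\ 0 & 1 \\ \end{bmatrix}^{2} \begin{bmatrix} x \\ 1 \\ \end{bmatrix} = \begin{bmatrix} a & b \\ 0 & 1 \\ \end{bmatrix} \begin{bmatrix} ax + b\\ 1 \\ \end{bmatrix} = \begin{bmatrix} a(ax + b) + b\\ 1 \\ \end{bmatrix} \end{aligned}$$

Each additional multiplication by the matrix applies $f$ once more, so by induction the top entry of the product is exactly $f_n$. Since $n$ can be as large as $10^{18}$, the matrix power is computed by binary exponentiation. This needs only $O(\log n)$ multiplications of $2 \times 2$ matrices, and every entry is reduced modulo $10^9 + 7$ along the way.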

AC code (a `matrix` struct with an overloaded `*` operator)

#include <bits/stdc++.h>
using namespace std;
typedef long long int ll;
// Modulus required by the problem. M < 2^30, so the product of two residues
// in [0, M) is below 2^60 and always fits in a signed 64-bit ll.
const ll M = ((ll)1e9 + 7);
// Values read by main(): coefficients a and b, iteration count n, start x.
ll a, b, x, n;

// 2x2 matrix over the integers modulo M.
struct matrix {
    ll data[2][2];

    // Builds the zero matrix.
    matrix()
    {
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++)
                data[i][j] = 0;
    }

    // Returns (*this) * q with every entry reduced modulo M.
    // Entries of both factors are expected to lie in [0, M), which keeps each
    // intermediate product below 2^63. q is taken by const reference to avoid
    // copying it.
    matrix operator*(const matrix &q) const
    {
        matrix ans;
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++)
                for (int k = 0; k < 2; k++)
                    ans.data[i][j] =
                        (ans.data[i][j] + (data[i][k] * q.data[k][j]) % M) % M;
        return ans;
    }
};

// Returns f^(exp)(x) mod M, where f(t) = a*t + b and f^(k) means f applied k
// times (so f^(0)(x) = x). It relies on the identity
//     [a b; 0 1]^exp * [x; 1] = [f^(exp)(x); 1]
// and computes the matrix power by binary exponentiation, i.e. with
// O(log exp) matrix multiplications.
//
// Reads the globals a, b and x. Each is first normalized into [0, M), so
// values >= M cannot overflow the 64-bit products and negative values cannot
// produce a negative result. A non-positive exp is treated as 0, which
// returns x mod M.
ll fast_pow(ll exp)
{
    // Maps any ll, including negatives, to its representative in [0, M).
    // Needed because C++ `%` keeps the sign of the dividend.
    auto norm = [](ll v) { return (v % M + M) % M; };

    matrix base;
    base.data[0][0] = norm(a);
    base.data[0][1] = norm(b);
    base.data[1][1] = 1; // data[1][0] stays 0 from the constructor

    // res = base ^ exp, starting from the identity matrix.
    matrix res;
    res.data[0][0] = res.data[1][1] = 1;
    // Use `> 0` rather than `!= 0`: for a negative exp, arithmetic `>>=`
    // sticks at -1 and would never reach 0.
    while (exp > 0) {
        if (exp & 1)
            res = res * base;
        base = base * base;
        exp >>= 1;
    }

    // Top entry of res * [x; 1]. Both terms are already in [0, M).
    return (res.data[0][0] * norm(x) % M + res.data[0][1]) % M;
}
// Reads "a b n x" from stdin and prints f^(n)(x) mod M, where f(t) = a*t + b.
int main()
{
    // If fewer than four values are read, exit with a non-zero status instead
    // of printing an answer computed from partially read (zero-default)
    // globals.
    if (scanf("%lld %lld %lld %lld", &a, &b, &n, &x) != 4)
        return 1;
    // Binary exponentiation of the 2x2 transition matrix.
    printf("%lld\n", fast_pow(n));
    return 0;
}

AC code (plain `ll[2][2]` arrays with a `mul` helper)

#include <bits/stdc++.h>
using namespace std;
typedef long long int ll;
// Modulus required by the problem. M < 2^30, so the product of two residues
// in [0, M) is below 2^60 and always fits in a signed 64-bit ll.
const ll M = ((ll)1e9 + 7);
// Values read by main(): coefficients a and b, iteration count n, start x.
ll a, b, x, n;

// In-place 2x2 matrix product modulo M: p = p * q.
// The product is accumulated in a temporary and copied back at the end, so
// the call is correct even when p and q are the same array (mul(base, base)).
// Entries of p and q are expected to lie in [0, M) so no product overflows.
void mul(ll p[2][2], ll q[2][2])
{
    ll tmp[2][2] = {{0, 0}, {0, 0}};
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++)
            for (int k = 0; k < 2; k++)
                tmp[i][j] = (tmp[i][j] + (p[i][k] * q[k][j]) % M) % M;
    memcpy(p, tmp, sizeof(tmp));
}

// Returns f^(exp)(x) mod M, where f(t) = a*t + b and f^(k) means f applied k
// times (so f^(0)(x) = x). It relies on the identity
//     [a b; 0 1]^exp * [x; 1] = [f^(exp)(x); 1]
// and computes the matrix power by binary exponentiation, i.e. with
// O(log exp) calls to mul().
//
// Reads the globals a, b and x. Each is first normalized into [0, M), so
// values >= M cannot overflow the 64-bit products and negative values cannot
// produce a negative result. A non-positive exp is treated as 0, which
// returns x mod M.
ll fast_pow(ll exp)
{
    // Maps any ll, including negatives, to its representative in [0, M).
    // Needed because C++ `%` keeps the sign of the dividend.
    auto norm = [](ll v) { return (v % M + M) % M; };

    ll base[2][2] = {{norm(a), norm(b)}, {0, 1}};
    // res = base ^ exp, starting from the identity matrix.
    ll res[2][2] = {{1, 0}, {0, 1}};
    // Use `> 0` rather than `!= 0`: for a negative exp, arithmetic `>>=`
    // sticks at -1 and would never reach 0.
    while (exp > 0) {
        if (exp & 1)
            mul(res, base);
        mul(base, base);
        exp >>= 1;
    }

    // Top entry of res * [x; 1]. Both terms are already in [0, M).
    return (res[0][0] * norm(x) % M + res[0][1]) % M;
}
// Reads "a b n x" from stdin and prints f^(n)(x) mod M, where f(t) = a*t + b.
int main()
{
    // If fewer than four values are read, exit with a non-zero status instead
    // of printing an answer computed from partially read (zero-default)
    // globals.
    if (scanf("%lld %lld %lld %lld", &a, &b, &n, &x) != 4)
        return 1;
    // Binary exponentiation of the 2x2 transition matrix.
    printf("%lld\n", fast_pow(n));
    return 0;
}