ABC414E Count A%B=C
题目描述
请计算满足以下条件的整数三元组 $(a, b, c)$ 的个数,并将结果对 $998244353$ 取模。
- $1 \leq a, b, c \leq N$。
- $a, b, c$ 互不相同。
- $a$ 除以 $b$ 的余数等于 $c$。
输入格式
输入从标准输入中给出,格式如下:
$N$
输出格式
请输出答案,输出占一行。
输入输出样例 #1
输入 #1
7
输出 #1
12
输入输出样例 #2
输入 #2
441
输出 #2
94700
输入输出样例 #3
输入 #3
411111111114
输出 #3
462474062
说明/提示
限制条件
- $3 \leq N \leq 10^{12}$
- $N$ 是整数
样例解释 1
满足条件的整数三元组 $(a, b, c)$ 共 $12$ 组,分别为:
- $(3,2,1)$
- $(4,3,1)$
- $(5,2,1)$
- $(5,3,2)$
- $(5,4,1)$
- $(6,4,2)$
- $(6,5,1)$
- $(7,2,1)$
- $(7,3,1)$
- $(7,4,3)$
- $(7,5,2)$
- $(7,6,1)$
分析:
如果能确定a, b, 那么 c 也能确定。
但是题目限制a, b, c互不相等
考虑两种情况:
- a < b
- a % b == 0
显然两种情况不可能重复
那最终的三元组的数量就是 $n^2$ 减去 这两种情况
第一种是个等差数列
第二种发现可以枚举以i作为因子对总数的贡献 $\sum_{i=1}^{n} \left\lfloor \frac{n}{i} \right\rfloor$
这里用分块来做
证明
已知条件:
对任意块 $[l, r]$,满足
$\left\lfloor \dfrac{n}{l} \right\rfloor = \left\lfloor \dfrac{n}{r} \right\rfloor = k$
目标:
证明块右端点最大值为
$r_{\max} = \left\lfloor \dfrac{n}{\lfloor n/l \rfloor} \right\rfloor$
证明过程:
- 由取整函数性质得:
$$
k \leq \dfrac{n}{r} < k + 1
$$
- 不等式变形:
$$
\dfrac{n}{k+1} < r \leq \dfrac{n}{k}
$$
- 代入 $k = \left\lfloor \dfrac{n}{l} \right\rfloor$:
$$
r \leq \dfrac{n}{\lfloor n/l \rfloor}
$$
- 考虑整数约束:
因 $r \in \mathbf{N}$,且 $r$ 需满足不等式:
$$
r_{\max} = \max \left{ r \in \mathbf{Z} \ \middle| \ r \leq \dfrac{n}{\lfloor n/l \rfloor} \right} = \left\lfloor \dfrac{n}{\lfloor n/l \rfloor} \right\rfloor
$$
结论:
$$
\boxed{r_{\max} = \left\lfloor \dfrac{n}{\lfloor n/l \rfloor} \right\rfloor}
$$
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define int long long
const ll mod = 998244353;
ll ksm(ll a, ll b) {
ll ans = 1;
while (b) {
if (b&1) ans = ans * a % mod;
a = a * a % mod;
b >>= 1;
}
return ans;
}
void solve() {
ll n;
cin >> n;
// ll ans = n * n - n * (n - 1) / 2;
ll ans = (n % mod * ((n + 1) % mod)) % mod * ((mod+1) / 2) % mod;
ll cnt = 0;
for (int l = 1, r = 1; l <= n; l = r + 1) {
r = min(n,n/(n/l));
cnt = (cnt + (r - l + 1) % mod * (n / l) % mod) % mod;
}
cout << (ans - cnt + mod) % mod<< '\n';
}
signed main() {
ios::sync_with_stdio(false), cin.tie(nullptr);
int tt = 1;
while (tt--) solve();
return 0;
}