整除分块

ABC414E Count A%B=C

题目描述

请计算满足以下条件的整数三元组 $(a, b, c)$ 的个数,并将结果对 $998244353$ 取模。

  • $1 \leq a, b, c \leq N$。
  • $a, b, c$ 互不相同。
  • $a$ 除以 $b$ 的余数等于 $c$。

输入格式

输入从标准输入中给出,格式如下:

$N$

输出格式

请输出答案,输出占一行。

输入输出样例 #1

输入 #1

7

输出 #1

12

输入输出样例 #2

输入 #2

441

输出 #2

94700

输入输出样例 #3

输入 #3

411111111114

输出 #3

462474062

说明/提示

限制条件

  • $3 \leq N \leq 10^{12}$
  • $N$ 是整数

样例解释 1

满足条件的整数三元组 $(a, b, c)$ 共 $12$ 组,分别为:

  • $(3,2,1)$
  • $(4,3,1)$
  • $(5,2,1)$
  • $(5,3,2)$
  • $(5,4,1)$
  • $(6,4,2)$
  • $(6,5,1)$
  • $(7,2,1)$
  • $(7,3,1)$
  • $(7,4,3)$
  • $(7,5,2)$
  • $(7,6,1)$

分析:
如果能确定a, b, 那么 c 也能确定。
但是题目限制a, b, c互不相等
考虑两种情况:

  1. a < b
  2. a % b == 0

显然两种情况不可能重复
那最终的三元组的数量就是 $n^2$ 减去 这两种情况
第一种是个等差数列
第二种发现可以枚举以i作为因子对总数的贡献 $\sum_{i=1}^{n} \left\lfloor \frac{n}{i} \right\rfloor$

这里用分块来做

证明

已知条件:
对任意块 $[l, r]$,满足
$\left\lfloor \dfrac{n}{l} \right\rfloor = \left\lfloor \dfrac{n}{r} \right\rfloor = k$

目标:
证明块右端点最大值为
$r_{\max} = \left\lfloor \dfrac{n}{\lfloor n/l \rfloor} \right\rfloor$

证明过程:

  1. 由取整函数性质得:

$$
k \leq \dfrac{n}{r} < k + 1
$$

  1. 不等式变形:

$$
\dfrac{n}{k+1} < r \leq \dfrac{n}{k}
$$

  1. 代入 $k = \left\lfloor \dfrac{n}{l} \right\rfloor$:

$$
r \leq \dfrac{n}{\lfloor n/l \rfloor}
$$

  1. 考虑整数约束:
    因 $r \in \mathbf{N}$,且 $r$ 需满足不等式:

$$
r_{\max} = \max \left{ r \in \mathbf{Z} \ \middle| \ r \leq \dfrac{n}{\lfloor n/l \rfloor} \right} = \left\lfloor \dfrac{n}{\lfloor n/l \rfloor} \right\rfloor
$$

结论:

$$
\boxed{r_{\max} = \left\lfloor \dfrac{n}{\lfloor n/l \rfloor} \right\rfloor}
$$

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define int long long
const ll mod = 998244353;

ll ksm(ll a, ll b) {
    ll ans = 1;
    while (b) {
        if (b&1) ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
} 

void solve() {
    ll n;
    cin >> n;
    // ll ans = n * n - n * (n - 1) / 2;
    ll ans = (n % mod * ((n + 1) % mod)) % mod * ((mod+1) / 2)  % mod; 
    ll cnt = 0;
    for (int l = 1, r = 1; l <= n; l = r + 1) {
        r = min(n,n/(n/l));
        cnt = (cnt + (r - l + 1) % mod * (n / l) % mod) % mod;
    }
    cout << (ans - cnt + mod) % mod<< '\n';
}

signed main() {
    ios::sync_with_stdio(false), cin.tie(nullptr);
    int tt = 1;
    while (tt--) solve();
    return 0;
}
github