题意:
给定长度为n的数组a,和整数m
求满足条件的数组b的个数取模998244353的结果
数组b的每个元素值<=m
b_i <= b_a_i
转换成图,有n个点和n条边,所得的图必定至少有一个环
对于每个环,要满足约束条件必定环上的点的值相等,可以把环缩成一个点,
这样多个环形成了多棵以环为根的有根树,用tarjan缩点后按照拓扑序进行DP。
记dp[i][j]表示第i个点为j时它的所有子树的满足约束的数目,从外往里进行转移即可。
Code
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define int ll
const int mod = 998244353;
#define dbg(x) cout << #x << "=" << x << endl
struct SCC {
int n, now, cnt;
vector<vector<int>> ver;
vector<int> dfn, low, col, stk;
SCC(int n) : n(n), ver(n + 1), low(n + 1) {
dfn.resize(n + 1, -1);
col.resize(n + 1, -1);
now = cnt = 0;
}
void add(int x, int y) {
ver[x].push_back(y);
}
void tarjan(int x) {
dfn[x] = low[x] = now++;
stk.push_back(x);
for (auto y : ver[x]) {
if (dfn[y] == -1) {
tarjan(y);
low[x] = min(low[x], low[y]);
} else if (col[y] == -1) {
low[x] = min(low[x], dfn[y]);
}
}
if (dfn[x] == low[x]) {
int pre;
cnt++;
do {
pre = stk.back();
col[pre] = cnt;
stk.pop_back();
} while (pre != x);
}
}
auto work() { // cnt 新图的顶点数量
for (int i = 1; i <= n; i++) { // 避免图不连通
if (dfn[i] == -1) {
tarjan(i);
}
}
vector<int> siz(cnt + 1); // siz 每个 scc 中点的数量
vector<vector<int>> adj(cnt + 1);
for (int i = 1; i <= n; i++) {
siz[col[i]]++;
for (auto j : ver[i]) {
int x = col[i], y = col[j];
if (x != y) {
adj[x].push_back(y);
}
}
}
return make_tuple(cnt, adj, col, siz);
}
//SCC scc(n);
//auto [cnt, adj, col, siz] = scc.work(); C++17 结构化绑定
};
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, m;
cin >> n >> m;
SCC scc(n);
for (int i = 1; i <= n; i++) {
int x;
cin >> x;
scc.add(i,x);
}
vector<int> in(n+1), out(n+1);
vector<vector<int>> dp(n+1,vector<int>(m+1,1));
auto [cnt, adj, col, siz] = scc.work();
for (int i = 1; i <= cnt; i++) {
for (auto j : adj[i]) {
in[j]++;
out[i]++;
}
}
// dp[i][j] 为节点i为j时的子树合法种数
for (int i = cnt; i >= 1; i--) {
for (auto j : adj[i]) {
int cc = 0;
for (int k = 1; k <= m; k++) {
cc = (cc + dp[i][k]) % mod;
dp[j][k] = (cc * dp[j][k]) % mod;
}
}
}
vector<int> st;
for (int i = 1; i <= cnt; i++) {
if (out[i] == 0) {
st.push_back(i);
}
}
ll ans = 1;
for (auto i : st) {
ll cal = 0;
for (int j = 1; j <= m; j++) {
cal = (cal + dp[i][j]) % mod;
}
ans = (ans * cal) % mod;// 每个联通块之间独立
}
cout << ans;
return 0;
}