F - Count Arrays

F - Count Arrays

题意:
给定长度为n的数组a,和整数m
求满足条件的数组b的个数取模998244353的结果
数组b的每个元素值<=m
b_i <= b_a_i

转换成图,有n个点和n条边,所得的图必定至少有一个环

对于每个环,要满足约束条件必定环上的点的值相等,可以把环缩成一个点,
这样多个环形成了多棵以环为根的有根树,用tarjan缩点后按照拓扑序进行DP。
记dp[i][j]表示第i个点为j时它的所有子树的满足约束的数目,从外往里进行转移即可。

Code


#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define int ll
const int mod = 998244353; 
#define dbg(x) cout << #x << "=" << x << endl


struct SCC {
    int n, now, cnt;
    vector<vector<int>> ver;
    vector<int> dfn, low, col, stk;
    SCC(int n) : n(n), ver(n + 1), low(n + 1) {
        dfn.resize(n + 1, -1);
        col.resize(n + 1, -1);
        now = cnt = 0;
    }
    void add(int x, int y) {
        ver[x].push_back(y);
    }
    void tarjan(int x) {
        dfn[x] = low[x] = now++;
        stk.push_back(x);
        for (auto y : ver[x]) {
            if (dfn[y] == -1) {
                tarjan(y);
                low[x] = min(low[x], low[y]);
            } else if (col[y] == -1) {
                low[x] = min(low[x], dfn[y]);
            }
        }
        if (dfn[x] == low[x]) {
            int pre;
            cnt++;
            do {
                pre = stk.back();
                col[pre] = cnt;
                stk.pop_back();
            } while (pre != x);
        }
    }
    auto work() {                       // cnt 新图的顶点数量
        for (int i = 1; i <= n; i++) {  // 避免图不连通
            if (dfn[i] == -1) {
                tarjan(i);
            }
        }
        vector<int> siz(cnt + 1);  // siz 每个 scc 中点的数量
        vector<vector<int>> adj(cnt + 1);
        for (int i = 1; i <= n; i++) {
            siz[col[i]]++;
            for (auto j : ver[i]) {
                int x = col[i], y = col[j];
                if (x != y) {
                    adj[x].push_back(y);
                }
            }
        }
        return make_tuple(cnt, adj, col, siz);
    }
    
    //SCC scc(n);
    //auto [cnt, adj, col, siz] = scc.work(); C++17 结构化绑定
};
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> n >> m;
    SCC scc(n);
    for (int i = 1; i <= n; i++) {
        int x;
        cin >> x;
        scc.add(i,x);
    }
    vector<int> in(n+1), out(n+1);
    vector<vector<int>> dp(n+1,vector<int>(m+1,1));
    auto [cnt, adj, col, siz] = scc.work();
    for (int i = 1; i <= cnt; i++) {
        for (auto j : adj[i]) {
            in[j]++;
            out[i]++;
        }
    }

    // dp[i][j] 为节点i为j时的子树合法种数
    for (int i = cnt; i >= 1; i--) {
        for (auto j : adj[i]) {
            int cc = 0;
            for (int k = 1; k <= m; k++) {
                cc = (cc + dp[i][k]) % mod;
                dp[j][k] = (cc * dp[j][k]) % mod; 
            }
        }
    }

    vector<int> st;
    for (int i = 1; i <= cnt; i++) {
        if (out[i] == 0) {
            st.push_back(i);
        }
    }

    ll ans = 1;
    for (auto i : st) {
        ll cal = 0;
        for (int j = 1; j <= m; j++) {
            cal = (cal + dp[i][j]) % mod;
        }
        ans = (ans * cal) % mod;// 每个联通块之间独立
    }
    cout << ans;
    return 0;
}
github