Double Sum 2

按 lowbit 对元素分组：lowbit 不同的两个数之和的 lowbit 等于其中较小的那个，因此不同组之间的贡献可以用前缀和与计数直接计算；同一组内先把每个数除以 lowbit 变成奇数，再用 0-1 Trie 统计。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define int long long
#define LOCAL
// Debug tracing, compiled in only when LOCAL is defined.
//
// dbg(a, b, c) prints "a=<a>, b=<b>, c=<c>" followed by a newline, labelling
// each value with the text of the expression that produced it. Output goes to
// std::cerr so that trace lines can never mix into the answer written to
// standard output.
//
// Limitation: labels are split at the first ',' of the stringified argument
// list, so an argument that itself contains a comma (e.g. f(x, y)) is
// mislabelled.
#ifdef LOCAL
#define dbg(...) dbg_print(#__VA_ARGS__, __VA_ARGS__)

// Last (or only) argument: whatever remains of the label text belongs to it.
template <typename T>
void dbg_print(const char* names, T value) {
    std::cerr << names << '=' << value << '\n';
}

// Prints the label up to the next ',' together with the first value, then
// recurses on the remaining labels and values.
template <typename T, typename... Rest>
void dbg_print(const char* names, T value, Rest... rest) {
    // Stop at the terminator as well, so a malformed label list cannot make
    // the scan run off the end of the string.
    while (*names != '\0' && *names != ',') std::cerr << *names++;
    std::cerr << '=' << value << ',';
    if (*names == ',') ++names;  // skip the separator itself
    dbg_print(names, rest...);
}
#else
// Without LOCAL the macro collapses to a harmless constant expression.
#define dbg(...) 0
#endif

// Isolates the lowest set bit of x: the result has exactly that one bit set,
// or is 0 when x is 0.
int lowbit(int x) {
    const int negated = -x;
    return negated & x;
}

// Capacity of the node pool shared by every trie in this file. insert()
// allocates at most 28 nodes per call and qr() allocates none, so the pool
// has to hold 28 * (number of insertions) plus one root per trie.
// NOTE(review): the input limits are not visible here - confirm they fit.
const int N = 6000000 + 10;
// nx[p][b]: id of the child of node p along bit value b; 0 means "no child".
int nx[N][2];
// tot[p]: sum of all values whose insertion path passes through node p.
// num[p]: number of insertions whose path passes through node p.
// Entries of node 0 are never written, so they stay 0.
int tot[N], num[N];
// Id of the most recently allocated node. Node 0 is never handed out as a
// child, which is what lets 0 double as the "no child" marker.
int cnt = 0;

// Adds x to the trie rooted at node p.
//
// The path is keyed by bits 1..28 of x, lowest bit first (bit 0 is skipped).
// Every node on the path below the root accumulates x into tot[] and is
// counted once in num[].
void insert(int x, int p) {
    for (int i = 1; i <= 28; i++) {
        const bool c = (x >> i) & 1;
        if (!nx[p][c]) nx[p][c] = ++cnt;  // allocate the branch on first use
        p = nx[p][c];
        tot[p] += x;
        num[p]++;
    }
}

// Queries the trie rooted at node p with x and returns the sum, over every
// stored value y, of (x + y) >> i, where i is the lowest bit position in
// 1..28 at which x and y carry the same bit. Stored values that differ from x
// at every one of those positions contribute nothing.
//
// When x and y are both odd, that position is exactly the lowest set bit of
// x + y, so each term equals (x + y) / lowbit(x + y).
// NOTE(review): this relies on the caller storing and querying odd values
// only - confirm at the call site.
//
// The trie is not modified: once the branch that disagrees with x is missing,
// no stored value remains on the walk and the loop ends.
int qr(int x, int p) {
    int ans = 0;
    for (int i = 1; i <= 28; i++) {
        const bool c = (x >> i) & 1;
        // Values below this child agree with x at bit i for the first time.
        const int same = nx[p][c];
        if (same) ans += (tot[same] + num[same] * x) >> i;
        // Values that still differ from x are settled at a deeper level.
        const int differ = nx[p][!c];
        if (!differ) break;
        p = differ;
    }
    return ans;
}

signed main() {
    ios::sync_with_stdio(false), cin.tie(nullptr);

    int n;
    cin >> n;
    vector<int> a(n+1);
    
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    int ans = 0;
    sort(a.begin() + 1,a.end(),[](int a, int b) { return lowbit(a) > lowbit(b);});
    
    for (int sum1 = 0, sum2 = 0, cnt1 = 0, cnt2 = 0, r = 0, i = 1; i <= n; i++) {
        ans += (a[i] * cnt1 + sum1) /  lowbit(a[i]);
        sum2 += a[i];
        cnt2++;
        int now = a[i] / lowbit(a[i]);

        insert(now,r);
        ans += qr(now,r);

        if (i < n && lowbit(a[i]) != lowbit(a[i+1])) {
            sum1 += sum2;
            cnt1 += cnt2;
            sum2 = cnt2 = 0;
            r = ++cnt;
        }
    }

    cout << ans;
    
    return 0;
}
github