競プロメモ

競プロなど

JOI'22 春合宿参加記

JOI'22 春合宿参加記

春合宿はまじで楽しいので全人類行きましょう。

結果

A B C 日計 当日順位 順位
Day1 66 10 0 76 76 17 17
Day2 20 15 64 99 175 26 21
Day3 5 53 6 64 239 17 22
Day4 7 13 14 34 273 22 22

感想

さすがにやっぱりレベルが違って、全く勝負にならなかった。が、全体としてかなり楽しくて超満足なのでぜひ皆さん春合宿に行きましょう!

ほとんど解いたことのない(おい)春合宿タイプの JOI の問題は、5 時間 3 問想定なだけあって、一つの問題に対してかなり長い時間をかけてひとステップずつ考察していくので、かなり楽しい。また今度たくさん精進したいな。

また、初めての競プロ系オンサイトで割と交流できてよかった。みつばちおもろい。

Day1

B と C が全くわからなかったのでほとんど A に使った。

A: jail

まず小課題 3 の計算量が明らかに $O(QNM!)$ くらいなので、順列全探索でおそらく解けるということを考えると、ある人に対して操作を行うときは全部やってしまえばいいことがわかる。

各順列において、操作が可能かどうかを愚直に調べると、小課題 2, 3 が通る。

次に小課題 1 を見ると、列においてこの問題を解けば良いので、いい感じにやる。

ここらへんで、列において人 $i$, $j$ が干渉する条件を考えていたので、木においても同じことを考えると、$P_i$ を人 $i$ の通るパス $S_i$ $T_i$ とすると、

  • $P_i$ 内に $S_j$, $T_j$ がともに含まれているとき、No

  • $P_i$ 内に $S_j$ が含まれているときと、$P_j$ 内に $T_j$ が含まれているとき、j を i より先に動かす

  • 上記の二つを満たさないとき、制約なし

という条件が得られるので、いい感じに処理の順番の制約を作ったらそれをトポソして判定。小課題 4 では $O(QM^{2}N)$, 小課題 5 では $O(Q(NMlogN+M^{2}logN))$ くらいで解く。

小課題 6 では $P_i$ をすべて陽に持てるが、$O(パスの長さ)$ でこれを求めなければならないので、lca を書いてちょちょっとやると解ける。実装間に合わず。

小課題を一つずつ見て行って、要求されていることを一つずつ考察していくと進んでいったので楽しかった。

f:id:RheoTommy:20220321083619p:plain

小課題 1

#include <algorithm>
#include <bits/stdc++.h>
#include <numeric>
#include <shared_mutex>
using namespace std;

void solve() {
    int n;
    cin >> n;
    vector<vector<int>> graph(n);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        a--, b--;
        graph[a].push_back(b);
        graph[b].push_back(a);
    }

    int m;
    cin >> m;
    vector<int> s(m);
    vector<int> t(m);
    for (int j = 0; j < m; j++)
        cin >> s[j] >> t[j], s[j]--, t[j]--;

    bool flag1 = true;
    {
        vector<int> si(n, -1);
        vector<int> ti(n - 1);
        for (int j = 0; j < m; j++)
            si[s[j]] = j, ti[t[j]] = j;

        set<int> t_st;
        for (int j = 0; j < m; j++)
            t_st.emplace(t[j]);

        for (int i = 0; i < n; i++) {
            if (si[i] == -1)
                continue;

            int j = si[i];
            if (s[j] > t[j])
                continue;

            flag1 &= *t_st.lower_bound(i) == t[j];
            t_st.erase(t[j]);
        }
    }

    for (int j = 0; j < m; j++)
        swap(s[j], t[j]);

    bool flag2 = true;
    {
        vector<int> si(n, -1);
        vector<int> ti(n - 1);
        for (int j = 0; j < m; j++)
            si[s[j]] = j, ti[t[j]] = j;

        set<int> t_st;
        for (int j = 0; j < m; j++)
            t_st.emplace(t[j]);

        for (int i = 0; i < n; i++) {
            if (si[i] == -1)
                continue;

            int j = si[i];
            if (s[j] > t[j])
                continue;

            flag2 &= *t_st.lower_bound(i) == t[j];
            t_st.erase(t[j]);
        }
    }

    cout << (flag1 && flag2 ? "Yes" : "No") << '\n';
}

int main() {
    int q;
    cin >> q;
    while (q--)
        solve();
}

小課題 2, 3

#include <algorithm>
#include <bits/stdc++.h>
#include <numeric>
using namespace std;

void solve() {
    int n;
    cin >> n;
    vector<vector<int>> graph(n);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        a--, b--;
        graph[a].push_back(b);
        graph[b].push_back(a);
    }

    int m;
    cin >> m;
    vector<int> s(m);
    vector<int> t(m);
    for (int j = 0; j < m; j++)
        cin >> s[j] >> t[j], s[j]--, t[j]--;

    vector<bool> staying(n, false);
    for (int j = 0; j < m; j++)
        staying[s[j]] = true;

    vector<int> p(m);
    iota(p.begin(), p.end(), 0);

    auto check_dfs = [&](auto &&dfs, int now, int par, int t) -> bool {
        if (now == t)
            return true;

        bool flag = false;
        for (auto next : graph[now]) {
            if (next == par)
                continue;

            if (!staying[next])
                flag |= dfs(dfs, next, now, t);
        }

        return flag;
    };

    bool ans = false;

    do {
        staying = vector<bool>(n, false);
        for (int j = 0; j < m; j++)
            staying[s[j]] = true;

        bool flag = true;
        for (int k = 0; k < m; k++) {
            flag &= check_dfs(check_dfs, s[p[k]], -1, t[p[k]]);
            staying[s[p[k]]] = false;
            staying[t[p[k]]] = true;
        }
        ans |= flag;
    } while (next_permutation(p.begin(), p.end()));

    cout << (ans ? "Yes" : "No") << '\n';
}

int main() {
    int q;
    cin >> q;
    while (q--)
        solve();
}

小課題 4

#include <algorithm>
#include <bits/stdc++.h>
#include <numeric>
using namespace std;

void solve() {
    int n;
    cin >> n;
    vector<vector<int>> graph(n);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        a--, b--;
        graph[a].push_back(b);
        graph[b].push_back(a);
    }

    int m;
    cin >> m;
    vector<int> s(m);
    vector<int> t(m);
    for (int j = 0; j < m; j++)
        cin >> s[j] >> t[j], s[j]--, t[j]--;

    auto get_path_set = [&](int s, int t) -> set<int> {
        vector<int> depth(n, -1);
        depth[s] = 0;
        auto dfs = [&](auto &&dfs, int now, int par) -> void {
            for (auto next : graph[now]) {
                if (next == par)
                    continue;
                if (depth[next] != -1)
                    continue;

                depth[next] = depth[now] + 1;
                dfs(dfs, next, now);
            }
        };
        dfs(dfs, s, -1);

        set<int> st;
        st.emplace(s);

        int now = t;
        while (now != s) {
            st.emplace(now);
            for (auto next : graph[now]) {
                if (depth[next] + 1 == depth[now]) {
                    now = next;
                    break;
                }
            }
        }
        return st;
    };

    vector<vector<int>> topo_graph(m);
    vector<int> degree(m);

    for (int i = 0; i < m; i++) {
        for (int j = i + 1; j < m; j++) {
            auto pi = get_path_set(s[i], t[i]);
            auto pj = get_path_set(s[j], t[j]);

            // cout << "HERE i " << pi.size() << endl;
            // cout << "HERE j " << pj.size() << endl;
            // for (auto p : pi)
            //     cout << p << ' ';
            // cout << endl;
            // for (auto p : pj)
            //     cout << p << ' ';
            // cout << endl;

            bool ng_flag = true;
            for (auto t : pj)
                ng_flag &= pi.count(t);
            if (ng_flag) {
                cout << "No" << endl;
                return;
            }

            if (pj.count(s[i]) || pi.count(t[j]))
                topo_graph[i].push_back(j), degree[j]++;
            if (pi.count(s[j]) || pj.count(t[i]))
                topo_graph[j].push_back(i), degree[i]++;
        }
    }

    vector<int> sorted;
    queue<int> que;
    for (int i = 0; i < m; i++)
        if (degree[i] == 0)
            que.emplace(i);

    while (!que.empty()) {
        int now = que.front();
        sorted.emplace_back(now);
        que.pop();
        for (auto next : topo_graph[now]) {
            degree[next]--;
            if (degree[next] == 0)
                que.emplace(next);
        }
    }

    cout << (sorted.size() == m ? "Yes" : "No") << endl;
}

int main() {
    int q;
    cin >> q;
    while (q--)
        solve();
}

小課題 5

#include <algorithm>
#include <bits/stdc++.h>
#include <numeric>
using namespace std;

void solve() {
    int n;
    cin >> n;
    vector<vector<int>> graph(n);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        a--, b--;
        graph[a].emplace_back(b);
        graph[b].emplace_back(a);
    }

    int m;
    cin >> m;
    vector<int> s(m);
    vector<int> t(m);
    for (int j = 0; j < m; j++)
        cin >> s[j] >> t[j], s[j]--, t[j]--;

    auto get_path = [&](int s, int t) -> vector<int> {
        vector<int> depth(n, -100);
        depth[s] = 0;
        auto dfs = [&](auto &&dfs, int now, int par) -> void {
            if (now == t)
                return;

            for (auto next : graph[now]) {
                if (next == par)
                    continue;
                if (depth[next] != -100)
                    continue;

                depth[next] = depth[now] + 1;
                dfs(dfs, next, now);
            }
        };
        dfs(dfs, s, -1);

        vector<int> v;
        v.emplace_back(s);

        int now = t;
        while (now != s) {
            v.emplace_back(now);
            for (auto next : graph[now]) {
                if (depth[next] + 1 == depth[now]) {
                    now = next;
                    break;
                }
            }
        }

        sort(v.begin(), v.end());
        return v;
    };

    vector<vector<int>> paths(m);
    for (int i = 0; i < m; i++)
        paths[i] = get_path(s[i], t[i]);

    auto contains = [&](int i, int x) -> bool {
        return lower_bound(paths[i].begin(), paths[i].end(), x) !=
                   paths[i].end() &&
               *lower_bound(paths[i].begin(), paths[i].end(), x) == x;
    };

    vector<vector<int>> topo_graph(m);
    vector<int> degree(m);

    for (int i = 0; i < m; i++) {
        for (int j = i + 1; j < m; j++) {
            if ((contains(i, s[j]) && contains(i, t[j])) ||
                (contains(j, s[i]) && contains(j, t[i]))) {
                cout << "No" << '\n';
                return;
            }

            if (contains(j, s[i]) || contains(i, t[j]))
                topo_graph[i].push_back(j), degree[j]++;
            if (contains(i, s[j]) || contains(j, t[i]))
                topo_graph[j].push_back(i), degree[i]++;
        }
    }

    vector<int> sorted;
    queue<int> que;
    for (int i = 0; i < m; i++)
        if (degree[i] == 0)
            que.emplace(i);

    while (!que.empty()) {
        int now = que.front();
        sorted.emplace_back(now);
        que.pop();
        for (auto next : topo_graph[now]) {
            degree[next]--;
            if (degree[next] == 0)
                que.emplace(next);
        }
    }

    cout << (sorted.size() == m ? "Yes" : "No") << '\n';
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int q;
    cin >> q;
    while (q--)
        solve();
}

小課題 6

#include <algorithm>
#include <bits/stdc++.h>
#include <numeric>
using namespace std;

void solve() {
    int n;
    cin >> n;
    vector<vector<int>> graph(n);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        a--, b--;
        graph[a].emplace_back(b);
        graph[b].emplace_back(a);
    }

    int m;
    cin >> m;
    vector<int> s(m);
    vector<int> t(m);
    for (int j = 0; j < m; j++)
        cin >> s[j] >> t[j], s[j]--, t[j]--;

    vector<int> depth(n, -1);
    depth[0] = 0;
    vector<int> ord(n);
    vector<int> num_to_ord(n);
    vector<vector<int>> doubling(18, vector<int>(n));
    vector<int> sz(n);
    doubling[0][0] = 0;
    int k = 0;
    auto dfs = [&](auto &&dfs, int now, int par) -> void {
        sz[now] = 1;
        ord[k] = now;
        num_to_ord[now] = k++;
        for (auto next : graph[now]) {
            if (next == par)
                continue;
            depth[next] = depth[now] + 1;
            doubling[0][next] = now;
            dfs(dfs, next, now);
            sz[now] += sz[next];
        }
    };
    dfs(dfs, 0, -1);

    for (int k = 0; k < 17; k++) {
        for (int i = 0; i < n; i++) {
            doubling[k + 1][i] = doubling[k][doubling[k][i]];
        }
    }

    auto lca = [&](int x, int y) {
        if (depth[x] > depth[y])
            swap(x, y);

        int diff = depth[y] - depth[x];
        for (int k = 17; k >= 0; k--) {
            if (diff >> k & 1)
                y = doubling[k][y];
        }

        if (x == y)
            return x;

        for (int k = 17; k >= 0; k--) {
            if (doubling[k][x] != doubling[k][y])
                x = doubling[k][x], y = doubling[k][y];
        }

        return doubling[0][x];
    };

    auto get_path = [&](int x, int y) {
        int l = lca(x, y);
        set<int> st;
        int lo = num_to_ord[l];
        int xo = num_to_ord[x];
        int yo = num_to_ord[y];

        while (x != l) {
            st.insert(x);
            x = doubling[0][x];
        }
        while (y != l) {
            st.insert(y);
            y = doubling[0][y];
        }
        st.insert(l);

        vector<int> v;
        for (auto si : st)
            v.push_back(si);
        sort(v.begin(), v.end());
        return v;
    };

    vector<vector<int>> paths(m);
    for (int i = 0; i < m; i++)
        paths[i] = get_path(s[i], t[i]);

    auto contains = [&](int i, int x) -> bool {
        return lower_bound(paths[i].begin(), paths[i].end(), x) !=
                   paths[i].end() &&
               *lower_bound(paths[i].begin(), paths[i].end(), x) == x;
    };

    vector<vector<int>> includes_x(n);
    for (int i = 0; i < m; i++) {
        for (auto p : paths[i])
            includes_x[p].emplace_back(i);
    }

    vector<set<int>> topo_graph(m);
    vector<int> degree(m);

    for (int i = 0; i < m; i++) {
        set<int> set_s;
        set<int> set_t;
        for (auto j : includes_x[s[i]])
            set_s.emplace(j);
        for (auto j : includes_x[t[i]])
            set_t.emplace(j);

        for (auto j : set_s) {
            if (i == j)
                continue;
            if (set_t.count(j)) {
                cout << "No" << '\n';
                return;
            }
            if (!topo_graph[i].count(j)) {
                topo_graph[i].insert(j);
                degree[j]++;
            }
        }
        for (auto j : set_t) {
            if (i == j)
                continue;
            if (!topo_graph[j].count(i)) {
                topo_graph[j].insert(i);
                degree[i]++;
            }
        }
    }

    vector<int> sorted;
    queue<int> que;
    for (int i = 0; i < m; i++)
        if (degree[i] == 0)
            que.emplace(i);

    while (!que.empty()) {
        int now = que.front();
        sorted.emplace_back(now);
        que.pop();
        for (auto next : topo_graph[now]) {
            degree[next]--;
            if (degree[next] == 0)
                que.emplace(next);
        }
    }

    cout << (sorted.size() == m ? "Yes" : "No") << '\n';
}

int main() {
    // ios::sync_with_stdio(false);
    // cin.tie(nullptr);
    // cout.tie(nullptr);

    int q;
    cin >> q;
    while (q--)
        solve();
}

B: kyoto

小課題 1 が自明なので取る。

小課題 2 がちまちま考えても一生わからず。 $A_i \leq s$, $B_j \leq t$ の時のなんか、みたいな方針も考えたけどうまくいかず。

f:id:RheoTommy:20220321083815p:plain

小課題 1

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const ll linf = 1001001001001001001ll;

template <typename T> bool chmin(T &x, T y) {
    if (x > y) {
        x = y;
        return true;
    }
    return false;
}

int main() {
    int h, w;
    cin >> h >> w;
    vector<ll> a(h);
    vector<ll> b(w);
    for (int i = 0; i < h; i++)
        cin >> a[i];
    for (int j = 0; j < w; j++)
        cin >> b[j];

    vector<vector<ll>> dp(h, vector<ll>(w, linf));

    dp[0][0] = 0;
    for (int i = 0; i < h; i++) {
        for (int j = 0; j < w; j++) {
            if (i + 1 < h) {
                chmin(dp[i + 1][j], dp[i][j] + b[j]);
            }
            if (j + 1 < w) {
                chmin(dp[i][j + 1], dp[i][j] + a[i]);
            }
        }
    }

    cout << dp[h - 1][w - 1] << '\n';
}

C: misspelling

$S_i \leq S_j$ を同値変形できませんでした!!!!! 隣接する文字の大小関係だけに持ち込んだら確かにある程度行けそうだけど簡単ではないだろ。

Day2

小課題をちまちまやりすぎて時間が足りなくなった。考察は小課題をちゃんと一つ一つ見て行ってもいいけど、実装はある程度まとめてやった方がよさそう(自明取らずに中途半端にバグらせて全部落とすのが怖くてついこまめに実装してしまう)

A: copypaste3

小課題 1 は普通に性質を考えて場合分け。小課題 2 は $X$, $Y$ が文字列の長さで表せるので、頂点数 $O(N^{2})$ ダイクストラをする。

小課題 3 で、$X$, $Y$ がともに $S$ の部分文字列にしかならないことに気づいたので、状態量 $O(N^{4})$, 遷移 $O(N^{2})$ のダイクストラで通す。

priority_queue のキーに tuple じゃなくて自作構造体を使おうとしたら、Ord の実装の仕方がわからなかったり、普通にバグらせたりしてここまでで 70 分かかった。

そのあとも当然考察をして、S を作る際に

  • $X$ が空な状態から、A か C のみを用いて適当な文字列を作る
  • すべて切り取って $X$ を空にし、$Y$ を更新する

のステップを繰り返すことは分かったので、いい感じの区間 DP を考えるといいなーというところまでわかる。

ここで、適当な文字列 $S_i$ を作るためのコストを $S_i$ に含まれる任意の部分文字列を作るコストが既知としてメモ化再帰する方針を立ててしまい、$Y$ の文字列と $S$ の文字列の組が状態になるような方針しか生えなくなってしまった。

すでに S を作る際の操作手順を考えていたので、

  • 最初, $X=空文字列$, $Y=S_y$, 完成

というプロセスから素直に $X=空文字列$, $Y=S_y$ を一つの状態とする DP が思いついてもよかったと今は感じているけど、できなかった。

z_algo もロリハもかけないので DP 高速化パートで困るのはわかっていたけど、一時間くらいは考察していたので、せめて $N \leq 200$ は取りたかった。

f:id:RheoTommy:20220321192848p:plain

小課題 1

#include <bits/stdc++.h>
#include <cassert>

using namespace std;
using ll = long long;

int main() {
    int n;
    string s;
    ll a, b, c;
    cin >> n >> s >> a >> b >> c;

    assert(n == 3);

    if (s[0] == s[1] && s[1] == s[2]) {
        cout << min(3 * a, a + b + c * 3) << endl;
    } else if (s[0] == s[1] || s[0] == s[2] || s[1] == s[2]) {
        cout << min(a * 3, a + b + a + c * 2) << endl;
    } else {
        cout << a * 3 << endl;
    }
}

小課題 2

#include <bits/stdc++.h>
#include <cassert>
#include <queue>
#include <tuple>

using namespace std;
using ll = long long;

const ll linf = 1001001001001001001ll;

template <typename T> bool chmin(T &x, T y) {
    if (x > y) {
        x = y;
        return true;
    }
    return false;
}

int main() {
    int n;
    string s;
    ll a, b, c;
    cin >> n >> s >> a >> b >> c;

    for (int i = 0; i < n; i++) assert(s[i] == 'a');

    vector<vector<ll>> dp(n + 1, vector<ll>(n + 1, linf));
    dp[0][0] = 0;
    priority_queue<tuple<ll, int, int>, vector<tuple<ll, int, int>>, greater<>>
        pri;
    pri.emplace(0, 0, 0);

    while (!pri.empty()) {
        ll cost;
        int x, y;
        tie(cost, x, y) = pri.top();
        pri.pop();

        if (x + 1 <= n && chmin(dp[x + 1][y], cost + a))
            pri.emplace(cost + a, x + 1, y);

        if (chmin(dp[0][x], cost + b)) pri.emplace(cost + b, 0, x);

        if (x + y <= n && chmin(dp[x + y][y], cost + c))
            pri.emplace(cost + c, x + y, y);
    }

    ll ans = linf;

    for (int y = 0; y <= n; y++) chmin(ans, dp[n][y]);

    cout << ans << endl;
}

小課題 3

#include <bits/stdc++.h>
#include <cassert>
#include <queue>
#include <tuple>

using namespace std;
using ll = long long;

const ll linf = 1001001001001001001ll;

template <typename T> bool chmin(T &x, T y) {
    if (x > y) {
        x = y;
        return true;
    }
    return false;
}

int main() {
    int n;
    string s;
    ll a, b, c;
    cin >> n >> s >> a >> b >> c;

    assert(n <= 30);

    map<pair<int, int>, string> mp;
    vector<string> subs;
    for (int l = 0; l < n; l++) {
        for (int r = l + 1; r <= n; r++) {
            mp[{l, r}] = s.substr(l, r - l);
        }
    }
    map<string, vector<pair<int, int>>> st;
    for (auto [k, v] : mp) st[v].emplace_back(k);

    vector<vector<vector<vector<ll>>>> dp(
        n + 1, vector<vector<vector<ll>>>(
                   n + 1, vector<vector<ll>>(n + 1, vector<ll>(n + 1, linf))));
    priority_queue<tuple<ll, int, int, int, int>,
                   vector<tuple<ll, int, int, int, int>>, greater<>>
        pri;

    for (int x = 0; x < n; x++) {
        for (int y = 0; y < n; y++) {
            dp[x][x][y][y] = 0;
            pri.emplace(0, x, x, y, y);
        }
    }

    while (!pri.empty()) {
        ll cost;
        int lx, rx, ly, ry;
        tie(cost, lx, rx, ly, ry) = pri.top();
        pri.pop();

        for (auto p : st[mp[{lx, rx + 1}]]) {
            auto [l, r] = p;
            if (chmin(dp[l][r][ly][ry], cost + a))
                pri.emplace(cost + a, l, r, ly, ry);
        }

        if (chmin(dp[0][0][lx][rx], cost + b))
            pri.emplace(cost + b, 0, 0, lx, rx);

        for (auto p : st[mp[{lx, rx}] + mp[{ly, ry}]]) {
            auto [l, r] = p;
            if (chmin(dp[l][r][ly][ry], cost + c))
                pri.emplace(cost + c, l, r, ly, ry);
        }
    }

    ll ans = linf;

    for (int ly = 0; ly < n; ly++) {
        for (int ry = 0; ry <= n; ry++) {
            chmin(ans, dp[0][n][ly][ry]);
        }
    }

    cout << ans << endl;
}

B: flights

基本的には Communication Task なので後に回していて、C が少し行き詰った時に自明だけ取りに来た。

20 bit で $X$ と $Y$ の上位 8 bit だけ送り、$Y$ の候補 $2^{8}$ すべてを送るので、$14 \times 2^{8} = 3584$ で 15 点。

バグらせて 50 分使った。

f:id:RheoTommy:20220321193715p:plain

小課題 1

#include "Ali.h"
#include <bits/stdc++.h>
#include <string>

using namespace std;

namespace {
vector<vector<int>> graph;
int n;

string num_to_bits(int x) {
    string s;
    for (int i = 0; i < 14; i++) {
        s.push_back(((x >> (13 - i)) & 1) ? '1' : '0');
    }
    return s;
}

int bits_to_num(string s) {
    int x = 0;
    for (int i = 0; i < s.size(); i++) {
        x *= 2;
        x += s[i] == '1';
    }
    return x;
}
} // namespace

void Init(int N, std::vector<int> U, std::vector<int> V) {
    for (int i = 0; i < N; i++) SetID(i, i);
    graph = vector<vector<int>>(N);
    for (int i = 0; i < N - 1; i++) {
        graph[U[i]].emplace_back(V[i]);
        graph[V[i]].emplace_back(U[i]);
    }
    n = N;
}

std::string SendA(std::string S) {
    int x = bits_to_num(S.substr(0, 14));
    int y = bits_to_num(S.substr(14, 6));
    y <<= 8;

    vector<int> depth(10000, -1);
    depth[x] = 0;

    auto dfs = [&](auto &&dfs, int now, int par) -> void {
        for (auto next : graph[now]) {
            if (next == par) continue;
            depth[next] = depth[now] + 1;
            dfs(dfs, next, now);
        }
    };

    dfs(dfs, x, -1);

    vector<int> y_list;
    for (int k = 0; k < (1 << 8); k++) y_list.emplace_back(y + k);
    map<int, int> mp;
    for (auto yi : y_list) mp[yi] = depth[yi];

    string res;
    for (auto [k, v] : mp) {
        res += num_to_bits(v);
    }
    return res;
}
#include "Benjamin.h"
#include <bits/stdc++.h>
#include <string>
#include <vector>
using namespace std;

namespace {
int y;
string num_to_bits(int x) {
    string s;
    for (int i = 0; i < 14; i++) {
        s.push_back(((x >> (13 - i)) & 1) ? '1' : '0');
    }
    return s;
}

int bits_to_num(string s) {
    int x = 0;
    for (int i = 0; i < 14; i++) {
        x *= 2;
        x += s[i] == '1';
    }
    return x;
}

} // namespace

std::string SendB(int N, int X, int Y) {
    y = Y;
    string x = num_to_bits(X);
    string yi = num_to_bits(Y);
    return x + yi.substr(0, 6);
}

int Answer(std::string T) {
    int len = T.size() / 14;
    vector<int> y_depth;
    for (int l = 0; l < len; l++) {
        y_depth.emplace_back(bits_to_num(T.substr(l * 14, 14)));
    }
    int yz = y & ((1 << 8) - 1);
    return y_depth[yz];
}

C: team

小課題を全部愚直に考察実装していたら時間を食いつぶしてしまった。

小課題 1 で全探索。小課題 2 では、一人固定して、ソートを使って残り二人の組み合わせを $O(NlogN)$ で舐めて $O(N^{2}logN)$ で通した。

小課題 3 は $(x, y, z)$ の組み合わせ $5^{3}$ すべてに対して $O(N)$ で判定。小課題 4 は $(x, y, z)$ の組み合わせが $20^{3}$ で $N$ より小さいので、$N=20^{3}$ として小課題 2 の解法を使う。

小課題 5 は $300^{3}$ を考えたくなったので、$(x, y, z)$ それぞれの組に対して、

  • $X_i = x$, $Y_i < y$, $Z_i < z$
  • $Y_i = y$, $X_i < x$, $Z_i < z$
  • $Z_i = z$, $X_i < x$, $Y_i < Y$

を満たすモノが存在するかどうかを高速に求めたくなり、$X_i = x$, $Y_i < y$, $Z_i < z$ を $X_i$ ごとに $Y$, $Z$ で二次元累積和を持つことで $O(1)$ にした。

小課題 6 は $4000N$ が間に合いそうなので、$X$ を固定して各 $O(N)$ で求めればよさそう。

$X=X_i$ のとき、$X_j < X$ を満たすもののうち大抵 $Y$ が最大のものと $Z$ が最大になるものが組になる(最大値同士の $i$ でない限り)ので、$Y$ と $Z$ も同じように固定する求めることで最大値同士のペアを無視できて、$X_i=X$ で $max Y > Y_i$ かつ $max X > Z_i$ な $i$ があるかをデータ構造で高速に判定したくなる。

$4000N$ が間に合うなら $X$ を愚直に全部チェックすればいいはず。実装間に合わず。

満点解法も、これを基準にすると二次元セグ木があればできそうであることは分かったが、実装時間もなかったので不可能だった。

想定満点解法は、部分点をまじめに考えない方が出てくると思う。

f:id:RheoTommy:20220321195729p:plain

小課題 1, 2

#include <algorithm>
#include <bits/stdc++.h>

using namespace std;
const int inf = 1001001001;

template <typename T> bool chmax(T &x, T y) {
    if (x < y) {
        x = y;
        return true;
    }
    return false;
}

int main() {
    int n;
    cin >> n;

    assert(n <= 4000);

    vector<int> x(n);
    vector<int> y(n);
    vector<int> z(n);

    vector<tuple<int, int, int>> t;

    for (int i = 0; i < n; i++) cin >> x[i] >> y[i] >> z[i];
    for (int i = 0; i < n; i++) t.emplace_back(x[i], y[i], z[i]);

    int ans = -inf;

    for (int i = 0; i < n; i++) {
        int xi = x[i];
        int yi = y[i];
        int zi = z[i];
        vector<pair<int, int>> yz;

        for (int j = 0; j < n; j++) {
            if (x[j] < xi && (y[j] > yi || z[j] > zi))
                yz.emplace_back(y[j], z[j]);
        }

        sort(yz.begin(), yz.end());

        int m = yz.size();

        if (m < 2) continue;

        vector<int> mx(m + 1, -inf);
        vector<int> yy(m);
        for (int j = 0; j < m; j++) mx[j + 1] = max(mx[j], yz[j].second);
        for (int j = 0; j < m; j++) yy[j] = yz[j].first;

        for (int j = 0; j < m; j++) {
            int k = lower_bound(yy.begin(), yy.end(), yy[j]) - yy.begin();

            if (yy[j] > yi && mx[k] > zi && mx[k] > yz[j].second)
                chmax(ans, xi + yy[j] + mx[k]);
            // cout << xi << ' ' << yy[j] << ' ' << mx[k] << endl;
        }
    }

    if (ans < 0) ans = -1;
    cout << ans << endl;
}

小課題 3

#include <algorithm>
#include <bits/stdc++.h>
#include <cassert>

using namespace std;
const int inf = 1001001001;

template <typename T> bool chmax(T &x, T y) {
    if (x < y) {
        x = y;
        return true;
    }
    return false;
}

int main() {
    cin.tie(nullptr);
    cout.tie(nullptr);
    ios::sync_with_stdio(false);

    int n;
    cin >> n;

    vector<int> x(n);
    vector<int> y(n);
    vector<int> z(n);

    vector<tuple<int, int, int>> t;

    for (int i = 0; i < n; i++) {
        cin >> x[i] >> y[i] >> z[i];
        assert(x[i] <= 20);
        assert(y[i] <= 20);
        assert(z[i] <= 20);
    }
    for (int i = 0; i < n; i++) t.emplace_back(x[i], y[i], z[i]);

    int ans = -inf;

    for (int xt = 20; xt >= 1; xt--) {
        for (int yt = 20; yt >= 1; yt--) {
            for (int zt = 20; zt >= 1; zt--) {
                if (xt + yt + zt <= ans) continue;

                bool x_flag = false;
                bool y_flag = false;
                bool z_flag = false;

                for (int i = 0; i < n; i++) {
                    x_flag |= x[i] == xt && y[i] < yt && z[i] < zt;
                    y_flag |= y[i] == yt && x[i] < xt && z[i] < zt;
                    z_flag |= z[i] == zt && x[i] < xt && y[i] < yt;
                }

                if (x_flag && y_flag && z_flag) chmax(ans, xt + yt + zt);
            }
        }
    }

    if (ans < 0) ans = -1;
    cout << ans << endl;
}

小課題 4

#include <algorithm>
#include <bits/stdc++.h>
#include <cassert>
#include <tuple>

using namespace std;
const int inf = 1001001001;

template <typename T> bool chmax(T &x, T y) {
    if (x < y) {
        x = y;
        return true;
    }
    return false;
}

int main() {
    int n;
    cin >> n;

    set<tuple<int, int, int>> st;
    for (int i = 0; i < n; i++) {
        int x, y, z;
        cin >> x >> y >> z;
        st.emplace(x, y, z);
    }

    n = st.size();

    vector<int> x;
    vector<int> y;
    vector<int> z;

    vector<tuple<int, int, int>> t;

    for (auto tup : st) {
        int xi, yi, zi;
        tie(xi, yi, zi) = tup;
        x.emplace_back(xi);
        y.emplace_back(yi);
        z.emplace_back(zi);
    }
    for (int i = 0; i < n; i++) t.emplace_back(x[i], y[i], z[i]);

    int ans = -inf;

    for (int i = 0; i < n; i++) {
        int xi = x[i];
        int yi = y[i];
        int zi = z[i];
        vector<pair<int, int>> yz;

        for (int j = 0; j < n; j++) {
            if (x[j] < xi && (y[j] > yi || z[j] > zi))
                yz.emplace_back(y[j], z[j]);
        }

        sort(yz.begin(), yz.end());

        int m = yz.size();

        if (m < 2) continue;

        vector<int> mx(m + 1, -inf);
        vector<int> yy(m);
        for (int j = 0; j < m; j++) mx[j + 1] = max(mx[j], yz[j].second);
        for (int j = 0; j < m; j++) yy[j] = yz[j].first;

        for (int j = 0; j < m; j++) {
            int k = lower_bound(yy.begin(), yy.end(), yy[j]) - yy.begin();

            if (yy[j] > yi && mx[k] > zi && mx[k] > yz[j].second)
                chmax(ans, xi + yy[j] + mx[k]);
            // cout << xi << ' ' << yy[j] << ' ' << mx[k] << endl;
        }
    }

    if (ans < 0) ans = -1;
    cout << ans << endl;
}

小課題 5

#include <algorithm>
#include <bits/stdc++.h>
#include <cassert>
#include <tuple>

using namespace std;
const int inf = 1001001001;

template <typename T> bool chmax(T &x, T y) {
    if (x < y) {
        x = y;
        return true;
    }
    return false;
}

int main() {
    int n;
    cin >> n;

    vector<vector<vector<bool>>> yz(
        301, vector<vector<bool>>(302, vector<bool>(302, 0)));
    vector<vector<vector<bool>>> xz(
        301, vector<vector<bool>>(302, vector<bool>(302, 0)));
    vector<vector<vector<bool>>> xy(
        301, vector<vector<bool>>(302, vector<bool>(302, 0)));

    vector<int> xn(n);
    vector<int> yn(n);
    vector<int> zn(n);

    for (int i = 0; i < n; i++) {
        cin >> xn[i] >> yn[i] >> zn[i];
        assert(xn[i] <= 300 && yn[i] <= 300 && zn[i] <= 300);
        yz[xn[i]][yn[i]][zn[i]] = true;
        xz[yn[i]][xn[i]][zn[i]] = true;
        xy[zn[i]][xn[i]][yn[i]] = true;
    }

    for (int x = 1; x <= 300; x++) {
        // yz[x][0][0] = true;
        for (int y = 0; y <= 300; y++) {
            for (int z = 0; z <= 300; z++) {
                yz[x][y][z + 1] = yz[x][y][z + 1] || yz[x][y][z];
            }
        }

        for (int z = 0; z <= 300; z++) {
            for (int y = 0; y <= 300; y++) {
                yz[x][y + 1][z] = yz[x][y + 1][z] || yz[x][y][z];
            }
        }
    }

    for (int y = 1; y <= 300; y++) {
        // xz[y][0][0] = true;
        for (int x = 0; x <= 300; x++) {
            for (int z = 0; z <= 300; z++) {
                xz[y][x][z + 1] = xz[y][x][z + 1] || xz[y][x][z];
            }
        }

        for (int z = 0; z <= 300; z++) {
            for (int x = 0; x <= 300; x++) {
                xz[y][x + 1][z] = xz[y][x + 1][z] || xz[y][x][z];
            }
        }
    }

    for (int z = 1; z <= 300; z++) {
        // xy[z][0][0] = true;
        for (int x = 0; x <= 300; x++) {
            for (int y = 0; y <= 300; y++) {
                xy[z][x][y + 1] = xy[z][x][y + 1] || xy[z][x][y];
            }
        }

        for (int y = 0; y <= 300; y++) {
            for (int x = 0; x <= 300; x++) {
                xy[z][x + 1][y] = xy[z][x + 1][y] || xy[z][x][y];
            }
        }
    }

    int ans = -inf;

    for (int x = 1; x <= 300; x++) {
        for (int y = 1; y <= 300; y++) {
            for (int z = 1; z <= 300; z++) {
                if (yz[x][y - 1][z - 1] > 0 && xz[y][x - 1][z - 1] > 0 &&
                    xy[z][x - 1][y - 1] > 0)
                    chmax(ans, x + y + z);
            }
        }
    }

    // for (int x = 1; x <= 5; x++) {
    //     for (int y = 1; y <= 5; y++) {
    //         for (int z = 1; z <= 5; z++) {
    //             cerr << x << ' ' << y << ' ' << z << ' ' << yz[x][y][z] << ' '
    //                  << xz[y][x][z] << ' ' << xy[z][x][y] << endl;
    //         }
    //     }
    // }

    if (ans < 0) ans = -1;

    cout << ans << endl;
}

Day3

最初の方ずっと 3 点とかでまー--じで焦った。結論から言うと難しい回だったっぽいけど、それにしても。

A: device2

Communication Task なので、実装でバグらせてデバッグに無限時間使うのが怖くて後回し。自明だけ取って撤退。

f:id:RheoTommy:20220326203838p:plain

小課題 1

#include "Anna.h"
#include <utility>
#include <vector>

using namespace std;

namespace {}

int Declare() { return 2000; }

std::pair<std::vector<int>, std::vector<int>> Anna(long long A) {
    vector<int> a;
    for (int i = 0; i < A; i++) a.push_back(1);
    for (int i = 0; i < 2000 - A; i++) a.push_back(0);
    vector<int> s;
    vector<int> t;
    for (int i = 0; i < 1000; i++) s.push_back(a[i]);
    for (int i = 0; i < 1000; i++) t.push_back(a[1000 + i]);
    return {s, t};
}
#include "Bruno.h"
#include <utility>
#include <vector>

using ll = long long;

namespace {

int variable_example = 0;

}

long long Bruno(std::vector<int> u) {
    ll cnt = 0;
    for (auto ui : u) cnt += ui == 1;
    return cnt;
}

B: sprinkler

これ満点取れたから悔しい。

小課題 1 はとりあえずやるだけなので消化。

小課題 2, 3 も方針は比較的早く?思いついた気もしたけど、めっちゃくっちゃバグってた。

まず、木を見たのでオイラーツアーを考えると、どうにもきれいな性質がないので、BFS 順を考えると、頂点 $v$ から深さ $d$ の区間が連続になっているので遅延セグ木で勝てる!となる。 遅延セグ木事態はすぐかけたんだけれど、BFS 順を振ったため、もとの頂点番号と BFS 順における頂点番号がごっちゃになっていろいろなバグを生み、小課題 2 が通るまでにかなーり時間を使う。

小課題 3 も同じ方針ですでに思いついていたけれど、親がいないときの自分自身に対するクエリの処理でずっとバグっていて時間を食った。

小課題 4 は DFS を再帰的に呼ぶみたいなイメージですぐに解いた。

ここで、各クエリごとに自分の親頂点からそれぞれ二つの区間に対して操作をすればいい、というのに気づいたので、満点解法は $O((N+QD)logN)$ で終わるじゃん!と思い、実装。3 点しか入らず、キレ、終わり。

後から見ると、この時点で満点解法に必要な考察はすべて終わっていて、$dp[v][d] = 頂点 v から深さ d の頂点に一様にかけられた値$ を考えればもう終わりなんだけれど、木を見たら適当な順に番号を振りたくなるパターン認識的な思考に完全に敗北してしまった。

f:id:RheoTommy:20220326205142p:plain

小課題 1

#include <bits/stdc++.h>

using namespace std;
using ll = long long;

int main() {
    cin.tie(nullptr);
    cout.tie(nullptr);
    ios::sync_with_stdio(false);

    int n;
    ll l;
    cin >> n >> l;

    vector<vector<int>> graph(n);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        a--, b--;
        graph[a].emplace_back(b);
        graph[b].emplace_back(a);
    }

    vector<ll> h(n);
    for (int i = 0; i < n; i++) cin >> h[i];

    int q;
    cin >> q;
    for (int qi = 0; qi < q; qi++) {
        int t;
        cin >> t;
        if (t == 1) {
            int x, d;
            ll w;
            cin >> x >> d >> w;
            x--;

            auto dfs = [&](auto &&dfs, int now, int par, int dist) -> void {
                if (dist > d) return;
                h[now] *= w;
                h[now] %= l;

                for (auto next : graph[now]) {
                    if (next == par) continue;
                    dfs(dfs, next, now, dist + 1);
                }
            };

            dfs(dfs, x, -1, 0);
        } else {
            int x;
            cin >> x;
            x--;
            cout << h[x] << endl;
        }
    }
}

小課題 2, 3

#include <algorithm>
#include <bits/stdc++.h>
#include <functional>

using namespace std;
using ll = long long;

template <typename T, typename L> struct lazy_segtree {
    using F = function<T(T, T)>;
    using FL = function<L(L, L)>;
    using FM = function<T(T, L)>;

    int n;
    int height;
    vector<T> node;
    vector<L> lazy;

    T idt;
    L idl;

    F op;
    FL composition;
    FM merge;

  private:
    T propagate_at(int i) {
        if (lazy[i] != idl) {
            node[i] = merge(node[i], lazy[i]);
            if (i < n) {
                lazy[i * 2] = composition(lazy[i * 2], lazy[i]);
                lazy[i * 2 + 1] = composition(lazy[i * 2 + 1], lazy[i]);
            }
            lazy[i] = idl;
        }
        return node[i];
    }

    void propagate_topdown(int i) {
        for (int k = height; k >= 0; k--) propagate_at(i >> k);
    }

  public:
    lazy_segtree(int _n, F op, FL composition, FM merge, T idt, L idl)
        : op(op), composition(composition), merge(merge), idt(idt), idl(idl) {
        int size = 1;
        int h = 0;
        while (size < _n) size *= 2, h++;
        n = size;
        height = h;

        node = vector<T>(2 * n, idt);
        lazy = vector<L>(2 * n, idl);
    }

    T get(int i) {
        i += n;
        propagate_topdown(i);
        return node[i];
    }

    void set(int i, T x) {
        i += n;
        propagate_topdown(i);
        node[i] = x;
        while (i /= 2)
            node[i] = op(propagate_at(i * 2), propagate_at(i * 2 + 1));
    }

    void set(int a, int b, L y) {
        if (a >= b) return;
        int l = a + n, r = b + n;
        propagate_topdown(l);
        propagate_topdown(r - 1);
        for (; l < r; l /= 2, r /= 2) {
            if (l & 1) propagate_at(l), lazy[l++] = y;
            if (r & 1) propagate_at(--r), lazy[r] = y;
        }
        l = a + n, r = b + n - 1;
        while (l /= 2)
            node[l] = op(propagate_at(l * 2), propagate_at(l * 2 + 1));
        while (r /= 2)
            node[r] = op(propagate_at(r * 2), propagate_at(r * 2 + 1));
    }
};

int main() {
    cin.tie(nullptr);
    cout.tie(nullptr);
    ios::sync_with_stdio(false);

    int n;
    ll l;
    cin >> n >> l;

    vector<vector<int>> graph(n);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        a--, b--;
        graph[a].emplace_back(b);
        graph[b].emplace_back(a);
    }

    vector<ll> h(n);
    for (int i = 0; i < n; i++) cin >> h[i];

    vector<int> idx(n, -1);
    vector<int> ord;
    vector<int> par(n, -1);
    vector<int> left(n, -1);
    vector<int> right(n, -1);

    queue<int> que;
    que.emplace(0);
    while (!que.empty()) {
        int now = que.front();
        que.pop();
        idx[now] = ord.size();
        ord.emplace_back(now);
        int last = -1;
        for (auto next : graph[now]) {
            if (idx[next] != -1) continue;
            if (left[now] == -1) left[now] = next;
            par[next] = now;
            last = next;
            que.emplace(next);
        }
        if (last != -1) right[now] = last;
    }

    vector<int> par_by_idx(n);
    for (int i = 0; i < n; i++) par_by_idx[i] = idx[par[ord[i]]];

    auto op = [&](ll x, ll y) { return (x * y) % l; };

    auto composition = [&](ll x, ll y) { return (x * y) % l; };

    auto merge = [&](ll x, ll y) { return (x * y) % l; };

    auto lst = lazy_segtree<ll, ll>(n, op, composition, merge, 1, 1);

    for (int i = 0; i < n; i++) lst.set(idx[i], h[i]);

    int q;
    cin >> q;
    for (int qi = 0; qi < q; qi++) {
        int t;
        cin >> t;
        if (t == 1) {
            int x, d;
            ll w;
            cin >> x >> d >> w;
            x--;

            assert(d <= 2);

            if (d == 0) {
                lst.set(idx[x], idx[x] + 1, w);
            } else if (d == 1) {
                lst.set(idx[x], idx[x] + 1, w);
                if (par[x] != -1) lst.set(idx[par[x]], idx[par[x]] + 1, w);
                if (left[x] != -1) lst.set(idx[left[x]], idx[right[x]] + 1, w);
            } else {
                if (par[x] != -1 && par[par[x]] != -1)
                    lst.set(idx[par[par[x]]], idx[par[par[x]]] + 1, w);
                if (par[x] != -1) lst.set(idx[par[x]], idx[par[x]] + 1, w);
                if (par[x] != -1)
                    lst.set(idx[left[par[x]]], idx[right[par[x]]] + 1, w);
                else
                    lst.set(idx[x], idx[x] + 1, w);
                if (left[x] != -1) lst.set(idx[left[x]], idx[right[x]] + 1, w);

                if (left[x] != -1) {
                    int l = idx[left[x]];
                    int r = idx[right[x]];

                    int p =
                        lower_bound(par_by_idx.begin(), par_by_idx.end(), l) -
                        par_by_idx.begin();

                    int q =
                        upper_bound(par_by_idx.begin(), par_by_idx.end(), r) -
                        par_by_idx.begin();

                    lst.set(p, q, w);
                }
            }
        } else {
            int x;
            cin >> x;
            x--;
            cout << lst.get(idx[x]) << "\n";
        }
    }
}

小課題 4

#include <bits/stdc++.h>

using namespace std;
using ll = long long;

int main() {
    cin.tie(nullptr);
    cout.tie(nullptr);
    ios::sync_with_stdio(false);

    int n;
    ll l;
    cin >> n >> l;

    vector<vector<int>> graph(n);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        a--, b--;
        graph[a].emplace_back(b);
        graph[b].emplace_back(a);
    }

    vector<ll> h(n);
    for (int i = 0; i < n; i++) cin >> h[i];

    vector<vector<bool>> seen(n, vector<bool>(40, false));

    auto dfs = [&](auto &&dfs, int now, int par, int d, int w) -> void {
        if (seen[now][d]) return;
        seen[now][d] = true;
        h[now] *= w;
        h[now] %= l;
        if (d == 0) return;

        for (auto next : graph[now]) {
            if (next == par) continue;
            dfs(dfs, next, now, d - 1, w);
        }
    };

    int q;
    cin >> q;
    for (int qi = 0; qi < q; qi++) {
        int t;
        cin >> t;
        if (t == 1) {
            int x, d;
            ll w;
            cin >> x >> d >> w;
            x--;

            assert(w == 0);

            dfs(dfs, x, -1, d, 0);
        } else {
            int x;
            cin >> x;
            x--;
            cout << h[x] << endl;
        }
    }
}

C: suger

こんなのは、無理だよ。

maroonk さんが解説してた。

f:id:RheoTommy:20220326205833p:plain

小課題 1

#include <bits/stdc++.h>

using namespace std;
using ll = long long;

int main() {
    int q, l;
    cin >> q >> l;

    map<ll, ll> suger;
    map<ll, ll> ants;

    for (int qi = 0; qi < q; qi++) {
        int t;
        ll x, a;
        cin >> t >> x >> a;
        if (t == 1)
            ants[x] += a;
        else
            suger[x] += a;

        vector<pair<ll, ll>> sv;
        vector<pair<ll, ll>> av;
        for (auto [k, v] : suger) sv.emplace_back(k, v);
        for (auto [k, v] : ants) av.emplace_back(k, v);

        ll ans = 0;
        int i = 0;
        int j = 0;
        while (i < sv.size() && j < av.size()) {
            if (abs(sv[i].first - av[j].first) <= l) {
                ll m = min(sv[i].second, av[j].second);
                sv[i].second -= m;
                av[j].second -= m;
                ans += m;
                if (sv[i].second == 0) i++;
                if (av[j].second == 0) j++;
            } else {
                if (sv[i].first < av[j].first)
                    i++;
                else
                    j++;
            }
        }

        cout << ans << endl;
    }
}

Day4

過去一点が取れなかった。終わってみると、B はみんな取れてなくて、C は 100 か 14 以下か、みたいな感じで、取れるべきは A だったなぁと。

A: dango3

Communication Task. Output Only はでなかった。傾向的に Communication Task はひらめかないと結構厳しくて、Batch に可能枠があるのでそっちを優先してしまい、自明しかとってない。

f:id:RheoTommy:20220326210415p:plain

小課題 1, 2

#include "dango3.h"
#include <bits/stdc++.h>

using namespace std;

namespace {

int variable_example = 1;

} // namespace

void Solve(int N, int M) {
    vector<bool> checked(N * M + 1, false);
    vector<vector<int>> ids;

    // cerr << N << " " << M << endl;

    for (int ni = 0; ni < N; ni++) {
        vector<int> ask;
        for (auto v : ids) ask.emplace_back(v[0]);

        // for (auto k : ask) cerr << k << ' ';
        // cerr << endl;

        int cnt = 0;
        int id = 1;
        while (cnt < N - 1 - ni) {
            if (!checked[id]) cnt++, ask.emplace_back(id);
            id++;
        }

        vector<int> group;
        while (id <= N * M) {
            if (!checked[id]) {
                ask.emplace_back(id);
                if (Query(ask)) {
                    group.emplace_back(id);
                    ask.pop_back();
                    checked[id] = true;
                }
            }
            id++;
        }

        ids.emplace_back(group);
    }

    for (int j = 0; j < M; j++) {
        vector<int> ans(N);
        for (int i = 0; i < N; i++) ans[i] = ids[i][j];
        Answer(ans);
    }
}

B: fish2

見た目が解かれそうなやつでめっちゃ時間かけちゃった気がする。

小課題 1 は愚直に処理するだけなので、やる。

ある魚が勝てない必要十分条件は、$X\ a\ b\ c\ Y$ みたいな並びで $X > a + b + c$ かつ $a + b + c < Y$ であることなので、ある $X$ に対して $Y$ の候補は一つで、これを二分探索で求めれば小課題 2 は通る。

小課題 3 も、これを毎回やれば $O(QNlogN)$ だけど通る気がしなかったので手を付けず。

小課題 2 の実装で $X$ や $Y$ が右端、左端のときの処理をしてなかったりで、実装には苦労した。

f:id:RheoTommy:20220326211108p:plain

小課題 1

#include <bits/stdc++.h>

using namespace std;
using ll = long long;

int main() {
    int n;
    cin >> n;
    vector<ll> a(n);
    for (int i = 0; i < n; i++) cin >> a[i];

    int q;
    cin >> q;
    for (int qi = 0; qi < q; qi++) {
        int t;
        cin >> t;
        if (t == 1) {
            int x;
            ll y;
            cin >> x >> y;
            x--;
            a[x] = y;
        } else {
            int l, r;
            cin >> l >> r;
            l--;

            int ans = 0;

            for (int k = l; k < r; k++) {
                int li = k - 1;
                int ri = k + 1;
                ll now = a[k];
                while (true) {
                    if (li >= l && now >= a[li])
                        now += a[li], li--;
                    else if (ri < r && now >= a[ri])
                        now += a[ri], ri++;
                    else
                        break;
                }
                ans += li < l && ri + 1 > r;
            }

            cout << ans << endl;
        }
    }
}

小課題 2

#include <algorithm>
#include <bits/stdc++.h>

using namespace std;
using ll = long long;

int main() {
    int n;
    cin >> n;
    vector<ll> a(n);
    for (int i = 0; i < n; i++) cin >> a[i];

    vector<ll> sum(n + 1);
    for (int i = 0; i < n; i++) sum[i + 1] = sum[i] + a[i];

    vector<ll> sum_r(n + 1);
    for (int i = n - 1; i >= 0; i--) sum_r[i] += sum_r[i + 1] + a[i];

    auto max_l = [&](int r) {
        int ok = -1;
        int ng = r;
        while (ng - ok > 1) {
            int mid = (ng + ok) / 2;

            if (sum_r[mid] - sum_r[r] >= a[r])
                ok = mid;
            else
                ng = mid;
        }
        return ok;
    };

    int q;
    cin >> q;
    assert(q == 1);
    for (int qi = 0; qi < q; qi++) {
        int t;
        cin >> t;
        assert(t == 2);
        if (t == 1) {
            int x;
            ll y;
            cin >> x >> y;
            x--;
            a[x] = y;
        } else {
            int l, r;
            cin >> l >> r;
            l--;
            assert(l == 0 && r == n);

            vector<ll> imos(n + 1, 0);

            for (int l = 0; l < n - 1; l++) {
                int r = lower_bound(sum.begin(), sum.end(), a[l] + sum[l + 1]) -
                        sum.begin() - 1;
                int s = sum[r] - sum[l + 1];
                if (s < a[r]) {
                    imos[l + 1] += 1;
                    imos[r] -= 1;
                }
            }

            for (int r = n - 1; r > 0; r--) {
                int l = max_l(r);
                if (l == -1 || sum_r[l + 1] - sum_r[r] < a[l])
                    imos[l + 1] += 1, imos[r] -= 1;
            }

            for (int i = 1; i < n; i++) {
                if (sum[i] < a[i]) imos[0] += 1, imos[i] -= 1;
            }

            for (int k = 1; k < n; k++) {
                if (sum[n] - sum[n - k] < a[n - k - 1])
                    imos[n - k] += 1, imos[n] -= 1;
            }

            for (int i = 0; i < n; i++) imos[i + 1] += imos[i];
            int cnt = 0;
            for (int i = 0; i < n; i++) cnt += imos[i] == 0;

            cout << cnt << endl;
        }
    }
}

C: reconstruction

可能枠よりに見えたけど、自明なクラスカル法から抜け出せなかった。

ある辺が採用される必要十分条件を考えるような方針にまずあんまりならなかったのがまずい。

小課題 3 ですべての辺を使う場合をやったので、ある辺が使われる条件を考える流れにもっていかないとだけど、小課題 4 とかで、グラフの差分が高速に計算できるんじゃないか的な発想に支配され、ダメだった。

f:id:RheoTommy:20220326211912p:plain

小課題 1, 2

#include <algorithm>
#include <bits/stdc++.h>
#include <shared_mutex>
#include <tuple>

using namespace std;
using ll = long long;


template <typename T> bool chmin(T &x, T y) {
    if (x > y) {
        x = y;
        return true;
    }
    return false;
}

struct union_find {
    int n;
    vector<int> par;

    explicit union_find(int n) : n(n), par(n, -1) {}

    int root(int x) {
        if (par[x] < 0) return x;
        return par[x] = root(par[x]);
    }

    bool unite(int x, int y) {
        x = root(x), y = root(y);
        if (x == y) return false;
        if (x > y) swap(x, y);
        par[x] += par[y];
        par[y] = x;
        return true;
    }

    bool is_same(int x, int y) { return root(x) == root(y); }
};

int main() {
    int n, m;
    cin >> n >> m;

    vector<tuple<int, int, ll>> edges;
    for (int i = 0; i < m; i++) {
        int a, b;
        ll w;
        cin >> a >> b >> w;
        a--, b--;
        edges.emplace_back(a, b, w);
    }

    int q;
    cin >> q;
    for (int qi = 0; qi < q; qi++) {
        ll x;
        cin >> x;

        auto uf = union_find(n);
        ll ans = 0;

        vector<tuple<ll, int, int>> diff;
        for (int i = 0; i < m; i++) {
            int a, b;
            ll w;
            tie(a, b, w) = edges[i];
            diff.emplace_back(abs(w - x), a, b);
        }

        sort(diff.begin(), diff.end());
        for (auto [w, a, b] : diff)
            if (uf.unite(a, b)) ans += w;

        cout << ans << endl;
    }
}

小課題 3

#include <bits/stdc++.h>
#include <cassert>

using namespace std;
using ll = long long;
const ll linf = 1001001001001001001ll;

template <typename T> bool chmin(T &x, T y) {
    if (x > y) {
        x = y;
        return true;
    }
    return false;
}

int main() {
    int n, m;
    cin >> n >> m;

    vector<vector<ll>> edge(n - 1);
    for (int j = 0; j < m; j++) {
        int a, b;
        ll w;
        cin >> a >> b >> w;
        a--, b--;
        assert(a + 1 == b);
        edge[a].push_back(w);
    }

    for (int i = 0; i < n - 1; i++) sort(edge[i].begin(), edge[i].end());

    vector<int> inds(n - 1, 0);

    int q;
    cin >> q;
    for (int qi = 0; qi < q; qi++) {
        ll x;
        cin >> x;

        ll ans = 0;
        for (int i = 0; i < n - 1; i++) {
            while (inds[i] + 1 < edge[i].size() && edge[i][inds[i] + 1] < x)
                inds[i]++;

            ll t = abs(x - edge[i][inds[i]]);
            if (inds[i] + 1 < edge[i].size())
                chmin(t, abs(edge[i][inds[i] + 1] - x));

            ans += t;
        }

        cout << ans << endl;
    }
}

JOI 2021/2022 本選参加記

JOI 2021/2022 本選参加記

追記

本選通ってた.俺の勝ち!

合計点

ID Score
A 100
B 100
C 77
D 8
E 19
Sum 304

非公式順位表によると 304 点は 9 人いるらしい.304 点を落としたら春合宿の人数多くした意味ないので,32 人くらいなら 30 人超えてもセーフってことにしましょう.

304 点以上が非公式順位表で 29 人,300 点以上が非公式順位表で 30 人,恐怖の C 満点潜伏 er が一人いるので,怖い.

反省

競プロあんまり触れてないまま本選に出たけれど,一応レート相応の点は取れたのでよかった.

ところで本選突破のハードル上がってませんか?人数増えたけど要求されているレベル高いままと感じた.

春合宿は 30 人くらいいたほうが楽しいと思うので,通せ.

難易度は,5-6-10-10-?? だと思っています.

以下問題ごとに.

A, B を割と速攻で通したのはよかった.

C も嘘からすぐ復活して,77 点とるまでの流れはよかった.が,$O(N^{2}KlogN)$ にこだわってそのあとかなりの時間を使ったのはよくない.TL 1.6s なのと, $ 500 \cdot 500 \cdot 500 \cdot log(500) \fallingdotseq 1.1 \cdot 10^{9} $ で微妙と思ったけど $log$ 解法捨てるべきだったのかも.

D, E はあんまりうまく解けず.C の $log$ 解法に粘って,残り 1 時間半位になった後,安全策で部分点一個一個死守しにいったけど,D はもうちょっと考えるべきだったのかもしれない.解ききれない問題があったときに頭をリセットできないことが分かった.

A: インターカステラー

去年に比べてめっちゃ簡単になっているように感じる.

二分探索で終わり.7 分!

f:id:RheoTommy:20220213202724p:plain

B: 自習

去年に比べてめっちゃ簡単になっていて,うれしい!

これまた二分探索すれば,どうせできるので適当にちゃちゃっと書くと WA る.

f:id:RheoTommy:20220213203008p:plain f:id:RheoTommy:20220213203104p:plain

小課題 1 で WA っているのがわかるので,とりあえず

10 1
1 1 1 1 1 1 1 1 1 1

を試すとバグる.デバッグですぐオーバーフローに気づき,修正.

f:id:RheoTommy:20220213203521p:plain

A, B 合わせて 30 分!めっちゃ順調!

C: 選挙で勝とう

最初の考察

  • $B_i$ を選ぶなら最初に全部選んでしまったほうがいい
  • $A_i$ しか選ばないなら,最後に昇順に選べばいい
  • $B_i$ も選ぶなら小さい順でいいじゃん
    • 嘘です

これに基づいて協力者数全探索で終わりじゃんっていって提出.当然 WA.

f:id:RheoTommy:20220213204533p:plain

A_i B_i
--------
1 109
100 101

みたいなときに嘘だなーってなるので,もうちょっと考察を進める.

  • $B_i$ を選ぶなら最初に全部選んでしまったほうがいい
  • $A_i$ しか選ばないなら,最後に昇順に選べばいい
  • 協力者数を $K_i$ で固定すると,$\frac{B_i}{1} + \frac{B_j}{1+1} + \cdots + \frac{\sum A_k}{K_i}$ の最小値を求める問題になって,DP で解けそう.
  • 選ぶ $B_i$ に関しては当然昇順がいい.

そこで,$B$ をソートして協力者数 $K_i$ の時の答えを

$dp[i][j][k] = 州を i 個見て,協力者を j 人集めて,票を k 票集めた時の最小時間$

で求める.これで部分点 1, 2, 3, 4, 5 が取れて 66 点.

f:id:RheoTommy:20220213205528p:plain

ここまでで 1 時間 23 分.割と順調.

ここで部分点 6 を見ると,自明 DP の k の次元を消せるので,3 分で部分点回収.

f:id:RheoTommy:20220213205737p:plain

最後に,$K_i$ の固定をなくして一つの DP で解くか,$K_i$ 固定で DP の j か k の次元を消すことを考えたけど,うまくいかず,30 分ぐらいそのまま.

ここで,$K_i$ に対して答えに単調性がありそうだと気づいたので,試してみたところ行けそうで,バグにちょっと苦戦しながら実装.$O(N^{2}KlogN)$ なので勝ったと思ったが,なんと通らない.

f:id:RheoTommy:20220213210107p:plain

ここで折り返し.

この後も C の満点を狙って定数倍高速化を 10 分ちょっとしたが,通らず,断念.

最後 10 分も粘ってみるも,JOI は $log(N)$ をちゃんと落としてくれます(無念).

f:id:RheoTommy:20220213210511p:plain

D: 鉄道旅行 2

C を断念してから 30 分くらい,部分点ごとにちゃんと考察してみたが,$O(QMN)$ 以上に計算量の良い解法浮かばなかった.考察も実装もちょっと右往左往して,残り 1 時間程度で自明部分点の小課題 1 のみとる.

なんか C より簡単とか言っている人結構多いけどちょっとわかりません.

f:id:RheoTommy:20220213210742p:plain

E: 砂の城 2

残り時間も少ないので,まずは自明部分点である小課題 1 を通す. 次に小課題 2 に手を付けたところ,

ただし,Ai, j の値はすべて相異なる

の制約をちゃーんとすっぽかして,加えてよくわからない考察とよくわからない実装で迷子になる(木の直径求めようとしてた)

最終的には,長方形全探索した後, DAG の長さが長方形の面積に等しいかっていうコードを書いて小課題 2 だけは死守する.30 分掛かった.

f:id:RheoTommy:20220213211249p:plain

AtCoder青になりました!

ありがとうABC190、ありがとう早解き回。 f:id:RheoTommy:20210130222908p:plain

f:id:RheoTommy:20210130230916p:plain

やったこと

  • 緑以下をすべて埋めて、水を途中まで埋めた。
  • ちょっとライブラリを整理した(リンクはこちら!
  • 黄コーダーとたくさん話した(難しいアルゴリズムも知識として知っておくと、たまに使える)
  • オライリーのRust本と実践Rustを再履修した(道具を知っていないと、道具に殺されますよ(笑))
  • コンテストにで続けた(出ないとレート変動しませんよ(笑))
  • PCKに出た(楽しかった、が、文化祭と被らんでくれ(泣))
  • JOIに出た(JOI本選さん、C++以外を使わせてくれぇ(泣))

一言で言うなら、競プロを継続的に楽しんでいた(とりあえず継続してやってれば、いずれレートは上がるだろうと信じて)

f:id:RheoTommy:20210130231435p:plainf:id:RheoTommy:20210130231501p:plainf:id:RheoTommy:20210130231525p:plain

次は

PCKまでに黄色めざすぞー!

【プログラミング言語速度比較】Collatz数列ベンチマークを言語別比較しよー!

Collatz ベンチマーク

目次

追記(2021/02/23)

このベンチマークでは Collatz 数列と呼ばれるものを使って適当に言語ごとの速度を測っているのですが、正直なところ良いベンチマークとは言えません。

適当に書いたら思っていた以上にアクセスされており驚いています。

時間があれば、これよりもっと正確で実用的なベンチマークをしたいと思っているのですが、以下の条件を満たすようないいベンチマークが見つかっていません。

  • 実装が軽い(いろんな言語で書くことを考えると、できるだけ軽いほうがいい)
  • 計算量解析が容易
  • ボトルネックが自明(その処理の速度を言語ごとに比較できる)

Dijkstra や DP など案は適当に上がったのですが、どうもしっくりきていません。良さそうな案があったら是非コメントください。

Collatz 数列とは

整数 n が偶数ならば 2 で割り、奇数ならば 3 倍して 1 を足す。これを n が 1 になるまで繰り返す。

証明されていないが、1020 くらいまでは必ず 1 になるらしい。

やること

コラッツ数列の長さを求める関数を書き、それを用いて言語の実行速度を測る。

計測方法

AtCoder のコードテストを用います(AtCoder 上での処理時間がわかれば、言語別の実行速度の参考になるかと思ったため。サーバーに負荷がかかるとかで怒られたらやめます)

任意の整数i(1 <= i <= 104,105,106, 107)について、iが 1 になるまで上記の操作をしたときの操作回数を求めるプログラムを書き、総和を 1000000007 で割った余りを求める(107 程度ではオーバーフローはしませんでしたが)

基本的には、collatz()関数をmain()関数内でn回回し、その総和の Mod を取る形で実装しています。

Haskell は言語仕様的にループを使わないので、再帰で実装しています。

ベンチマーク結果

結果表

単位は[ms] TLEは 105 ms 以上

言語 104 105 106 107
Rust(メモ化チート) 7 12 61 512
Rust(usize) 10 23 176 1887
Rust(isize) 10 37 275 2910
C++ (GCC) 8 34 301 3211
C++ (Clang) 13 34 302 3211
C++(GCC)編集版 7 25 203 2119
C++(GCC)編集版(unsigned) 9 31 301 3394
C++(Clang)編集版 11 38 272 2727
C++(Clang)編集版(unsigned) 13 25 179 1896
C(GCC) 4 22 194 2027
C(GCC)(unsigned) 8 30 314 3559
C(Clang) 6 27 249 2727
C(Clang)(unsigned) 7 20 171 1887
C(Clang)(unsigned)(アセンブリ参考) 6 19 164 1807
D(LDC) 10 25 175 1870
D(GDC) 30 46 258 2586
D(DMD) 10 51 400 4282
Nim 11 30 229 2424
Nim(uint) 10 35 251 2678
Fortran 12 28 252 2889
Objective-C 76 51 333 2910
Swift 9 40 305 3078
Golang 9 42 291 3103
Haskell 11 30 285 3241
Crystal 15 48 345 3746
Java 116 146 489 4404
Kotlin 119 146 522 4551
Scala 518 565 920 4681
Julia 216 229 383 2071
Pascal 9 50 413 5012
C# 89 109 487 4825
Visual Basic 95 120 511 4702
PyPy2 69 113 805 5522
PyPy3 80 116 599 5528
JavaScript 72 89 1557 TLE
Common Lisp 53 157 1657 TLE
Php 57 454 5269 TLE
Cython 74 364 3885 TLE
Cython2 54 67 280 2552
Clojure 1921 2136 5005 TLE
Ruby 110 685 7384 TLE
Scheme 70 604 7425 TLE
Python 109 1059 TLE TLE
Python(numba) 130 124 271 1968
Lua(Lua) 104 1152 TLE TLE
Lua(LuaJIT) 12 65 642 7169
Perl 124 1244 TLE TLE
Awk 121 1398 TLE TLE
Bc 747 9481 TLE TLE
Raku(Perl6) 1232 TLE TLE TLE
dc 1378 TLE TLE TLE
Vim 3689 TLE TLE TLE
Bash 7674 TLE TLE TLE

編集履歴

詳細を見る

C#を追加しました。 白緑さん、ありがとうございます!

C++の編集版を追加しました。ありがとうございます!

Python と PyPy を追加しました。arasius さん、ありがとうございます!

Crystal を追加しました。yuruhiya さん、ありがとうございます!

Golang を追加しました。

JavaScript を追加しました。

Fortran を追加しました。jj1gui さん、ありがとうございます!

Ruby を追加しました。KowerKoint さん、ありがとうございます!

Rust だけ非負整数だったので、Rust の isize 版、C++の unsigned 版も作りました。

D 言語を追加しました。lempiji さん、ありがとうございます!

Bash を追加しました。KowerKoint さん、ありがとうございます!

Swift を追加しました。

Objective-C を追加しました。matsumo さん、ありがとうございます!

C を追加しました。mikhail さん、fujitanozomu さん、ありがとうございます! https://twitter.com/Mikhail_chan/status/1284665518524256256?s=20

Bc を追加しました。fujitanozomu さん、ありがとうございます!

Perl を追加しました。fujitanozomu さん、ありがとうございます!

Vim を追加しました。KowerKoint さん、ありがとうございます!

PhpAWKLuaPascal を追加しました。fujitanozomu さん、本当にありがとうございます!!

Rust のメモ化版を追加しました。

Visual Basic を追加しました。KowerKoint さん、ありがとうございます! https://twitter.com/KowerKoint2010/status/1285871100329525254?s=20

Raku(Perl6)を追加しました。fujitanozomu さん、ありがとうございます! https://twitter.com/fujitanozomu/status/1285957649926811654?s=20

dc を追加しました。fujitanozomu さん、ありがとうございます! https://twitter.com/fujitanozomu/status/1287306706444115968?s=20

Nim を追加しました。Kitagawa さん、ありがとうございます! https://twitter.com/kitagawahr1992/status/1288004487802572800?s=20

コメントより、Nim の uint 版を追加しました。udoooon さん、ありがとうございます!

コメントより、Cython を追加しました。papico さん、ありがとうございます!

コメントより、Numba を使った Python を追加しました。sooooba さん、ありがとうございます。入力を可変にできていません

Clojure CommonLisp Scheme を追加しました。Linuxmetal さんありがとうございます! https://twitter.com/linuxmetel/status/1364182475480522756?s=20

Twitter の DM より,Cython, julia, numba を追加しました.ありがとうございます!

感想

基本的に最速クラス

2000ms を切るくらいです。

Rust は相変わらず早いです。ただし、usize から isize にすると結構速度が落ちました。

C++も、Clang においては Rust と同じ傾向です。

Rust も D も C++(Clang)も C(Clang) も、非負整数の最適化が早いみたいです。

特に Rust は、配列の添字に使えるのが usize だけだったり、非負整数を扱う機会が多いのでそれなりに最適化されていそうです。

癖がある最速クラス

2000ms 前半でした。 unsigned より signed のほうが早い、ちょっと癖のある感じです。GCC は特殊ですね。どっちにしろ最速クラスなのには変わりないです。

準最速クラス

2000ms 後半くらいで結構早いです。Java の 1.5~2 倍程度高速です。困らない速さだと思います。

Fortran は速さ的には C++に近い感じです。結構イメージ通りではあります。というか初めて Fortran のコードを見ました()

Objective-C ってこんなに C なんですね。書きにくそう()

Nimは思ったより遅かったです。Rust並かRustを超えると思ったんだけどな・・・

まぁまぁ早いクラス

3000ms 台で十分に高速だと思います。ただ、ここから最速とは言われないイメージがあります。

Haskell 思ったより早い!

GolangC++Haskell 並の速度は出ますね

Crystal は Ruby 版の Nim だと思うと早いのも不思議じゃないですね。

Swift は Kotlin よりちょっとだけ早いです。

普通クラス

4000ms 台で高速ですが、Rust などと比べると 2 倍の差がついています。

Scala は圧倒的に JVM 起動時間が重いですね。Rust なら 1ms の問題でも Scala だとジャッジがめっちゃ遅いのでちょっと不便です。

C#JVM の起動時間を無視したとき、Java とほとんど同じくらいの速度に見えます。

Visual Basic ってこんなに早いんですね。

やや遅いクラス

  • PyPy
  • Lua(LuaJIT)

5000ms を超えました。Python と比べれば圧倒的な改善ですが、流石にトップレベルにはかなわない感じです。

PyPy は想像以上に早かったです。Java と対して差がついていないので、PyPy ならどの問題も通せそうです(MLE 問題がありそうですが)。

LuaJIT コンパイルすると PyPy より遅い程度で動作しますね。

ビリ

ダントツで遅く、TLE が出ました・・・。 動的型付け言語なので仕方ないですね。

JavaScriptPython に比べれば圧倒的に早いですが、それでも静的型付け言語には全然及びませんね。想定解が通らない問題は少なそうですが、あんまり書きやすいイメージもないので、競プロにはあんまり向いていないかもしれません。

Php も思ったより早いですね。JS に近いイメージでしょうか

RubyJavaScript には及びませんが、Python と比べると結構早いですね。Ruby は関数がめちゃめちゃ豊富で、すごく短いコードを見かけたりすることがあるので、問題によっては楽に解けそうです。ただ、ゴリゴリの実装問題とかだともしかすると間に合わないことがあるかもしれないですね・・・。

PyPy に対して Python は想像以上に遅かった・・・。想定解が通らない問題があってもおかしくないかもしれないです。Cython ももうちょっと早いと思ってました。

う 笑

う 笑

言語別実装

Rust メモ化版

メモ化すると 500ms 程度になりましたが、これ以上は早くなりませんでした。(HashMap や BTreeMap ですべての要素をメモ化したり、一部の要素をメモ化したりしましたが、Vector で n 個までメモ化するのが最速でした)

コードを見る

#![allow(unused_macros)]
#![allow(dead_code)]
#![allow(unused_imports)]

use std::io::stdin;
use std::collections::{HashMap, BTreeMap};

const U_INF: usize = 1 << 60;
const I_INF: isize = 1 << 60;

fn main() {
    let mut buf = String::new();
    stdin().read_line(&mut buf).unwrap();
    let n = buf.split_whitespace().next().unwrap().parse().unwrap();
    // let mut memo = HashMap::new();
    let mut memo = vec![None; n];


    let mut acc = 0;
    for i in 1..=n {
        acc += collatz(i, &mut memo);
        acc %= 1000000007;
    }
    println!("{}", acc);
}

fn collatz(i: usize, memo: &mut [Option<usize>]) -> usize {
    match memo.get(i).and_then(|o| *o) {
        Some(t) => t,
        None => {
            let cnt =
                if i == 1 { 0 } else if i % 2 == 0 { 1 + collatz(i / 2, memo) } else { 1 + collatz(i * 3 + 1, memo) };

            if (0..memo.len()).contains(&i) {
                memo[i] = Some(cnt);
            }
            cnt
        }
    }
}

Rust

大好き!適当に書いてもコンパイルが通れば一定に高速で安全なコードになるのは最高です。

isize 版はcollatz(mut i:isize)->isizeとして測定しました。

コードを見る

#![allow(unused_macros)]
#![allow(dead_code)]
#![allow(unused_imports)]

use std::io::stdin;

const U_INF: usize = 1 << 60;
const I_INF: isize = 1 << 60;

fn main() {
    let mut buf = String::new();
    stdin().read_line(&mut buf).unwrap();
    let n = buf.split_whitespace().next().unwrap().parse().unwrap();
    let mut acc = 0;
    for i in 1..=n {
        acc += collatz(i);
        acc %= 1000000007;
    }
    println!("{}", acc);
}

fn collatz(mut i: usize) -> usize {
    let mut cnt = 0;
    while i != 1 {
        cnt += 1;
        if i % 2 == 0 {
            i /= 2;
        } else {
            i *= 3;
            i += 1;
        }
    }
    cnt
}

C++

意外と Rust のほうが早いです C++の書き心地は Rust とあんまり変わりませんね。危ない橋を渡れる言語だと思っています。

for を 0 から始めるようにすると 1.5 倍位早くなりました。不思議。

unsigned 版はusing ll = unsigned long longにしました。

編集前

コードを見る

#include <iostream>

using namespace std;

using ll = long long;

ll collatz(ll i) {
    ll cnt = 0;
    while (i != 1) {
        cnt++;
        if (i % 2 == 0) {
            i /= 2;
        } else {
            i *= 3;
            i += 1;
        }
    }
    return cnt;
}

int main() {
    ll n;
    cin >> n;
    ll acc = 0;
    for (ll i = 1; i <= n; i++) {
        acc += collatz(i);
        acc %= 1000000007;
    }
    cout << acc << endl;
}

編集後

コードを見る

#include <iostream>

using namespace std;

using ll = long long;

ll collatz(ll i) {
    ll cnt = 0;
    while (i != 1) {
        cnt++;
        if (i % 2 == 0) {
            i /= 2;
        } else {
            i *= 3;
            i += 1;
        }
    }
    return cnt;
}

int main() {
    ll n;
    cin >> n;
    ll acc = 0;
    for (ll i = 0; i < n; i++) {
        acc += collatz(i+1);
        acc %= 1000000007;
    }
    cout << acc << endl;
}

Java

C++,Rust ほどではないが早いです。

コードは C++以上に書きにくい印象があります。

コードを見る

import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        long n = sc.nextLong();
        long acc = 0;
        for (long i = 1; i <= n; i++) {
            acc += collatz(i);
            acc %= 1000000007;
        }
        System.out.println(acc);
    }

    public static long collatz(long j) {
        long i = j;
        long cnt = 0;
        while (i != 1) {
            cnt++;
            if (i % 2 == 0) {
                i /= 2;
            } else {
                i *= 3;
                i += 1;
            }
        }
        return cnt;
    }
}

Kotlin

Java とほとんど変わらないです。

Kotlin がかける環境ならば Java を書く理由がない気がします。

コードを見る

import java.util.*

fun main() {
    val sc = Scanner(System.`in`)
    val n = sc.nextLong();
    var acc = 0L;
    for (i in 1..n) {
        acc += collatz(i)
        acc %= 1000000007
    }
    println(acc)
}

fun collatz(j: Long): Long {
    var cnt = 0L
    var i = j
    while (i != 1L) {
        cnt += 1
        if (i % 2L == 0L) {
            i /= 2
        } else {
            i *= 3
            i += 1
        }
    }
    return cnt
}

Scala

JVM 起動時間がすごく重いですねー。 N が大きくなると差がなくなることから、ロジック部分の処理速度は Java と対して変わらなそうです

コードを見る

import java.util.Scanner

object Main {
  def main(args: Array[String]): Unit = {
    val sc = new Scanner(System.in)
    val n = sc.nextLong()
    var acc = 0L
    for (i <- 1L to n) {
      acc += collatz(i)
      acc %= 1000000007L
    }
    println(acc)
  }

  def collatz(j: Long): Long = {
    var i = j
    var cnt = 0
    while (i != 1) {
      cnt += 1
      if (i % 2 == 0) {
        i /= 2
      } else {
        i *= 3
        i += 1
      }
    }
    cnt
  }
}

Haskell

Haskell は C と同じデータ構造とアルゴリズムならば同等の速度を出せると聞いていました。

その話通り、Int 型の指定を入れると C++(編集前)並に爆速です。

solve :: Int -> Int を指定しないと solve :: Integral a => a -> a と多相になるので遅いです

コードを見る

main = do
  n <- read <$> getLine
  print $ solve n

solve :: Int -> Int
solve 1 = collatz 1 0
solve n = (collatz n 0 + solve (n-1)) `mod` 1000000007

collatz :: Int -> Int -> Int
collatz 1 acc = acc
collatz n acc | n `mod` 2 == 0 = collatz (n `div` 2) (acc + 1)
collatz n acc = collatz (n * 3 + 1) (acc + 1)

CSharp

Java と同じくらいの速度なので、罠が少なければ結構良い言語かもしれないです(人気ありますよね)

コードを見る

using System;

class Test{
  public static void Main(string[] args){
    var n = long.Parse(Console.ReadLine());
    long acc=0;
    for(long i=1;i<=n;i++){
      acc+=collatz(i);
      acc%=1_000_000_007;
    }
    Console.WriteLine(acc);
  }

  private static long collatz(long i){
    long cnt = 0;
    while (i != 1) {
      cnt++;
      if (i % 2 == 0) {
        i /= 2;
      } else {
        i *= 3;
        i += 1;
      }
    }
    return cnt;
  }
}

Python & PyPy

コードはやっぱりシンプルですね

PyPy じゃないと速度はだいぶ遅いです

コードを見る

def collatx(i):
    cnt=0
    while i!=1:
        cnt+=1
        if i%2==0:
            i//=2
        else:
            i*=3
            i+=1
    return cnt

def main():
    n=int(input())
    acc=0
    for i in range(1,n+1):
        acc+=collatx(i)
        acc%=1000000007
    print(acc)

if __name__ == "__main__":
    main()

Crystal

コードが凄くシンプルでびっくりです。

それでいてかなり高速でした。

コードを見る

def collatz(i)
  cnt = 0i64
  while i != 1
    cnt += 1
    if i.even?
      i //= 2
    else
      i *= 3
      i += 1
    end
  end
  cnt
end

n = read_line.to_i64
puts (1i64..n).reduce(0i64) { |acc, i|
  acc += collatz(i)
  acc %= 1_000_000_007
}

Golang

結構早いです。

IDE との相性もよく、気持ちよく書けます。

コードを見る

package main

import "fmt"

func main() {
 var n int
 _, _ = fmt.Scan(&n)
 acc := 0

 for i := 1; i <= n; i++ {
  acc += collatz(i)
  acc %= 100000007
 }

 fmt.Println(acc)
}

func collatz(j int) int {
 i := j
 cnt := 0
 for ; i != 1; {
  cnt++
  if i%2 == 0 {
   i /= 2
  } else {
   i *= 3
   i += 1
  }
 }
 return cnt
}

JavaScript

動的型付け言語書けません

コードを見る

"use strict";

const main = (arg) => {
    const n = parseInt(arg);

    let acc = 0;
    for (let i = 1; i <= n; i++) {
        acc += collatz(i);
        acc %= 1000000007;
    }

    console.log(acc);
};
main(require("fs").readFileSync("/dev/stdin", "utf8"));

function collatz(i) {
    let cnt = 0;
    while (i !== 1) {
        cnt++;
        if (i % 2 === 0) {
            i /= 2;
        } else {
            i *= 3;
            i += 1;
        }
    }
    return cnt;
}

Fortran

コードを雰囲気で読むことができませんでした。

bash とかとも近い気がしました(素人感)

コードを見る

program main
    implicit none
    integer(8) n,i,acc
    read*,n
    acc=0
    do i=1,n
        acc=acc+collatz(i)
        acc=mod(acc,1000000007)
    end do
    print'(i0)',acc
contains
integer(8) function collatz(n_in)
    integer(8) n,n_in
    collatz=0
    n=n_in

    do while(n/=1)
        collatz=collatz+1
        if(mod(n,2)==0)then
            n=n/2
        else
            n=3*n+1
        end if
    end do
end function
end program main

Ruby

Ruby のコードは左によっているイメージがあります。

コードを見る

def collatz(i)
  cnt = 0
  while i != 1 do
    cnt += 1
    if i % 2 == 0
      i /= 2
    else
      i *= 3
      i += 1
    end
  end
  cnt
end

n = gets.to_i
acc = 0
n.times do |i|
  acc += collatz(i + 1)
  acc %= 1000000007
end
puts(acc)

D 言語

びっくりするくらい早かったです。

コードを見る

import std.stdio;
import std.conv;

size_t collatz()(size_t i)
{
    size_t cnt = 0;
    while (i != 1)
    {
        cnt++;
        if (i % 2 == 0)
        {
            i /= 2;
        }
        else
        {
            i *= 3;
            i += 1;
        }
    }
    return cnt;
}

void main()
{
    size_t n = readln().to!size_t;

    size_t acc = 0;
    for (size_t i = 1; i <= n; i++)
    {
        acc += collatz(i + 1);
        acc %= 1000000007;
    }
    writeln(acc);
}

Bash

Bash に処理速度を期待してはいけません

コードを見る

collatz() {
    cnt=0
    j=$1
    while [ $j -ne 1 ] ; do
        cnt=$((cnt+1))
        if [ $((j % 2)) -eq 0 ]; then
            j=$((j / 2))
        else
            j=$((j * 3))
            j=$((j + 1))
        fi
    done
    return_var=$cnt
}

read n
acc=0
for i in `seq 1 $n`; do
    collatz $i
    acc=$((acc + return_var))
    acc=$((acc % 1000000007))
done
echo $acc

Swift

Swift のコードも見やすいと思います。

速度も申し分ないですね。

コードを見る

func collatz(_ j: Int64) -> Int64 {
    var cnt: Int64 = 0
    var i = j
    while (i != 1) {
        cnt += 1
        if (i % 2 == 0) {
            i /= 2
        } else {
            i *= 3
            i += 1
        }
    }
    return cnt
}

let n = Int64(readLine()!)!
var acc: Int64 = 0;
for i in 1...n {
    acc += collatz(i)
    acc %= 1000000007
}
print(acc)

Objective-C

見た目は C++と同じ感じですね

思ったより遅かったです。

コードを見る

#import <stdio.h>

typedef long long ll;

ll collatz(ll i) {
    ll cnt = 0;
    while (i != 1) {
        cnt++;
        if (i % 2 == 0) {
            i /= 2;
        } else {
            i *= 3;
            i += 1;
        }
    }
    return cnt;
}

int main() {
    ll n;
    scanf("%lld", &n);
    ll acc = 0;
    for (ll i = 1; i <= n; i++) {
        acc += collatz(i);
        acc %= 1000000007;
    }
    printf("%lld\n", acc);
}

C

longAtCoder 上では 64bit らしいです https://twitter.com/Mikhail_chan/status/1284743491872907265?s=20

通常版

unsignedlong longの前に追加するとunsigned版になります

コードを見る

#include <stdio.h>

long long collatz(long long n) {
   long long cnt = 0;
    while (n != 1) {
        cnt++;
        if (n % 2 == 0) {
            n /= 2;
        } else {
            n *= 3;
            n++;
        }
    }
    return cnt;
}

int main(void) {
    long long n, ans = 0;
    scanf("%ld", &n);
    for (unsigned long long i = 0; i < n; i++) {
        ans += collatz(i + 1);
        ans %= 1000000007;
    }
    printf("%ld\n", ans);
}

アセンブリ修正版

コードを見る

    #include <stdio.h>

    #if 1
    unsigned long collatz(unsigned long n);
    __asm(
    "        .text                        \n"
    "        .globl  collatz              \n"
    "        .p2align 4                   \n"
    "collatz:                             \n"
    "        xor     %eax, %eax           \n"
    "        cmp     $1, %rdi             \n"
    "        je      1f                   \n"
    "        .p2align 4, 0x90             \n"
    "0:      inc     %rax                 \n"
    "        mov     %rdi, %rcx           \n"
    "        shr     $1, %rcx             \n"
    "        lea     1(%rdi,%rdi,2), %rdi \n"
    "        cmovnc  %rcx, %rdi           \n"
    "        cmp     $1, %rdi             \n"
    "        jne     0b                   \n"
    "1:      ret                          \n"
    );
    #else
    unsigned long collatz(unsigned long n) {
        unsigned long cnt = 0;
        while (n != 1) {
            cnt++;
            if (n % 2 == 0) {
                n /= 2;
            } else {
                n *= 3;
                n++;
            }
        }
        return cnt;
    }
    #endif

    int main(void) {
        unsigned long n, ans = 0;
        scanf("%lu", &n);
        for (unsigned long i = 0; i < n; i++) {
            ans += collatz(i + 1);
            ans %= 1000000007;
        }
        printf("%lu\n", ans);
    }

Bc

Bash よりは書きやすそうです

コードを見る

define collatz(i) {
    cnt = 0
    while (i != 1) {
        cnt += 1
        if (i % 2 == 0) {
            i /= 2
        } else {
            i *= 3
            i += 1
        }
    }
    return cnt;
}

scale = 0
n = read()
ans = 0
for (i = 1; i <= n; i++) {
    ans += collatz(i);
    ans %= 1000000007;
}
print ans, "\n";

Perl

最近はあんまり聞かない言語ですね~。

速度は Python に近いようです。

コードを見る

#!/usr/bin/perl

sub collatz {
    my ($i) = @_;

    my $cnt = 0;
    while ($i != 1) {
        $cnt += 1;
        if ($i % 2 == 0) {
            $i /= 2
        } else {
            $i *= 3;
            $i += 1
        }
    }
    return $cnt
}

my $scale = 0;
my $n = <STDIN>;
my $ans = 0;
for (my $i = 1; $i <= $n; $i++) {
    $ans += collatz($i);
    $ans %= 1000000007
}
print "$ans\n"

Php

Web で使うからか、Python などより JS に近い速度でした(JS よりは遅いですが)

コードを見る

<?php

function collatz($i) {
    $cnt = 0;
    while ($i != 1) {
        $cnt += 1;
        if ($i % 2 == 0) {
            $i /= 2;
        } else {
            $i *= 3;
            $i += 1;
        }
    }
    return $cnt;
}

$n = fgets(STDIN);
$ans = 0;
for ($i = 1; $i <= $n; $i++) {
    $ans += collatz($i);
    $ans %= 1000000007;
}
echo "$ans\n";

?>

Vim

Vim が提出できるとは思っていませんでした

コードを見る

:function! Collatz(j) abort
    :let cnt = 0
    :let x = a:j
    :while x != 1
        :let cnt += 1
        :if x % 2 == 0
            :let x /= 2
        :else
            :let x *= 3
            :let x += 1
        :endif
    :endwhile
    :return cnt
:endfunction

:delete a
:let n = @a
:let acc = 0
:for i in range(n)
    :let acc += Collatz(i + 1)
    :let acc %= 1000000007
:endfor
:put! = acc
:write
:quit

Awk

Bc に近いイメージでしょうか。Bash や Bc よりは早いです。

コードを見る

function collatz(i) {
    cnt = 0
    while (i != 1) {
        cnt += 1
        if (i % 2 == 0) {
            i /= 2
        } else {
            i *= 3
            i += 1
        }
    }
    return cnt
}

{
    acc = 0
    for (i = 1; i <= $0; i++) {
        acc += collatz(i)
        acc %= 1000000007
    }
    print acc
}

Lua

コードはなんとなく Bash みたいな雰囲気がありますが Bash より圧倒的に読みやすいです(Pascal に近い構文らしいです)

コードを見る

function collatz(i)
  cnt = 0
  while i ~= 1 do
    cnt = cnt + 1
    if i % 2 == 0 then
      i = i / 2
    else
      i = 3 * i
      i = i + 1
    end
  end
  return cnt
end

n = io.stdin:read()
acc = 0
for i = 1, n do
  acc = acc + collatz(i)
  acc = acc % 1000000007
end
print(acc)

Pascal

C#より早いくらいの速度でした。コードは RubyBash に近い印象です(あんまり詳しくないのでこの辺的外れだと思います・・・)

コードを見る

program collatz;

function collatz(i: UInt64): UInt64;
var cnt: UInt64;
begin
  cnt := 0;
  while i <> 1 do
    begin
      cnt := cnt + 1;
      if i mod 2 = 0 then
        begin
          i := i div 2
        end
      else
        begin
          i := 3 * i;
          i := i + 1
        end
    end;
  collatz := cnt
end;

var n:   UInt64;
var acc: UInt64;
var i:   UInt64;

begin
  readln(n);
  acc := 0;
  for i := 1 to n do
    begin
      acc := acc + collatz(i);
      acc := acc mod 1000000007
    end;
  writeln(acc)
end.

Visual Basic

思ったより早い言語

コードを見る

Module Collatz
    Function Collatz(ByVal i As Long)
        Dim cnt As Long
        cnt = 0
        While i <> 1
            cnt += 1
            If i Mod 2 = 0 Then
                i \= 2
            Else
                i *= 3
                i += 1
            End If
        End While
        Return cnt
    End Function

    Sub Main(ByVal args() as String)
        Dim n, i, acc As Long
        n = Console.ReadLine()
        acc = 0
        For i=1 To n
            acc += Collatz(i)
            acc = acc Mod 1000000007
        Next
        Console.WriteLine(acc)
    End Sub
End Module

Raku(Perl6)

Perl と比較しても断然遅かったです。(Perl6 とは)

コードを見る

#!/usr/bin/perl6

sub collatz(Int $i is copy --> Int) {
    my Int $cnt = 0;
    while ($i != 1) {
        $cnt += 1;
        if ($i % 2 == 0) {
            $i div= 2;
        } else {
            $i *= 3;
            $i += 1;
        }
    }
    return $cnt;
}

my Int $n = +get();
my Int $ans = 0;
loop (my Int $i = 1; $i <= $n; $i++) {
    $ans += collatz($i);
    $ans %= 1000000007;
}
say "$ans\n"

dc

電卓のあれらしいですが正直全く読めません。このへんの言語ってどんなきっかけで学ぶんでしょうか

コードを見る

[cq]sa[3*1+q]sb[d2%1=b2/]sc[d1=alcxlx1+sxldx]sd[0sxldxlx]se?sn0si0sy[li1+dsilexly+syliln>l]dslxlyn

Nim

Haskellと同じ感じの再帰実装のようです。コードは比較的読みやすいですね。

コードを見る

import strutils, sequtils

const MOD = 1000000007'u

func collatz(i: uint): uint =
  if i == 1: 0'u
  elif i mod 2 == 0: collatz(i div 2) + 1'u
  else: collatz(i * 3 + 1) + 1'u

proc main() =
  let n = stdin.readLine().parseUInt()
  echo (1'u..n).foldl((a + collatz(b)) mod MOD, first = 0'u)

if isMainModule:
  main()

Nim(uint)

uint を使った非再帰実装のようです。

前者とあまり変わらないですね。

コードを見る

import strutils

const MOD = 1000000007'u

proc collatz(i: uint): uint = 
    var 
        i = i 
        res = 0'u 
    while i != 1:
        if (i and 1) == 0:
            i = i shr 1 
        else:
            i *= 3 
            inc i 
            inc res 
    return res

proc main() = 
    let n = stdin.readLine().parseUInt()
    var 
        i = 1'u 
        ret = 0'u

    while i <= n:
        ret += collatz(i)
        ret = ret mod MOD 
        inc i 
    echo ret

if isMainModule:
    main()

Cython

うーん、PyPy くらい早くなると思っていたのですが。

コードを見る

cpdef int collatx(i):
    cpdef int cnt=0
    while i!=1:
        cnt+=1
        if i%2==0:
            i//=2
        else:
            i*=3
            i+=1
    return cnt

cpdef main():
    cpdef int n=int(input())
    cpdef int acc=0
    for i in range(1,n+1):
        acc+=collatx(i)
        acc%=1000000007
    print(acc)

if __name__ == "__main__":
    main()

Cython2

詳しくないので違いがあんまり判らないんですが,以下のコードだと上のコードに比べかなり高速に実行できるようです

コードを見る

# cython: language_level = 3, nonecheck = False, cdivision = True
ctypedef unsigned long long ull
cdef ull collatz(ull i) nogil:
    cdef ull cnt = 0
    while i != 1:
        cnt += 1
        if i % 2 == 0:
            i //= 2
        else:
            i *= 3
            i += 1
    return cnt

def main():
    cdef:
        ull n = int(input())
        ull acc = 0
    for 1 <= i < n + 1:
        acc += collatz(i)
        acc %= 1000000007
    print(acc)

if __name__ == '__main__':
    main()

Python(Numba)

これ実は動かせていません。

入力を可変にしようと思って色々やったのですが、Python の知識が欠けており…

コードを見る

from numba import jit

@jit('i8(i8)', nopython=True)
def collatx(i):
    cnt = 0
    while i != 1:
        cnt += 1
        if i % 2 == 0:
            i //= 2
        else:
            i *= 3
            i += 1
    return cnt


@jit(nopython=True)
def main():
    n = 10 ** 7
    acc = 0
    for i in range(1, n + 1):
        acc += collatx(i)
        acc %= 1000000007
    print(acc)


if __name__ == "__main__":
    main()

Clojure

コードを見る

(defn collatz [x]
  (loop [i x cnt 0]
    (cond (= i 1) cnt
          (even? i) (recur (quot i 2) (inc cnt))
          :else (recur (inc (* i 3)) (inc cnt)))))

(def n (read))

(println (loop [i 1 acc 0]
           (if (<= i n)
             (recur (inc i) (rem (+ acc (collatz i)) 1000000007))
             acc)))

Common Lisp

コードを見る

(defun collatz (i)
  (loop for cnt from 0
        until (= i 1)
        do (setq i (if (evenp i)
                       (floor i 2)
                       (1+ (* 3 i))))
        finally (return cnt)))

(let ((n (read))
      (acc 0))
  (loop for i from 1 to n
        do (setq acc (rem (+ acc (collatz i)) 1000000007)))
  (princ acc)
  (fresh-line))

Scheme

コードを見る

(define (collatz x)
  (let loop ((i x)
             (cnt 0))
    (if (= i 1)
      cnt
      (if (even? i)
        (loop (quotient i 2) (+ cnt 1))
        (loop (+ 1 (* 3 i)) (+ cnt 1))))))

(define n (read))

(print (let loop ((i 1)
                  (acc 0))
         (if (<= i n)
           (loop (+ i 1) (remainder (+ acc (collatz i)) 1000000007))
           acc)))

Julia

科学計算用途の動的型付け言語らしいんですがかなり早いですね.JIT コンパイルでもしてるんでしょうか

コードを見る

function collatz(i)
    cnt = 0
    while i != 1
        cnt += 1
        if i & 1 == 0
            i >>= 1
        else
            i *= 3
            i += 1
        end
    end
    cnt
end

function main()
    n = parse(Int, readline())
    acc = 0
    for i = 1:n
        acc += collatz(i)
        acc %= 1000000007
    end
    println(acc)
end

if abspath(PROGRAM_FILE) == @__FILE__
    main()
end

numba

i8 を使っているようで結構早いですね.

コードを見る

import sys
if sys.argv[-1] == 'ONLINE_JUDGE':
    from numba import *
    from numba.pycc import CC

    @njit('i8(i8)', parallel = True, cache = True)
    def collatz(i):
        cnt = 0
        while i != 1:
            cnt += 1
            if i & 1 == 0:
                i >>= 1
            else:
                i *= 3
                i += 1
        return cnt

    cc = CC('my_module')
    @cc.export('main', 'void(i8)')
    def main(n):
        acc = 0
        for i in prange(1, n + 1):
            acc += collatz(i)
            acc %= 1000000007
        print(acc)
    cc.compile()
    exit()

if __name__ == '__main__':
    from my_module import main
    main(int(input()))