謎バタフライ演算(WIP)
Twitter見てたらこういうのが流れてきた。
キャッシュ意識して書いたら高速ゼータ変換が倍速くらいになってヤバい
— 熨斗袋 (@noshi91) 2023年4月11日
— 熨斗袋 (@noshi91) 2023年4月11日
n ≤ 21 くらいだとほとんど同じ速度になる (多分 L2 に全部乗るから)
— 熨斗袋 (@noshi91) 2023年4月11日
どうなっているのか追うの結構大変だけど、上の桁から"非同期的"に徐々に処理されていく感じになっている。
(TODO:もっと丁寧に説明)
よく眺めてみると、別にこれは高速ゼータ変換だけでなく一般のバタフライ演算(クロネッカー積やFFT)の処理にも適用できる(はず)。
ということでとりあえずこれをFFTのバタフライ演算に使ってみた。(以下はACLのbutterflyの1 bit版を変形したもの)
template <class mint, internal::is_static_modint_t<mint>* = nullptr> void butterfly2(std::vector<mint>& a) { using namespace atcoder::internal; int n = int(a.size()); static const fft_info<mint> info; std::array<mint, info.rank2> temp; mint rot = 1; for (int i = 0; i < n; i += 2) { const int j_end = __builtin_ctz(i | n); mint r = rot; for (int j = 1; j < j_end; j++) { temp[j] = r; r *= r; } for (int j = j_end - 1; j >= 0; r = temp[j--]) { const int w = 1 << j; for (int k = 0; k < w; k++) { auto vl = a[i + k]; auto vr = a[i + k + w] * r; a[i + k] = vl + vr; a[i + k + w] = vl - vr; } } if (i + 2 != n) { rot *= info.rate2[bsf(~(unsigned int)(i >> 1))]; } } }
ACLの1 bit版と速度を比較したけど、ほとんど差が出なくて速くなったのか遅くなったのか良く分かりませんでした。いかがでしたか?
(TODO:もうちょっとちゃんとやる)
(TODO: 4-baseに相当するものを書いて比較) とりあえず書いた
template <class mint, internal::is_static_modint_t<mint>* = nullptr> void butterly2_4base(std::vector<mint>& a) { using namespace atcoder::internal; int n = int(a.size()); int h = ceil_pow2(n); static const fft_info<mint> info; if (h % 2 == 1) { int p = 1 << (h - 1); for (int i = 0; i < p; i++) { auto l = a[i]; auto r = a[i + p]; a[i] = l + r; a[i + p] = l - r; } } static std::array<mint, info.rank2> rots; for (int j = 0; j < h; j += 2) rots[j] = 1; const mint imag = info.root[2]; for (int i = 0, j_end = h; i < n; i += 4) { for (int j = (j_end & ~1) - 2; j >= 0; j -= 2) { mint rot = rots[j]; mint rot2 = rot * rot; mint rot3 = rot2 * rot; const int w = 1 << j; for (int k = 0; k < w; k++) { auto mod2 = 1ULL * mint::mod() * mint::mod(); auto a0 = 1ULL * a[k + i].val(); auto a1 = 1ULL * a[k + i + w].val() * rot.val(); auto a2 = 1ULL * a[k + i + 2 * w].val() * rot2.val(); auto a3 = 1ULL * a[k + i + 3 * w].val() * rot3.val(); auto a1na3imag = 1ULL * mint(a1 + mod2 - a3).val() * imag.val(); auto na2 = mod2 - a2; a[k + i] = a0 + a2 + a1 + a3; a[k + i + 1 * w] = a0 + a2 + (2 * mod2 - (a1 + a3)); a[k + i + 2 * w] = a0 + na2 + a1na3imag; a[k + i + 3 * w] = a0 + na2 + (mod2 - a1na3imag); } } if (i + 4 != n) { j_end = bsf(~(unsigned int)(i | 3)); for (int j = (j_end & ~1) - 2; j >= 0; j -= 2) { rots[j] *= info.rate3[j_end - j - 2]; } } } }
ついでの話。逆向きに回せば下の桁から徐々に処理されていくようになるはず。つまり下みたいにi,jの増減の方向を逆にするとそうなる。kはどうでもいい。
void FZT_pre_2(std::vector<u64> &a) { const int n = a.size(); for (int i = n - 2; i >= 0; i -= 2) { for (int j = 0, j_end = __builtin_ctz(i | n); j < j_end; j++) { const int w = 1 << j; for (int k = 0; k < w; k++) { a[i + k] += a[i + k + w]; } } } }
終わりに
こういう風にもバタフライ演算を書くことができるんですね。すごい。