This documentation is automatically generated by online-judge-tools/verification-helper
View the Project on GitHub luzhiled1333/comp-library
#include "src/math/convolution/modint-convolution.hpp"
std::vector< modint > modint_convolution(std::vector< modint > f, std::vector< modint > g)
f と g の畳み込みを返す。
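- $n = |f|$, $m = |g|$ とおく。戻り値 $h$ の長さは $n + m - 1$ で、$h_k = \sum_{i + j = k} f_i g_j$ である。
- 制約: $f$, $g$ はともに空でない。$\mathrm{mod} \ge 3$ は奇数で $4 \mid (\mathrm{mod} - 1)$、かつ変換長 $2^{\lceil \log_2 (n + m - 1) \rceil}$ が $\mathrm{mod} - 1$ を割り切ること（$\mathrm{mod}$ は素数を想定。コード上は素数性を検査しない）。
- 計算量: $O((n + m) \log (n + m))$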
#pragma once #include "src/cpp-template/header/int-alias.hpp" #include "src/cpp-template/header/rep.hpp" #include "src/cpp-template/header/size-alias.hpp" #include "src/utility/bit/bit-width.hpp" #include <cassert> #include <vector> namespace luz::internal { template < typename modint > class ButterflyInfo { public: ButterflyInfo() { if (not roots.empty()) { return; } u32 mod = modint::get_mod(); assert(mod >= 3 && mod % 2 == 1); u32 tmp = mod - 1; usize max_base = 0; while (tmp % 2 == 0) { tmp >>= 1; max_base++; } modint root = 2; while (root.pow((mod - 1) >> 1) == modint(1)) { root += 1; } assert(root.pow(mod - 1) == modint(1)); assert(max_base >= 2); roots.resize(max_base + 1); iroots.resize(max_base + 1); rate3.resize(max_base + 1); irate3.resize(max_base + 1); roots[max_base] = root.pow((mod - 1) >> max_base); iroots[max_base] = modint(1) / roots[max_base]; for (usize i: rrep(0, max_base)) { roots[i] = roots[i + 1] * roots[i + 1]; iroots[i] = iroots[i + 1] * iroots[i + 1]; } modint prod = 1, iprod = 1; for (usize i: rep(0, max_base - 2)) { rate3[i] = roots[i + 3] * prod; irate3[i] = iroots[i + 3] * iprod; prod *= iroots[i + 3]; iprod *= roots[i + 3]; } } static std::vector< modint > roots, iroots, rate3, irate3; }; template < typename modint > std::vector< modint > ButterflyInfo< modint >::roots = std::vector< modint >(); template < typename modint > std::vector< modint > ButterflyInfo< modint >::iroots = std::vector< modint >(); template < typename modint > std::vector< modint > ButterflyInfo< modint >::rate3 = std::vector< modint >(); template < typename modint > std::vector< modint > ButterflyInfo< modint >::irate3 = std::vector< modint >(); template < typename modint > void butterfly(std::vector< modint > &vs) { usize n = vs.size(); assert((n & (n - 1)) == 0); usize h = bit_width(n) - 1; ButterflyInfo< modint > info; assert(h < info.iroots.size()); usize len = 0; modint imag = info.roots[2]; if (h & 1) { usize p = 1 << (h - 1); for (usize i: rep(0, p)) { 
modint r = vs[i + p]; vs[i + p] = vs[i] - r; vs[i] += r; } len++; } while (len + 1 < h) { usize p = 1 << (h - len - 2); for (usize i: rep(0, p)) { modint a0 = vs[i + 0 * p]; modint a1 = vs[i + 1 * p]; modint a2 = vs[i + 2 * p]; modint a3 = vs[i + 3 * p]; modint a1na3imag = (a1 - a3) * imag; modint a0a2 = a0 + a2; modint a1a3 = a1 + a3; modint a0na2 = a0 - a2; vs[i + 0 * p] = a0a2 + a1a3; vs[i + 1 * p] = a0a2 - a1a3; vs[i + 2 * p] = a0na2 + a1na3imag; vs[i + 3 * p] = a0na2 - a1na3imag; } modint rot = info.rate3[0]; for (usize s: rep(1, 1 << len)) { usize offset = s << (h - len); modint rot2 = rot * rot; modint rot3 = rot2 * rot; for (usize i: rep(0, p)) { modint a0 = vs[i + offset + 0 * p]; modint a1 = vs[i + offset + 1 * p] * rot; modint a2 = vs[i + offset + 2 * p] * rot2; modint a3 = vs[i + offset + 3 * p] * rot3; modint a1na3imag = (a1 - a3) * imag; modint a0a2 = a0 + a2; modint a1a3 = a1 + a3; modint a0na2 = a0 - a2; vs[i + offset + 0 * p] = a0a2 + a1a3; vs[i + offset + 1 * p] = a0a2 - a1a3; vs[i + offset + 2 * p] = a0na2 + a1na3imag; vs[i + offset + 3 * p] = a0na2 - a1na3imag; } rot *= info.rate3[__builtin_ctz(~s)]; } len += 2; } } template < typename modint > void butterfly_inv(std::vector< modint > &vs) { usize n = vs.size(); assert((n & (n - 1)) == 0); usize h = bit_width(n) - 1; ButterflyInfo< modint > info; assert(h < info.iroots.size()); usize len = h; modint iimag = info.iroots[2]; while (len > 1) { usize p = 1 << (h - len); for (usize i: rep(0, p)) { modint a0 = vs[i + 0 * p]; modint a1 = vs[i + 1 * p]; modint a2 = vs[i + 2 * p]; modint a3 = vs[i + 3 * p]; modint a2na3iimag = (a2 - a3) * iimag; modint a0na1 = a0 - a1; modint a0a1 = a0 + a1; modint a2a3 = a2 + a3; vs[i + 0 * p] = a0a1 + a2a3; vs[i + 1 * p] = (a0na1 + a2na3iimag); vs[i + 2 * p] = (a0a1 - a2a3); vs[i + 3 * p] = (a0na1 - a2na3iimag); } modint irot = info.irate3[0]; for (usize s: rep(1, 1 << (len - 2))) { usize offset = s << (h - len + 2); modint irot2 = irot * irot; modint irot3 = irot2 * 
irot; for (usize i: rep(0, p)) { modint a0 = vs[i + offset + 0 * p]; modint a1 = vs[i + offset + 1 * p]; modint a2 = vs[i + offset + 2 * p]; modint a3 = vs[i + offset + 3 * p]; modint a2na3iimag = (a2 - a3) * iimag; modint a0na1 = a0 - a1; modint a0a1 = a0 + a1; modint a2a3 = a2 + a3; vs[i + offset + 0 * p] = a0a1 + a2a3; vs[i + offset + 1 * p] = (a0na1 + a2na3iimag) * irot; vs[i + offset + 2 * p] = (a0a1 - a2a3) * irot2; vs[i + offset + 3 * p] = (a0na1 - a2na3iimag) * irot3; } irot *= info.irate3[__builtin_ctz(~s)]; } len -= 2; } if (len > 0) { usize p = 1 << (h - 1); for (usize i: rep(0, p)) { modint ajp = vs[i] - vs[i + p]; vs[i] += vs[i + p]; vs[i + p] = ajp; } } } } // namespace luz::internal namespace luz { template < typename modint > std::vector< modint > modint_convolution(std::vector< modint > f, std::vector< modint > g) { assert(not f.empty() and not g.empty()); usize n = f.size(), m = g.size(); usize s = 1 << bit_width(n + m - 2); f.resize(s); g.resize(s); internal::butterfly(f); internal::butterfly(g); modint s_inv = modint(1) / s; for (usize i: rep(0, s)) { f[i] *= g[i] * s_inv; } internal::butterfly_inv(f); f.resize(n + m - 1); return f; } } // namespace luz
#line 2 "src/math/convolution/modint-convolution.hpp"
#line 2 "src/cpp-template/header/int-alias.hpp"
#include <cstdint>

namespace luz {

  // Fixed-width integer aliases used throughout the library.
  using i32 = std::int32_t;
  using i64 = std::int64_t;
  using i128 = __int128_t;  // GCC/Clang extension

  using u32 = std::uint32_t;
  using u64 = std::uint64_t;
  using u128 = __uint128_t;  // GCC/Clang extension

} // namespace luz
#line 2 "src/cpp-template/header/rep.hpp"

#line 2 "src/cpp-template/header/size-alias.hpp"
#include <cstddef>

namespace luz {

  using isize = std::ptrdiff_t;
  using usize = std::size_t;

} // namespace luz
#line 4 "src/cpp-template/header/rep.hpp"
#include <algorithm>

namespace luz {

  // Ascending index range [f, l) for range-based for loops:
  //   for (usize i: rep(0, n)) { ... }  // i = 0, 1, ..., n - 1
  // The range is empty when f >= l (begin is clamped to min(f, l)).
  struct rep {
    struct itr {
      usize i;
      constexpr itr(const usize i) noexcept: i(i) {}
      void operator++() noexcept {
        ++i;
      }
      constexpr usize operator*() const noexcept {
        return i;
      }
      constexpr bool operator!=(const itr x) const noexcept {
        return i != x.i;
      }
    };

    const itr f, l;

    constexpr rep(const usize f, const usize l) noexcept
        : f(std::min(f, l)), l(l) {}

    constexpr auto begin() const noexcept {
      return f;
    }
    constexpr auto end() const noexcept {
      return l;
    }
  };

  // Descending index range over [f, l): yields l - 1, l - 2, ..., f.
  // Empty when f >= l. The end sentinel is min(f, l) - 1, which relies on
  // unsigned wrap-around when f == 0.
  struct rrep {
    struct itr {
      usize i;
      constexpr itr(const usize i) noexcept: i(i) {}
      void operator++() noexcept {
        --i;
      }
      constexpr usize operator*() const noexcept {
        return i;
      }
      constexpr bool operator!=(const itr x) const noexcept {
        return i != x.i;
      }
    };

    const itr f, l;

    constexpr rrep(const usize f, const usize l) noexcept
        : f(l - 1), l(std::min(f, l) - 1) {}

    constexpr auto begin() const noexcept {
      return f;
    }
    constexpr auto end() const noexcept {
      return l;
    }
  };

} // namespace luz
#line 2 "src/utility/bit/bit-width.hpp"

#line 2 "src/utility/bit/popcount.hpp"

#line 5 "src/utility/bit/popcount.hpp"
#include <cassert>

namespace luz {

  // Returns the number of set bits in x.
  // `inline` because this non-template function is defined in a header.
  // Without it, including the header from two translation units produces a
  // multiple-definition link error.
  // The runtime assert rejects builds newer than C++17 (presumably so that
  // std::popcount from <bit> is used there instead).
  inline usize popcount(u64 x) {
    assert(__cplusplus <= 201703L);
#if defined(__GNUC__)
    return __builtin_popcountll(x);
#else
    // Portable SWAR fallback: count bits per 2-, 4-, then 8-bit lane, then
    // sum all byte counts into the top byte with one multiplication.
    x -= (x >> 1) & 0x5555555555555555;
    x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333);
    x += (x >> 4) & 0x0f0f0f0f0f0f0f0f;
    return x * 0x0101010101010101 >> 56 & 0x7f;
#endif
  }

} // namespace luz
#line 6 "src/utility/bit/bit-width.hpp" #line 8 "src/utility/bit/bit-width.hpp" namespace luz { usize bit_width(u64 x) { assert(__cplusplus <= 201703L); if (x == 0) { return 0; } #ifdef __GNUC__ return 64 - __builtin_clzll(x); #endif x |= x >> 1; x |= x >> 2; x |= x >> 4; x |= x >> 8; x |= x >> 16; x |= x >> 32; return popcount(x); } } // namespace luz #line 7 "src/math/convolution/modint-convolution.hpp" #line 9 "src/math/convolution/modint-convolution.hpp" #include <vector> namespace luz::internal { template < typename modint > class ButterflyInfo { public: ButterflyInfo() { if (not roots.empty()) { return; } u32 mod = modint::get_mod(); assert(mod >= 3 && mod % 2 == 1); u32 tmp = mod - 1; usize max_base = 0; while (tmp % 2 == 0) { tmp >>= 1; max_base++; } modint root = 2; while (root.pow((mod - 1) >> 1) == modint(1)) { root += 1; } assert(root.pow(mod - 1) == modint(1)); assert(max_base >= 2); roots.resize(max_base + 1); iroots.resize(max_base + 1); rate3.resize(max_base + 1); irate3.resize(max_base + 1); roots[max_base] = root.pow((mod - 1) >> max_base); iroots[max_base] = modint(1) / roots[max_base]; for (usize i: rrep(0, max_base)) { roots[i] = roots[i + 1] * roots[i + 1]; iroots[i] = iroots[i + 1] * iroots[i + 1]; } modint prod = 1, iprod = 1; for (usize i: rep(0, max_base - 2)) { rate3[i] = roots[i + 3] * prod; irate3[i] = iroots[i + 3] * iprod; prod *= iroots[i + 3]; iprod *= roots[i + 3]; } } static std::vector< modint > roots, iroots, rate3, irate3; }; template < typename modint > std::vector< modint > ButterflyInfo< modint >::roots = std::vector< modint >(); template < typename modint > std::vector< modint > ButterflyInfo< modint >::iroots = std::vector< modint >(); template < typename modint > std::vector< modint > ButterflyInfo< modint >::rate3 = std::vector< modint >(); template < typename modint > std::vector< modint > ButterflyInfo< modint >::irate3 = std::vector< modint >(); template < typename modint > void butterfly(std::vector< modint > &vs) { 
usize n = vs.size(); assert((n & (n - 1)) == 0); usize h = bit_width(n) - 1; ButterflyInfo< modint > info; assert(h < info.iroots.size()); usize len = 0; modint imag = info.roots[2]; if (h & 1) { usize p = 1 << (h - 1); for (usize i: rep(0, p)) { modint r = vs[i + p]; vs[i + p] = vs[i] - r; vs[i] += r; } len++; } while (len + 1 < h) { usize p = 1 << (h - len - 2); for (usize i: rep(0, p)) { modint a0 = vs[i + 0 * p]; modint a1 = vs[i + 1 * p]; modint a2 = vs[i + 2 * p]; modint a3 = vs[i + 3 * p]; modint a1na3imag = (a1 - a3) * imag; modint a0a2 = a0 + a2; modint a1a3 = a1 + a3; modint a0na2 = a0 - a2; vs[i + 0 * p] = a0a2 + a1a3; vs[i + 1 * p] = a0a2 - a1a3; vs[i + 2 * p] = a0na2 + a1na3imag; vs[i + 3 * p] = a0na2 - a1na3imag; } modint rot = info.rate3[0]; for (usize s: rep(1, 1 << len)) { usize offset = s << (h - len); modint rot2 = rot * rot; modint rot3 = rot2 * rot; for (usize i: rep(0, p)) { modint a0 = vs[i + offset + 0 * p]; modint a1 = vs[i + offset + 1 * p] * rot; modint a2 = vs[i + offset + 2 * p] * rot2; modint a3 = vs[i + offset + 3 * p] * rot3; modint a1na3imag = (a1 - a3) * imag; modint a0a2 = a0 + a2; modint a1a3 = a1 + a3; modint a0na2 = a0 - a2; vs[i + offset + 0 * p] = a0a2 + a1a3; vs[i + offset + 1 * p] = a0a2 - a1a3; vs[i + offset + 2 * p] = a0na2 + a1na3imag; vs[i + offset + 3 * p] = a0na2 - a1na3imag; } rot *= info.rate3[__builtin_ctz(~s)]; } len += 2; } } template < typename modint > void butterfly_inv(std::vector< modint > &vs) { usize n = vs.size(); assert((n & (n - 1)) == 0); usize h = bit_width(n) - 1; ButterflyInfo< modint > info; assert(h < info.iroots.size()); usize len = h; modint iimag = info.iroots[2]; while (len > 1) { usize p = 1 << (h - len); for (usize i: rep(0, p)) { modint a0 = vs[i + 0 * p]; modint a1 = vs[i + 1 * p]; modint a2 = vs[i + 2 * p]; modint a3 = vs[i + 3 * p]; modint a2na3iimag = (a2 - a3) * iimag; modint a0na1 = a0 - a1; modint a0a1 = a0 + a1; modint a2a3 = a2 + a3; vs[i + 0 * p] = a0a1 + a2a3; vs[i + 1 * p] = 
(a0na1 + a2na3iimag); vs[i + 2 * p] = (a0a1 - a2a3); vs[i + 3 * p] = (a0na1 - a2na3iimag); } modint irot = info.irate3[0]; for (usize s: rep(1, 1 << (len - 2))) { usize offset = s << (h - len + 2); modint irot2 = irot * irot; modint irot3 = irot2 * irot; for (usize i: rep(0, p)) { modint a0 = vs[i + offset + 0 * p]; modint a1 = vs[i + offset + 1 * p]; modint a2 = vs[i + offset + 2 * p]; modint a3 = vs[i + offset + 3 * p]; modint a2na3iimag = (a2 - a3) * iimag; modint a0na1 = a0 - a1; modint a0a1 = a0 + a1; modint a2a3 = a2 + a3; vs[i + offset + 0 * p] = a0a1 + a2a3; vs[i + offset + 1 * p] = (a0na1 + a2na3iimag) * irot; vs[i + offset + 2 * p] = (a0a1 - a2a3) * irot2; vs[i + offset + 3 * p] = (a0na1 - a2na3iimag) * irot3; } irot *= info.irate3[__builtin_ctz(~s)]; } len -= 2; } if (len > 0) { usize p = 1 << (h - 1); for (usize i: rep(0, p)) { modint ajp = vs[i] - vs[i + p]; vs[i] += vs[i + p]; vs[i + p] = ajp; } } } } // namespace luz::internal namespace luz { template < typename modint > std::vector< modint > modint_convolution(std::vector< modint > f, std::vector< modint > g) { assert(not f.empty() and not g.empty()); usize n = f.size(), m = g.size(); usize s = 1 << bit_width(n + m - 2); f.resize(s); g.resize(s); internal::butterfly(f); internal::butterfly(g); modint s_inv = modint(1) / s; for (usize i: rep(0, s)) { f[i] *= g[i] * s_inv; } internal::butterfly_inv(f); f.resize(n + m - 1); return f; } } // namespace luz