comp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub luzhiled1333/comp-library

:heavy_check_mark: Number Theoretic Transform
(src/math/convolution/modint-convolution.hpp)

modint_convolution

std::vector< modint > modint_convolution(std::vector< modint > f, std::vector< modint > g)

f と g の畳み込みを返す。

計算量

$O((|f| + |g|) \log (|f| + |g|))$

Depends on

Required by

Verified with

Code

#pragma once

#include "src/cpp-template/header/int-alias.hpp"
#include "src/cpp-template/header/rep.hpp"
#include "src/cpp-template/header/size-alias.hpp"
#include "src/utility/bit/bit-width.hpp"

#include <cassert>
#include <vector>

namespace luz::internal {

  template < typename modint >
  class ButterflyInfo {
   public:
    ButterflyInfo() {
      if (not roots.empty()) {
        return;
      }

      u32 mod = modint::get_mod();
      assert(mod >= 3 && mod % 2 == 1);
      u32 tmp        = mod - 1;
      usize max_base = 0;
      while (tmp % 2 == 0) {
        tmp >>= 1;
        max_base++;
      }
      modint root = 2;
      while (root.pow((mod - 1) >> 1) == modint(1)) {
        root += 1;
      }
      assert(root.pow(mod - 1) == modint(1));
      assert(max_base >= 2);
      roots.resize(max_base + 1);
      iroots.resize(max_base + 1);
      rate3.resize(max_base + 1);
      irate3.resize(max_base + 1);

      roots[max_base]  = root.pow((mod - 1) >> max_base);
      iroots[max_base] = modint(1) / roots[max_base];
      for (usize i: rrep(0, max_base)) {
        roots[i]  = roots[i + 1] * roots[i + 1];
        iroots[i] = iroots[i + 1] * iroots[i + 1];
      }
      modint prod = 1, iprod = 1;
      for (usize i: rep(0, max_base - 2)) {
        rate3[i]  = roots[i + 3] * prod;
        irate3[i] = iroots[i + 3] * iprod;
        prod *= iroots[i + 3];
        iprod *= roots[i + 3];
      }
    }
    static std::vector< modint > roots, iroots, rate3, irate3;
  };
  template < typename modint >
  std::vector< modint > ButterflyInfo< modint >::roots =
      std::vector< modint >();
  template < typename modint >
  std::vector< modint > ButterflyInfo< modint >::iroots =
      std::vector< modint >();
  template < typename modint >
  std::vector< modint > ButterflyInfo< modint >::rate3 =
      std::vector< modint >();
  template < typename modint >
  std::vector< modint > ButterflyInfo< modint >::irate3 =
      std::vector< modint >();

  template < typename modint >
  void butterfly(std::vector< modint > &vs) {
    usize n = vs.size();
    assert((n & (n - 1)) == 0);
    usize h = bit_width(n) - 1;
    ButterflyInfo< modint > info;
    assert(h < info.iroots.size());
    usize len   = 0;
    modint imag = info.roots[2];
    if (h & 1) {
      usize p = 1 << (h - 1);
      for (usize i: rep(0, p)) {
        modint r  = vs[i + p];
        vs[i + p] = vs[i] - r;
        vs[i] += r;
      }
      len++;
    }
    while (len + 1 < h) {
      usize p = 1 << (h - len - 2);
      for (usize i: rep(0, p)) {
        modint a0        = vs[i + 0 * p];
        modint a1        = vs[i + 1 * p];
        modint a2        = vs[i + 2 * p];
        modint a3        = vs[i + 3 * p];
        modint a1na3imag = (a1 - a3) * imag;
        modint a0a2      = a0 + a2;
        modint a1a3      = a1 + a3;
        modint a0na2     = a0 - a2;
        vs[i + 0 * p]    = a0a2 + a1a3;
        vs[i + 1 * p]    = a0a2 - a1a3;
        vs[i + 2 * p]    = a0na2 + a1na3imag;
        vs[i + 3 * p]    = a0na2 - a1na3imag;
      }
      modint rot = info.rate3[0];
      for (usize s: rep(1, 1 << len)) {
        usize offset = s << (h - len);
        modint rot2  = rot * rot;
        modint rot3  = rot2 * rot;
        for (usize i: rep(0, p)) {
          modint a0              = vs[i + offset + 0 * p];
          modint a1              = vs[i + offset + 1 * p] * rot;
          modint a2              = vs[i + offset + 2 * p] * rot2;
          modint a3              = vs[i + offset + 3 * p] * rot3;
          modint a1na3imag       = (a1 - a3) * imag;
          modint a0a2            = a0 + a2;
          modint a1a3            = a1 + a3;
          modint a0na2           = a0 - a2;
          vs[i + offset + 0 * p] = a0a2 + a1a3;
          vs[i + offset + 1 * p] = a0a2 - a1a3;
          vs[i + offset + 2 * p] = a0na2 + a1na3imag;
          vs[i + offset + 3 * p] = a0na2 - a1na3imag;
        }
        rot *= info.rate3[__builtin_ctz(~s)];
      }
      len += 2;
    }
  }

  template < typename modint >
  void butterfly_inv(std::vector< modint > &vs) {
    usize n = vs.size();
    assert((n & (n - 1)) == 0);
    usize h = bit_width(n) - 1;
    ButterflyInfo< modint > info;
    assert(h < info.iroots.size());
    usize len    = h;
    modint iimag = info.iroots[2];
    while (len > 1) {
      usize p = 1 << (h - len);
      for (usize i: rep(0, p)) {
        modint a0         = vs[i + 0 * p];
        modint a1         = vs[i + 1 * p];
        modint a2         = vs[i + 2 * p];
        modint a3         = vs[i + 3 * p];
        modint a2na3iimag = (a2 - a3) * iimag;
        modint a0na1      = a0 - a1;
        modint a0a1       = a0 + a1;
        modint a2a3       = a2 + a3;
        vs[i + 0 * p]     = a0a1 + a2a3;
        vs[i + 1 * p]     = (a0na1 + a2na3iimag);
        vs[i + 2 * p]     = (a0a1 - a2a3);
        vs[i + 3 * p]     = (a0na1 - a2na3iimag);
      }
      modint irot = info.irate3[0];
      for (usize s: rep(1, 1 << (len - 2))) {
        usize offset = s << (h - len + 2);
        modint irot2 = irot * irot;
        modint irot3 = irot2 * irot;
        for (usize i: rep(0, p)) {
          modint a0              = vs[i + offset + 0 * p];
          modint a1              = vs[i + offset + 1 * p];
          modint a2              = vs[i + offset + 2 * p];
          modint a3              = vs[i + offset + 3 * p];
          modint a2na3iimag      = (a2 - a3) * iimag;
          modint a0na1           = a0 - a1;
          modint a0a1            = a0 + a1;
          modint a2a3            = a2 + a3;
          vs[i + offset + 0 * p] = a0a1 + a2a3;
          vs[i + offset + 1 * p] = (a0na1 + a2na3iimag) * irot;
          vs[i + offset + 2 * p] = (a0a1 - a2a3) * irot2;
          vs[i + offset + 3 * p] = (a0na1 - a2na3iimag) * irot3;
        }
        irot *= info.irate3[__builtin_ctz(~s)];
      }
      len -= 2;
    }
    if (len > 0) {
      usize p = 1 << (h - 1);
      for (usize i: rep(0, p)) {
        modint ajp = vs[i] - vs[i + p];
        vs[i] += vs[i + p];
        vs[i + p] = ajp;
      }
    }
  }

} // namespace luz::internal

namespace luz {

  template < typename modint >
  std::vector< modint > modint_convolution(std::vector< modint > f,
                                           std::vector< modint > g) {
    assert(not f.empty() and not g.empty());
    usize n = f.size(), m = g.size();
    usize s = 1 << bit_width(n + m - 2);
    f.resize(s);
    g.resize(s);
    internal::butterfly(f);
    internal::butterfly(g);
    modint s_inv = modint(1) / s;
    for (usize i: rep(0, s)) {
      f[i] *= g[i] * s_inv;
    }
    internal::butterfly_inv(f);
    f.resize(n + m - 1);
    return f;
  }

} // namespace luz
#line 2 "src/math/convolution/modint-convolution.hpp"

#line 2 "src/cpp-template/header/int-alias.hpp"

#include <cstdint>

// Short fixed-width integer aliases used throughout the library.
namespace luz {

  // Signed integers. i128 maps to the compiler-provided __int128_t, which is
  // not one of the <cstdint> types.
  using i32  = std::int32_t;
  using i64  = std::int64_t;
  using i128 = __int128_t;

  // Unsigned integers. u128 maps to the compiler-provided __uint128_t.
  using u32  = std::uint32_t;
  using u64  = std::uint64_t;
  using u128 = __uint128_t;

} // namespace luz
#line 2 "src/cpp-template/header/rep.hpp"

#line 2 "src/cpp-template/header/size-alias.hpp"

#include <cstddef>

// Short aliases for the standard size / pointer-difference types.
namespace luz {

  using isize = std::ptrdiff_t; // signed counterpart of usize
  using usize = std::size_t;    // type used for sizes and indices

} // namespace luz
#line 4 "src/cpp-template/header/rep.hpp"

#include <algorithm>

namespace luz {

  // Ascending half-open index range for range-based `for`:
  // `for (usize i: rep(a, b))` visits a, a + 1, ..., b - 1.
  // If a >= b the start is clamped to b, so the range is empty.
  struct rep {
    // Minimal cursor: only the operations a range-based `for` needs.
    struct itr {
      usize i;

      constexpr itr(const usize i) noexcept: i(i) {}

      void operator++() noexcept {
        i += 1;
      }

      constexpr usize operator*() const noexcept {
        return i;
      }

      constexpr bool operator!=(const itr x) const noexcept {
        return not(i == x.i);
      }
    };

    // f: first cursor position, l: one-past-the-end position.
    const itr f, l;

    // Inside the initializers `f` and `l` name the constructor parameters.
    constexpr rep(const usize f, const usize l) noexcept
        : f(l < f ? l : f),
          l(l) {}

    constexpr itr begin() const noexcept {
      return f;
    }

    constexpr itr end() const noexcept {
      return l;
    }
  };

  // Descending index range for range-based `for`:
  // `for (usize i: rrep(a, b))` visits b - 1, b - 2, ..., a.
  // If a >= b the range is empty.
  struct rrep {
    // Minimal cursor; "incrementing" it moves one index down.
    struct itr {
      usize i;

      constexpr itr(const usize i) noexcept: i(i) {}

      void operator++() noexcept {
        i -= 1;
      }

      constexpr usize operator*() const noexcept {
        return i;
      }

      constexpr bool operator!=(const itr x) const noexcept {
        return not(i == x.i);
      }
    };

    // f: first cursor position (highest index), l: stop position.
    const itr f, l;

    // Inside the initializers `f` and `l` name the constructor parameters.
    // The stop position is min(f, l) - 1; for f == 0 this wraps around to the
    // largest usize value, which is fine because cursors are only compared
    // with `!=`.
    constexpr rrep(const usize f, const usize l) noexcept
        : f(l - 1),
          l((l < f ? l : f) - 1) {}

    constexpr itr begin() const noexcept {
      return f;
    }

    constexpr itr end() const noexcept {
      return l;
    }
  };

} // namespace luz
#line 2 "src/utility/bit/bit-width.hpp"

#line 2 "src/utility/bit/popcount.hpp"

#line 5 "src/utility/bit/popcount.hpp"

#include <cassert>

namespace luz {

  // Returns the number of set bits in x.
  //
  // The assert fails at run time when the translation unit is compiled with a
  // language standard newer than C++17 (presumably because <bit> offers a
  // standard equivalent from C++20 on -- confirm with the project policy).
  //
  // Declared `inline` because the function is defined in a header.
  inline usize popcount(u64 x) {
    assert(__cplusplus <= 201703L);

#ifdef __GNUC__
    return __builtin_popcountll(x);
#else
    // Portable fallback: parallel (SWAR) bit count.
    // 1) every 2-bit field holds the count of its two bits
    x -= (x >> 1) & 0x5555555555555555;
    // 2) every 4-bit field holds the count of its four bits
    x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333);
    // 3) every byte holds the count of its eight bits; the mask must be
    //    applied to the sum, otherwise the high-nibble counts stay in the
    //    byte and corrupt the total (e.g. 0xF0 would yield 68, not 4)
    x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0f;
    // 4) the multiplication adds all bytes into the top byte
    return x * 0x0101010101010101 >> 56 & 0x7f;
#endif
  }

} // namespace luz
#line 6 "src/utility/bit/bit-width.hpp"

#line 8 "src/utility/bit/bit-width.hpp"

namespace luz {

  // Returns the number of bits needed to represent x: 0 for x == 0,
  // otherwise the 1-based position of the highest set bit.
  //
  // The assert fails at run time when the translation unit is compiled with a
  // language standard newer than C++17 (presumably because <bit> offers a
  // standard equivalent from C++20 on -- confirm with the project policy).
  //
  // Declared `inline` because the function is defined in a header.
  inline usize bit_width(u64 x) {
    assert(__cplusplus <= 201703L);

    // Handled separately: the count-leading-zeros builtin below must not be
    // called with 0.
    if (x == 0) {
      return 0;
    }

#ifdef __GNUC__
    return 64 - __builtin_clzll(x);
#else
    // Portable fallback: copy the highest set bit into every lower position,
    // then the number of set bits equals the bit width.
    x |= x >> 1;
    x |= x >> 2;
    x |= x >> 4;
    x |= x >> 8;
    x |= x >> 16;
    x |= x >> 32;
    return popcount(x);
#endif
  }

} // namespace luz
#line 7 "src/math/convolution/modint-convolution.hpp"

#line 9 "src/math/convolution/modint-convolution.hpp"
#include <vector>

namespace luz::internal {

  template < typename modint >
  class ButterflyInfo {
   public:
    ButterflyInfo() {
      if (not roots.empty()) {
        return;
      }

      u32 mod = modint::get_mod();
      assert(mod >= 3 && mod % 2 == 1);
      u32 tmp        = mod - 1;
      usize max_base = 0;
      while (tmp % 2 == 0) {
        tmp >>= 1;
        max_base++;
      }
      modint root = 2;
      while (root.pow((mod - 1) >> 1) == modint(1)) {
        root += 1;
      }
      assert(root.pow(mod - 1) == modint(1));
      assert(max_base >= 2);
      roots.resize(max_base + 1);
      iroots.resize(max_base + 1);
      rate3.resize(max_base + 1);
      irate3.resize(max_base + 1);

      roots[max_base]  = root.pow((mod - 1) >> max_base);
      iroots[max_base] = modint(1) / roots[max_base];
      for (usize i: rrep(0, max_base)) {
        roots[i]  = roots[i + 1] * roots[i + 1];
        iroots[i] = iroots[i + 1] * iroots[i + 1];
      }
      modint prod = 1, iprod = 1;
      for (usize i: rep(0, max_base - 2)) {
        rate3[i]  = roots[i + 3] * prod;
        irate3[i] = iroots[i + 3] * iprod;
        prod *= iroots[i + 3];
        iprod *= roots[i + 3];
      }
    }
    static std::vector< modint > roots, iroots, rate3, irate3;
  };
  template < typename modint >
  std::vector< modint > ButterflyInfo< modint >::roots =
      std::vector< modint >();
  template < typename modint >
  std::vector< modint > ButterflyInfo< modint >::iroots =
      std::vector< modint >();
  template < typename modint >
  std::vector< modint > ButterflyInfo< modint >::rate3 =
      std::vector< modint >();
  template < typename modint >
  std::vector< modint > ButterflyInfo< modint >::irate3 =
      std::vector< modint >();

  template < typename modint >
  void butterfly(std::vector< modint > &vs) {
    usize n = vs.size();
    assert((n & (n - 1)) == 0);
    usize h = bit_width(n) - 1;
    ButterflyInfo< modint > info;
    assert(h < info.iroots.size());
    usize len   = 0;
    modint imag = info.roots[2];
    if (h & 1) {
      usize p = 1 << (h - 1);
      for (usize i: rep(0, p)) {
        modint r  = vs[i + p];
        vs[i + p] = vs[i] - r;
        vs[i] += r;
      }
      len++;
    }
    while (len + 1 < h) {
      usize p = 1 << (h - len - 2);
      for (usize i: rep(0, p)) {
        modint a0        = vs[i + 0 * p];
        modint a1        = vs[i + 1 * p];
        modint a2        = vs[i + 2 * p];
        modint a3        = vs[i + 3 * p];
        modint a1na3imag = (a1 - a3) * imag;
        modint a0a2      = a0 + a2;
        modint a1a3      = a1 + a3;
        modint a0na2     = a0 - a2;
        vs[i + 0 * p]    = a0a2 + a1a3;
        vs[i + 1 * p]    = a0a2 - a1a3;
        vs[i + 2 * p]    = a0na2 + a1na3imag;
        vs[i + 3 * p]    = a0na2 - a1na3imag;
      }
      modint rot = info.rate3[0];
      for (usize s: rep(1, 1 << len)) {
        usize offset = s << (h - len);
        modint rot2  = rot * rot;
        modint rot3  = rot2 * rot;
        for (usize i: rep(0, p)) {
          modint a0              = vs[i + offset + 0 * p];
          modint a1              = vs[i + offset + 1 * p] * rot;
          modint a2              = vs[i + offset + 2 * p] * rot2;
          modint a3              = vs[i + offset + 3 * p] * rot3;
          modint a1na3imag       = (a1 - a3) * imag;
          modint a0a2            = a0 + a2;
          modint a1a3            = a1 + a3;
          modint a0na2           = a0 - a2;
          vs[i + offset + 0 * p] = a0a2 + a1a3;
          vs[i + offset + 1 * p] = a0a2 - a1a3;
          vs[i + offset + 2 * p] = a0na2 + a1na3imag;
          vs[i + offset + 3 * p] = a0na2 - a1na3imag;
        }
        rot *= info.rate3[__builtin_ctz(~s)];
      }
      len += 2;
    }
  }

  template < typename modint >
  void butterfly_inv(std::vector< modint > &vs) {
    usize n = vs.size();
    assert((n & (n - 1)) == 0);
    usize h = bit_width(n) - 1;
    ButterflyInfo< modint > info;
    assert(h < info.iroots.size());
    usize len    = h;
    modint iimag = info.iroots[2];
    while (len > 1) {
      usize p = 1 << (h - len);
      for (usize i: rep(0, p)) {
        modint a0         = vs[i + 0 * p];
        modint a1         = vs[i + 1 * p];
        modint a2         = vs[i + 2 * p];
        modint a3         = vs[i + 3 * p];
        modint a2na3iimag = (a2 - a3) * iimag;
        modint a0na1      = a0 - a1;
        modint a0a1       = a0 + a1;
        modint a2a3       = a2 + a3;
        vs[i + 0 * p]     = a0a1 + a2a3;
        vs[i + 1 * p]     = (a0na1 + a2na3iimag);
        vs[i + 2 * p]     = (a0a1 - a2a3);
        vs[i + 3 * p]     = (a0na1 - a2na3iimag);
      }
      modint irot = info.irate3[0];
      for (usize s: rep(1, 1 << (len - 2))) {
        usize offset = s << (h - len + 2);
        modint irot2 = irot * irot;
        modint irot3 = irot2 * irot;
        for (usize i: rep(0, p)) {
          modint a0              = vs[i + offset + 0 * p];
          modint a1              = vs[i + offset + 1 * p];
          modint a2              = vs[i + offset + 2 * p];
          modint a3              = vs[i + offset + 3 * p];
          modint a2na3iimag      = (a2 - a3) * iimag;
          modint a0na1           = a0 - a1;
          modint a0a1            = a0 + a1;
          modint a2a3            = a2 + a3;
          vs[i + offset + 0 * p] = a0a1 + a2a3;
          vs[i + offset + 1 * p] = (a0na1 + a2na3iimag) * irot;
          vs[i + offset + 2 * p] = (a0a1 - a2a3) * irot2;
          vs[i + offset + 3 * p] = (a0na1 - a2na3iimag) * irot3;
        }
        irot *= info.irate3[__builtin_ctz(~s)];
      }
      len -= 2;
    }
    if (len > 0) {
      usize p = 1 << (h - 1);
      for (usize i: rep(0, p)) {
        modint ajp = vs[i] - vs[i + p];
        vs[i] += vs[i + p];
        vs[i + p] = ajp;
      }
    }
  }

} // namespace luz::internal

namespace luz {

  template < typename modint >
  std::vector< modint > modint_convolution(std::vector< modint > f,
                                           std::vector< modint > g) {
    assert(not f.empty() and not g.empty());
    usize n = f.size(), m = g.size();
    usize s = 1 << bit_width(n + m - 2);
    f.resize(s);
    g.resize(s);
    internal::butterfly(f);
    internal::butterfly(g);
    modint s_inv = modint(1) / s;
    for (usize i: rep(0, s)) {
      f[i] *= g[i] * s_inv;
    }
    internal::butterfly_inv(f);
    f.resize(n + m - 1);
    return f;
  }

} // namespace luz
Back to top page