comp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub luzhiled1333/comp-library

:heavy_check_mark: ワイルドカードパターンマッチング
(src/sequence/wildcard-pattern-matching.hpp)

wildcard_pattern_matching

template< class modint, class T, class Iter >
std::vector< i32 >
wildcard_pattern_matching(Iter f1, Iter l1, Iter f2, Iter l2, T wildcard)

イテレータ $[f1, l1)$, $[f2, l2)$ のワイルドカードありパターンマッチングの結果を返す。

result の配列の $i$ 番目には $[f1 + i, f1 + i + (l2 - f2))$ と $[f2, l2)$ がマッチしている場合には 1 が、マッチしていない場合は 0 が格納されている。

wildcard にはワイルドカードと扱うべき値を渡す。

`T` と `Iter` は引数から推論されるが、`modint` は引数に現れないため推論されない。`wildcard_pattern_matching< modint >(f1, l1, f2, l2, wildcard)` のように明示的に指定すること。

計算量: $n = l1 - f1$, $m = l2 - f2$ として $O(n \log n)$（長さ $n + m - 1$ の畳み込みを 3 回行う。$m \le n$ である）。

note: 衝突について

内部的に $\bmod M$ で計算しているため、「取りうる値の 2 乗 × 長さ」程度の値が $M$ を超える場合は、値が十分ランダムであっても、各位置の判定につき期待値 $O(1/M)$ の確率で誤判定（偽陽性）が起こりうる。

本来マッチしない列（文字列など）がマッチしたと判定されること（偽陽性）はあっても、本来マッチする列がマッチしないと判定されること（偽陰性）はない。そのため、modint として異なる法を持つものを複数渡し、それぞれの結果の AND を取るほうが実用上安全である。

note: 値の取りうる範囲について

列の値が取りうる値として 0 が含まれている場合、何らかの手段を用いて $[1, k)$ への単射で写しておく必要がある。

Depends on

Verified with

Code

#pragma once

#include "src/cpp-template/header/int-alias.hpp"
#include "src/cpp-template/header/size-alias.hpp"
#include "src/math/convolution/modint-convolution.hpp"

#include <cassert>
#include <vector>

namespace luz {

  // [warning] false positive occur expect O(1/M)
  //           when values are randomized
  // [note] try to use multiple mods if necessary
  template < class modint, class T, class Iter >
  std::vector< i32 > wildcard_pattern_matching(Iter f1, Iter l1,
                                               Iter f2, Iter l2,
                                               const T wildcard) {
    usize n = l1 - f1, m = l2 - f2;
    assert(m <= n);

    std::vector< modint > as(n), bs(n), cs(n), ss(m), ts(m), us(m);

    for (Iter iter = f1; iter != l1; ++iter) {
      modint x(*iter == wildcard ? 0 : *iter);
      modint y(*iter == wildcard ? 0 : 1);
      usize i = iter - f1;
      as[i]   = y * x * x;
      bs[i]   = y * x * -2;
      cs[i]   = y;
    }

    for (Iter iter = f2; iter != l2; ++iter) {
      modint x(*iter == wildcard ? 0 : *iter);
      modint y(*iter == wildcard ? 0 : 1);
      usize i = l2 - iter - 1;
      ss[i]   = y;
      ts[i]   = y * x;
      us[i]   = y * x * x;
    }

    auto f = modint_convolution(as, ss);
    auto g = modint_convolution(bs, ts);
    auto h = modint_convolution(cs, us);

    std::vector< i32 > result(n - m + 1);
    for (usize i: rep(0, result.size())) {
      usize j = i + m - 1;
      modint x(f[j] + g[j] + h[j]);
      if (x.val() == 0) result[i] = 1;
    }

    return result;
  }

} // namespace luz
#line 2 "src/sequence/wildcard-pattern-matching.hpp"

#line 2 "src/cpp-template/header/int-alias.hpp"

#include <cstdint>

namespace luz {

  // Fixed-width integer shorthands used throughout the library.
  // The 128-bit aliases rely on the GCC/Clang __int128 extension.
  using u32  = std::uint32_t;
  using u64  = std::uint64_t;
  using u128 = __uint128_t;

  using i32  = std::int32_t;
  using i64  = std::int64_t;
  using i128 = __int128_t;

} // namespace luz
#line 2 "src/cpp-template/header/size-alias.hpp"

#include <cstddef>

namespace luz {

  // Size shorthands: usize for sizes and indices, isize for signed
  // differences such as pointer distances.
  using usize = std::size_t;
  using isize = std::ptrdiff_t;

} // namespace luz
#line 2 "src/math/convolution/modint-convolution.hpp"

#line 2 "src/cpp-template/header/rep.hpp"

#line 4 "src/cpp-template/header/rep.hpp"

#include <algorithm>

namespace luz {

  // Ascending half-open index range for range-based for:
  //   for (usize i: rep(a, b))  visits a, a + 1, ..., b - 1.
  // When a > b the begin is clamped to b, so the range is empty.
  struct rep {
    struct itr {
      usize i;
      constexpr itr(const usize v) noexcept: i(v) {}
      void operator++() noexcept {
        i += 1;
      }
      constexpr usize operator*() const noexcept {
        return i;
      }
      constexpr bool operator!=(const itr other) const noexcept {
        return !(i == other.i);
      }
    };
    const itr f, l;
    constexpr rep(const usize first, const usize last) noexcept
        : f(std::min(first, last)),
          l(last) {}
    constexpr itr begin() const noexcept {
      return f;
    }
    constexpr itr end() const noexcept {
      return l;
    }
  };

  // Descending index range:
  //   for (usize i: rrep(a, b))  visits b - 1, b - 2, ..., a.
  // Empty when a >= b. The end sentinel is min(a, b) - 1, which relies on
  // unsigned wrap-around when a == 0.
  struct rrep {
    struct itr {
      usize i;
      constexpr itr(const usize v) noexcept: i(v) {}
      void operator++() noexcept {
        i -= 1;
      }
      constexpr usize operator*() const noexcept {
        return i;
      }
      constexpr bool operator!=(const itr other) const noexcept {
        return !(i == other.i);
      }
    };
    const itr f, l;
    constexpr rrep(const usize first, const usize last) noexcept
        : f(last - 1),
          l(std::min(first, last) - 1) {}
    constexpr itr begin() const noexcept {
      return f;
    }
    constexpr itr end() const noexcept {
      return l;
    }
  };

} // namespace luz
#line 2 "src/utility/bit/bit-width.hpp"

#line 2 "src/utility/bit/popcount.hpp"

#line 5 "src/utility/bit/popcount.hpp"

#include <cassert>

namespace luz {

  // Returns the number of set bits in x.
  //
  // Declared inline because this non-template function is defined in a
  // header; without it, including the header from more than one translation
  // unit is a one-definition-rule violation (duplicate symbol at link time).
  //
  // The assert rejects C++20 and later at runtime (debug builds only).
  // NOTE(review): presumably this is meant to steer C++20 users to
  // std::popcount from <bit>; confirm intent.
  inline usize popcount(u64 x) {
    assert(__cplusplus <= 201703L);

#ifdef __GNUC__
    return __builtin_popcountll(x);
#else
    // Portable SWAR fallback: count bits per 2-, 4-, then 8-bit lane, then
    // sum the eight byte counts into the top byte with one multiply.
    x -= (x >> 1) & 0x5555555555555555;
    x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333);
    x += (x >> 4) & 0x0f0f0f0f0f0f0f0f;
    return x * 0x0101010101010101 >> 56 & 0x7f;
#endif
  }

} // namespace luz
#line 6 "src/utility/bit/bit-width.hpp"

#line 8 "src/utility/bit/bit-width.hpp"

namespace luz {

  // Returns the number of bits needed to represent x: 0 for x == 0,
  // otherwise floor(log2(x)) + 1.
  //
  // Declared inline because this non-template function is defined in a
  // header; without it, including the header from more than one translation
  // unit is a one-definition-rule violation.
  //
  // The assert rejects C++20 and later at runtime (debug builds only).
  // NOTE(review): presumably this steers C++20 users to std::bit_width from
  // <bit>; confirm intent.
  inline usize bit_width(u64 x) {
    assert(__cplusplus <= 201703L);

    if (x == 0) {
      return 0;
    }

#ifdef __GNUC__
    return 64 - __builtin_clzll(x);
#else
    // Portable fallback: smear the highest set bit into every lower
    // position, then count the resulting ones.
    x |= x >> 1;
    x |= x >> 2;
    x |= x >> 4;
    x |= x >> 8;
    x |= x >> 16;
    x |= x >> 32;
    return popcount(x);
#endif
  }

} // namespace luz
#line 7 "src/math/convolution/modint-convolution.hpp"

#line 9 "src/math/convolution/modint-convolution.hpp"
#include <vector>

namespace luz::internal {

  // Per-modint tables of roots of unity for the power-of-two NTT performed
  // by butterfly / butterfly_inv.
  //
  // The tables are static members shared by every instance for a given
  // modint type. The first constructed instance fills them; later
  // constructions return immediately. Construction is not synchronised.
  //
  // After construction, with 2^max_base the largest power of two dividing
  // mod - 1 (and assuming mod is prime):
  //   roots[i]  : a primitive 2^i-th root of unity, for i in [0, max_base]
  //   iroots[i] : the inverse of roots[i]
  //   rate3[i], irate3[i] (filled for i < max_base - 2): multiplicative
  //               steps for the radix-4 twiddles. See the
  //               `rot *= rate3[__builtin_ctz(~s)]` updates in butterfly and
  //               the `irot *= irate3[...]` updates in butterfly_inv.
  template < typename modint >
  class ButterflyInfo {
   public:
    ButterflyInfo() {
      // Tables are already built for this modint type.
      if (not roots.empty()) {
        return;
      }

      u32 mod = modint::get_mod();
      assert(mod >= 3 && mod % 2 == 1);
      // max_base = number of factors of two in mod - 1.
      u32 tmp        = mod - 1;
      usize max_base = 0;
      while (tmp % 2 == 0) {
        tmp >>= 1;
        max_base++;
      }
      // Smallest root >= 2 with root^((mod - 1) / 2) != 1, i.e. a quadratic
      // non-residue when mod is prime. Raising it to (mod - 1) / 2^max_base
      // then yields an element of order exactly 2^max_base.
      modint root = 2;
      while (root.pow((mod - 1) >> 1) == modint(1)) {
        root += 1;
      }
      assert(root.pow(mod - 1) == modint(1));
      // The radix-4 passes use roots[2] (a 4th root of unity), and the
      // rate3 loop below computes max_base - 2, so at least 2 is required.
      assert(max_base >= 2);
      roots.resize(max_base + 1);
      iroots.resize(max_base + 1);
      rate3.resize(max_base + 1);
      irate3.resize(max_base + 1);

      // Start from the highest-order root and square downward, so that
      // roots[i] == roots[i + 1]^2 (and likewise for the inverses).
      roots[max_base]  = root.pow((mod - 1) >> max_base);
      iroots[max_base] = modint(1) / roots[max_base];
      for (usize i: rrep(0, max_base)) {
        roots[i]  = roots[i + 1] * roots[i + 1];
        iroots[i] = iroots[i + 1] * iroots[i + 1];
      }
      // rate3[i] = roots[i + 3] * prod_{j < i} iroots[j + 3], and irate3[i]
      // is the same with roots and iroots swapped.
      modint prod = 1, iprod = 1;
      for (usize i: rep(0, max_base - 2)) {
        rate3[i]  = roots[i + 3] * prod;
        irate3[i] = iroots[i + 3] * iprod;
        prod *= iroots[i + 3];
        iprod *= roots[i + 3];
      }
    }
    static std::vector< modint > roots, iroots, rate3, irate3;
  };
  // Out-of-class definitions of the static tables; they start empty, which
  // is what the constructor's "already built?" check relies on.
  template < typename modint >
  std::vector< modint > ButterflyInfo< modint >::roots =
      std::vector< modint >();
  template < typename modint >
  std::vector< modint > ButterflyInfo< modint >::iroots =
      std::vector< modint >();
  template < typename modint >
  std::vector< modint > ButterflyInfo< modint >::rate3 =
      std::vector< modint >();
  template < typename modint >
  std::vector< modint > ButterflyInfo< modint >::irate3 =
      std::vector< modint >();

  // In-place forward NTT of vs over modint.
  //
  // Requires: vs.size() is a power of two, and its log2 is smaller than
  // iroots.size(); both are asserted. No reordering pass is performed, so
  // the output is left in a permuted (presumably bit-reversed) order. It is
  // meant to be multiplied pointwise with another such transform and then
  // passed to butterfly_inv.
  //
  // When log2(size) is odd, one twiddle-free radix-2 pass is done first;
  // every remaining pass is radix-4.
  template < typename modint >
  void butterfly(std::vector< modint > &vs) {
    usize n = vs.size();
    assert((n & (n - 1)) == 0);
    // h = log2(n).
    usize h = bit_width(n) - 1;
    ButterflyInfo< modint > info;
    assert(h < info.iroots.size());
    // len = number of levels already processed.
    usize len   = 0;
    // roots[2] is a primitive 4th root of unity (a square root of -1).
    modint imag = info.roots[2];
    if (h & 1) {
      // Single radix-2 pass over the whole array (block 0, no twiddle).
      usize p = 1 << (h - 1);
      for (usize i: rep(0, p)) {
        modint r  = vs[i + p];
        vs[i + p] = vs[i] - r;
        vs[i] += r;
      }
      len++;
    }
    while (len + 1 < h) {
      // Radix-4 pass: 2^len blocks, each split into four quarters of size p.
      usize p = 1 << (h - len - 2);
      // Block 0 needs no twiddle factors.
      for (usize i: rep(0, p)) {
        modint a0        = vs[i + 0 * p];
        modint a1        = vs[i + 1 * p];
        modint a2        = vs[i + 2 * p];
        modint a3        = vs[i + 3 * p];
        modint a1na3imag = (a1 - a3) * imag;
        modint a0a2      = a0 + a2;
        modint a1a3      = a1 + a3;
        modint a0na2     = a0 - a2;
        vs[i + 0 * p]    = a0a2 + a1a3;
        vs[i + 1 * p]    = a0a2 - a1a3;
        vs[i + 2 * p]    = a0na2 + a1na3imag;
        vs[i + 3 * p]    = a0na2 - a1na3imag;
      }
      // Blocks s >= 1: scale the quarters by rot, rot^2, rot^3 first. rot
      // is advanced incrementally using the trailing-ones count of s.
      modint rot = info.rate3[0];
      for (usize s: rep(1, 1 << len)) {
        usize offset = s << (h - len);
        modint rot2  = rot * rot;
        modint rot3  = rot2 * rot;
        for (usize i: rep(0, p)) {
          modint a0              = vs[i + offset + 0 * p];
          modint a1              = vs[i + offset + 1 * p] * rot;
          modint a2              = vs[i + offset + 2 * p] * rot2;
          modint a3              = vs[i + offset + 3 * p] * rot3;
          modint a1na3imag       = (a1 - a3) * imag;
          modint a0a2            = a0 + a2;
          modint a1a3            = a1 + a3;
          modint a0na2           = a0 - a2;
          vs[i + offset + 0 * p] = a0a2 + a1a3;
          vs[i + offset + 1 * p] = a0a2 - a1a3;
          vs[i + offset + 2 * p] = a0na2 + a1na3imag;
          vs[i + offset + 3 * p] = a0na2 - a1na3imag;
        }
        rot *= info.rate3[__builtin_ctz(~s)];
      }
      len += 2;
    }
  }

  // In-place inverse of butterfly: it undoes butterfly up to a factor of
  // vs.size(). It does NOT divide by the size; the caller must apply the
  // 1 / size factor (modint_convolution folds it into the pointwise
  // product).
  //
  // Requires the same size constraints as butterfly (asserted). The passes
  // run in the reverse order of butterfly: radix-4 passes first, with the
  // inverse twiddles applied after combining, then a final radix-2 pass
  // when log2(size) is odd.
  template < typename modint >
  void butterfly_inv(std::vector< modint > &vs) {
    usize n = vs.size();
    assert((n & (n - 1)) == 0);
    // h = log2(n).
    usize h = bit_width(n) - 1;
    ButterflyInfo< modint > info;
    assert(h < info.iroots.size());
    // len = number of levels still to undo.
    usize len    = h;
    // Inverse of the primitive 4th root of unity used in butterfly.
    modint iimag = info.iroots[2];
    while (len > 1) {
      usize p = 1 << (h - len);
      // Block 0 needs no twiddle factors.
      for (usize i: rep(0, p)) {
        modint a0         = vs[i + 0 * p];
        modint a1         = vs[i + 1 * p];
        modint a2         = vs[i + 2 * p];
        modint a3         = vs[i + 3 * p];
        modint a2na3iimag = (a2 - a3) * iimag;
        modint a0na1      = a0 - a1;
        modint a0a1       = a0 + a1;
        modint a2a3       = a2 + a3;
        vs[i + 0 * p]     = a0a1 + a2a3;
        vs[i + 1 * p]     = (a0na1 + a2na3iimag);
        vs[i + 2 * p]     = (a0a1 - a2a3);
        vs[i + 3 * p]     = (a0na1 - a2na3iimag);
      }
      // Blocks s >= 1: combine first, then scale by irot, irot^2, irot^3.
      modint irot = info.irate3[0];
      for (usize s: rep(1, 1 << (len - 2))) {
        usize offset = s << (h - len + 2);
        modint irot2 = irot * irot;
        modint irot3 = irot2 * irot;
        for (usize i: rep(0, p)) {
          modint a0              = vs[i + offset + 0 * p];
          modint a1              = vs[i + offset + 1 * p];
          modint a2              = vs[i + offset + 2 * p];
          modint a3              = vs[i + offset + 3 * p];
          modint a2na3iimag      = (a2 - a3) * iimag;
          modint a0na1           = a0 - a1;
          modint a0a1            = a0 + a1;
          modint a2a3            = a2 + a3;
          vs[i + offset + 0 * p] = a0a1 + a2a3;
          vs[i + offset + 1 * p] = (a0na1 + a2na3iimag) * irot;
          vs[i + offset + 2 * p] = (a0a1 - a2a3) * irot2;
          vs[i + offset + 3 * p] = (a0na1 - a2na3iimag) * irot3;
        }
        irot *= info.irate3[__builtin_ctz(~s)];
      }
      len -= 2;
    }
    if (len > 0) {
      // Final radix-2 pass, undoing butterfly's initial radix-2 pass.
      usize p = 1 << (h - 1);
      for (usize i: rep(0, p)) {
        modint ajp = vs[i] - vs[i + p];
        vs[i] += vs[i + p];
        vs[i + p] = ajp;
      }
    }
  }

} // namespace luz::internal

namespace luz {

  template < typename modint >
  std::vector< modint > modint_convolution(std::vector< modint > f,
                                           std::vector< modint > g) {
    assert(not f.empty() and not g.empty());
    usize n = f.size(), m = g.size();
    usize s = 1 << bit_width(n + m - 2);
    f.resize(s);
    g.resize(s);
    internal::butterfly(f);
    internal::butterfly(g);
    modint s_inv = modint(1) / s;
    for (usize i: rep(0, s)) {
      f[i] *= g[i] * s_inv;
    }
    internal::butterfly_inv(f);
    f.resize(n + m - 1);
    return f;
  }

} // namespace luz
#line 6 "src/sequence/wildcard-pattern-matching.hpp"

#line 9 "src/sequence/wildcard-pattern-matching.hpp"

namespace luz {

  // [warning] false positive occur expect O(1/M)
  //           when values are randomized
  // [note] try to use multiple mods if necessary
  template < class modint, class T, class Iter >
  std::vector< i32 > wildcard_pattern_matching(Iter f1, Iter l1,
                                               Iter f2, Iter l2,
                                               const T wildcard) {
    usize n = l1 - f1, m = l2 - f2;
    assert(m <= n);

    std::vector< modint > as(n), bs(n), cs(n), ss(m), ts(m), us(m);

    for (Iter iter = f1; iter != l1; ++iter) {
      modint x(*iter == wildcard ? 0 : *iter);
      modint y(*iter == wildcard ? 0 : 1);
      usize i = iter - f1;
      as[i]   = y * x * x;
      bs[i]   = y * x * -2;
      cs[i]   = y;
    }

    for (Iter iter = f2; iter != l2; ++iter) {
      modint x(*iter == wildcard ? 0 : *iter);
      modint y(*iter == wildcard ? 0 : 1);
      usize i = l2 - iter - 1;
      ss[i]   = y;
      ts[i]   = y * x;
      us[i]   = y * x * x;
    }

    auto f = modint_convolution(as, ss);
    auto g = modint_convolution(bs, ts);
    auto h = modint_convolution(cs, us);

    std::vector< i32 > result(n - m + 1);
    for (usize i: rep(0, result.size())) {
      usize j = i + m - 1;
      modint x(f[j] + g[j] + h[j]);
      if (x.val() == 0) result[i] = 1;
    }

    return result;
  }

} // namespace luz
Back to top page