@@ -101,25 +101,29 @@ void butterfly_inv(std::vector<mint>& a) {
101101 }
102102}
103103
104- } // namespace internal
105-
106104template <class mint , internal::is_static_modint_t <mint>* = nullptr >
107- std::vector<mint> convolution ( std::vector<mint> a, std::vector<mint> b) {
105+ std::vector<mint> convolution_naive ( const std::vector<mint>& a, const std::vector<mint>& b) {
108106 int n = int (a.size ()), m = int (b.size ());
109- if (!n || !m) return {};
110- if (std::min (n, m) <= 60 ) {
111- if (n < m) {
112- std::swap (n, m);
113- std::swap (a, b);
107+ std::vector<mint> ans (n + m - 1 );
108+ if (n < m) {
109+ for (int j = 0 ; j < m; j++) {
110+ for (int i = 0 ; i < n; i++) {
111+ ans[i + j] += a[i] * b[j];
112+ }
114113 }
115- std::vector<mint> ans (n + m - 1 );
114+ } else {
116115 for (int i = 0 ; i < n; i++) {
117116 for (int j = 0 ; j < m; j++) {
118117 ans[i + j] += a[i] * b[j];
119118 }
120119 }
121- return ans;
122120 }
121+ return ans;
122+ }
123+
124+ template <class mint , internal::is_static_modint_t <mint>* = nullptr >
125+ std::vector<mint> convolution_fft (std::vector<mint> a, std::vector<mint> b) {
126+ int n = int (a.size ()), m = int (b.size ());
123127 int z = 1 << internal::ceil_pow2 (n + m - 1 );
124128 a.resize (z);
125129 internal::butterfly (a);
@@ -132,7 +136,25 @@ std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
132136 a.resize (n + m - 1 );
133137 mint iz = mint (z).inv ();
134138 for (int i = 0 ; i < n + m - 1 ; i++) a[i] *= iz;
135- return a;
139+ return std::move (a);
140+ }
141+
142+ } // namespace internal
143+
144+ template <class mint , internal::is_static_modint_t <mint>* = nullptr >
145+ std::vector<mint> convolution (std::vector<mint>&& a, std::vector<mint>&& b) {
146+ int n = int (a.size ()), m = int (b.size ());
147+ if (!n || !m) return {};
148+ if (std::min (n, m) <= 60 ) return convolution_naive (a, b);
149+ return internal::convolution_fft (a, b);
150+ }
151+
152+ template <class mint , internal::is_static_modint_t <mint>* = nullptr >
153+ std::vector<mint> convolution (const std::vector<mint>& a, const std::vector<mint>& b) {
154+ int n = int (a.size ()), m = int (b.size ());
155+ if (!n || !m) return {};
156+ if (std::min (n, m) <= 60 ) return convolution_naive (a, b);
157+ return internal::convolution_fft (a, b);
136158}
137159
138160template <unsigned int mod = 998244353 ,
0 commit comments