ttamx's library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

✔ polynomial/ntt.hpp

Depends on

Required by

Verified with

Code

#pragma once
#include "modular-arithmetic/binpow.hpp"

/**
 * Author: Teetat T.
 * Description: Number Theoretic Transform
 * Time: $O(N \log N)$
 */

template<class mint>
struct NTT{
	using vm = vector<mint>;
	
	// Generator supplied by the modint type; its powers give the roots of unity.
	static constexpr mint root=mint::get_root();
	static_assert(root!=0, "root must be nonzero");

	/**
	 * In-place forward transform (bit-reversal permutation, then iterative
	 * butterflies). a.size() must be a power of two; sizes 0 and 1 are
	 * left unchanged (the length-1 transform is the identity).
	 * NOTE(review): twiddles are derived from the global MOD, which is assumed
	 * to be the modulus of `mint`; MOD-1 must be divisible by a.size().
	 */
	static void ntt(vm &a){
		int n=a.size();
		// Nothing to do; also avoids __builtin_clz(0) and the rt[1] write below.
		if(n<=1)return;
		int L=31-__builtin_clz(n);
		// rt[k..2k) holds the powers 0..k-1 of a primitive 2k-th root of unity.
		vm rt(n);
		rt[1]=1;
		for(int k=2,s=2;k<n;k*=2,s++){
			mint z[]={1,binpow(root,MOD>>s)};
			for(int i=k;i<2*k;i++)rt[i]=rt[i/2]*z[i&1];
		}
		vector<int> rev(n);
		for(int i=1;i<n;i++)rev[i]=(rev[i/2]|(i&1)<<L)/2;
		for(int i=1;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
		for(int k=1;k<n;k*=2)for(int i=0;i<n;i+=2*k)for(int j=0;j<k;j++){
			mint z=rt[j+k]*a[i+j+k];
			a[i+j+k]=a[i+j]-z;
			a[i+j]+=z;
		}
	}
	/**
	 * Linear convolution: returns c with c[k] = sum_{i+j=k} a[i]*b[j],
	 * of length a.size()+b.size()-1. Returns an empty vector if either
	 * input is empty.
	 */
	static vm conv(const vm &a,const vm &b){
		if(a.empty()||b.empty())return {};
		int s=a.size()+b.size()-1,n=1;
		// Smallest power of two >= s. Anything larger only wastes work and
		// demands a root of unity of higher order than the modulus may have.
		while(n<s)n<<=1;
		mint inv=mint(n).inv();
		vm in1(a),in2(b),out(n);
		in1.resize(n),in2.resize(n);
		ntt(in1),ntt(in2);
		// Storing the pointwise products at index -i mod n makes the next
		// forward transform act as the inverse one; 1/n is folded in here.
		for(int i=0;i<n;i++)out[-i&(n-1)]=in1[i]*in2[i]*inv;
		ntt(out);
		return vm(out.begin(),out.begin()+s);
	}
	/// Function-object form of conv.
	vm operator()(const vm &a,const vm &b)const{
		return conv(a,b);
	}
};
#line 2 "modular-arithmetic/binpow.hpp"

/**
 * Author: Teetat T.
 * Date: 2024-01-15
 * Description: Computes a^b by binary exponentiation (square-and-multiply).
 *   Returns T(1) when b <= 0. T must be constructible from 1 and support *=.
 * Time: $O(\log b)$
 */

template<class T>
constexpr T binpow(T a,ll b){
    T res=1;
    for(;b>0;b>>=1){
        if(b&1)res*=a;
        // Square only while higher bits remain: a trailing a*=a is wasted work
        // and can overflow (UB for signed T) even when the result itself fits.
        if(b>1)a*=a;
    }
    return res;
}

#line 3 "polynomial/ntt.hpp"

/**
 * Author: Teetat T.
 * Description: Number Theoretic Transform
 * Time: $O(N \log N)$
 */

template<class mint>
struct NTT{
	using vm = vector<mint>;
	
	// Generator supplied by the modint type; its powers give the roots of unity.
	static constexpr mint root=mint::get_root();
	static_assert(root!=0, "root must be nonzero");

	/**
	 * In-place forward transform (bit-reversal permutation, then iterative
	 * butterflies). a.size() must be a power of two; sizes 0 and 1 are
	 * left unchanged (the length-1 transform is the identity).
	 * NOTE(review): twiddles are derived from the global MOD, which is assumed
	 * to be the modulus of `mint`; MOD-1 must be divisible by a.size().
	 */
	static void ntt(vm &a){
		int n=a.size();
		// Nothing to do; also avoids __builtin_clz(0) and the rt[1] write below.
		if(n<=1)return;
		int L=31-__builtin_clz(n);
		// rt[k..2k) holds the powers 0..k-1 of a primitive 2k-th root of unity.
		vm rt(n);
		rt[1]=1;
		for(int k=2,s=2;k<n;k*=2,s++){
			mint z[]={1,binpow(root,MOD>>s)};
			for(int i=k;i<2*k;i++)rt[i]=rt[i/2]*z[i&1];
		}
		vector<int> rev(n);
		for(int i=1;i<n;i++)rev[i]=(rev[i/2]|(i&1)<<L)/2;
		for(int i=1;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
		for(int k=1;k<n;k*=2)for(int i=0;i<n;i+=2*k)for(int j=0;j<k;j++){
			mint z=rt[j+k]*a[i+j+k];
			a[i+j+k]=a[i+j]-z;
			a[i+j]+=z;
		}
	}
	/**
	 * Linear convolution: returns c with c[k] = sum_{i+j=k} a[i]*b[j],
	 * of length a.size()+b.size()-1. Returns an empty vector if either
	 * input is empty.
	 */
	static vm conv(const vm &a,const vm &b){
		if(a.empty()||b.empty())return {};
		int s=a.size()+b.size()-1,n=1;
		// Smallest power of two >= s. Anything larger only wastes work and
		// demands a root of unity of higher order than the modulus may have.
		while(n<s)n<<=1;
		mint inv=mint(n).inv();
		vm in1(a),in2(b),out(n);
		in1.resize(n),in2.resize(n);
		ntt(in1),ntt(in2);
		// Storing the pointwise products at index -i mod n makes the next
		// forward transform act as the inverse one; 1/n is folded in here.
		for(int i=0;i<n;i++)out[-i&(n-1)]=in1[i]*in2[i]*inv;
		ntt(out);
		return vm(out.begin(),out.begin()+s);
	}
	/// Function-object form of conv.
	vm operator()(const vm &a,const vm &b)const{
		return conv(a,b);
	}
};
Back to top page