The Way

이항 계수 빠르게 구하기 본문

공부/컴퓨터 알고리즘

이항 계수 빠르게 구하기

Jeonggyun 2018. 6. 5. 22:01

초보자들에게 이항 계수는 보통 DP를 이용하여 계산하는 것이 정석이라 여겨진다.

하지만 어느 정도 문제를 풀다보면 DP의 메모리 범위 이상의 숫자들을 더 빠르게 계산해야 한다.


이번 AtCoder Grand Contest 025 B번 문제에서도 이를 이용하는 문제가 있었는데, 버벅이다가 결국 못풀었다.

따라서 이항계수를 구하는 방법들을 간단하게 공부해보았다. 이 분야의 한줄기 빛과도 같은 구사과님의 블로그를 참고하였다.



1. DP

생략. 시간, 공간 복잡도가 $O(n^2)$이라 쓸모가 없다.

대신 여기서 사용하는 점화식 $_n C_k = _{n-1}C_{k-1} + _{n-1}C_k$은 많은 아이디어의 기반이 되니 기억하자.



* 아래 2~4번의 방법은 MOD가 N보다 크다고 가정하자. 예기치 못한 오류가 자꾸 발생한다.

발생 원인은 뭐.. N이 MOD보다 크다면 (MOD + 1)!으로 나누는 상황이 발생할 수 있는데, 0으로 나누는 것이기 때문에 썩 좋지가 못하다.

그리고 페르마의 소정리도 a와 p가 서로소일 때만 성립한다.


* 180606 약간 수정. 전역 변수를 initialize list로 초기화하면 컴파일 시간과 결과물 용량이 무식하게 증가한다. main문으로 옮겼다.


2. 페르마의 소정리 이용

보통 자비가 있다면 mod 소수꼴로 구하라고 하는데, 이러면 페르마의 소정리를 이용할 수 있다.

$a^{p - 1} = 1 (mod p)$이므로 $a^{-1} = a^{p - 2} (mod p)$이다.

$a^{p - 1}$은 분할 정복을 이용해 $O(\log{p})$에 계산 가능하므로 총 시간은 $O(n\log{p})$가 된다.


이렇게 전처리를 해놓으면 $_n C_k = n! * k!^{-1} * (n - k)!^{-1} $을 이용해서 구할 수 있다.


mod하는 소수 범위를 잘 보고, 5만 이상이면 저장은 상관 없지만 계산할 때는 long long을 써야 하는 것 잊지 말자.


사실 구하려는 n이 100만 이하라면 long long을 써도 배열 크기 다 합쳐봤자 최악의 경우 (4번 방법) 24MB정도로 넉넉하므로,

코딩의 편의상 long long을 사용하자.

#include <bits/stdc++.h>
#define MOD 1000000007
#define N 1000000

long long fac[N + 1], facinv[N + 1];

long long power(long long base, int index) {
	long long r = 1;
	while (index) {
		if (index & 1) r = (r * base) % MOD;
		base = (base * base) % MOD;
		index >>= 1;
	}
	return r;
}

long long comb(int n, int k) {
	long long r = (fac[n] * facinv[n - k]) % MOD;
	r = (r * facinv[k]) % MOD;
	return r;
}
 
int main() {
	fac[0] = fac[1] = facinv[0] = facinv[1] = 1;
	for (int i = 2; i <= N; ++i) {
		fac[i] = (fac[i - 1] * i) % MOD;
		facinv[i] = power(fac[i], MOD - 2);
	}
	std::cout << comb(987654, 456789) << '\n';
}



3. facinv 계산을 조금 더 빠르게

$(n - 1)!^{-1} = n * (n!)^{-1}$ 라는 놀라운 식이 또 있었다. 바로 적용하자.

#include <bits/stdc++.h>
#define MOD 1000000007
#define N 1000000

long long fac[N + 1], facinv[N + 1];

long long power(long long base, int index) {
	long long r = 1;
	while (index) {
		if (index & 1) r = (r * base) % MOD;
		base = (base * base) % MOD;
		index >>= 1;
	}
	return r;
}

long long comb(int n, int k) {
	long long r = (fac[n] * facinv[n - k]) % MOD;
	r = (r * facinv[k]) % MOD;
	return r;
}
 
int main() {
	fac[0] = fac[1] = facinv[0] = facinv[1] = 1;
	for (int i = 2; i <= N; ++i) {
		fac[i] = (fac[i - 1] * i) % MOD;
	}
	facinv[N] = power(fac[N], MOD - 2);
	for (int i = N; i > 2; --i) {
		facinv[i - 1] = (facinv[i] * i) % MOD;
	}
	
	std::cout << comb(987654, 456789) << '\n';
}

시간 복잡도는 $O(\log{p} + n)$이다. 사실상 이정도면 최선의 코드라고 할 수 있겠다.



4. 정수의 역원을 이용한 계산

정수의 역원을 이용한 흥미로운 방법이다. 시간 복잡도는 놀랍게 $O(n)$인데, 계산 과정상 3번 방법이 실행 속도는 약간 빠를지도 모르겠다.

하지만 정수의 역원을 쉽게 구할 수 있는 방법이니 알아두자. 또, power 코드를 짤 필요도 없어 코드도 더 간단해진다.


$2^{-1}, 3^{-1}$ 등의 정수의 역원을 알고있다면

facinv를 factorial 구하듯 빠르게 계산할 수 있다.


정수의 역원은 다음과 같은 관계를 통해 구할 수 있다.


일반적인 식(여기서 /는 정수 나눗셈이다)에서 출발.

P = (P / x) * x + P % x

P - (P / x) * x = P % x

- (P / x) * x = P % x (mod P)


이제 양변을 x * (P % x)로 나누면

- (P / x) * (P % x)^-1 = x^-1 (mod P)


P % x는 항상 x보다 작으므로 아래서부터 차례로 구해나갈 수 있다.

참고로 중간에 P % x로 나누어주기 때문에 P가 x의 배수가 아닐 때에만,

다시 말해 MOD가 N보다 클 경우에만 + 소수일 경우에만 사용 가능할 것으로.. 보인다.

#include <bits/stdc++.h>
#define MOD 1000000007
#define N 1000000
 
long long fac[N + 1], facinv[N + 1], inv[N + 1];

long long findInv(int n) {
	if (inv[n] > 0) return inv[n];
	return inv[n] = ((MOD - MOD / n) * findInv(MOD % n)) % MOD;
}

long long comb(int n, int k) {
	long long r = (fac[n] * facinv[n - k]) % MOD;
	r = (r * facinv[k]) % MOD;
	return r;
}
  
int main() {
	fac[0] = fac[1] = facinv[0] = facinv[1] = inv[1] = 1;
	for (int i = 2; i <= N; ++i) {
		fac[i] = (fac[i - 1] * i) % MOD;
		findInv(i);
		facinv[i] = (facinv[i - 1] * inv[i]) % MOD;
	}
	std::cout << comb(987654, 456789) << '\n';
}



5. n이 무식하게 크다면 - lucas의 정리

참고하자.


P가 상대적으로 작고, n이 무식하게 클 때 사용함직한 방법이다.

예를 들어 아래와 같이, n이 12345678987654321정도 된다면(...) array를 만들 수가 없다.

이럴 때 사용하기 딱 좋다.

참고로 여기서는 mod 이하의 조합을 구할 때 DP를 썼는데, 꼭 그럴 필요는 없다.


참고로 계산 중간에 n과 k의 mod값이 역전되는 상황이 발생하는데, 그냥 0을 곱해주면 된다.

#include <bits/stdc++.h>
#define MOD 1009

int c[MOD][MOD];

int lucas_comb(long long n, long long k) {
	int r = 1;
	while (n || k) {
		r = (r * c[n % MOD][k % MOD]) % MOD;
		n /= MOD; k /= MOD;
	}
	return r;
}

int main() {
	for (int i = 0; i < MOD; ++i) for (int j = 0; j <= i; ++j) {
		if (i == j || j == 0) c[i][j] = 1;
		else c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % MOD;
	}
	std::cout << lucas_comb(12345678987654321LL, 987654323456789LL) << '\n';
}



6. 만약 p가 소수가 아니면?

난감하다. 일단 마땅히 생각나는 방법은 없다. 일단 차차 생각해보자.

Comments