快速幂取模算法

yuyu888 于 2021-01-19 发布

在做RSA算法实例推导的时候,遇到一个计算问题:求 23123 % 2091 的值
我需要先算出23123 的值

普通求幂算法

根据表达式的原始含义,求 ab

package main

import (
	"fmt"
)

func main() {
	fmt.Println(pow(3, 11))
}

func pow(a int, b int) int {
	ans := 1
	for {
		if b == 0 {
			break
		}
		ans *= a
		b--
	}
	return ans
}

这里 b 等于几就要循环几次,效率太低

快速幂算法

先做一个公式变换:

ab = ab/2 * ab/2 = (a * a)b/2

假如 b 是偶数
那么 我们只用先算出 a*a 然后 再循环 b/2 次就行了 是不是节约了很多?

假如 b 是奇数
公式稍微变换下:

ab = a * a(b-1)/2 * a(b-1)/2 = a * ( (a * a)(b-1)/2 )


为了方便, 假如 b 是 4 的倍数
把 a * a 的结果当成一个数 m ,把 b/2 当成另一个数n
a * a = m
b/2 = n
ab = (a * a)b/2 = mn 其实就变成 mn 也可以套用上述的公式变换
那么 ab = mn = (m * m)n/2 = (a * a * a * a)(b/2)/2 = (a4)(b/4)

我们只需要 先算出 a4 再循环 b/4 次就行了 又节约了 近一半运算

按照这个思路, 我们就可以不断的去拆解 b 进行递归运算;

判断奇偶 b % 2 == 1 可以用 b&1 替换 ; b/2 可以用 b»1 替换

递归求解代码:

func quick_pow(a int, b int) int {
	if b == 0 {
        return 1
    }
    // 	fmt.Println(a)
    if b&1 == 1 {
        return a * quick_pow(a*a, (b-1)>>1)
    } else {
        return quick_pow(a*a, b>>1)
    }
}

通过上述函数计算 311 = 177147 把注释取消打印出a
3         ——>31
9         ——>32
81       ——>34
6561   ——>38
观察规律
311 = 177147 = 3 * 7 * 6561 = 31 * 32 * 38
311 = 3 * (3 * 3) * (3 * 3 * 3 * 3 * 3 * 3 * 3 * 3)
其中 81 没有参加运算
11 = 1 + 2 + 8 是一个11 转换 2进制的 分解因式; 11 的二进制表示是 1011;

由此我们可以很容易把对指数b的分解,变成一个把b分解成对其二进制的表达式的分解,位数值为0时不参与最终值的计算
快速幂也可以使用一种非递归方式

 func QuickPow(a int, b int) int {
	ans := 1
	for {
		if b == 0 {
			break
		}
		if b&1 == 1 {
			ans *= a
		}
		a *= a
		b = b >> 1
	}
	return ans
}

快速幂求模

通过快速幂取模算法,我们确实可以极大的减少计算复杂度,回归到 我们初始目标:

求 23123 % 2091 的值

我们真的要先求 23123 再对 2091 取模吗? 23123的结果会溢出; 我们实际上需要得到的数小于 2091

数论里有取模运算的几个法则:
(a + b) % p = (a % p + b % p) % p
(a - b) % p = (a % p - b % p) % p
(a * b) % p = (a % p )*(b % p) % p

重点关注下 (a * b) % p = (a % p )*(b % p) % p 这个
证明

假设 a = p * m + e 则 a%p = e   
假设 b = p * n + d 则 b%p = d
(a * b) % p =(p*m+e) * (p*n+d) % p = (p*m*p*n+p*m*d+e*p*n+e*d)%p = (e * d) % p = (a % p )*(b % p) % p 

通过这个可以推导出 a b % p = ((a % p)b) % p

根据快速幂求解 可以稍微改造下

func QuickPowMod(a int, b int, m int) int {
	ans := 1
	mod := a % m
	for {
		if b == 0 {
			break
		}
		if b&1 == 1 {
			ans = (ans * mod) % m
		}
		mod = (mod * mod) % m
		b = b >> 1
	}
	return ans
}