How do I do arithmetic modulo another number without overflow?

I am trying to implement a fast primality test for the Rust types u32 and u64. As part of this, I need to calculate (n*n)%d, where n and d are u32 (or u64, respectively).

While the result can easily fit into the data type, I don’t understand how to calculate it. As far as I know, there is no processor primitive for this.

For u32 we can fake it: upcast to u64 so that the product cannot overflow, take the remainder, and then cast back down to u32, knowing that the result fits. However, since I don't have a u128 data type (as far as I know), this trick will not work for u64.
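
The u32 workaround described above, as a minimal sketch (mul_mod32 is a hypothetical name, not from the question). `u64::from` makes the widening explicit, and the final `as u32` cannot lose data because the remainder is always smaller than m:

```rust
/// (a * b) % m for u32 operands, computed in u64 so the product cannot overflow.
/// Panics if m == 0 (division by zero).
fn mul_mod32(a: u32, b: u32, m: u32) -> u32 {
    // (2^32 - 1)^2 < 2^64, so the widened product always fits in u64.
    let product = u64::from(a) * u64::from(b);
    // The remainder is < m <= u32::MAX, so narrowing back to u32 is lossless.
    (product % u64::from(m)) as u32
}
```

For example, `mul_mod32(65536, 65536, 1000)` returns 296, the last three digits of 2^32 = 4294967296.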

So for u64, the most obvious approach I can think of is to somehow compute x*y as a pair (carry, product) of u64 values, so that the overflow is captured instead of simply being lost (or causing a panic, or something else).
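
The (carry, product) pair described above can be computed with only u64 operations by splitting each operand into 32-bit halves, as in schoolbook multiplication. A minimal sketch, where the function name mul_hi_lo is hypothetical; note that it only produces the full 128-bit product, and reducing that pair modulo d is still a separate step:

```rust
/// Hypothetical helper: returns (hi, lo) with x * y == hi * 2^64 + lo,
/// using only u64 arithmetic on 32-bit halves.
fn mul_hi_lo(x: u64, y: u64) -> (u64, u64) {
    const MASK: u64 = 0xFFFF_FFFF;
    let (x_lo, x_hi) = (x & MASK, x >> 32);
    let (y_lo, y_hi) = (y & MASK, y >> 32);

    // Each partial product is at most (2^32 - 1)^2, which fits in u64.
    let ll = x_lo * y_lo;
    let lh = x_lo * y_hi;
    let hl = x_hi * y_lo;
    let hh = x_hi * y_hi;

    // Bits 32..63 of the result: the high half of ll plus the low halves of
    // the two cross terms. The sum is below 3 * 2^32, so it cannot overflow.
    let mid = (ll >> 32) + (lh & MASK) + (hl & MASK);

    // `mid << 32` keeps only mid's low 32 bits; its high bits are carried into hi.
    let lo = (ll & MASK) | (mid << 32);
    // The true high word is < 2^64 because x * y < 2^128, so this sum is exact.
    let hi = hh + (lh >> 32) + (hl >> 32) + (mid >> 32);
    (hi, lo)
}
```

For example, `mul_hi_lo(u64::MAX, u64::MAX)` returns `(u64::MAX - 1, 1)`, since (2^64 - 1)^2 = (2^64 - 2) * 2^64 + 1.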

Is there any way to do this? Or another standard way to solve the problem?

3 answers

Richard Rast pointed out that the Wikipedia version only works for 63-bit moduli (m < 2^63). I have extended the code provided by Boiethios to work with the full range of 64-bit unsigned integers.

    fn mul_mod64(mut x: u64, mut y: u64, m: u64) -> u64 {
        let msb = 0x8000_0000_0000_0000;
        let mut d = 0;
        let mp2 = m >> 1;
        x %= m;
        y %= m;

        if m & msb == 0 {
            // m < 2^63: 2*d and d + y always fit in u64.
            for _ in 0..64 {
                d = if d > mp2 {
                    (d << 1) - m
                } else {
                    d << 1
                };
                if x & msb != 0 {
                    d += y;
                }
                if d >= m {
                    d -= m;
                }
                x <<= 1;
            }
            d
        } else {
            // m >= 2^63: doubling d or adding y may wrap, so use explicit wrapping ops.
            for _ in 0..64 {
                d = if d > mp2 {
                    d.wrapping_shl(1).wrapping_sub(m)
                } else {
                    // the case d == m && x == 0 is taken care of
                    // after the end of the loop
                    d << 1
                };
                if x & msb != 0 {
                    let (mut d1, overflow) = d.overflowing_add(y);
                    if overflow {
                        d1 = d1.wrapping_sub(m);
                    }
                    d = if d1 >= m { d1 - m } else { d1 };
                }
                x <<= 1;
            }
            if d >= m { d - m } else { d }
        }
    }

    #[test]
    fn test_mul_mod64() {
        let half = 1 << 16;
        let max = u64::MAX;

        assert_eq!(mul_mod64(0, 0, 2), 0);
        assert_eq!(mul_mod64(1, 0, 2), 0);
        assert_eq!(mul_mod64(0, 1, 2), 0);
        assert_eq!(mul_mod64(1, 1, 2), 1);
        assert_eq!(mul_mod64(42, 1, 2), 0);
        assert_eq!(mul_mod64(1, 42, 2), 0);
        assert_eq!(mul_mod64(42, 42, 2), 0);
        assert_eq!(mul_mod64(42, 42, 42), 0);
        assert_eq!(mul_mod64(42, 42, 41), 1);
        assert_eq!(mul_mod64(1239876, 2948635, 234897), 163320);
        assert_eq!(mul_mod64(1239876, 2948635, half), 18476);
        assert_eq!(mul_mod64(half, half, half), 0);
        assert_eq!(mul_mod64(half+1, half+1, half), 1);

        assert_eq!(mul_mod64(max, max, max), 0);
        assert_eq!(mul_mod64(1239876, 2948635, max), 3655941769260);
        assert_eq!(mul_mod64(1239876, max, max), 0);
        assert_eq!(mul_mod64(1239876, max-1, max), max-1239876);
        assert_eq!(mul_mod64(max, 2948635, max), 0);
        assert_eq!(mul_mod64(max-1, 2948635, max), max-2948635);
        assert_eq!(mul_mod64(max-1, max-1, max), 1);
        assert_eq!(mul_mod64(2, max/2, max-1), 0);
    }
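
Because u128 has been available in stable Rust since version 1.26, a one-line 128-bit version makes a convenient reference for checking a bit-by-bit routine like the one above. A minimal sketch; the name mul_mod_u128 is hypothetical:

```rust
/// Reference (a * b) % m using a 128-bit intermediate. Panics if m == 0.
fn mul_mod_u128(a: u64, b: u64, m: u64) -> u64 {
    // The full product of two u64 values always fits in u128, and the
    // remainder is < m, so casting it back to u64 is lossless.
    ((u128::from(a) * u128::from(b)) % u128::from(m)) as u64
}
```

A property test could then assert `mul_mod64(x, y, m) == mul_mod_u128(x, y, m)` for random x, y and nonzero m. Whether the u128 version is also faster depends on the target: on x86-64, for example, a 128-bit remainder may compile to a call into a runtime helper rather than a single instruction.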

Use simple math:

 (n*n)%d = (n%d)*(n%d)%d 

To see why this holds, write n = k*d + r, where r = n%d:

 n*n % d = (k*k*d*d + 2*k*d*r + r*r) % d = r*r % d = (n%d)*(n%d) % d

The middle step holds because every term except r*r is a multiple of d.
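
The identity is correct, but it only reduces the operands to values below d. For u32 inputs computed in u64 that is enough, yet for u64 with d above 2^32 the reduced product (n%d)*(n%d) can itself exceed u64::MAX. A minimal sketch that applies the identity and reports the overflow instead of hiding it (square_mod_checked is a hypothetical name):

```rust
/// Computes (n * n) % d via ((n % d) * (n % d)) % d.
/// Returns None when the reduced product still overflows u64.
/// Panics if d == 0.
fn square_mod_checked(n: u64, d: u64) -> Option<u64> {
    let r = n % d;
    // checked_mul returns None instead of wrapping or panicking on overflow.
    r.checked_mul(r).map(|sq| sq % d)
}
```

For instance, with n = 2^40 and d = 2^40 + 1 it returns None, so a full 64-bit solution still needs an overflow-free multiplication such as the ones in the other answers.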

red75prime added a helpful comment. Here is Rust code, adapted from Wikipedia, for computing the product of two numbers modulo a third:

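    fn mul_mod(mut x: u64, mut y: u64, m: u64) -> u64 {
        // Valid only for m < 2^63: for larger m, doubling d or adding y
        // can exceed u64::MAX.
        let mut d = 0_u64;
        let mp2 = m >> 1;
        x %= m;
        y %= m;
        // Scan x from its most significant bit: d = 2*d + (bit ? y : 0),
        // reduced modulo m at every step.
        for _ in 0..64 {
            d = if d > mp2 { (d << 1) - m } else { d << 1 };
            if x & 0x8000_0000_0000_0000_u64 != 0 {
                d += y;
            }
            // Use `>=` so that d == m is also reduced to 0
            // (e.g. mul_mod(2, 2, 4) must return 0, not 4).
            if d >= m {
                d -= m;
            }
            x <<= 1;
        }
        d
    }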
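
Since the goal in the question is a primality test, the (n*n)%d step usually appears as the squaring inside modular exponentiation. A minimal square-and-multiply sketch, assuming the caller passes in one of the mul-mod functions from the answers; the name pow_mod and its closure parameter are hypothetical:

```rust
/// Computes base^exp mod m by square-and-multiply.
/// `mul_mod` must return (a * b) % m without overflowing.
/// Assumes m >= 1.
fn pow_mod<F>(mut base: u64, mut exp: u64, m: u64, mul_mod: F) -> u64
where
    F: Fn(u64, u64, u64) -> u64,
{
    // 1 % m handles m == 1, where every residue is 0.
    let mut result = 1 % m;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            result = mul_mod(result, base, m);
        }
        // This squaring is the (n*n)%d step from the question.
        base = mul_mod(base, base, m);
        exp >>= 1;
    }
    result
}
```

For example, `pow_mod(2, 1_000_000_006, 1_000_000_007, |a, b, m| a * b % m)` returns 1, as Fermat's little theorem predicts for the prime 1_000_000_007. The plain `a * b % m` closure is only safe here because m < 2^32; for larger moduli, pass `mul_mod64` instead.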

Source: https://habr.com/ru/post/1271292/

