(Warning: I had some trouble with the range of multipliers not being [0, n), so I adjusted it to that range. It is easy to compensate for the difference.)
I am going to sketch, with tested Python code, an implementation that runs in time O(log max{a, b}). First, here are some utility functions and a naive implementation.
    from math import gcd  # fractions.gcd was removed in Python 3.9
    from random import randrange


    def coprime(a, b):
        return gcd(a, b) == 1


    def floordiv(a, b):
        return a // b


    def ceildiv(a, b):
        return floordiv(a + b - 1, b)


    def count1(a, b, n, m):
        # Naive reference: test every multiplier k in [0, n).
        assert 1 <= a < b
        assert coprime(a, b)
        assert 0 <= n < b + 1
        assert 0 <= m < b + 1
        return sum(k * a % b < m for k in range(n))
Now, how can we speed this up? The first improvement is to partition the multipliers into disjoint ranges such that, within each range, the corresponding multiples of a lie between two consecutive multiples of b. Knowing the lowest and highest values in a range, we can use ceiling division to count the multiples whose residue is less than m.
    def count2(a, b, n, m):
        assert 1 <= a < b
        assert coprime(a, b)
        assert 0 <= n < b + 1
        assert 0 <= m < b + 1
        count = 0
        first = 0  # smallest residue in the current range
        while 0 < n:
            count += min(ceildiv(m - first, a), n)
            k = ceildiv(b - first, a)  # multipliers in the current range
            n -= k
            first = first + k * a - b
        return count
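To make the partition concrete, here is a minimal sketch that lists the ranges explicitly. The helper names `blocks` and `count_by_blocks` are mine and purely illustrative; the counting logic is meant to mirror count2, under the same assumptions (1 <= a < b, a and b coprime, 0 <= n, m <= b).

```python
def ceildiv(a, b):
    return (a + b - 1) // b


def blocks(a, b, n):
    """Yield (first, size) for each range: the smallest residue k * a % b
    in the range and how many of the n multipliers fall into it."""
    first = 0
    while n > 0:
        k = ceildiv(b - first, a)  # multiples of a before the next multiple of b
        yield first, min(k, n)
        n -= k
        first = first + k * a - b  # residue of the first multiple in the next range


def count_by_blocks(a, b, n, m):
    # Inside a range the residues are first, first + a, first + 2a, ...,
    # so ceildiv(m - first, a) of them are below m, capped by the range size.
    return sum(min(ceildiv(m - first, a), size) for first, size in blocks(a, b, n))


print(list(blocks(3, 7, 7)))
print(count_by_blocks(3, 7, 7, 3))
```

For a = 3, b = 7, n = 7 the residues are 0, 3, 6, 2, 5, 1, 4, which split into the ranges (0, 3, 6), (2, 5) and (1, 4).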
This is not fast enough. The second improvement is to replace most of the while loop with a recursive call. In the code below, j is the number of iterations that are “complete” in the sense that there is a wraparound. term3 accounts for the remaining iteration, using logic similar to count2.
Each complete iteration contributes either floor(m / a) or floor(m / a) + 1 residues under the threshold m. Whether we get the + 1 depends on the value of first for that iteration. first starts at 0 and changes by a - (b % a), modulo a, on each pass through the while loop. We get the + 1 whenever first is under some threshold (namely m % a), and this count is computed by the recursive call.
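    def count3(a, b, n, m):
        assert 1 <= a < b
        assert coprime(a, b)
        assert 0 <= n < b + 1
        assert 0 <= m < b + 1
        if 1 == a:
            return min(n, m)
        j = floordiv(n * a, b)  # number of complete iterations
        term1 = j * floordiv(m, a)
        term2 = count3(a - b % a, a, j, m % a)  # iterations that get the + 1
        last = n * a % b
        first = last % a
        # Remaining iteration: residues below m, capped by the multiples left.
        term3 = min(ceildiv(m - first, a), (last - first) // a)
        return term1 + term2 + term3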
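The two facts that the recursion relies on can be checked directly. The following is only a sketch with hypothetical helper names (`block_firsts`, `predicted_firsts`, `per_block_count`), assuming 1 <= a < b with a and b coprime: it compares the values of first found by direct search against the recurrence, and expresses the per-iteration count as floor(m / a) plus an optional 1.

```python
def ceildiv(a, b):
    return (a + b - 1) // b


def block_firsts(a, b, j):
    """Smallest residue in each of the first j iterations, by direct search."""
    firsts = []
    for i in range(j):
        k = ceildiv(i * b, a)  # first multiplier with k * a >= i * b
        firsts.append(k * a - i * b)
    return firsts


def predicted_firsts(a, b, j):
    """The same values from the recurrence: add a - (b % a), reduce modulo a."""
    step = a - b % a
    return [i * step % a for i in range(j)]


def per_block_count(first, a, m):
    """Residues below m in a complete iteration: floor(m / a), plus 1 if first < m % a."""
    return m // a + (first < m % a)


print(block_firsts(7, 25, 7))
print(predicted_firsts(7, 25, 7))
```

Both calls produce the same list for a = 7, b = 25, and per_block_count agrees with ceildiv(m - first, a) for every first in [0, a), which is why counting the iterations with first < m % a is the same kind of problem with (a - b % a, a) in place of (a, b).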
Runtime can be analyzed similarly to the Euclidean GCD algorithm.
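As an aid for that analysis, this small sketch lists the argument pairs that the recursion passes through. Euclid's algorithm replaces (a, b) with (b % a, a), while count3 replaces it with (a - b % a, a). The function name `recursion_chain` is hypothetical, and the sketch assumes 1 <= a < b with a and b coprime.

```python
def recursion_chain(a, b):
    """List the (a, b) pairs visited by the recursion, ending at a == 1."""
    chain = [(a, b)]
    while a != 1:
        # Coprimality guarantees b % a >= 1 here, so a strictly decreases.
        a, b = a - b % a, a
        chain.append((a, b))
    return chain


print(recursion_chain(7, 25))
```

Printing the chain for a few inputs of interest shows how quickly the arguments shrink in each case.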
Here is some test code to back up my claims of correctness. Remember to delete the asserts before testing performance.
    def test(p, f1, f2):
        assert 3 <= p
        for t in range(100):
            # Pick a random coprime pair with 1 <= a < b < p.
            while True:
                b = randrange(2, p)
                a = randrange(1, b)
                if coprime(a, b):
                    break
            for n in range(b + 1):
                for m in range(b + 1):
                    args = (a, b, n, m)
                    print(args)
                    assert f1(*args) == f2(*args)


    if __name__ == '__main__':
        test(25, count1, count2)
        test(25, count1, count3)