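At bit level power (the bit worth 2^power), the values taken by consecutive integers starting from 0 form a repeating pattern of 2^power 0s followed by 2^power 1s, so the pattern repeats every 2^(power+1) numbers.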
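A quick way to see this pattern is to print the bit for the first few integers. This is a minimal sketch; `bit_pattern` is a hypothetical helper used only for illustration, not part of the answer's implementation.

```python
def bit_pattern(power, count):
    """Value of bit `power` (0 or 1) for each integer 0, 1, ..., count - 1."""
    return [(x >> power) & 1 for x in range(count)]


if __name__ == "__main__":
    for power in range(3):
        # Runs of 2**power zeros alternate with runs of 2**power ones.
        print(power, bit_pattern(power, 16))
```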
So there are three cases:
Case 1: M mod 2^(power+1) = 0 and N mod 2^(power+1) = 2^(power+1) - 1, i.e. [M, N] covers only whole blocks of 2^(power+1) numbers. Each block contains exactly 2^power numbers with the bit set, so the answer is simply (N - M + 1) / 2.
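The first case can be sketched as a tiny function that checks its precondition before using the halving formula. The name `count_whole_blocks` is a hypothetical label for illustration.

```python
def count_whole_blocks(M, N, power):
    """Case 1 sketch: [M, N] starts and ends on block boundaries."""
    block = 2 ** (power + 1)
    # Precondition: M starts a block and N ends one.
    assert M % block == 0 and N % block == block - 1
    # Each block holds 2**power ones out of 2**(power+1) numbers,
    # so exactly half of the N - M + 1 numbers have the bit set.
    return (N - M + 1) // 2


if __name__ == "__main__":
    # Bit 2 over [8, 23]: 12-15 and 20-23 have it set.
    print(count_whole_blocks(8, 23, 2))  # 8
```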
Case 2: M and N give the same quotient under integer division by 2^(power+1), i.e. they lie in the same block of 2^(power+1) numbers. There are two subcases:
- M and N also give the same quotient under integer division by 2^power, i.e. they lie in the same half of the block, so the bit has the same value for every number in [M, N]. If N mod 2^(power+1) < 2^power, that half is the 0s and the answer is 0; otherwise the answer is N - M + 1.
- Otherwise they lie in different halves. Because they are in the same block, M must be in the 0s half and N in the 1s half. The 1s start at (N/2^(power+1))*2^(power+1) + 2^power (integer division), so the answer is N - ((N/2^(power+1))*2^(power+1) + 2^power) + 1.
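The same-block case, including the count from the start of the 1s half up to N, can be sketched like this. `count_same_block` is a hypothetical name, and the function assumes M <= N lie in the same block.

```python
def count_same_block(M, N, power):
    """Case 2 sketch: M <= N lie in the same block of 2**(power+1) numbers."""
    half = 2 ** power
    block = 2 * half
    assert M <= N and M // block == N // block
    if M // half == N // half:
        # Same half of the block: the bit is constant over [M, N].
        return 0 if N % block < half else N - M + 1
    # Different halves: M is in the 0s half, N is in the 1s half.
    ones_start = (N // block) * block + half  # first number with the bit set
    return N - ones_start + 1


if __name__ == "__main__":
    # Bit 2 over [9, 14]: only 12, 13 and 14 have it set.
    print(count_same_block(9, 14, 2))  # 3
```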
Case 3: M and N give different quotients under integer division by 2^(power+1), i.e. they lie in different blocks. In this case you can combine cases 1 and 2: split the range into three pieces and add up their counts.
- From M to (M/2^(power+1) + 1)*2^(power+1) - 1, the rest of M's block (case 2).
- From (M/2^(power+1) + 1)*2^(power+1) to (N/2^(power+1))*2^(power+1) - 1, the whole blocks in between, which may be empty (case 1).
- From (N/2^(power+1))*2^(power+1) to N, the start of N's block (case 2).
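The three split points can be computed directly. This sketch only shows the decomposition; each piece would then be counted with case 2, case 1 and case 2 respectively. `split_ranges` is a hypothetical helper that assumes M and N are in different blocks.

```python
def split_ranges(M, N, power):
    """Case 3 sketch: split [M, N] into three inclusive (start, end) ranges.

    Assumes M and N lie in different blocks of 2**(power+1) numbers.
    The middle range is empty (start == end + 1) when the blocks are adjacent.
    """
    block = 2 ** (power + 1)
    assert M // block != N // block, "M and N are in the same block (case 2)"
    first_end = (M // block + 1) * block - 1  # last number of M's block
    last_start = (N // block) * block         # first number of N's block
    return ((M, first_end), (first_end + 1, last_start - 1), (last_start, N))


if __name__ == "__main__":
    first, middle, last = split_ranges(5, 18, 2)
    print(first, middle, last)  # (5, 7) (8, 15) (16, 18)
```

For bit 2, the pieces contain 3 (5, 6, 7), 4 (12 to 15) and 0 numbers with the bit set, for a total of 7.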
If this answer has logical errors, let me know; it's complicated, and I may have messed up a bit.
UPDATE: Python implementation:

    def case1(M, N):
        # [M, N] covers whole blocks of 2**(power+1) numbers: half have the bit set.
        return (N - M + 1) // 2


    def case2(M, N, power):
        # M and N lie in the same block of 2**(power+1) numbers.
        if M > N:
            return 0
        if M // 2**power == N // 2**power:
            # Same half of the block: the bit is constant over [M, N].
            if N % 2**(power+1) < 2**power:
                return 0
            return N - M + 1
        # Different halves: M is in the 0s half and N is in the 1s half.
        return N - (getNextLower(N, power+1) + 2**power) + 1


    def case3(M, N, power):
        # M and N lie in different blocks: rest of M's block,
        # whole blocks in between, then the start of N's block.
        return (case2(M, getNextHigher(M, power+1) - 1, power)
                + case1(getNextHigher(M, power+1), getNextLower(N, power+1) - 1)
                + case2(getNextLower(N, power+1), N, power))


    def getNextLower(M, power):
        # Largest multiple of 2**power that is <= M.
        return (M // 2**power) * 2**power


    def getNextHigher(M, power):
        # Smallest multiple of 2**power that is > M.
        return (M // 2**power + 1) * 2**power


    def numSetBits(M, N, power):
        # Number of integers in [M, N] whose bit `power` is set.
        if M % 2**(power+1) == 0 and N % 2**(power+1) == 2**(power+1) - 1:
            return case1(M, N)
        if M // 2**(power+1) == N // 2**(power+1):
            return case2(M, N, power)
        return case3(M, N, power)


    if __name__ == "__main__":
        # Prints 5, 5, 4, 3, 0 and then 7, 7, 7, 8, 3 (one value per line).
        for M, N in ((0, 10), (5, 18)):
            for power in range(5):
                print(numSetBits(M, N, power))
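Since the case analysis is easy to get wrong, a brute-force counter makes a handy reference: it takes time linear in N - M instead of constant time, but it is simple enough to trust. This is a minimal sketch; `brute_force_count` is a hypothetical helper. Comparing numSetBits(M, N, power) against it over all small M <= N and a few powers is a quick way to test the cases.

```python
def brute_force_count(M, N, power):
    """Count integers x in [M, N] whose bit `power` is 1 by checking each one."""
    return sum((x >> power) & 1 for x in range(M, N + 1))


if __name__ == "__main__":
    # Same (M, N, power) combinations as the __main__ block of numSetBits.
    for M, N in ((0, 10), (5, 18)):
        print([brute_force_count(M, N, power) for power in range(5)])
    # [5, 5, 4, 3, 0]
    # [7, 7, 7, 8, 3]
```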