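Three solutions are compared below: a numba-jitted loop, a vectorized numpy version, and a Newton-type root finder.

Solution 1: `numba.njit` (an explicit loop compiled with the numba JIT)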
import numpy as np
from numba import njit
@njit
def find_k_numba(a, t):
    """Return the cut level ``k`` that lowers the total of ``a`` to ``t``.

    ``k`` is chosen so that subtracting it from every element, with elements
    smaller than ``k`` dropping to zero, leaves a total of ``t``; that is,
    ``np.maximum(a - k, 0).sum() == t``.

    Parameters
    ----------
    a : 1-D numeric array
        Values to cut. The input is not modified; a sorted copy is used.
    t : number
        Target total after cutting.

    Returns
    -------
    float
        The cut level. ``0.0`` when ``a.sum() <= t`` (nothing has to be
        removed) or when ``a`` is empty. When ``t`` is negative no level
        satisfies the equation; the value of the last loop step is returned.
    """
    a = np.sort(a)
    m = len(a)
    to_remove = a.sum() - t

    # Bind k up front as a float: every path then returns the same type, and
    # k is defined even when the loop below never executes (empty input).
    k = 0.0
    if to_remove <= 0:
        return k

    for i, x in enumerate(a):
        # Spread what is still to be removed evenly over the m - i elements
        # that have not been consumed yet.
        k = to_remove / (m - i)
        if k < x:
            # Every remaining element is at least x > k, so all of them can
            # give up k in full: k is the answer.
            break
        # x is too small to give up k; it is removed entirely and the rest
        # is shared among the larger elements.
        to_remove -= x
    return k
Solution 2: `numpy` (vectorized; attributed to Paul in the original text)
import numpy as np
def find_k_numpy(a, t):
    """Return the cut level ``k`` that lowers the total of ``a`` to ``t``.

    ``k`` is chosen so that ``np.maximum(a - k, 0).sum() == t``. The search
    is vectorized: the amount removed at every candidate level ``k = a[i]``
    is computed at once, and the answer is interpolated inside the segment
    that contains the required amount.

    Parameters
    ----------
    a : 1-D numeric array
        Values to cut. The input is not modified; a sorted copy is used.
    t : number
        Target total after cutting.

    Returns
    -------
    float
        The cut level. ``0.0`` when ``a.sum() <= t`` (nothing has to be
        removed) or when ``a`` is empty. When ``t`` is negative no level
        satisfies the equation; the last segment is extrapolated instead of
        raising ``IndexError``.
    """
    a = np.sort(a)
    m = len(a)
    to_remove = a.sum() - t
    if to_remove <= 0 or m == 0:
        return 0.0

    c = a.cumsum()
    n = np.arange(m)[::-1]  # m-1, m-2, ..., 0
    # b[i] is the total removed when cutting at k = a[i]: the elements up to
    # and including i are removed entirely (c[i]), and each of the n[i]
    # larger elements gives up a[i].
    b = n * a + c

    # First level whose removed amount reaches the target. searchsorted
    # returns m when to_remove exceeds b[-1]; that happens for t < 0 and can
    # happen for t == 0 when a.sum() and a.cumsum()[-1] round differently.
    # Clamp so the final segment is used in those cases.
    i = min(np.searchsorted(b, to_remove), m - 1)

    # Between a[i-1] and a[i] the removed amount grows by (m - i) per unit
    # of k, so step back from a[i] by the overshoot divided by that slope.
    return a[i] + (to_remove - b[i]) / (m - i)
Solution 3: `scipy.optimize.newton` (Newton-type root finding)
import numpy as np
from scipy.optimize import newton
def find_k_newton(a, t):
    """Return the cut level ``k`` found by root finding.

    Solves ``np.clip(a - k, 0, a.max()).sum() - t == 0`` for ``k`` with
    ``scipy.optimize.newton``, starting from the level that would be correct
    if every element could give up the same amount, ``(a.sum() - t) / len(a)``.

    Parameters
    ----------
    a : numpy array
        Values to cut. Not modified and not required to be sorted.
    t : number
        Target total after cutting.

    Returns
    -------
    float
        ``0.0`` when ``a.sum() <= t`` (nothing has to be removed), otherwise
        the root returned by ``newton``.

    Notes
    -----
    No derivative is passed to ``newton``. Any exception ``newton`` raises,
    for example on failure to converge, propagates to the caller.
    """
    s = a.sum()
    if s <= t:
        return 0.0

    # The upper clip bound does not depend on k; compute it once instead of
    # on every evaluation of the objective.
    upper = a.max()

    def f(k_):
        # Total that remains after cutting at k_, minus the target.
        return np.clip(a - k_, 0, upper).sum() - t

    return newton(f, (s - t) / len(a))
import pandas as pd
from timeit import timeit
# Benchmark table: one row per input size, one column per implementation.
# Cells start as NaN and are filled with the total seconds for 200 calls.
res = pd.DataFrame(
    np.nan,
    index=[10, 30, 100, 300, 1000, 3000, 10000, 30000],
    columns='find_k_newton find_k_numpy find_k_numba'.split(),
)

for size in res.index:
    a = np.random.rand(size)
    t = a.sum() * .6  # target: keep 60% of the total
    for name in res.columns:
        # Pass the objects to timeit explicitly instead of importing them
        # from __main__, so the benchmark also works when this module is
        # not the entry point.
        namespace = {'func': globals()[name], 'a': a, 't': t}
        # Untimed warm-up call: one-off costs such as JIT compilation of the
        # numba version would otherwise be charged to the first measurement.
        namespace['func'](a, t)
        res.at[size, name] = timeit('func(a, t)', globals=namespace, number=200)

res.plot(loglog=True)

# Timings relative to the fastest implementation in each row (fastest = 1.0).
res.div(res.min(1), 0)
find_k_newton find_k_numpy find_k_numba
10 57.265421 17.552150 1.000000
30 29.221947 9.420263 1.000000
100 16.920463 5.294890 1.000000
300 10.761341 3.037060 1.000000
1000 1.455159 1.033066 1.000000
3000 1.000000 2.076484 2.550152
10000 1.000000 3.798906 4.421955
30000 1.000000 5.551422 6.784594