We can extend this solution to your 3D case by using sliding-window views (built on np.lib.stride_tricks.as_strided) to efficiently retrieve the patches. Here, scikit-image's view_as_windows creates those views:
```python
import numpy as np
from skimage.util.shape import view_as_windows

def get_patches(data, locations, size):
    # Get 2D sliding windows for each element of data
    w = view_as_windows(data, (1,1,size,size))

    # Use fancy/advanced indexing to select the required ones
    return w[np.arange(len(locations)), :, locations[:,0], locations[:,1]][:,:,0,0]
```
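We need those 1s in the window shape passed to view_as_windows because it expects window_shape to have as many elements as the input array has dimensions. We slide along the last two axes of data, so we keep the first two entries as 1s, which means no sliding along the first two axes of data. The windowed array therefore has shape (n, c, h-size+1, w-size+1, 1, 1, size, size), and the trailing [:,:,0,0] drops the two singleton window axes.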
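To see where each axis goes, here is a small shape trace. So that it runs without scikit-image, it swaps in NumPy's own numpy.lib.stride_tricks.sliding_window_view (NumPy 1.20+), assuming that with the default step of 1 it produces the same windowed layout as view_as_windows. The data and locations mirror the sample run below, with 2 channels.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

n, c, h, w = 3, 2, 4, 4
size = 2
data = np.arange(n * c * h * w).reshape(n, c, h, w)
locations = np.array([[0, 1], [1, 1], [0, 2]])

# Window shape (1, 1, size, size): no sliding along axes 0 and 1
windows = sliding_window_view(data, (1, 1, size, size))
print(windows.shape)                    # (3, 2, 3, 3, 1, 1, 2, 2)
print(np.shares_memory(windows, data))  # True: a strided view, nothing copied yet

# The index arrays on axes 0, 2 and 3 are separated by the slice on axis 1,
# so NumPy puts the broadcast index dimension (length n) first
picked = windows[np.arange(len(locations)), :, locations[:, 0], locations[:, 1]]
print(picked.shape)                     # (3, 2, 1, 1, 2, 2)

# Drop the two singleton window axes
patches = picked[:, :, 0, 0]
print(patches.shape)                    # (3, 2, 2, 2)
```

Only the final advanced-indexing step allocates memory, and only for the n selected patches.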
Sample runs for single-channel and multi-channel data -
```
In [78]: n, c, h, w = 3, 1, 4, 4 # number of channels = 1
    ...: data = np.arange(n * c * h * w).reshape(n, c, h, w)
    ...: 
    ...: size = 2
    ...: locations = np.array([
    ...:     [0, 1],
    ...:     [1, 1],
    ...:     [0, 2]
    ...: ])
    ...: 
    ...: crops = np.stack([d[:, y:y+size, x:x+size]
    ...:                   for d, (y,x) in zip(data, locations)])

In [79]: print(np.allclose(get_patches(data, locations, size), crops))
True

In [80]: n, c, h, w = 3, 5, 4, 4 # number of channels = 5
    ...: data = np.arange(n * c * h * w).reshape(n, c, h, w)
    ...: 
    ...: size = 2
    ...: locations = np.array([
    ...:     [0, 1],
    ...:     [1, 1],
    ...:     [0, 2]
    ...: ])
    ...: 
    ...: crops = np.stack([d[:, y:y+size, x:x+size]
    ...:                   for d, (y,x) in zip(data, locations)])

In [81]: print(np.allclose(get_patches(data, locations, size), crops))
True
```
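If scikit-image isn't available, a similar sketch can rely on numpy.lib.stride_tricks.sliding_window_view (NumPy 1.20+) alone. get_patches_np is a hypothetical name, and this variant was not part of the timings below. Passing axis=(-2, -1) slides only over the two spatial axes, so neither the dummy 1s nor the trailing [:,:,0,0] is needed.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def get_patches_np(data, locations, size):
    """Hypothetical pure-NumPy variant of get_patches (no scikit-image).

    data: (n, c, h, w) array; locations: (n, 2) integer (y, x) top-left corners.
    Returns an (n, c, size, size) array (a copy, via advanced indexing).
    """
    # Slide only over the last two axes -> (n, c, h-size+1, w-size+1, size, size)
    windows = sliding_window_view(data, (size, size), axis=(-2, -1))
    # One window per sample; the broadcast index axis (length n) lands first
    return windows[np.arange(len(locations)), :, locations[:, 0], locations[:, 1]]


# Same check as the multi-channel run above
n, c, h, w = 3, 5, 4, 4
data = np.arange(n * c * h * w).reshape(n, c, h, w)
size = 2
locations = np.array([[0, 1], [1, 1], [0, 2]])
crops = np.stack([d[:, y:y+size, x:x+size] for d, (y, x) in zip(data, locations)])
print(np.allclose(get_patches_np(data, locations, size), crops))  # True
```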
Benchmarking
Other approaches -
```python
# Original soln
def stack(data, locations, size):
    crops = np.stack([d[:, y:y+size, x:x+size] for d, (y,x) in zip(data, locations)])
    return crops
```
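Another loop-free option skips strided views altogether: build per-sample row and column index grids with broadcasting and gather every patch with one advanced-indexing call. This is a different technique from the one above, sketched under the hypothetical name fancy_index_patches and not included in the timings below. Unlike slicing, an out-of-range corner raises IndexError instead of yielding a truncated crop.

```python
import numpy as np


def fancy_index_patches(data, locations, size):
    """Hypothetical gather-based alternative (no strided views).

    data: (n, c, h, w); locations: (n, 2) integer (y, x) corners.
    Returns an (n, c, size, size) array.
    """
    n, c = data.shape[:2]
    offsets = np.arange(size)
    rows = locations[:, 0, None] + offsets  # (n, size) row indices per sample
    cols = locations[:, 1, None] + offsets  # (n, size) column indices per sample
    # The four index arrays broadcast together to (n, c, size, size)
    return data[np.arange(n)[:, None, None, None],
                np.arange(c)[None, :, None, None],
                rows[:, None, :, None],
                cols[:, None, None, :]]


data = np.arange(3 * 5 * 4 * 4).reshape(3, 5, 4, 4)
locations = np.array([[0, 1], [1, 1], [0, 2]])
crops = np.stack([d[:, y:y+2, x:x+2] for d, (y, x) in zip(data, locations)])
print(np.array_equal(fancy_index_patches(data, locations, 2), crops))  # True
```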
From the comments, it seems that OP is interested in data of shape (512,1,60,60) with size as 12, 24 and 48. So let's set up the input data for those cases with a function -
```python
# Setup data
def create_inputs(size):
    np.random.seed(0)
    n, c, h, w = 512, 1, 60, 60
    data = np.arange(n * c * h * w).reshape(n, c, h, w)
    locations = np.random.randint(0,3,(n,2))
    return data, locations, size
```
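Before timing, it's worth confirming that the approaches agree on these exact inputs. The sketch below checks the loop-based stack against a pure-NumPy stand-in for get_patches (the hypothetical get_patches_np, using sliding_window_view from NumPy 1.20+) for all three sizes. Since locations are drawn from [0, 3), every patch fits inside the 60x60 frame even for size 48.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def stack(data, locations, size):
    # Original loop-based solution
    return np.stack([d[:, y:y+size, x:x+size] for d, (y, x) in zip(data, locations)])


def get_patches_np(data, locations, size):
    # Hypothetical pure-NumPy stand-in for get_patches
    windows = sliding_window_view(data, (size, size), axis=(-2, -1))
    return windows[np.arange(len(locations)), :, locations[:, 0], locations[:, 1]]


def create_inputs(size):
    np.random.seed(0)
    n, c, h, w = 512, 1, 60, 60
    data = np.arange(n * c * h * w).reshape(n, c, h, w)
    locations = np.random.randint(0, 3, (n, 2))
    return data, locations, size


for size in (12, 24, 48):
    data, locations, size = create_inputs(size)
    out = get_patches_np(data, locations, size)
    # Every sample yields one (1, size, size) patch identical to the loop's
    assert out.shape == (512, 1, size, size)
    assert np.array_equal(out, stack(data, locations, size))
print("all sizes match")
```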
Timings -
```
In [186]: data, locations, size = create_inputs(size=12)

In [187]: %timeit stack(data, locations, size)
     ...: %timeit allocate_assign(data, locations, size)
     ...: %timeit get_patches(data, locations, size)
1000 loops, best of 3: 1.26 ms per loop
1000 loops, best of 3: 1.06 ms per loop
10000 loops, best of 3: 124 µs per loop

In [188]: data, locations, size = create_inputs(size=24)

In [189]: %timeit stack(data, locations, size)
     ...: %timeit allocate_assign(data, locations, size)
     ...: %timeit get_patches(data, locations, size)
1000 loops, best of 3: 1.66 ms per loop
1000 loops, best of 3: 1.55 ms per loop
1000 loops, best of 3: 470 µs per loop

In [190]: data, locations, size = create_inputs(size=48)

In [191]: %timeit stack(data, locations, size)
     ...: %timeit allocate_assign(data, locations, size)
     ...: %timeit get_patches(data, locations, size)
100 loops, best of 3: 2.8 ms per loop
100 loops, best of 3: 3.33 ms per loop
1000 loops, best of 3: 1.45 ms per loop
```
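%timeit is IPython magic, so here is a rough sketch for repeating the comparison as a plain script with the standard timeit module. best_time is a hypothetical helper that reports the best-of-3 average per call, similar in spirit to the output above. allocate_assign is left out because its definition isn't shown here, and get_patches is replaced by the hypothetical pure-NumPy get_patches_np. Absolute numbers will depend on the machine and the NumPy version, so they won't necessarily match the timings above.

```python
import timeit

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def stack(data, locations, size):
    # Original loop-based solution
    return np.stack([d[:, y:y+size, x:x+size] for d, (y, x) in zip(data, locations)])


def get_patches_np(data, locations, size):
    # Hypothetical pure-NumPy stand-in for get_patches
    windows = sliding_window_view(data, (size, size), axis=(-2, -1))
    return windows[np.arange(len(locations)), :, locations[:, 0], locations[:, 1]]


def create_inputs(size):
    np.random.seed(0)
    n, c, h, w = 512, 1, 60, 60
    data = np.arange(n * c * h * w).reshape(n, c, h, w)
    locations = np.random.randint(0, 3, (n, 2))
    return data, locations, size


def best_time(func, args, number=100, repeat=3):
    """Best of `repeat` runs, each averaging `number` calls; seconds per call."""
    timer = timeit.Timer(lambda: func(*args))
    return min(timer.repeat(repeat=repeat, number=number)) / number


if __name__ == "__main__":
    for size in (12, 24, 48):
        args = create_inputs(size)
        for func in (stack, get_patches_np):
            per_call_us = best_time(func, args) * 1e6
            print(f"size={size:2d}  {func.__name__:>15}: {per_call_us:8.1f} µs per loop")
```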