In numpy, how to efficiently list all fixed-size submatrices?

I have an arbitrary NxM matrix, for example:

    1 2 3 4 5 6
    7 8 9 0 1 2
    3 4 5 6 7 8
    9 0 1 2 3 4

I want to get a list of all 3x3 submatrices in this matrix:

    1 2 3    2 3 4            0 1 2
    7 8 9 ;  8 9 0 ;  ... ;   6 7 8
    3 4 5    4 5 6            2 3 4

I can do this with two nested loops:

    import numpy as np

    rows, cols = input_matrix.shape
    patches = []
    # the range end is exclusive, so use rows - 3 + 1 to include the last window
    for row in range(rows - 3 + 1):
        for col in range(cols - 3 + 1):
            patches.append(input_matrix[row:row+3, col:col+3])

But for a large input matrix, this is slow. Is there a way to make this faster with numpy?

I looked at np.split, but that gives me non-overlapping submatrices, whereas I want all possible submatrices, regardless of overlap.

1 answer

You need a window view:

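    import numpy as np
    from numpy.lib.stride_tricks import as_strided

    arr = np.arange(1, 25).reshape(4, 6) % 10
    sub_shape = (3, 3)
    # one window per valid top-left corner: (4 - 3 + 1, 6 - 3 + 1) = (2, 4)
    view_shape = tuple(np.subtract(arr.shape, sub_shape) + 1) + sub_shape
    # reuse the array's strides for both the window grid and the window contents
    arr_view = as_strided(arr, view_shape, arr.strides * 2)
    # flatten the 2x4 grid of windows into a list of 8 windows
    arr_view = arr_view.reshape((-1,) + sub_shape)

    >>> arr_view
    array([[[1, 2, 3],
            [7, 8, 9],
            [3, 4, 5]],

           [[2, 3, 4],
            [8, 9, 0],
            [4, 5, 6]],

           ...

           [[9, 0, 1],
            [5, 6, 7],
            [1, 2, 3]],

           [[0, 1, 2],
            [6, 7, 8],
            [2, 3, 4]]])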
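To see why passing the array's strides twice works, here is a quick check. It is only a sketch and not part of the original answer, and the names `windows` and `all_match` are mine. It compares every entry of the 4-D view, before the reshape, with the slice the nested loops would produce:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

arr = np.arange(1, 25).reshape(4, 6) % 10
sub_shape = (3, 3)
view_shape = tuple(np.subtract(arr.shape, sub_shape) + 1) + sub_shape

# 4-D view: windows[i, j] is the 3x3 window whose top-left corner is arr[i, j]
windows = as_strided(arr, view_shape, arr.strides * 2)

# Every window should equal the slice the nested loops would have produced
all_match = all(
    np.array_equal(windows[i, j], arr[i:i + 3, j:j + 3])
    for i in range(view_shape[0])
    for j in range(view_shape[1])
)
print(windows.shape, all_match)
```

The first two axes of the view step through the top-left corners and the last two step through the cells inside each window. Both pairs move through the same memory layout, which is why the same strides serve for each.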

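The nice part is that as_strided does not copy any data; it only accesses the data in your original array differently. For large arrays, this can lead to huge memory savings. Note that the final reshape cannot be expressed as a view of overlapping windows, so that step does make a copy; if memory matters, skip it and index the 4-D view directly.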
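On NumPy 1.20 or later, `numpy.lib.stride_tricks.sliding_window_view` builds the same kind of window view without computing shapes and strides by hand. The sketch below assumes that version is installed and was not part of the original answer. The names `windows`, `patches`, `is_view` and `is_copy` are mine. It also uses `np.shares_memory` to check which steps copy data:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # requires NumPy >= 1.20

arr = np.arange(1, 25).reshape(4, 6) % 10

# Read-only 4-D view of shape (2, 4, 3, 3); no data is copied
windows = sliding_window_view(arr, (3, 3))
is_view = np.shares_memory(windows, arr)

# Flattening the window grid forces a copy, because overlapping windows
# cannot be described by a single stride along the merged axis
patches = windows.reshape(-1, 3, 3)
is_copy = not np.shares_memory(patches, arr)

print(windows.shape, is_view)
print(patches.shape, is_copy)
```

Unlike a raw `as_strided` call, `sliding_window_view` validates the window shape and returns a read-only view by default, so a wrong shape raises an error rather than reading out-of-bounds memory.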


Source: https://habr.com/ru/post/1274902/
