In [1]:
import numpy as np
In [2]:
def view_diagonals(arr: np.ndarray) -> np.ndarray:
    """Return a strided view of the diagonals of `arr` along its first two axes.

    Diagonal ``i`` holds ``arr[i + j, j]`` for ``j in range(nelmts)``, i.e. the
    diagonals start on successive rows of the first column. If `arr` has shape
    ``(n1, n2, *rest)``, the result has shape ``(ndiags, nelmts, *rest)`` where

    - ``ndiags = max(1, n1 - n2 + 1)`` (only full-length diagonals are kept;
      a wide array, ``n1 < n2``, yields just the main diagonal), and
    - ``nelmts = min(n1, n2)``.

    Anti-diagonals are ``view_diagonals(np.fliplr(arr))``; diagonals along
    axes (1, 0) are ``view_diagonals(arr.transpose())``.

    Parameters
    ----------
    arr : np.ndarray
        Array with at least two dimensions. Trailing axes are carried along
        unchanged.

    Returns
    -------
    np.ndarray
        A view sharing memory with `arr` (no data is copied). It is writable
        whenever `arr` is, and writes through it land in `arr`.

    Raises
    ------
    ValueError
        If `arr` has fewer than two dimensions.
    """
    if arr.ndim < 2:
        raise ValueError(
            f"view_diagonals requires an array with at least 2 dimensions, got ndim={arr.ndim}"
        )
    n1, n2 = arr.shape[:2]
    ndiags = max(1, n1 - n2 + 1)
    nelmts = min(n1, n2)
    stride0, stride1 = arr.strides[:2]
    # Moving to the next diagonal steps one row down (stride0); moving along a
    # diagonal steps one row down AND one column right (stride0 + stride1), so
    # position (i, j) of the view addresses arr[i + j, j].
    return np.lib.stride_tricks.as_strided(
        arr,
        shape=(ndiags, nelmts, *arr.shape[2:]),
        strides=(stride0, stride0 + stride1, *arr.strides[2:]),
    )
In [3]:
# Exercise view_diagonals on a tall matrix, a wide matrix and a 3-D array (whose
# trailing axis is carried along). Each view is printed, then checked both for
# aliasing L (no copy was made) and for holding exactly the expected diagonals.
for L in np.arange(6*3).reshape(6,3), np.arange(6*3).reshape(3,6), np.arange(6*3*2).reshape(6,3,2):
    L_flipped = np.fliplr(L)
    L_transposed = L.transpose()
    Ldiags = view_diagonals(L)
    Lantis = view_diagonals(L_flipped)
    Ltdiags = view_diagonals(L_transposed)
    print('L = ', L, sep='\n')
    print('diagonals of L =', Ldiags, sep='\n')
    print('anti-diagonals of L =', Lantis, sep='\n')
    print('horizontal diagonals (diagonals along axes 1,0 instead of 0,1) of L =', Ltdiags, sep='\n')
    print()
    for base, diags in ((L, Ldiags), (L_flipped, Lantis), (L_transposed, Ltdiags)):
        # The view must alias the original L, not a copy of it.
        assert np.shares_memory(L, diags)
        n1, n2, *rest = base.shape
        assert diags.shape == (max(1, n1 - n2 + 1), min(n1, n2), *rest)
        # Independent reference: diagonal i is np.diagonal(base[i:]), but
        # np.diagonal puts the diagonal axis last, so move it back to the front.
        for i, diag in enumerate(diags):
            expected = np.moveaxis(np.diagonal(base[i:], axis1=0, axis2=1), -1, 0)
            np.testing.assert_array_equal(diag, expected)
L = [[ 0 1 2] [ 3 4 5] [ 6 7 8] [ 9 10 11] [12 13 14] [15 16 17]] diagonals of L = [[ 0 4 8] [ 3 7 11] [ 6 10 14] [ 9 13 17]] anti-diagonals of L = [[ 2 4 6] [ 5 7 9] [ 8 10 12] [11 13 15]] horizontal diagonals (diagonals along axes 1,0 instead of 0,1) of L = [[0 4 8]] L = [[ 0 1 2 3 4 5] [ 6 7 8 9 10 11] [12 13 14 15 16 17]] diagonals of L = [[ 0 7 14]] anti-diagonals of L = [[ 5 10 15]] horizontal diagonals (diagonals along axes 1,0 instead of 0,1) of L = [[ 0 7 14] [ 1 8 15] [ 2 9 16] [ 3 10 17]] L = [[[ 0 1] [ 2 3] [ 4 5]] [[ 6 7] [ 8 9] [10 11]] [[12 13] [14 15] [16 17]] [[18 19] [20 21] [22 23]] [[24 25] [26 27] [28 29]] [[30 31] [32 33] [34 35]]] diagonals of L = [[[ 0 1] [ 8 9] [16 17]] [[ 6 7] [14 15] [22 23]] [[12 13] [20 21] [28 29]] [[18 19] [26 27] [34 35]]] anti-diagonals of L = [[[ 4 5] [ 8 9] [12 13]] [[10 11] [14 15] [18 19]] [[16 17] [20 21] [24 25]] [[22 23] [26 27] [30 31]]] horizontal diagonals (diagonals along axes 1,0 instead of 0,1) of L = [[[ 0 6 12 18 24 30] [ 3 9 15 21 27 33]]]