from collections.abc import Sequence
from functools import reduce
import numpy as np
import gudhi
# Three levels of nesting of simplex trees, indexed [class][point cloud][coarsening]
# (the order in which point_clouds_to_simplex_trees builds them).
SimplexTrees = Sequence[Sequence[Sequence[gudhi.SimplexTree]]]
[docs]
def point_cloud_to_simplex_tree(point_cloud: np.ndarray) -> gudhi.SimplexTree:
    """Build the alpha-complex simplex tree of *point_cloud*, with persistence computed.

    *point_cloud* is passed to ``gudhi.AlphaComplex`` as its ``points``. The
    simplex tree created from that complex has ``compute_persistence()`` called
    on it before being returned, so persistence queries (such as
    ``persistence_intervals_in_dimension``) can be made on the result directly.
    """
    alpha_complex = gudhi.AlphaComplex(points=point_cloud)
    tree = alpha_complex.create_simplex_tree()
    # Persistence is computed eagerly so callers can query intervals right away.
    tree.compute_persistence()
    return tree
[docs]
def point_clouds_to_simplex_trees(point_clouds: np.ndarray) -> SimplexTrees:
    """Turn a 5-D stack of point clouds into nested lists of alpha simplex trees.

    The first three axes of *point_clouds* are read as
    ``(nb_classes, nb_point_clouds, nb_coarsened)``. Each remaining 2-D slice
    ``point_clouds[i, j, k]`` is converted with ``point_cloud_to_simplex_tree``.
    NOTE(review): the last two axes are presumably (points, coordinates), as
    ``gudhi.AlphaComplex`` expects; confirm this against the callers.

    Returns:
        Nested lists ``trees`` where ``trees[i][j][k]`` is built from
        ``point_clouds[i, j, k]``. Trees are built in (i, j, k) order, with
        ``k`` varying fastest.

    Raises:
        ValueError: if *point_clouds* does not have exactly five dimensions
            (raised by the shape unpacking).
    """
    # Unpacking the shape doubles as the 5-D check; the last two sizes are unused.
    n_classes, n_clouds, n_coarse, _, _ = point_clouds.shape

    trees = []
    for i in range(n_classes):
        class_trees = []
        for j in range(n_clouds):
            class_trees.append(
                [point_cloud_to_simplex_tree(point_clouds[i, j, k])
                 for k in range(n_coarse)]
            )
        trees.append(class_trees)
    return trees
[docs]
def _as_interval_array(intervals) -> np.ndarray:
    """Return *intervals* as a NumPy array, giving an empty result the 2-D shape ``(0, 2)``.

    The caller pads each array with a two-axis pad width, which ``np.pad``
    rejects for a 1-D array. NOTE(review): some GUDHI releases are believed to
    return a 1-D empty array from ``persistence_intervals_in_dimension`` when a
    dimension has no intervals (GUDHI issue #507); this normalisation guards
    against that. Non-empty arrays are returned unchanged.
    """
    intervals = np.asarray(intervals)
    if intervals.size == 0:
        return intervals.reshape(0, 2)
    return intervals


def simplex_trees_to_persistence_intervals_in_dim(*,
                                                  simplex_trees: SimplexTrees,
                                                  dim: int,
                                                  pad_value: float = 0.0
                                                  ) -> np.ndarray:
    """Gather the ``dim``-dimensional persistence intervals of every tree into one padded array.

    Each tree is queried with ``persistence_intervals_in_dimension(dim)``, so
    persistence must already have been computed on it (as
    ``point_cloud_to_simplex_tree`` does). Trees have different numbers of
    intervals, so each interval list is padded with rows filled with
    *pad_value* up to the longest one before everything is stacked.

    Args:
        simplex_trees: Nested ``[class][point cloud][coarsening]`` sequence of
            simplex trees. Siblings at each level must have equal lengths,
            because the result is a regular array.
        dim: Homology dimension whose intervals are extracted.
        pad_value: Value written into padding rows. The default ``0.0``
            reproduces the original zero padding.

    Returns:
        Array of shape ``(nb_classes, nb_point_clouds, nb_coarsened, max_len, 2)``.
        ``max_len`` is the largest interval count over all trees. Each tree's
        intervals occupy the leading rows of its slot, followed by padding.

    Raises:
        ValueError: if *simplex_trees* (or any nested level) is empty or ragged;
            this is raised by ``np.stack``.
    """
    # Query every tree once, preserving the [class][point cloud][coarsening] nesting.
    persistence_intervals = [
        [
            [
                _as_interval_array(simplex_tree.persistence_intervals_in_dimension(dim))
                for simplex_tree in coarseneds
            ]
            for coarseneds in point_clouds
        ]
        for point_clouds in simplex_trees
    ]

    # Largest interval count over all trees; every slot is padded to this length.
    max_len = max(
        (intervals.shape[0]
         for point_clouds in persistence_intervals
         for coarseneds in point_clouds
         for intervals in coarseneds),
        default=0,
    )

    def pad(intervals: np.ndarray) -> np.ndarray:
        # Append rows filled with pad_value after the real intervals; columns untouched.
        return np.pad(intervals,
                      ((0, max_len - intervals.shape[0]), (0, 0)),
                      constant_values=pad_value)

    # np.stack needs equal shapes among siblings and at least one element at
    # each level, so ragged or empty nesting raises ValueError here.
    return np.stack([
        np.stack([
            np.stack([pad(intervals) for intervals in coarseneds])
            for coarseneds in point_clouds
        ])
        for point_clouds in persistence_intervals
    ])