from typing import ClassVar, Optional, Tuple
from collections.abc import Callable, Iterable, Sequence
from importlib import import_module
from math import pi
import io
import torch
import numpy as np
from . import css, svg
from .seizing import Length, Seizing, SeizingByHeight, Unit
from .strip_orientation import StripOrientation
from .viewbox import ViewboxCreator, ViewboxWithPadding
from ..nn import Unravel
__all__ = [ 'Draw', 'SVG' ]
class SVG(str):
    """SVG markup held in a plain string that Jupyter renders as a picture.

    IPython's rich display protocol probes objects for ``_repr_svg_`` and
    ``_repr_html_``; both hooks hand back this very string object.
    """

    def _repr_svg_(self):
        """Return the markup itself for IPython's SVG display hook."""
        return self

    # The HTML hook is served by exactly the same implementation.
    _repr_html_ = _repr_svg_
class Draw:
    """Callable drawing unravelled persistence diagrams as :class:`SVG`.

    Parameters
    ----------
    range_intervals : :class:`Tuple[float, float]`, optional
        See :attr:`range_intervals` for details,
        defaults to :obj:`(0.0, pi/2.0)`.
    seizing : :class:`Optional[Seizing]`, optional
        See :attr:`seizing` for details.
        If this is :obj:`None` (the default),
        a fresh instance of :class:`SeizingByHeight` is created,
        specifically :obj:`SeizingByHeight( length = Length(650.0, Unit.PX) )`.
    viewbox_creator : :class:`Optional[ViewboxCreator]`, optional
        See :attr:`viewbox_creator` for details.
        If this is :obj:`None` (the default),
        a fresh :obj:`ViewboxWithPadding( padding=0.06 )` is created.
    strip_orientation : :class:`StripOrientation`, optional
        The orientation of the strip to be drawn,
        see also :attr:`strip_orientation`.
        The default is :obj:`StripOrientation.ORDINARY`,
        :obj:`StripOrientation.CROSS` can be a sensible choice
        for functional or extended persistence.
    background_creator : :class:`Optional[svg.BackgroundCreator]`, optional
        Draws the background based on the input,
        see also :attr:`background_creator`.
        If this is :obj:`None` (the default),
        a fresh default instance of :class:`svg.StripWithAxesPattern`
        is created.
    unravel : :class:`Optional[ Callable ]`, optional
        If this is not :obj:`None`,
        then it is used to compute unravelled persistence diagrams
        instead of a newly created instance of :class:`Unravel`,
        see also :attr:`unravel`.
        The default is :obj:`None`.
    fmt : :class:`str`, optional
        See :attr:`fmt` for details,
        defaults to '%.5f'.
    device : optional
        Passed on to the :class:`Unravel` instance that is created
        when ``unravel`` is :obj:`None`; ignored otherwise.
        The default is :obj:`None`.
    dtype : optional
        Passed on to the :class:`Unravel` instance that is created
        when ``unravel`` is :obj:`None`; ignored otherwise.
        The default is :obj:`None`.
    """

    stylesheet: ClassVar[str] = css.POLYLINE + css.LIGHT_SCHEME
    """:class:`str`:
    The stylesheet to be used by the SVG.
    The initial value is the concatenation of
    :obj:`css.POLYLINE` and :obj:`css.LIGHT_SCHEME`.
    """

    defs: ClassVar[str] = svg.MARKERS
    """:class:`str`:
    SVG code to be wrapped by a <defs> element.
    The initial value is :obj:`svg.MARKERS`.
    """

    standalone: ClassVar[bool] = True
    """:class:`bool`: If drawn SVGs should be standalone (by default)."""

    __slots__ = ('range_intervals',
                 'seizing',
                 'viewbox_creator',
                 'strip_orientation',
                 'background_creator',
                 'fmt',
                 'unravel'
                 )

    def __init__(
            self,
            *,
            range_intervals: Tuple[float, float] = (0.0, pi/2.0),
            seizing: Optional[Seizing] = None,
            viewbox_creator: Optional[ViewboxCreator] = None,
            strip_orientation: StripOrientation = StripOrientation.ORDINARY,
            background_creator: Optional[svg.BackgroundCreator] = None,
            unravel: Optional[ Callable ] = None,
            fmt: str = '%.5f',
            device = None,
            dtype = None):
        # The default helper objects are created here, once per instance,
        # rather than as default argument values: a default value is
        # evaluated a single time when the function is defined, so every
        # Draw built with defaults would otherwise share the very same
        # seizing / viewbox_creator / background_creator objects.
        if seizing is None:
            seizing = SeizingByHeight( length = Length(650.0, Unit.PX) )
        if viewbox_creator is None:
            viewbox_creator = ViewboxWithPadding( padding=0.06 )
        if background_creator is None:
            background_creator = svg.StripWithAxesPattern()
        if unravel is None:
            unravel = Unravel(
                range_intervals = range_intervals,
                # neg_x stays disabled only for the ORDINARY orientation.
                neg_x = strip_orientation != StripOrientation.ORDINARY,
                neg_y = True,
                device = device,
                dtype = dtype
            )

        self.range_intervals = range_intervals
        """:class:`Tuple[float, float]`:
        The finite range containing all persistence intervals.
        """
        self.seizing = seizing
        """:class:`Seizing`:
        Computes 'width' and 'height' attributes based on the input.
        """
        self.viewbox_creator = viewbox_creator
        """:class:`ViewboxCreator`: Creates a viewbox based on the input."""
        self.strip_orientation = strip_orientation
        """:class:`StripOrientation`:
        The orientation of the strip to be drawn.
        """
        self.background_creator = background_creator
        """:class:`svg.BackgroundCreator`:
        Draws the background based on the input.
        """
        self.fmt = fmt
        """:class:`str`:
        The format string used to render floats into strings.
        """
        self.unravel = unravel
        """:class:`Callable`:
        Used to compute unravelled persistence diagrams."""

    @classmethod
    def complement_notebook(cls):
        """Complements the notebook with CSS and SVG containing 'defs'.

        This classmethod is supposed to be called from Jupyter Notebook.
        It adds a <style> element wrapping :attr:`stylesheet`
        and an invisible <svg> containing a <defs> element wrapping
        :attr:`defs` to the notebook.
        Moreover,
        it has the side effect of setting :attr:`standalone` to :obj:`False`.

        Raises
        ------
        ModuleNotFoundError
            If :mod:`IPython.display` cannot be imported.
        """
        try:
            ipython_display_mod = import_module('IPython.display')
        except ModuleNotFoundError as err:
            # Chain the original error so the missing module stays visible.
            raise ModuleNotFoundError(
                "The method 'complement_notebook' is intended to be called "
                "from Jupyter Notebook."
            ) from err
        ipython_display_mod.display(
            ipython_display_mod.HTML(
                f'<style>{cls.stylesheet}</style>'
                '<svg width="0" '
                'height="0" '
                'style="position: absolute; z-index: -1">'
                f'<defs>{cls.defs}</defs>'
                '</svg>'
                "CSS and an SVG with 'defs' were added to the DOM."
                "<br>"
                "In order for this to take any effect "
                "you may have to press this button: "
                "<button>Activate CSS and SVG 'defs'</button>"
            )
        )
        # Later drawings can rely on the notebook-wide <style>/<defs>.
        cls.standalone = False

    def _unravel(self,
                 intervals: Sequence[ torch.Tensor | Iterable[torch.Tensor] ]
                 ) -> Tuple[ Sequence[torch.Tensor | Sequence[torch.Tensor]],
                             Sequence[Sequence[torch.Tensor]] ]:
        """Unravel ``intervals`` and return them degreewise and batchwise.

        If the first element of ``intervals`` is a tensor, the whole
        sequence is unravelled in one call and each per-degree result is
        iterated along its first axis to obtain one entry per diagram
        (NOTE(review): this presumes per-degree tensors carry a leading
        batch dimension -- confirm against :class:`Unravel`).
        Otherwise each element is treated as one diagram and unravelled
        separately.

        Returns a pair ``(degreewise, batchwise)`` in which one is the
        transposition of the other.
        An empty ``intervals`` raises :class:`IndexError`.
        """
        if isinstance(intervals[0], torch.Tensor):
            unravelled_intervals_degreewise = self.unravel(intervals)
            return (
                unravelled_intervals_degreewise,
                list( zip(*unravelled_intervals_degreewise) )
            )
        else:
            unravelled_intervals_batchwise = [
                self.unravel(intervals_b) for intervals_b in intervals
            ]
            return (
                list( zip(*unravelled_intervals_batchwise) ),
                unravelled_intervals_batchwise
            )

    def _style_n_defs(self,
                      standalone: bool
                      ) -> str:
        """Return inline <style> and <defs> if ``standalone``, else ''."""
        if standalone:
            return (
                f'<style>{self.stylesheet}</style>'
                f'<defs>{self.defs}</defs>'
            )
        else:
            return ''

    def _unr_ints_d_to_str(self,
                           unravelled_intervals_d: torch.Tensor
                           ) -> str:
        """Render one degree's unravelled points as SVG ``points`` text.

        Every row is formatted with :attr:`fmt`; values within a row are
        joined by ',' and rows are separated by a single space.
        """
        # Tensor.numpy() refuses tensors that require grad or that live
        # on a non-CPU device (e.g. when ``device`` was forwarded to
        # Unravel), so detach and move to host memory first.
        points_array = unravelled_intervals_d.detach().cpu().numpy()
        with io.StringIO() as txt_stream:
            np.savetxt(txt_stream,
                       points_array,
                       fmt = self.fmt,
                       delimiter = ',',
                       newline = ' '
                       )
            return txt_stream.getvalue()

    def _create_polyline(self,
                         unravelled_intervals: Iterable[torch.Tensor]
                         ) -> str:
        """Return one <polyline> joining all degrees of a single diagram."""
        points = ' '.join(
            self._unr_ints_d_to_str( unravelled_intervals_d ) for
            unravelled_intervals_d in
            unravelled_intervals
        )
        return f'<polyline points="{points}" />'

    def __call__(self,
                 intervals: Sequence[ torch.Tensor | Iterable[torch.Tensor] ],
                 *,
                 seizing: Optional[Seizing] = None,
                 viewbox_creator: Optional[ViewboxCreator] = None,
                 standalone: Optional[bool] = None,
                 complete_dimensions: Optional[bool] = None,
                 ) -> SVG:
        """Draws unravelled persistence diagrams as :class:`SVG`.

        Parameters
        ----------
        intervals : :class:`Sequence[ torch.Tensor | Iterable[torch.Tensor] ]`
            Persistence intervals.
            If a :class:`Sequence[torch.Tensor]` is passed,
            the indices of the outer sequence correspond to degrees,
            whereas if a
            :class:`Sequence[Iterable[torch.Tensor]]`
            is passed,
            then the inner :class:`Iterable` is assumed to correspond
            to the degrees
            and the outer :class:`Sequence` corresponds to different
            persistence diagrams.
        seizing : :class:`Optional[Seizing]`, optional
            If this is not :obj:`None`,
            then it is used in place of :attr:`seizing`,
            defaults to :obj:`None`.
        viewbox_creator : :class:`Optional[ViewboxCreator]`, optional
            If this is not :obj:`None`,
            then it is used in place of :attr:`viewbox_creator`,
            defaults to :obj:`None`.
        standalone : :class:`Optional[bool]`, optional
            If this is not :obj:`None`,
            then it is used in place of :attr:`standalone`,
            defaults to :obj:`None`.
        complete_dimensions : :class:`Optional[bool]`, optional
            Passed on to the seizing as ``complete_dimensions``.
            If this is :obj:`None`,
            then the effective value of ``standalone`` is used instead,
            defaults to :obj:`None`.

        Returns
        -------
        :class:`SVG`
            The SVG picturing unravelled persistence diagrams.
        """
        seizing = self.seizing if seizing is None else seizing
        viewbox_creator = (
            self.viewbox_creator if
            viewbox_creator is None else
            viewbox_creator
        )
        standalone = self.standalone if standalone is None else standalone
        complete_dimensions = (
            standalone if
            complete_dimensions is None else
            complete_dimensions
        )
        ( unravelled_intervals_degreewise,
          unravelled_intervals_batchwise ) = self._unravel( intervals )
        viewbox = viewbox_creator(
            range_intervals = self.range_intervals,
            strip_orientation = self.strip_orientation,
            unravelled_intervals = unravelled_intervals_degreewise
        )
        sizes = seizing(
            aspect_ratio = ( viewbox.right - viewbox.left ) / abs(
                viewbox.top - viewbox.bottom
            ),
            viewbox = viewbox,
            complete_dimensions = complete_dimensions
        )
        background = self.background_creator(
            range_intervals = self.range_intervals,
            strip_orientation = self.strip_orientation
        )
        # One <polyline> per diagram in the batch.
        polylines = ''.join(
            self._create_polyline( unravelled_intervals_b ) for
            unravelled_intervals_b in
            unravelled_intervals_batchwise
        )
        return SVG(
            f'<svg {sizes} '
            f'viewBox="{viewbox}" '
            'xmlns="http://www.w3.org/2000/svg">'
            f'{self._style_n_defs(standalone)}'
            f'{background}'
            f'<g class="unravelled-diagrams">{polylines}</g>'
            '</svg>'
        )