Source code for sknetwork.visualization.dendrograms

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created in April 2020
@author: Thomas Bonald <bonald@enst.fr>
"""
from typing import Iterable, Optional

import numpy as np

from sknetwork.hierarchy.postprocess import cut_straight
from sknetwork.visualization.colors import STANDARD_COLORS


def get_index(dendrogram, reorder=True):
    """Index nodes for pretty dendrogram."""
    n = dendrogram.shape[0] + 1
    tree = {i: [i] for i in range(n)}
    for t in range(n - 1):
        i = int(dendrogram[t, 0])
        j = int(dendrogram[t, 1])
        left: list = tree.pop(i)
        right: list = tree.pop(j)
        if reorder and len(left) < len(right):
            tree[n + t] = right + left
        else:
            tree[n + t] = left + right
    return list(tree.values())[0]


def svg_dendrogram_top(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters,
                       color, colors, font_size, reorder, rotate_names):
    """Dendrogram as SVG image with root on top."""

    # scaling
    height *= scale
    width *= scale

    # positioning
    labels = cut_straight(dendrogram, n_clusters, return_dendrogram=False)
    index = get_index(dendrogram, reorder)
    n = len(index)
    unit_height = height / dendrogram[-1, 2]
    unit_width = width / n
    height_basis = margin + height
    position = {index[i]: (margin + i * unit_width, height_basis) for i in range(n)}
    label = {i: l for i, l in enumerate(labels)}
    width += 2 * margin
    height += 2 * margin
    if names is not None:
        text_length = np.max(np.array([len(str(name)) for name in names]))
        height += text_length * font_size * .5 + margin_text

    svg = """<svg width="{}" height="{}"  xmlns="http://www.w3.org/2000/svg">""".format(width, height)

    # text
    if names is not None:
        for i in range(n):
            x, y = position[i]
            x -= margin_text
            y += margin_text
            text = str(names[i]).replace('&', ' ')
            if rotate_names:
                svg += """<text x="{}" y="{}"  transform="rotate(60, {}, {})" font-size="{}">{}</text>""" \
                    .format(x, y, x, y, font_size, text)
            else:
                y += margin_text
                svg += """<text x="{}" y="{}"  font-size="{}">{}</text>""" \
                    .format(x, y, font_size, text)

    # tree
    for t in range(n - 1):
        i = int(dendrogram[t, 0])
        j = int(dendrogram[t, 1])
        x1, y1 = position.pop(i)
        x2, y2 = position.pop(j)
        l1 = label.pop(i)
        l2 = label.pop(j)
        if l1 == l2:
            line_color = colors[l1 % len(colors)]
        else:
            line_color = color
        x = .5 * (x1 + x2)
        y = height_basis - dendrogram[t, 2] * unit_height
        svg += """<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""\
            .format(line_width, line_color, x1, y1, x1, y)
        svg += """<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""\
            .format(line_width, line_color, x2, y2, x2, y)
        svg += """<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""\
            .format(line_width, line_color, x1, y, x2, y)
        position[n + t] = (x, y)
        label[n + t] = l1

    svg += '</svg>'
    return svg


def svg_dendrogram_left(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters,
                        color, colors, font_size, reorder):
    """Dendrogram as SVG image with root on left side."""

    # scaling
    height *= scale
    width *= scale

    # positioning
    labels = cut_straight(dendrogram, n_clusters, return_dendrogram=False)
    index = get_index(dendrogram, reorder)
    n = len(index)
    unit_height = height / n
    unit_width = width / dendrogram[-1, 2]
    width_basis = width + margin
    position = {index[i]: (width_basis, margin + i * unit_height) for i in range(n)}
    label = {i: l for i, l in enumerate(labels)}
    width += 2 * margin
    height += 2 * margin
    if names is not None:
        text_length = np.max(np.array([len(str(name)) for name in names]))
        width += text_length * font_size * .5 + margin_text

    svg = """<svg width="{}" height="{}"  xmlns="http://www.w3.org/2000/svg">""".format(width, height)

    # text
    if names is not None:
        for i in range(n):
            x, y = position[i]
            x += margin_text
            y += unit_height / 3
            text = str(names[i]).replace('&', ' ')
            svg += """<text x="{}" y="{}" font-size="{}">{}</text>""" \
                .format(x, y, font_size, text)

    # tree
    for t in range(n - 1):
        i = int(dendrogram[t, 0])
        j = int(dendrogram[t, 1])
        x1, y1 = position.pop(i)
        x2, y2 = position.pop(j)
        l1 = label.pop(i)
        l2 = label.pop(j)
        if l1 == l2:
            line_color = colors[l1 % len(colors)]
        else:
            line_color = color
        y = .5 * (y1 + y2)
        x = width_basis - dendrogram[t, 2] * unit_width
        svg += """<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""\
            .format(line_width, line_color, x1, y1, x, y1)
        svg += """<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""\
            .format(line_width, line_color, x2, y2, x, y2)
        svg += """<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""\
            .format(line_width, line_color, x, y1, x, y2)
        position[n + t] = (x, y)
        label[n + t] = l1

    svg += '</svg>'

    return svg


[docs]def visualize_dendrogram(dendrogram: np.ndarray, names: Optional[np.ndarray] = None, rotate: bool = False, width: float = 400, height: float = 300, margin: float = 10, margin_text: float = 5, scale: float = 1, line_width: float = 2, n_clusters: int = 2, color: str = 'black', colors: Optional[Iterable] = None, font_size: int = 12, reorder: bool = False, rotate_names: bool = True, filename: Optional[str] = None): """Return the image of a dendrogram in SVG format. Parameters ---------- dendrogram : Dendrogram to display. names : Names of leaves. rotate : If ``True``, rotate the tree so that the root is on the left. width : Width of the image (margins excluded). height : Height of the image (margins excluded). margin : Margin. margin_text : Margin between leaves and their names, if any. scale : Scaling factor. line_width : Line width. n_clusters : Number of coloured clusters to display. color : Default SVG color for the dendrogram. colors : SVG colors of the clusters of the dendrogram (optional). font_size : Font size. reorder : If ``True``, reorder leaves so that left subtree has more leaves than right subtree. rotate_names : If ``True``, rotate names of leaves (only valid if **rotate** is ``False``). filename : Filename for saving image (optional). Example ------- >>> dendrogram = np.array([[0, 1, 1, 2], [2, 3, 2, 3]]) >>> from sknetwork.visualization import svg_dendrogram >>> image = svg_dendrogram(dendrogram) >>> image[1:4] 'svg' """ if colors is None: colors = STANDARD_COLORS elif isinstance(colors, dict): colors = np.array(list(colors.values())) elif isinstance(colors, list): colors = np.array(colors) if rotate: svg = svg_dendrogram_left(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters, color, colors, font_size, reorder) else: svg = svg_dendrogram_top(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters, color, colors, font_size, reorder, rotate_names) if filename is not None: with open(filename + '.svg', 'w') as f: f.write(svg) return svg
def svg_dendrogram(dendrogram: np.ndarray, names: Optional[np.ndarray] = None, rotate: bool = False, width: float = 400, height: float = 300, margin: float = 10, margin_text: float = 5, scale: float = 1, line_width: float = 2, n_clusters: int = 2, color: str = 'black', colors: Optional[Iterable] = None, font_size: int = 12, reorder: bool = False, rotate_names: bool = True, filename: Optional[str] = None): """Return the image of a dendrogram in SVG format. Alias for visualize_dendrogram. Parameters ---------- dendrogram : Dendrogram to display. names : Names of leaves. rotate : If ``True``, rotate the tree so that the root is on the left. width : Width of the image (margins excluded). height : Height of the image (margins excluded). margin : Margin. margin_text : Margin between leaves and their names, if any. scale : Scaling factor. line_width : Line width. n_clusters : Number of coloured clusters to display. color : Default SVG color for the dendrogram. colors : SVG colors of the clusters of the dendrogram (optional). font_size : Font size. reorder : If ``True``, reorder leaves so that left subtree has more leaves than right subtree. rotate_names : If ``True``, rotate names of leaves (only valid if **rotate** is ``False``). filename : Filename for saving image (optional). """ return visualize_dendrogram(dendrogram, names, rotate, width, height, margin, margin_text, scale, line_width, n_clusters, color, colors, font_size, reorder, rotate_names, filename)