# Dirichlet

This notebook illustrates how to classify the nodes of a graph by solving the Dirichlet problem, given the labels of a few seed nodes.

:

from IPython.display import SVG

:

import numpy as np

:

from sknetwork.data import karate_club, painters, movie_actor
from sknetwork.classification import DirichletClassifier, BiDirichletClassifier
from sknetwork.visualization import svg_graph, svg_digraph, svg_bigraph


## Graphs

:

# Karate club graph, loaded with its metadata: adjacency matrix, node positions
# for drawing, and the reference labels used below to score the predictions.
graph = karate_club(metadata=True)
adjacency, position, labels_true = graph.adjacency, graph.position, graph.labels


### Classification

:

# Reveal the true label of only two nodes (0 and 33). These are the seeds from
# which the Dirichlet diffusion infers the labels of all the other nodes.
seeds = {node: labels_true[node] for node in (0, 33)}

:

# Solve the Dirichlet problem from the seeds and get one predicted label per
# node (compared element-wise with labels_true below).
diffusion = DirichletClassifier()
labels_pred = diffusion.fit_transform(adjacency, seeds)

:

# Fraction of nodes whose predicted label equals the true label, rounded to two
# decimals; the bare name on the last line displays it in the notebook.
# NOTE(review): despite the variable name, this is accuracy, not precision in
# the classification-metric sense. The average is over all nodes, so the two
# seed nodes are included in it.
precision = np.round(np.mean(labels_pred == labels_true), 2)
precision

:

0.97

:

# Draw the graph with nodes labelled by the predictions; the seeds are passed
# separately so the drawing can mark them.
image = svg_graph(adjacency, position, labels=labels_pred, seeds=seeds)

:

# Render the SVG string inline in the notebook.
SVG(image)

### Soft classification

:

# Per-node scores computed by the fitted diffusion; presumably one column per
# label, given how a single column is extracted in the next cell.
membership = diffusion.membership_

:

# Score of every node for label 1: column 1 of the sparse membership matrix,
# densified and flattened into a 1-D array for plotting.
scores = np.ravel(membership[:, 1].toarray())

:

# Draw the graph with nodes shaded by their label-1 score, marking the seeds.
image = svg_graph(adjacency, position, scores=scores, seeds=seeds)

:

# Render the SVG string inline in the notebook.
SVG(image)

## Digraphs

:

# Painters graph (directed), loaded with its adjacency matrix, node positions
# and the painter names used as node labels in the drawings.
graph = painters(metadata=True)
adjacency, position, names = graph.adjacency, graph.position, graph.names


### Classification

:

# Node indices of three painters (assumed to match the order of graph.names),
# each seeded with its own label: Cezanne -> 0, Rembrandt -> 1, Klimt -> 2.
rembrandt, klimt, cezanne = 5, 6, 11
seeds = {cezanne: 0, rembrandt: 1, klimt: 2}

:

# Same classifier as for the karate club, now fitted on the painters digraph
# with three seed labels; returns the predicted label of every node.
diffusion = DirichletClassifier()
labels = diffusion.fit_transform(adjacency, seeds)

:

# Draw the digraph with painter names and predicted labels, marking the seeds.
image = svg_digraph(adjacency, position, names, labels, seeds=seeds)

:

# Render the SVG string inline in the notebook.
SVG(image)

### Soft classification

:

# Per-node scores computed by the diffusion fitted on the painters digraph;
# presumably one column per label, given the column extraction in the next cell.
membership = diffusion.membership_

:

# Score of every node for label 0 (Cezanne's seed label): column 0 of the
# sparse membership matrix, densified and flattened into a 1-D array.
scores = np.ravel(membership[:, 0].toarray())

:

# Draw the digraph shaded by the label-0 scores. Only Cezanne is marked as a
# seed here, because label 0 is the label Cezanne was seeded with.
image = svg_digraph(adjacency, position, names=names, scores=scores, seeds=[cezanne])

:

# Render the SVG string inline in the notebook.
SVG(image)

## Bigraphs

:

# Movie-actor bipartite graph: a biadjacency matrix whose rows and columns each
# carry their own names (rows appear to be movies, given the row seeds below).
graph = movie_actor(metadata=True)
biadjacency, names_row, names_col = graph.biadjacency, graph.names_row, graph.names_col


### Classification

:

# Row indices of the three movies used as seeds (assumed to match the order of
# graph.names_row).
inception, drive, budapest = 0, 3, 8

:

# Give each seed movie its own label, in order: inception -> 0, drive -> 1,
# budapest -> 2.
seeds_row = {movie: label for label, movie in enumerate((inception, drive, budapest))}

:

# Fit the bipartite Dirichlet classifier from the row seeds only, then read the
# predicted labels of both sides of the graph (rows and columns).
bidiffusion = BiDirichletClassifier()
bidiffusion.fit(biadjacency, seeds_row)
labels_row, labels_col = bidiffusion.labels_row_, bidiffusion.labels_col_

:

# Draw the bigraph with row/column names and the predicted labels of both
# sides, marking the seed rows.
image = svg_bigraph(biadjacency, names_row, names_col, labels_row, labels_col, seeds_row=seeds_row)

:

# Render the SVG string inline in the notebook.
SVG(image)

### Soft classification

:

# Per-label scores of the rows and of the columns computed by the fitted
# bipartite diffusion.
membership_row, membership_col = bidiffusion.membership_row_, bidiffusion.membership_col_

:

# Score for label 1 (the "drive" seed label) on each side of the bigraph:
# column 1 of each sparse membership matrix, densified and flattened to 1-D.
scores_row, scores_col = (
    np.ravel(matrix[:, 1].toarray()) for matrix in (membership_row, membership_col)
)

:

# Draw the bigraph with rows and columns shaded by their label-1 scores,
# marking the seed rows.
image = svg_bigraph(
    biadjacency,
    names_row,
    names_col,
    scores_row=scores_row,
    scores_col=scores_col,
    seeds_row=seeds_row,
)

:

# Render the SVG string inline in the notebook.
SVG(image)

: 