# Dirichlet

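This notebook illustrates node classification by the Dirichlet problem. The labels of a few seed nodes are known, and the labels of all other nodes are inferred by diffusion from these seeds. The method is shown on a graph, a digraph and a bigraph.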

[1]:

from IPython.display import SVG

[2]:

import numpy as np

[3]:

from sknetwork.data import karate_club, painters, movie_actor
from sknetwork.classification import DirichletClassifier, BiDirichletClassifier
from sknetwork.visualization import svg_graph, svg_digraph, svg_bigraph


## Graphs

[4]:

# Zachary's karate club, loaded with its metadata (node positions for display, ground-truth labels).
graph = karate_club(metadata=True)
# The adjacency matrix is the input of the classifier and of svg_graph; it must be extracted here.
adjacency = graph.adjacency
position = graph.position
labels_true = graph.labels


### Classification

[5]:

# Reveal the true label of only two nodes (0 and 33); every other label is to be predicted.
seeds = dict((node, labels_true[node]) for node in (0, 33))

[6]:

# Solve the Dirichlet problem with the seed labels as boundary conditions.
# fit_predict returns one predicted label per node, to be compared with labels_true.
diffusion = DirichletClassifier()
labels_pred = diffusion.fit_predict(adjacency, seeds)

[7]:

# Fraction of nodes whose predicted label matches the ground truth. This is accuracy, not precision.
# The seed nodes are included in the average, so the figure is slightly optimistic.
accuracy = np.round(np.mean(labels_pred == labels_true), 2)
accuracy

[7]:

0.97

[8]:

# Hard classification: nodes coloured by predicted label, with the seed nodes passed for highlighting.
image = svg_graph(adjacency, position, seeds=seeds, labels=labels_pred)

[9]:

# Render the SVG string inline.
SVG(data=image)

[9]:


### Soft classification

[10]:

# Soft output of the fitted classifier: per-node scores for each label
# (presumably a sparse n_nodes x n_labels matrix — TODO confirm in the sknetwork docs).
membership = diffusion.membership_

[11]:

# Column 1 (scores for label 1), densified and flattened to one score per node.
scores = membership[:, 1].toarray()[:, 0]

[12]:

# Soft classification: nodes coloured by their score for label 1.
image = svg_graph(adjacency, position, seeds=seeds, scores=scores)

[13]:

# Render the SVG string inline.
SVG(data=image)

[13]:


## Digraphs

[14]:

# Directed graph between painters, loaded with its metadata (node positions and names).
graph = painters(metadata=True)
# The adjacency matrix is the input of the classifier and of svg_digraph; it must be extracted here.
adjacency = graph.adjacency
position = graph.position
names = graph.names


### Classification

[15]:

# Node indices of three seed painters, and the label given to each:
# 0 for Cezanne, 1 for Rembrandt, 2 for Klimt.
rembrandt, klimt, cezanne = 5, 6, 11
seeds = dict([(cezanne, 0), (rembrandt, 1), (klimt, 2)])

[16]:

# Solve the Dirichlet problem on the digraph with the three seed labels as boundary conditions.
# fit_predict returns one predicted label per node.
diffusion = DirichletClassifier()
labels = diffusion.fit_predict(adjacency, seeds)

[17]:

# Hard classification of the painters, with names displayed and seeds passed for highlighting.
image = svg_digraph(adjacency, position, names=names, labels=labels, seeds=seeds)

[18]:

# Render the SVG string inline.
SVG(data=image)

[18]:


### Soft classification

[19]:

# Soft output of the fitted classifier: per-node scores for each label
# (presumably a sparse n_nodes x n_labels matrix — TODO confirm in the sknetwork docs).
membership = diffusion.membership_

[20]:

# Column 0 (scores for label 0, the label seeded at Cezanne), flattened to one score per node.
scores = membership[:, 0].toarray()[:, 0]

[21]:

# Soft classification: painters coloured by their score for Cezanne's label; only Cezanne is highlighted.
image = svg_digraph(adjacency, position, names=names, seeds=[cezanne], scores=scores)

[22]:

# Render the SVG string inline.
SVG(data=image)

[22]:


## Bigraphs

[23]:

# Bipartite movie-actor graph, loaded with its metadata (names of row and column nodes).
graph = movie_actor(metadata=True)
# The biadjacency matrix is the input of the classifier and of svg_bigraph; it must be extracted here.
biadjacency = graph.biadjacency
names_row = graph.names_row
names_col = graph.names_col


### Classification

[24]:

# Row indices of the three seed movies.
inception, drive, budapest = 0, 3, 8

[25]:

# One distinct label per seed movie (row nodes only).
seeds_row = dict([(inception, 0), (drive, 1), (budapest, 2)])

[26]:

# Solve the Dirichlet problem on the bigraph, seeding row nodes only.
# The classifier must be fitted before its labels_*_ / membership_*_ attributes exist.
bidiffusion = BiDirichletClassifier()
bidiffusion.fit(biadjacency, seeds_row)
labels_row = bidiffusion.labels_row_
labels_col = bidiffusion.labels_col_

[27]:

# Hard classification of both movies (rows) and actors (columns), with seed movies passed for highlighting.
image = svg_bigraph(biadjacency, names_row, names_col,
                    labels_row=labels_row, labels_col=labels_col,
                    seeds_row=seeds_row)

[28]:

# Render the SVG string inline.
SVG(data=image)

[28]:


### Soft classification

[29]:

# Soft outputs of the fitted classifier for both sides of the bigraph (rows, then columns).
membership_row, membership_col = bidiffusion.membership_row_, bidiffusion.membership_col_

[30]:

# Scores for label 1 (the label seeded at Drive) on rows and columns, as flat dense arrays.
scores_row, scores_col = (
    matrix[:, 1].toarray().ravel() for matrix in (membership_row, membership_col)
)

[31]:

# Soft classification: movies and actors coloured by their score for label 1.
image = svg_bigraph(
    biadjacency,
    names_row,
    names_col,
    seeds_row=seeds_row,
    scores_row=scores_row,
    scores_col=scores_col,
)

[32]:

# Render the SVG string inline.
SVG(data=image)

[32]: