Source code for sknetwork.linalg.normalization

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created in November 2019
@author: Nathan de Lara <nathan.delara@polytechnique.org>
"""
from typing import Union

import numpy as np
from scipy import sparse
from scipy.sparse.linalg import LinearOperator


[docs]def diagonal_pseudo_inverse(weights: np.ndarray) -> sparse.csr_matrix: """Compute :math:`\\text{diag}(w)^+`, the pseudo-inverse of the diagonal matrix with diagonal elements given by the weights :math:`w`. Parameters ---------- weights: The weights to invert. Returns ------- sparse.csr_matrix """ diag: sparse.csr_matrix = sparse.diags(weights, format='csr') diag.data = 1 / diag.data return diag
def get_norms(matrix: Union[sparse.csr_matrix, np.ndarray, LinearOperator], p=1): """Get the norms of rows of a matrix. Parameters ---------- matrix : numpy array, sparse CSR matrix or linear operator, shape (n_rows, n_cols) Input matrix. p : Order of the norm. Returns ------- norms : np.array, shape (n_rows,) Vector norms """ if p == 1: norms = matrix.dot(np.ones(matrix.shape[1])) elif p == 2: if isinstance(matrix, np.ndarray): norms = np.linalg.norm(matrix, axis=1) elif isinstance(matrix, sparse.csr_matrix): data = matrix.data.copy() matrix.data = data ** 2 norms = np.sqrt(matrix.dot(np.ones(matrix.shape[1]))) matrix.data = data else: raise NotImplementedError('Norm 2 is not available for a LinearOperator.') else: raise NotImplementedError('Only norms 1 and 2 are available at the moment.') return norms
[docs]def normalize(matrix: Union[sparse.csr_matrix, np.ndarray, LinearOperator], p=1): """Normalize the rows of a matrix so that all have norm 1 (or 0; null rows remain null). Parameters ---------- matrix : Input matrix. p : Order of the norm. Returns ------- normalized matrix : Normalized matrix (same format as input matrix). """ norms = get_norms(matrix, p) diag = diagonal_pseudo_inverse(norms) if hasattr(matrix, 'left_sparse_dot') and callable(matrix.left_sparse_dot): return matrix.left_sparse_dot(diag) return diag.dot(matrix)