Source code for reptrix.utils

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.decomposition import PCA  # type: ignore


[docs] def get_eigenspectrum( activations_np: np.ndarray, max_eigenvals: int = 2048 ) -> np.ndarray: """Get eigenspectrum of activation covariance matrix. Args: activations_np (np.ndarray): Numpy arr of activations, shape (bsz,d1,d2...dn) max_eigenvals (int, optional): Maximum #eigenvalues to compute. Defaults to 2048. Returns: np.ndarray: Returns the eigenspectrum of the activation covariance matrix """ feats = activations_np.reshape(activations_np.shape[0], -1) feats_center = feats - feats.mean(axis=0) pca = PCA( n_components=min(max_eigenvals, feats_center.shape[0], feats_center.shape[1]), svd_solver='full', ) pca.fit(feats_center) eigenspectrum = pca.explained_variance_ratio_ return eigenspectrum
[docs] def plot_eigenspectrum(eigenspectrum: np.ndarray) -> None: """Plot eigenspectrum in log-log scale Args: eigenspectrum (np.ndarray): Eigenspectrum of activation covariance matrix """ xrange = np.arange(1, 1 + len(eigenspectrum)) plt.loglog(xrange, eigenspectrum, c='blue', lw=2.0, label='Eigenspectrum') xlim_max = 1024 if len(eigenspectrum) > 512 else len(eigenspectrum) - 20 plt.xlim(right=xlim_max) plt.ylim(bottom=1e-4) plt.legend() plt.grid(True, color='gray', lw=1.0, alpha=0.3) plt.xlabel('i') plt.ylabel(r'$\lambda_i$')
[docs] def mat_sqrt_inv(matrix: torch.Tensor) -> torch.Tensor: """Compute the square root of inverse of a square (positive definite) matrix. Args: matrix (torch.Tensor): Matrix to be be sqrt inverted Returns: torch.Tensor: Square root of inverse of matrix """ # Ensure the input is a square matrix assert matrix.size(0) == matrix.size(1), "Input must be a square matrix." # Compute the inverse of the matrix matrix_inv = torch.inverse(matrix) # Compute the eigenvalues and eigenvectors of the inverse matrix eigenvals, eigenvecs = torch.linalg.eig(matrix_inv) sqrt_eigenvals = torch.sqrt(eigenvals.real).type_as(matrix) sqrt_diag = torch.diag(sqrt_eigenvals) eigenvecs_real = eigenvecs.real.type_as(matrix) # Reconstruct the matrix result = torch.mm( torch.mm(eigenvecs_real, sqrt_diag), torch.linalg.inv(eigenvecs_real) ) return result