Source code for reptrix.alpha

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import r2_score  # type: ignore

import reptrix.utils as utils


[docs] def get_powerlaw(eigen: np.ndarray, trange: np.ndarray) -> tuple: """Fit powerlaw and return decay, powerlaw fit and the goodness of fit Args: eigen (np.ndarray): Eigenspectrum of activation covariance matrix trange (np.ndarray): Range to fit the powerlaw. Tip: Ignore the first couple of eigenvalues because we want to fit the tail of the spectrum Returns: tuple: Result of size 4 alpha: powerlaw decay coefficient ypred: Powerlaw fit fit_R2: goodness of powerlaw fit fit_R2_100: goodness of powerlaw fit (computed till the 100 eigvals) """ # Inspired by Stringer+Pachitariu 2018b github repo! # (https://github.com/MouseLand/stringer-pachitariu-et-al-2018b/blob/master/python/utils.py) logss = np.log(np.abs(eigen)) y = logss[trange][:, np.newaxis] trange += 1 nt = trange.size x = np.concatenate((-np.log(trange)[:, np.newaxis], np.ones((nt, 1))), axis=1) w = 1.0 / trange.astype(np.float32)[:, np.newaxis] b = np.linalg.solve(x.T @ (x * w), (w * x).T @ y).flatten() allrange = np.arange(0, eigen.size).astype(int) + 1 x = np.concatenate( (-np.log(allrange)[:, np.newaxis], np.ones((eigen.size, 1))), axis=1 ) ypred = np.exp((x * b).sum(axis=1)) alpha = b[0] max_range = ( 500 if len(eigen) >= 512 else len(eigen) - 10 ) # subtracting 10 here arbitrarily because we want to avoid the last tail! fit_R2 = r2_score( y_true=logss[trange[0] : max_range], y_pred=np.log(np.abs(ypred))[trange[0] : max_range], ) try: fit_R2_100 = r2_score( y_true=logss[trange[0] : 100], y_pred=np.log(np.abs(ypred))[trange[0] : 100] ) except: # noqa: E722 fit_R2_100 = None return (alpha, ypred, fit_R2, fit_R2_100)
[docs] def get_alpha( activations: torch.Tensor, max_eigenvals: int = 2048, fit_range: np.ndarray = None, ) -> tuple: """Get alpha and powerlaw fit (https://proceedings.neurips.cc/paper_files/paper/2022/hash/70596d70542c51c8d9b4e423f4bf2736-Abstract-Conference.html) Args: activations (np.ndarray): Activation tensor of shape (bsz,d1,d2...dn) max_eigenvals (int, optional): Maximum #eigenvalues to compute. Defaults to 2048. fit_range (np.ndarray, optional): Range to fit the powerlaw. Defaults to np.arange(10,100). Returns: tuple: Result of size 4 alpha: powerlaw decay coefficient ypred: Powerlaw fit fit_R2: goodness of powerlaw fit fit_R2_100: goodness of powerlaw fit (computed till the 100 eigvals) """ if fit_range is None: fit_range = np.arange(10, 100) if activations.requires_grad: activations_arr = activations.detach().cpu() else: activations_arr = activations.cpu() activations_arr_np = activations_arr.numpy() eigen = utils.get_eigenspectrum( activations_np=activations_arr_np, max_eigenvals=max_eigenvals ) alpha_res = get_powerlaw(eigen=eigen, trange=fit_range) return alpha_res
[docs] def plot_powerlaw(eigenspectrum: np.ndarray, alpha_res: tuple) -> None: """Plot eigenspectrum and powerlaw fit Args: eigenspectrum (np.ndarray): Eigenspectrum of activation covariance matrix alpha_res (tuple): Tuple containing alpha, powerlaw fit, goodness of powerlaw fit, goodness of powerlaw fit for first 100 eigenvalues """ utils.plot_eigenspectrum(eigenspectrum=eigenspectrum) alpha, ypred, r2, r2_100 = alpha_res xrange = np.arange(1, 1 + len(ypred)) plt.loglog(xrange, ypred, c='red', ls='--', lw=2.0, label='Powerlaw fit') plt.title( rf'Eigenspectrum and powerlaw fit, $\alpha$ = {alpha:.3f} ' + rf'(r2_100 = {r2_100:.3f})' )