diff --git a/matrix_functions.py b/matrix_functions.py index 973a44a7..de244293 100644 --- a/matrix_functions.py +++ b/matrix_functions.py @@ -955,7 +955,9 @@ def compute_matrix_root_inverse_residuals( root_inv_config=root_inv_config, epsilon=epsilon, ) - relative_error = torch.dist(X, X_hat, p=torch.inf) / torch.norm(X, p=torch.inf) + relative_error = torch.dist(X, X_hat, p=torch.inf) / torch.linalg.vector_norm( + X, ord=torch.inf + ) # compute residual X_invr, _, _ = _matrix_inverse_root_eigen( @@ -970,8 +972,8 @@ def compute_matrix_root_inverse_residuals( A_reg = A.double() + epsilon * torch.eye( A.shape[0], dtype=torch.float64, device=A.device ) - relative_residual = torch.dist(X_invr, A_reg, p=torch.inf) / torch.norm( - A_reg, p=torch.inf - ) + relative_residual = torch.dist( + X_invr, A_reg, p=torch.inf + ) / torch.linalg.vector_norm(A_reg, ord=torch.inf) return relative_error, relative_residual