diff --git a/graph_datasets/utils/evaluation/eval_tools.py b/graph_datasets/utils/evaluation/eval_tools.py index 8c2f01e..a7a4041 100644 --- a/graph_datasets/utils/evaluation/eval_tools.py +++ b/graph_datasets/utils/evaluation/eval_tools.py @@ -3,6 +3,7 @@ # pylint: disable=invalid-name,invalid-name,too-many-locals import os import random +from datetime import datetime import numpy as np import torch @@ -303,3 +304,27 @@ def evaluate_results_nc( f1_mean, f1_std, ) + + +def save_embedding( + node_embeddings: torch.tensor, + dataset_name: str, + model_name: str, + params: dict, + save_dir: str = "./save", + verbose: bool or int = True, +): + dataset_name = dataset_name.replace("_", "-") + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + file_name = f"{dataset_name.lower()}_{model_name.lower()}_embeds_{timestamp}.pth" + file_path = os.path.join(save_dir, file_name) + + result = { + "node_embeddings": node_embeddings.cpu().detach(), + "hyperparameters": params, + } + + torch.save(result, file_path) + + if verbose: + print(f"Embeddings and hyperparameters saved to {file_path}")