Skip to content

Commit 1103897

Browse files
committed
fix: add raw_normalize and add_self_loop
1 parent 62e7aee commit 1103897

File tree

6 files changed

+58
-12
lines changed

6 files changed

+58
-12
lines changed

.ci/install-dev.sh

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
rm -rf .env
22

3-
# # install python 3.8.16 using pyenv:
4-
# pyenv install 3.8.16
5-
# pyenv local 3.8.16
3+
bash .ci/py.sh
64

75
# create and activate virtual environment
86
python3 -m venv .env

.ci/install.sh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@ source .env/bin/activate
1616
# python -m pip install -U pip
1717

1818
# # torch cuda 12.1
19-
python -m pip install torch==1.12 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu113
19+
python -m pip install "torch>=1.12,<2.2" torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
2020

2121
# dgl cuda 12.1
2222
# add a source if preferred: -i https://pypi.tuna.tsinghua.edu.cn/simple/
23-
python -m pip install dgl==1.1.0 -f https://data.dgl.ai/wheels/cu113/repo.html
24-
python -m pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html
23+
python3 -m pip install dgl==2.1 -f https://data.dgl.ai/wheels/cu121/repo.html
2524

2625
# install requirements
2726
# add a source if preferred: -i https://pypi.tuna.tsinghua.edu.cn/simple/

.ci/py.sh

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Check if python3 command exists
2+
flag=0
3+
if command -v python3 &>/dev/null; then
4+
# Extracting the version number
5+
python_version=$(python3 --version 2>&1 | awk '{print $2}')
6+
7+
# Splitting the version number into major and minor parts
8+
IFS='.' read -r -a version_parts <<< "$python_version"
9+
10+
major_version="${version_parts[0]}"
11+
minor_version="${version_parts[1]}"
12+
13+
if [ "$major_version" -eq 3 ] && [ "$minor_version" -ge 8 ]; then
14+
echo "Python version $python_version is greater than or equal to 3.8."
15+
else
16+
echo "Python version $python_version is less than 3.8."
17+
flag=1
18+
fi
19+
else
20+
echo "Python>=3.8 is not installed."
21+
flag=1
22+
fi
23+
# # install python 3.8, i.e., using pyenv:
24+
# pyenv install 3.8
25+
if [ "$flag" -eq 1 ]; then
26+
command -v pyenv >/dev/null 2>&1 && pyenv local 3.8 && { echo >&2 "Install python 3.8 using pyenv successfully.";} || { echo >&2 "Pyenv is not installed. Did not install python 3.8 using pyenv. Please install mannually and retry"; exit 1;}
27+
fi

graph_datasets/load_data.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""Load Graph Datasets
22
"""
3+
34
# pylint:disable=protected-access
45
import ssl
56
from typing import Tuple
7+
from typing import Union
68

79
import dgl
810
import torch
11+
import torch.nn.functional as F
912
from torch_geometric.data import Data
1013
from torch_geometric.utils import from_dgl
1114

@@ -33,9 +36,11 @@ def load_data(
3336
verbosity: int = 0,
3437
source: str = "pyg",
3538
return_type: str = "dgl",
39+
raw_normalize: bool = True,
3640
rm_self_loop: bool = True,
41+
add_self_loop: bool = False,
3742
to_simple: bool = True,
38-
) -> Tuple[dgl.DGLGraph, torch.Tensor, int] or Data:
43+
) -> Union[Tuple[dgl.DGLGraph, torch.Tensor, int], Data]:
3944
"""Load graphs.
4045
4146
Args:
@@ -47,14 +52,18 @@ def load_data(
4752
source (str, optional): Source for data loading. Defaults to "pyg".
4853
return_type (str, optional): Return type of the graphs within ["dgl", "pyg"]. \
4954
Defaults to "dgl".
55+
raw_normalize (bool, optional): Row-normalize the node feature matrix (F.normalize along dim=1). Defaults to True.
5056
rm_self_loop (bool, optional): Remove self loops. Defaults to True.
57+
add_self_loop (bool, optional): Add a self loop to every node regardless of rm_self_loop (existing self loops are removed first). \
58+
Defaults to False.
5159
to_simple (bool, optional): Convert to a simple graph with no duplicate undirected edges. Defaults to True.
5260
5361
Raises:
5462
NotImplementedError: Dataset unknown.
5563
5664
Returns:
57-
Tuple[dgl.DGLGraph, torch.Tensor, int]: [graph, label, n_clusters]
65+
Tuple[dgl.DGLGraph, torch.Tensor, int]: [graph, label, n_clusters] or \
66+
torch_geometric.data.Data
5867
5968
Example:
6069
.. code-block:: python
@@ -143,9 +152,12 @@ def load_data(
143152
f"https://galogm.github.io/graph_datasets_docs/rst/table.html"
144153
)
145154

146-
# remove self loop and turn graphs into undirected ones
155+
if raw_normalize:
156+
graph.ndata["feat"] = F.normalize(graph.ndata["feat"], dim=1)
147157
if rm_self_loop:
148-
graph = dgl.remove_self_loop(graph)
158+
graph = graph.remove_self_loop()
159+
if add_self_loop:
160+
graph = graph.remove_self_loop().add_self_loop()
149161
if to_simple:
150162
graph = dgl.to_bidirected(graph, copy_ndata=True)
151163

@@ -155,9 +167,13 @@ def load_data(
155167
new_label = torch.tensor(list(map(lambda x: old2new[x.item()], label)))
156168
graph.ndata["label"] = new_label
157169

170+
name = f"{dataset_name}_{source}"
171+
graph.name = name
172+
158173
if verbosity:
159174
print_dataset_info(
160-
dataset_name=f"{source.upper()} undirected {dataset_name}\nwithout self-loops",
175+
dataset_name=
176+
f"{source.upper()} undirected {dataset_name}\n add_self_loop={add_self_loop} rm_self_loop={rm_self_loop}",
161177
n_nodes=graph.num_nodes(),
162178
n_edges=graph.num_edges(),
163179
n_feats=graph.ndata["feat"].shape[1],
@@ -168,7 +184,7 @@ def load_data(
168184
return graph, new_label, n_clusters
169185

170186
data = from_dgl(graph)
171-
data.name = dataset_name
187+
data.name = name
172188
data.num_classes = n_clusters
173189
data.x = data.feat
174190
data.y = data.label

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
ogb
12
gdown
23
wget
34
dgl>=1.1

tests/test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Test
22
"""
3+
34
# pylint:disable=duplicate-code
45
from graph_datasets import load_data
56
from graph_datasets.data_info import DATASETS
@@ -11,6 +12,10 @@ def main(_source, _dataset):
1112
directory="./data",
1213
source=_source,
1314
verbosity=3,
15+
raw_normalize=True,
16+
rm_self_loop=True,
17+
add_self_loop=True,
18+
to_simple=True,
1419
)
1520

1621
# import argparse

0 commit comments

Comments
 (0)