@@ -101,17 +101,24 @@ def test_backward(self):
101
101
mesh = ico_sphere ()
102
102
verts = mesh .verts_packed ()
103
103
edges = mesh .edges_packed ()
104
+ verts_cpu = verts .clone ()
105
+ edges_cpu = edges .clone ()
104
106
verts_cuda = verts .clone ().to (device )
105
107
edges_cuda = edges .clone ().to (device )
106
108
verts .requires_grad = True
109
+ verts_cpu .requires_grad = True
107
110
verts_cuda .requires_grad = True
108
111
109
112
neighbor_sums_cuda = gather_scatter (verts_cuda , edges_cuda , False )
113
+ neighbor_sums_cpu = gather_scatter (verts_cpu , edges_cpu , False )
110
114
neighbor_sums = gather_scatter_python (verts , edges , False )
111
- neighbor_sums_cuda .sum ().backward ()
112
- neighbor_sums .sum ().backward ()
115
+ randoms = torch .rand_like (neighbor_sums )
116
+ (neighbor_sums_cuda * randoms .cuda ()).sum ().backward ()
117
+ (neighbor_sums_cpu * randoms ).sum ().backward ()
118
+ (neighbor_sums * randoms ).sum ().backward ()
113
119
114
- self .assertClose (verts .grad .cpu (), verts_cuda .grad .cpu ())
120
+ self .assertClose (verts .grad , verts_cuda .grad .cpu ())
121
+ self .assertClose (verts .grad , verts_cpu .grad )
115
122
116
123
def test_repr (self ):
117
124
conv = GraphConv (32 , 64 , directed = True )
@@ -141,22 +148,24 @@ def test_gather_scatter(self):
141
148
w0 = nn .Linear (3 , 1 )
142
149
input = w0 (verts )
143
150
144
- # output
145
- output_cpu = gather_scatter_python (input , edges , False )
151
+ # undirected
152
+ output_python = gather_scatter_python (input , edges , False )
146
153
output_cuda = _C .gather_scatter (
147
154
input .to (device = device ), edges .to (device = device ), False , False
148
155
)
149
- self .assertClose (output_cuda .cpu (), output_cpu )
150
- with self . assertRaises ( Exception ) as err :
151
- _C .gather_scatter (input .cpu (), edges .cpu (), False , False )
152
- self .assertTrue ( "Not implemented on the CPU" in str ( err . exception ) )
156
+ self .assertClose (output_cuda .cpu (), output_python )
157
+
158
+ output_cpu = _C .gather_scatter (input .cpu (), edges .cpu (), False , False )
159
+ self .assertClose ( output_cpu , output_python )
153
160
154
161
# directed
155
- output_cpu = gather_scatter_python (input , edges , True )
162
+ output_python = gather_scatter_python (input , edges , True )
156
163
output_cuda = _C .gather_scatter (
157
164
input .to (device = device ), edges .to (device = device ), True , False
158
165
)
159
- self .assertClose (output_cuda .cpu (), output_cpu )
166
+ self .assertClose (output_cuda .cpu (), output_python )
167
+ output_cpu = _C .gather_scatter (input .cpu (), edges .cpu (), True , False )
168
+ self .assertClose (output_cpu , output_python )
160
169
161
170
@staticmethod
162
171
def graph_conv_forward_backward (
0 commit comments