11from  collections .abc  import  Sequence 
22
33import  torch 
4- from  class_resolver  import  Hint 
54from  torch  import  nn 
65
76from  torchdrug  import  core , layers 
87from  torchdrug .core  import  Registry  as  R 
9- from  torchdrug .layers  import  Readout , readout_resolver 
108
119
1210@R .register ("models.GCN" ) 
@@ -29,7 +27,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable):
2927    """ 
3028
3129    def  __init__ (self , input_dim , hidden_dims , edge_input_dim = None , short_cut = False , batch_norm = False ,
32-                  activation = "relu" , concat_hidden = False , readout :  Hint [ Readout ]  =   "sum" ):
30+                  activation = "relu" , concat_hidden = False , readout = "sum" ):
3331        super (GraphConvolutionalNetwork , self ).__init__ ()
3432
3533        if  not  isinstance (hidden_dims , Sequence ):
@@ -44,7 +42,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False,
4442        for  i  in  range (len (self .dims ) -  1 ):
4543            self .layers .append (layers .GraphConv (self .dims [i ], self .dims [i  +  1 ], edge_input_dim , batch_norm , activation ))
4644
47-         self .readout  =  readout_resolver .make (readout )
45+         if  readout  ==  "sum" :
46+             self .readout  =  layers .SumReadout ()
47+         elif  readout  ==  "mean" :
48+             self .readout  =  layers .MeanReadout ()
49+         elif  readout  ==  "max" :
50+             self .readout  =  layers .MaxReadout ()
51+         else :
52+             raise  ValueError ("Unknown readout `%s`"  %  readout )
4853
4954    def  forward (self , graph , input , all_loss = None , metric = None ):
5055        """ 
@@ -103,7 +108,7 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
103108    """ 
104109
105110    def  __init__ (self , input_dim , hidden_dims , num_relation , edge_input_dim = None , short_cut = False , batch_norm = False ,
106-                  activation = "relu" , concat_hidden = False , readout :  Hint [ Readout ]  =   "sum" ):
111+                  activation = "relu" , concat_hidden = False , readout = "sum" ):
107112        super (RelationalGraphConvolutionalNetwork , self ).__init__ ()
108113
109114        if  not  isinstance (hidden_dims , Sequence ):
@@ -120,7 +125,14 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
120125            self .layers .append (layers .RelationalGraphConv (self .dims [i ], self .dims [i  +  1 ], num_relation , edge_input_dim ,
121126                                                          batch_norm , activation ))
122127
123-         self .readout  =  readout_resolver .make (readout )
128+         if  readout  ==  "sum" :
129+             self .readout  =  layers .SumReadout ()
130+         elif  readout  ==  "mean" :
131+             self .readout  =  layers .MeanReadout ()
132+         elif  readout  ==  "max" :
133+             self .readout  =  layers .MaxReadout ()
134+         else :
135+             raise  ValueError ("Unknown readout `%s`"  %  readout )
124136
125137    def  forward (self , graph , input , all_loss = None , metric = None ):
126138        """ 
0 commit comments