File tree Expand file tree Collapse file tree 3 files changed +12
-6
lines changed Expand file tree Collapse file tree 3 files changed +12
-6
lines changed Original file line number Diff line number Diff line change @@ -1406,11 +1406,8 @@ def infer_static_shape(
14061406 `shape` will be validated and constant folded. As a result, this function
14071407 can be expensive and shouldn't be used unless absolutely necessary.
14081408
1409- It mostly exists as a hold-over from pre-static shape times, when it was
1410- required in order to produce correct broadcastable arrays and prevent
1411- some graphs from being unusable. Now, it is no longer strictly required,
1412- so don't use it unless you want the same shape graphs to be rewritten
1413- multiple times during graph construction.
1409+ It is often needed for `Op`s whose static shape and broadcastable flags
1410+ depend on the values of their inputs, such as `Alloc` and `RandomVariable`.
14141411
14151412 Returns
14161413 -------
Original file line number Diff line number Diff line change @@ -992,12 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node):
992992 return [specify_shape (inner_obj , shape )]
993993
994994
995+ _empty_shape = constant ([], dtype = "int64" )
996+
997+
995998@register_infer_shape
996999@node_rewriter ([Shape ])
9971000def local_shape_ground (fgraph , node ):
9981001 """Rewrite shape(x) -> make_vector(x.type.shape) when this is constant."""
9991002 [x ] = node .inputs
10001003 static_shape = x .type .shape
1004+ if len (static_shape ) == 0 :
1005+ return [_empty_shape ]
10011006 if not any (dim is None for dim in static_shape ):
10021007 return [stack ([constant (dim , dtype = "int64" ) for dim in static_shape ])]
10031008
Original file line number Diff line number Diff line change @@ -908,7 +908,7 @@ def test_runtime_broadcast(self, mode):
908908 self .check_runtime_broadcast (mode )
909909
910910
911- def test_infer_shape ():
911+ def test_infer_static_shape ():
912912 with pytest .raises (TypeError , match = "^Shapes must be scalar integers.*" ):
913913 infer_static_shape ([constant (1.0 )])
914914
@@ -925,6 +925,10 @@ def test_infer_shape():
925925 sh , static_shape = infer_static_shape (specify_size )
926926 assert static_shape == (1 ,)
927927
928+ x = scalar ("x" )
929+ sh , static_shape = infer_static_shape ([x .size ])
930+ assert static_shape == (1 ,)
931+
928932
929933# This is slow for the ('int8', 3) version.
930934def test_eye ():
You can’t perform that action at this time.
0 commit comments