@@ -458,16 +458,16 @@ def convert_shape(shape: Shape) -> Optional[WeakShape]:
458458 """Process a user-provided shape variable into None or a valid shape object."""
459459 if shape is None :
460460 return None
461-
462- if isinstance (shape , int ) or (isinstance (shape , TensorVariable ) and shape .ndim == 0 ):
461+ elif isinstance (shape , int ) or (isinstance (shape , TensorVariable ) and shape .ndim == 0 ):
463462 shape = (shape ,)
463+ elif isinstance (shape , TensorVariable ) and shape .ndim == 1 :
464+ shape = tuple (shape )
464465 elif isinstance (shape , (list , tuple )):
465466 shape = tuple (shape )
466467 else :
467468 raise ValueError (
468469 f"The `shape` parameter must be a tuple, TensorVariable, int or list. Actual: { type (shape )} "
469470 )
470-
471471 if isinstance (shape , tuple ) and any (s == Ellipsis for s in shape [:- 1 ]):
472472 raise ValueError (
473473 f"Ellipsis in `shape` may only appear in the last position. Actual: { shape } "
@@ -480,16 +480,16 @@ def convert_size(size: Size) -> Optional[StrongSize]:
480480 """Process a user-provided size variable into None or a valid size object."""
481481 if size is None :
482482 return None
483-
484- if isinstance (size , int ) or (isinstance (size , TensorVariable ) and size .ndim == 0 ):
483+ elif isinstance (size , int ) or (isinstance (size , TensorVariable ) and size .ndim == 0 ):
485484 size = (size ,)
485+ elif isinstance (size , TensorVariable ) and size .ndim == 1 :
486+ size = tuple (size )
486487 elif isinstance (size , (list , tuple )):
487488 size = tuple (size )
488489 else :
489490 raise ValueError (
490491 f"The `size` parameter must be a tuple, TensorVariable, int or list. Actual: { type (size )} "
491492 )
492-
493493 if isinstance (size , tuple ) and Ellipsis in size :
494494 raise ValueError (f"The `size` parameter cannot contain an Ellipsis. Actual: { size } " )
495495
0 commit comments