@@ -366,21 +366,25 @@ def __trunc__(self):
366366 # https://docs.xarray.dev/en/latest/api.html#id1
367367 @property
368368 def values (self ) -> TensorVariable :
369+ """Convert to a TensorVariable with the same data."""
369370 return typing .cast (TensorVariable , px .basic .tensor_from_xtensor (self ))
370371
371372 # Can't provide property data because that's already taken by Constants!
372373 # data = values
373374
374375 @property
375376 def coords (self ):
377+ """Not implemented."""
376378 raise NotImplementedError ("coords not implemented for XTensorVariable" )
377379
378380 @property
379381 def dims (self ) -> tuple [str , ...]:
382+ """The names of the dimensions of the variable."""
380383 return self .type .dims
381384
382385 @property
383386 def sizes (self ) -> dict [str , TensorVariable ]:
387+ """The sizes of the dimensions of the variable."""
384388 return dict (zip (self .dims , self .shape ))
385389
386390 @property
@@ -392,18 +396,22 @@ def as_numpy(self):
392396 # https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
393397 @property
394398 def ndim (self ) -> int :
399+ """The number of dimensions of the variable."""
395400 return self .type .ndim
396401
397402 @property
398403 def shape (self ) -> tuple [TensorVariable , ...]:
404+ """The shape of the variable."""
399405 return tuple (px .basic .tensor_from_xtensor (self ).shape ) # type: ignore
400406
401407 @property
402408 def size (self ) -> TensorVariable :
409+ """The total number of elements in the variable."""
403410 return typing .cast (TensorVariable , variadic_mul (* self .shape ))
404411
405412 @property
406- def dtype (self ):
413+ def dtype (self ) -> str :
414+ """The data type of the variable."""
407415 return self .type .dtype
408416
409417 @property
@@ -414,6 +422,7 @@ def broadcastable(self):
414422 # DataArray contents
415423 # https://docs.xarray.dev/en/latest/api.html#dataarray-contents
416424 def rename (self , new_name_or_name_dict = None , ** names ):
425+          """Rename the variable or its dimension(s).
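+
+        An illustrative sketch; the dict form is assumed to rename dimensions
+        (as in xarray), while a plain string renames the variable itself:
+
+        .. test-code::
+
+            import pytensor.xtensor as ptx
+
+            x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
+            print(x.rename({"a": "c"}).dims)
+
+        .. test-output::
+
+            ('c', 'b')
+        """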
417426 if isinstance (new_name_or_name_dict , str ):
418427 new_name = new_name_or_name_dict
419428 name_dict = None
@@ -425,31 +434,41 @@ def rename(self, new_name_or_name_dict=None, **names):
425434 return new_out
426435
427436 def copy (self , name : str | None = None ):
437+ """Create a copy of the variable.
438+
439+ This is just an identity operation, as XTensorVariables are immutable.
440+ """
428441 out = px .math .identity (self )
429442 out .name = name
430443 return out
431444
432445 def astype (self , dtype ):
446+ """Convert the variable to a different data type."""
433447 return px .math .cast (self , dtype )
434448
435449 def item (self ):
450+ """Not implemented."""
436451 raise NotImplementedError ("item not implemented for XTensorVariable" )
437452
438453 # Indexing
439454 # https://docs.xarray.dev/en/latest/api.html#id2
440455 def __setitem__ (self , idx , value ):
456+ """Not implemented. Use `x[idx].set(value)` or `x[idx].inc(value)` instead."""
441457 raise TypeError (
442458 "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
443459 )
444460
445461 @property
446462 def loc (self ):
463+ """Not implemented."""
447464 raise NotImplementedError ("loc not implemented for XTensorVariable" )
448465
449466 def sel (self , * args , ** kwargs ):
467+ """Not implemented."""
450468 raise NotImplementedError ("sel not implemented for XTensorVariable" )
451469
452470 def __getitem__ (self , idx ):
471+ """Index the variable positionally."""
453472 if isinstance (idx , dict ):
454473 return self .isel (idx )
455474
@@ -465,6 +484,7 @@ def isel(
465484 missing_dims : Literal ["raise" , "warn" , "ignore" ] = "raise" ,
466485 ** indexers_kwargs ,
467486 ):
487+          """Index the variable positionally along the specified dimension(s).
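+
+        A minimal sketch; the expected output assumes xarray-style positional
+        indexing along the named dimension:
+
+        .. test-code::
+
+            import pytensor.xtensor as ptx
+
+            x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
+            out = x.isel({"a": 1})  # equivalent to x.isel(a=1)
+            out.eval()
+
+        .. test-output::
+
+            array([3, 4])
+        """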
468488 if indexers_kwargs :
469489 if indexers is not None :
470490 raise ValueError (
@@ -505,6 +525,48 @@ def isel(
505525 return px .indexing .index (self , * indices )
506526
507527 def set (self , value ):
528+          """Return a copy of the variable being indexed, with the indexed values set to value.
529+
530+ The original variable is not modified.
531+
532+ Raises
533+ ------
534+            ValueError
535+                If self is not the result of an index operation.
536+
537+ Examples
538+ --------
539+
540+ .. test-code::
541+
542+ import pytensor.xtensor as ptx
543+
544+ x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
545+ idx = ptx.as_xtensor([0, 1], dims=("a",))
546+ out = x[:, idx].set(1)
547+ out.eval()
548+
549+ .. test-output::
550+
551+ array([[1, 0],
552+ [0, 1]])
553+
554+
555+ .. test-code::
556+
557+ import pytensor.xtensor as ptx
558+
559+ x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
560+ idx = ptx.as_xtensor([0, 1], dims=("a",))
561+ out = x.isel({"b": idx}).set(-1)
562+ out.eval()
563+
564+ .. test-output::
565+
566+ array([[-1, 0],
567+ [0, -1]])
568+
569+ """
508570 if not (
509571 self .owner is not None and isinstance (self .owner .op , px .indexing .Index )
510572 ):
@@ -516,6 +578,48 @@ def set(self, value):
516578 return px .indexing .index_assignment (x , value , * idxs )
517579
518580 def inc (self , value ):
581+          """Return a copy of the variable being indexed, with the indexed values incremented by value.
582+
583+ The original variable is not modified.
584+
585+ Raises
586+ ------
587+            ValueError
588+                If self is not the result of an index operation.
589+
590+ Examples
591+ --------
592+
593+ .. test-code::
594+
595+ import pytensor.xtensor as ptx
596+
597+ x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
598+ idx = ptx.as_xtensor([0, 1], dims=("a",))
599+ out = x[:, idx].inc(1)
600+ out.eval()
601+
602+ .. test-output::
603+
604+ array([[2, 1],
605+ [1, 2]])
606+
607+
608+ .. test-code::
609+
610+ import pytensor.xtensor as ptx
611+
612+ x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
613+ idx = ptx.as_xtensor([0, 1], dims=("a",))
614+ out = x.isel({"b": idx}).inc(-1)
615+ out.eval()
616+
617+ .. test-output::
618+
619+ array([[0, 1],
620+ [1, 0]])
621+
622+ """
519623 if not (
520624 self .owner is not None and isinstance (self .owner .op , px .indexing .Index )
521625 ):
@@ -579,7 +683,7 @@ def squeeze(
579683 drop = None ,
580684 axis : int | Sequence [int ] | None = None ,
581685 ):
582- """Remove dimensions of size 1 from an XTensorVariable .
686+ """Remove dimensions of size 1.
583687
584688 Parameters
585689 ----------
@@ -606,7 +710,7 @@ def expand_dims(
606710 axis : int | Sequence [int ] | None = None ,
607711 ** dim_kwargs ,
608712 ):
609- """Add one or more new dimensions to the tensor .
713+ """Add one or more new dimensions to the variable .
610714
611715 Parameters
612716 ----------
@@ -616,14 +720,10 @@ def expand_dims(
616720 - int: the new size
617721 - sequence: coordinates (length determines size)
618722 create_index_for_new_dim : bool, default: True
619- Currently ignored. Reserved for future coordinate support.
620- In xarray, when True (default), creates a coordinate index for the new dimension
621- with values from 0 to size-1. When False, no coordinate index is created.
723+              Ignored by PyTensor, as XTensorVariables have no coords.
622724 axis : int | Sequence[int] | None, default: None
623725 Not implemented yet. In xarray, specifies where to insert the new dimension(s).
624726 By default (None), new dimensions are inserted at the beginning (axis=0).
625- Symbolic axis is not supported yet.
626- Negative values count from the end.
627727 **dim_kwargs : int | Sequence
628728 Alternative to `dim` dict. Only used if `dim` is None.
629729
@@ -643,65 +743,75 @@ def expand_dims(
643743 # ndarray methods
644744 # https://docs.xarray.dev/en/latest/api.html#id7
645745 def clip (self , min , max ):
746+ """Clip the values of the variable to a specified range."""
646747 return px .math .clip (self , min , max )
647748
648749 def conj (self ):
750+ """Return the complex conjugate of the variable."""
649751 return px .math .conj (self )
650752
651753 @property
652754 def imag (self ):
755+ """Return the imaginary part of the variable."""
653756 return px .math .imag (self )
654757
655758 @property
656759 def real (self ):
760+ """Return the real part of the variable."""
657761 return px .math .real (self )
658762
659763 @property
660764 def T (self ):
661- """Return the full transpose of the tensor .
765+ """Return the full transpose of the variable .
662766
663767 This is equivalent to calling transpose() with no arguments.
664-
665- Returns
666- -------
667- XTensorVariable
668- Fully transposed tensor.
669768 """
670769 return self .transpose ()
671770
672771 # Aggregation
673772 # https://docs.xarray.dev/en/latest/api.html#id6
674773 def all (self , dim = None ):
774+ """Reduce the variable by applying `all` along some dimension(s)."""
675775 return px .reduction .all (self , dim )
676776
677777 def any (self , dim = None ):
778+ """Reduce the variable by applying `any` along some dimension(s)."""
678779 return px .reduction .any (self , dim )
679780
680781 def max (self , dim = None ):
782+ """Compute the maximum along the given dimension(s)."""
681783 return px .reduction .max (self , dim )
682784
683785 def min (self , dim = None ):
786+ """Compute the minimum along the given dimension(s)."""
684787 return px .reduction .min (self , dim )
685788
686789 def mean (self , dim = None ):
790+ """Compute the mean along the given dimension(s)."""
687791 return px .reduction .mean (self , dim )
688792
689793 def prod (self , dim = None ):
794+ """Compute the product along the given dimension(s)."""
690795 return px .reduction .prod (self , dim )
691796
692797 def sum (self , dim = None ):
798+ """Compute the sum along the given dimension(s)."""
693799 return px .reduction .sum (self , dim )
694800
695801 def std (self , dim = None , ddof = 0 ):
802+ """Compute the standard deviation along the given dimension(s)."""
696803 return px .reduction .std (self , dim , ddof = ddof )
697804
698805 def var (self , dim = None , ddof = 0 ):
806+ """Compute the variance along the given dimension(s)."""
699807 return px .reduction .var (self , dim , ddof = ddof )
700808
701809 def cumsum (self , dim = None ):
810+ """Compute the cumulative sum along the given dimension(s)."""
702811 return px .reduction .cumsum (self , dim )
703812
704813 def cumprod (self , dim = None ):
814+ """Compute the cumulative product along the given dimension(s)."""
705815 return px .reduction .cumprod (self , dim )
706816
707817 def diff (self , dim , n = 1 ):
@@ -720,7 +830,7 @@ def transpose(
720830 * dim : str | EllipsisType ,
721831 missing_dims : Literal ["raise" , "warn" , "ignore" ] = "raise" ,
722832 ):
723- """Transpose dimensions of the tensor .
833+ """Transpose the dimensions of the variable .
724834
725835 Parameters
726836 ----------
@@ -747,21 +857,38 @@ def transpose(
@@ -747,21 +857,38 @@ def transpose(
747857          return px .shape .transpose (self , * dim , missing_dims = missing_dims )
748858
749859 def stack (self , dim , ** dims ):
860+          """Stack existing dimensions into a single new dimension.
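+
+        An illustrative sketch; the stacked values are assumed to follow the
+        C-style (row-major) order described in `unstack` below:
+
+        .. test-code::
+
+            import pytensor.xtensor as ptx
+
+            x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
+            out = x.stack({"c": ["a", "b"]})
+            out.eval()
+
+        .. test-output::
+
+            array([1, 2, 3, 4])
+        """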
750861 return px .shape .stack (self , dim , ** dims )
751862
752863 def unstack (self , dim , ** dims ):
864+ """Unstack a dimension into multiple dimensions of a given size.
865+
866+ Because XTensorVariables don't have coords, this operation requires the sizes of each unstacked dimension to be specified.
867+ Also, unstacked dims will follow a C-style order, regardless of the order of the original dimensions.
868+
869+ .. test-code::
870+
871+ import pytensor.xtensor as ptx
872+
873+ x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
874+ stacked_cumsum = x.stack({"c": ["a", "b"]}).cumsum("c")
875+              unstacked_cumsum = stacked_cumsum.unstack({"c": x.sizes})
876+ unstacked_cumsum.eval()
877+
878+ .. test-output::
879+
880+ array([[ 1, 3],
881+ [ 6, 10]])
882+
883+ """
753884 return px .shape .unstack (self , dim , ** dims )
754885
755886 def dot (self , other , dim = None ):
756- """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims ."""
887+          """Generalized dot product with another XTensorVariable.
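+
+        An illustrative sketch; when `dim` is None, contraction is assumed to
+        happen over the dimensions shared with `other`, as in xarray:
+
+        .. test-code::
+
+            import pytensor.xtensor as ptx
+
+            x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
+            y = ptx.as_xtensor([1, 1], dims=("b",))
+            out = x.dot(y)  # contracts over the shared dimension "b"
+            out.eval()
+
+        .. test-output::
+
+            array([3, 7])
+        """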
757888 return px .math .dot (self , other , dim = dim )
758889
759- def broadcast (self , * others , exclude = None ):
760- """Broadcast this tensor against other XTensorVariables."""
761- return px .shape .broadcast (self , * others , exclude = exclude )
762-
763890 def broadcast_like (self , other , exclude = None ):
764- """Broadcast this tensor against another XTensorVariable."""
891+ """Broadcast against another XTensorVariable."""
765892 _ , self_bcast = px .shape .broadcast (other , self , exclude = exclude )
766893 return self_bcast
767894