@@ -366,21 +366,25 @@ def __trunc__(self):
366366 # https://docs.xarray.dev/en/latest/api.html#id1
367367 @property
368368 def values (self ) -> TensorVariable :
369+ """Convert to a TensorVariable with the same data."""
369370 return typing .cast (TensorVariable , px .basic .tensor_from_xtensor (self ))
370371
371372 # Can't provide property data because that's already taken by Constants!
372373 # data = values
373374
374375 @property
375376 def coords (self ):
377+ """Not implemented."""
376378 raise NotImplementedError ("coords not implemented for XTensorVariable" )
377379
378380 @property
379381 def dims (self ) -> tuple [str , ...]:
382+ """The names of the dimensions of the variable."""
380383 return self .type .dims
381384
382385 @property
383386 def sizes (self ) -> dict [str , TensorVariable ]:
387+ """The sizes of the dimensions of the variable."""
384388 return dict (zip (self .dims , self .shape ))
385389
386390 @property
@@ -392,18 +396,22 @@ def as_numpy(self):
392396 # https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
393397 @property
394398 def ndim (self ) -> int :
399+ """The number of dimensions of the variable."""
395400 return self .type .ndim
396401
397402 @property
398403 def shape (self ) -> tuple [TensorVariable , ...]:
404+ """The shape of the variable."""
399405 return tuple (px .basic .tensor_from_xtensor (self ).shape ) # type: ignore
400406
401407 @property
402408 def size (self ) -> TensorVariable :
409+ """The total number of elements in the variable."""
403410 return typing .cast (TensorVariable , variadic_mul (* self .shape ))
404411
405412 @property
406- def dtype (self ):
413+ def dtype (self ) -> str :
414+ """The data type of the variable."""
407415 return self .type .dtype
408416
409417 @property
@@ -414,6 +422,7 @@ def broadcastable(self):
414422 # DataArray contents
415423 # https://docs.xarray.dev/en/latest/api.html#dataarray-contents
416424 def rename (self , new_name_or_name_dict = None , ** names ):
425+ """Rename the variable or its dimension(s)."""
417426 if isinstance (new_name_or_name_dict , str ):
418427 new_name = new_name_or_name_dict
419428 name_dict = None
@@ -425,31 +434,41 @@ def rename(self, new_name_or_name_dict=None, **names):
425434 return new_out
426435
427436 def copy (self , name : str | None = None ):
437+ """Create a copy of the variable.
438+
439+ This is just an identity operation, as XTensorVariables are immutable.
440+ """
428441 out = px .math .identity (self )
429442 out .name = name
430443 return out
431444
432445 def astype (self , dtype ):
446+ """Convert the variable to a different data type."""
433447 return px .math .cast (self , dtype )
434448
435449 def item (self ):
450+ """Not implemented."""
436451 raise NotImplementedError ("item not implemented for XTensorVariable" )
437452
438453 # Indexing
439454 # https://docs.xarray.dev/en/latest/api.html#id2
440455 def __setitem__ (self , idx , value ):
456+ """Not implemented. Use `x[idx].set(value)` or `x[idx].inc(value)` instead."""
441457 raise TypeError (
442458 "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
443459 )
444460
445461 @property
446462 def loc (self ):
463+ """Not implemented."""
447464 raise NotImplementedError ("loc not implemented for XTensorVariable" )
448465
449466 def sel (self , * args , ** kwargs ):
467+ """Not implemented."""
450468 raise NotImplementedError ("sel not implemented for XTensorVariable" )
451469
452470 def __getitem__ (self , idx ):
471+ """Index the variable positionally."""
453472 if isinstance (idx , dict ):
454473 return self .isel (idx )
455474
@@ -465,6 +484,7 @@ def isel(
465484 missing_dims : Literal ["raise" , "warn" , "ignore" ] = "raise" ,
466485 ** indexers_kwargs ,
467486 ):
487+ """Index the variable along the specified dimension(s)."""
468488 if indexers_kwargs :
469489 if indexers is not None :
470490 raise ValueError (
@@ -505,6 +525,48 @@ def isel(
505525 return px .indexing .index (self , * indices )
506526
507527 def set (self , value ):
528+ """Return a copy of the variable indexed by self with the indexed values set to y.
529+
530+ The original variable is not modified.
531+
532+ Raises
533+ ------
534+ ValueError
535+ If self is not the result of an index operation
536+
537+ Examples
538+ --------
539+
540+ .. testcode::
541+
542+ import pytensor.xtensor as ptx
543+
544+ x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
545+ idx = ptx.as_xtensor([0, 1], dims=("a",))
546+ out = x[:, idx].set(1)
547+ out.eval()
548+
549+ .. test-output::
550+
551+ array([[1, 0],
552+ [0, 1]])
553+
554+
555+ .. testcode::
556+
557+ import pytensor.xtensor as ptx
558+
559+ x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
560+ idx = ptx.as_xtensor([0, 1], dims=("a",))
561+ out = x.isel({"b": idx}).set(-1)
562+ out.eval()
563+
564+ .. test-output::
565+
566+ array([[-1, 0],
567+ [0, -1]])
568+
569+ """
508570 if not (
509571 self .owner is not None and isinstance (self .owner .op , px .indexing .Index )
510572 ):
@@ -516,6 +578,48 @@ def set(self, value):
516578 return px .indexing .index_assignment (x , value , * idxs )
517579
518580 def inc (self , value ):
581+ """Return a copy of the variable indexed by self with the indexed values incremented by value.
582+
583+ The original variable is not modified.
584+
585+ Raises
586+ ------
587+ ValueError
588+ If self is not the result of an index operation
589+
590+ Examples
591+ --------
592+
593+ .. testcode::
594+
595+ import pytensor.xtensor as ptx
596+
597+ x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
598+ idx = ptx.as_xtensor([0, 1], dims=("a",))
599+ out = x[:, idx].inc(1)
600+ out.eval()
601+
602+ .. test-output::
603+
604+ array([[2, 1],
605+ [1, 2]])
606+
607+
608+ .. testcode::
609+
610+ import pytensor.xtensor as ptx
611+
612+ x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
613+ idx = ptx.as_xtensor([0, 1], dims=("a",))
614+ out = x.isel({"b": idx}).inc(-1)
615+ out.eval()
616+
617+ .. test-output::
618+
619+ array([[0, 1],
620+ [1, 0]])
621+
622+ """
519623 if not (
520624 self .owner is not None and isinstance (self .owner .op , px .indexing .Index )
521625 ):
@@ -579,7 +683,7 @@ def squeeze(
579683 drop = None ,
580684 axis : int | Sequence [int ] | None = None ,
581685 ):
582- """Remove dimensions of size 1 from an XTensorVariable .
686+ """Remove dimensions of size 1.
583687
584688 Parameters
585689 ----------
@@ -606,24 +710,21 @@ def expand_dims(
606710 axis : int | Sequence [int ] | None = None ,
607711 ** dim_kwargs ,
608712 ):
609- """Add one or more new dimensions to the tensor .
713+ """Add one or more new dimensions to the variable .
610714
611715 Parameters
612716 ----------
613717 dim : str | Sequence[str] | dict[str, int | Sequence] | None
614718 If str or sequence of str, new dimensions with size 1.
615719 If dict, keys are dimension names and values are either:
616- - int: the new size
617- - sequence: coordinates (length determines size)
720+
721+ - int: the new size
722+ - sequence: coordinates (length determines size)
618723 create_index_for_new_dim : bool, default: True
619- Currently ignored. Reserved for future coordinate support.
620- In xarray, when True (default), creates a coordinate index for the new dimension
621- with values from 0 to size-1. When False, no coordinate index is created.
724+ Ignored by PyTensor
622725 axis : int | Sequence[int] | None, default: None
623726 Not implemented yet. In xarray, specifies where to insert the new dimension(s).
624727 By default (None), new dimensions are inserted at the beginning (axis=0).
625- Symbolic axis is not supported yet.
626- Negative values count from the end.
627728 **dim_kwargs : int | Sequence
628729 Alternative to `dim` dict. Only used if `dim` is None.
629730
@@ -643,65 +744,75 @@ def expand_dims(
643744 # ndarray methods
644745 # https://docs.xarray.dev/en/latest/api.html#id7
645746 def clip (self , min , max ):
747+ """Clip the values of the variable to a specified range."""
646748 return px .math .clip (self , min , max )
647749
648750 def conj (self ):
751+ """Return the complex conjugate of the variable."""
649752 return px .math .conj (self )
650753
651754 @property
652755 def imag (self ):
756+ """Return the imaginary part of the variable."""
653757 return px .math .imag (self )
654758
655759 @property
656760 def real (self ):
761+ """Return the real part of the variable."""
657762 return px .math .real (self )
658763
659764 @property
660765 def T (self ):
661- """Return the full transpose of the tensor .
766+ """Return the full transpose of the variable .
662767
663768 This is equivalent to calling transpose() with no arguments.
664-
665- Returns
666- -------
667- XTensorVariable
668- Fully transposed tensor.
669769 """
670770 return self .transpose ()
671771
672772 # Aggregation
673773 # https://docs.xarray.dev/en/latest/api.html#id6
674774 def all (self , dim = None ):
775+ """Reduce the variable by applying `all` along some dimension(s)."""
675776 return px .reduction .all (self , dim )
676777
677778 def any (self , dim = None ):
779+ """Reduce the variable by applying `any` along some dimension(s)."""
678780 return px .reduction .any (self , dim )
679781
680782 def max (self , dim = None ):
783+ """Compute the maximum along the given dimension(s)."""
681784 return px .reduction .max (self , dim )
682785
683786 def min (self , dim = None ):
787+ """Compute the minimum along the given dimension(s)."""
684788 return px .reduction .min (self , dim )
685789
686790 def mean (self , dim = None ):
791+ """Compute the mean along the given dimension(s)."""
687792 return px .reduction .mean (self , dim )
688793
689794 def prod (self , dim = None ):
795+ """Compute the product along the given dimension(s)."""
690796 return px .reduction .prod (self , dim )
691797
692798 def sum (self , dim = None ):
799+ """Compute the sum along the given dimension(s)."""
693800 return px .reduction .sum (self , dim )
694801
695802 def std (self , dim = None , ddof = 0 ):
803+ """Compute the standard deviation along the given dimension(s)."""
696804 return px .reduction .std (self , dim , ddof = ddof )
697805
698806 def var (self , dim = None , ddof = 0 ):
807+ """Compute the variance along the given dimension(s)."""
699808 return px .reduction .var (self , dim , ddof = ddof )
700809
701810 def cumsum (self , dim = None ):
811+ """Compute the cumulative sum along the given dimension(s)."""
702812 return px .reduction .cumsum (self , dim )
703813
704814 def cumprod (self , dim = None ):
815+ """Compute the cumulative product along the given dimension(s)."""
705816 return px .reduction .cumprod (self , dim )
706817
707818 def diff (self , dim , n = 1 ):
@@ -720,7 +831,7 @@ def transpose(
720831 * dim : str | EllipsisType ,
721832 missing_dims : Literal ["raise" , "warn" , "ignore" ] = "raise" ,
722833 ):
723- """Transpose dimensions of the tensor .
834+ """Transpose the dimensions of the variable .
724835
725836 Parameters
726837 ----------
@@ -729,6 +840,7 @@ def transpose(
729840 Can use ellipsis (...) to represent remaining dimensions.
730841 missing_dims : {"raise", "warn", "ignore"}, default="raise"
731842 How to handle dimensions that don't exist in the tensor:
843+
732844 - "raise": Raise an error if any dimensions don't exist
733845 - "warn": Warn if any dimensions don't exist
734846 - "ignore": Silently ignore any dimensions that don't exist
@@ -747,21 +859,38 @@ def transpose(
747859 return px .shape .transpose (self , * dim , missing_dims = missing_dims )
748860
749861 def stack (self , dim , ** dims ):
862+ """Stack existing dimensions into a single new dimension."""
750863 return px .shape .stack (self , dim , ** dims )
751864
752865 def unstack (self , dim , ** dims ):
866+ """Unstack a dimension into multiple dimensions of a given size.
867+
868+ Because XTensorVariables don't have coords, this operation requires the sizes of each unstacked dimension to be specified.
869+ Also, unstacked dims will follow a C-style order, regardless of the order of the original dimensions.
870+
871+ .. testcode::
872+
873+ import pytensor.xtensor as ptx
874+
875+ x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
876+ stacked_cumsum = x.stack({"c": ["a", "b"]}).cumsum("c")
877+ unstacked_cumsum = stacked_cumsum.unstack({"c": x.sizes})
878+ unstacked_cumsum.eval()
879+
880+ .. test-output::
881+
882+ array([[ 1, 3],
883+ [ 6, 10]])
884+
885+ """
753886 return px .shape .unstack (self , dim , ** dims )
754887
755888 def dot (self , other , dim = None ):
756- """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims ."""
889+ """Generalized dot product with another XTensorVariable."""
757890 return px .math .dot (self , other , dim = dim )
758891
759- def broadcast (self , * others , exclude = None ):
760- """Broadcast this tensor against other XTensorVariables."""
761- return px .shape .broadcast (self , * others , exclude = exclude )
762-
763892 def broadcast_like (self , other , exclude = None ):
764- """Broadcast this tensor against another XTensorVariable."""
893+ """Broadcast against another XTensorVariable."""
765894 _ , self_bcast = px .shape .broadcast (other , self , exclude = exclude )
766895 return self_bcast
767896
0 commit comments