| 
4 | 4 | from pytensor.graph.basic import Apply, Constant  | 
5 | 5 | from pytensor.graph.op import Op  | 
6 | 6 | from pytensor.misc.safe_asarray import _asarray  | 
7 |  | -from pytensor.tensor.basic import arange, as_tensor_variable, flatten, switch  | 
 | 7 | +from pytensor.tensor.basic import arange, as_tensor_variable, switch  | 
8 | 8 | from pytensor.tensor.math import eq, ge, mul  | 
9 |  | -from pytensor.tensor.shape import shape  | 
10 |  | -from pytensor.tensor.subtensor import set_subtensor  | 
11 |  | -from pytensor.tensor.type import TensorType, integer_dtypes  | 
 | 9 | +from pytensor.tensor.type import TensorType  | 
12 | 10 | 
 
  | 
13 | 11 | 
 
  | 
14 | 12 | def _variable_is_none(var):  | 
@@ -304,270 +302,3 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):  | 
304 | 302 |     else:  | 
305 | 303 |         zi = np.argpartition(x, -k, axis=axis)[tuple(idx)]  | 
306 | 304 |         return zi.astype(idx_dtype)  | 
307 |  | - | 
308 |  | - | 
class TopKOp(Op):
    """Operations related to finding k-largest elements.

    Parameters
    ----------
    axis: integer
        Defaults to ``-1``.
        The axis to perform the operation. Must be in range ``[-ndim, ndim)``, where
        ``ndim`` is the dimensionality of input tensor.

    sorted: bool
        NOTE: NOT IMPLEMENTED YET.
        Defaults to ``True``, but any truthy value currently raises
        ``NotImplementedError``; pass ``sorted=False`` explicitly.

        If True, the result array would be sorted in descending order.

    idx_dtype: string
        Specify output dtype for indices, defaults to ``int64``, must be integer type.

    return_values: bool
        Defaults to ``True``. Whether the Op has a "values" output.

    return_indices: bool
        Defaults to ``True``. Whether the Op has an "indices" output.
        At least one of ``return_values`` and ``return_indices`` must be true.

    Raises
    ------
    TypeError
        If ``axis`` is not an ``int`` or ``idx_dtype`` is not an integer dtype.
    NotImplementedError
        If ``sorted`` is truthy.
    ValueError
        If both ``return_values`` and ``return_indices`` are false.

    Notes
    -----
    - The output order is not guaranteed. On the CPU, we use
      ``np.partition`` and ``np.argpartition`` that only make sure the
      k-th element is the correct one and that the other
      elements are on the correct side.
    - By default, this Op gives two outputs: values and indices. However
      optimizers may remove a certain output if not needed.
    - Computing the gradient requests the computation of the indices in
      forward pass, so the gradient with respect to the input is only
      defined when the Op returns both values and indices.
    - If the top-k-th value is not unique, we cannot guarantee the
      output indices being deterministically chosen.

    See Also
    --------
    topk
    argtopk
    topk_and_argtopk

    """

    # TODO more params
    #   only_top_kth: bool
    #       Defaults to ``False``
    #       If ``True``, will only find one exact top k-th element on given axis.

    # TODO c_code
    # TODO add opt, if k==1, use max/min reduce
    #      also if k is axis size, just copy input tensor
    # TODO add opt, to merge argtopk / topk
    __props__ = ("axis", "sorted", "return_values", "return_indices", "idx_dtype")

    def __init__(
        self,
        axis=-1,
        sorted=True,
        idx_dtype="int64",
        return_values=True,
        return_indices=True,
    ):
        # numpy always uses int64 as output dtype for arg*() routines
        # however, we add "idx_dtype" param as memory is more precious on gpu
        if not isinstance(axis, int):
            raise TypeError(f'"axis" parameter must be integer, got "{type(axis)}"')
        if sorted:
            raise NotImplementedError(
                "The sorted parameter is not yet implemented. Use sorted=False for now."
            )
        if idx_dtype not in integer_dtypes:
            raise TypeError(
                f'"idx_dtype" parameter must be an integer dtype, got "{idx_dtype}"'
            )

        if not (return_indices or return_values):
            raise ValueError(
                "Neither return_values nor return_indices is True, this isn't allowed"
            )

        self.axis = axis
        self.sorted = sorted
        self.return_values = return_values
        self.return_indices = return_indices
        self.idx_dtype = idx_dtype

    def __str__(self):
        """Return ``ClassName{axis=<int>, sorted=<bool>}``."""
        return f"{self.__class__.__name__}{{axis={self.axis:d}, sorted={self.sorted}}}"

    def make_node(self, inp, kth):
        """Build the Apply node for input tensor `inp` and scalar `kth`.

        The node has one output per enabled ``return_*`` flag, values first.
        Every output has the same number of dimensions as `inp`, with all
        static shape entries left unknown.

        Raises
        ------
        ValueError
            If `inp` is a scalar (0-d) tensor.
        IndexError
            If ``self.axis`` is outside ``[-ndim, ndim)``.
        """
        inp = as_tensor_variable(inp)
        ndim = inp.ndim
        if ndim == 0:
            raise ValueError("Cannot take scalar as input")
        if not -ndim <= self.axis < ndim:
            raise IndexError(
                '"axis" parameter out of range,'
                f" expected integer within [{int(-ndim)}, {int(ndim - 1)}]"
            )

        kth = as_tensor_variable(kth)
        _check_tensor_is_scalar(kth)

        unknown_shape = (None,) * inp.type.ndim
        outs = []
        if self.return_values:
            outs.append(TensorType(dtype=inp.type.dtype, shape=unknown_shape)())
        if self.return_indices:
            outs.append(TensorType(dtype=self.idx_dtype, shape=unknown_shape)())
        return Apply(self, [inp, kth], outs)

    def perform(self, node, inputs, output_storage):
        """Compute the enabled outputs with ``_topk_py_impl``.

        The last argument handed to ``_topk_py_impl`` is the index dtype, or
        ``None`` when only values are requested.
        """
        x, k = inputs
        axis = self.axis
        if not self.return_indices:
            # Values only.
            pzv = output_storage[0]
            pzv[0] = _topk_py_impl(self, x, k, axis, None)
        elif self.return_values:
            # Values and indices; the index dtype comes from the second output.
            pzv = output_storage[0]
            pzi = output_storage[1]
            pzv[0], pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[1].dtype)
        else:
            # Indices only; they are the node's single output.
            pzi = output_storage[0]
            pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[0].dtype)

    def infer_shape(self, fgraph, node, inp_shapes):
        """Return one shape per output: the input shape with ``|k|`` on ``axis``."""
        shp = list(inp_shapes[0])
        # ``node.inputs[1]`` is a graph variable, not an ndarray, so use the
        # builtin ``abs`` (dispatching to the variable's ``__abs__``) rather
        # than routing it through ``np.abs``.
        shp[self.axis] = abs(node.inputs[1])
        shp = tuple(shp)
        return [shp for enabled in (self.return_values, self.return_indices) if enabled]

    def L_op(self, inputs, outputs, out_grads):
        """Return the gradients with respect to ``(x, k)``.

        The gradient with respect to ``k`` is always undefined. The gradient
        with respect to ``x`` scatters the gradient of the values output back
        to the positions given by the indices output, so it is only defined
        when the Op returns both values and indices.
        """
        x, k = inputs
        k_grad = grad_undefined(self, 1, k, "topk: k is not differentiable")

        # The scatter below reads ``out_grads[0]`` as the gradient of the
        # values and ``outputs[-1]`` as the indices; that only holds when both
        # outputs exist. (``or`` here could never trigger, because
        # ``__init__`` rejects the case where both flags are false.)
        if not (self.return_indices and self.return_values):
            x_grad = grad_undefined(
                self,
                0,
                x,
                "topk: cannot get gradient without both indices and values",
            )
        else:
            x_shp = shape(x)
            z_grad = out_grads[0]
            ndim = x.ndim
            axis = self.axis % ndim
            # One index per dimension: the selected indices along ``axis``,
            # and a broadcastable ``arange`` along every other dimension.
            grad_indices = [
                arange(x_shp[i]).dimshuffle([0] + ["x"] * (ndim - i - 1))
                if i != axis
                else outputs[-1]
                for i in range(ndim)
            ]
            x_grad = x.zeros_like(dtype=z_grad.dtype)
            x_grad = set_subtensor(x_grad[tuple(grad_indices)], z_grad)

        return [x_grad, k_grad]
472 |  | - | 
473 |  | - | 
def topk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
    """
    Returns the k-largest elements along an axis.

    Parameters
    ----------

    x: tensor instance

    kth: integer constant/variable
        Must not be 0. If negative, gives k-smallest elements instead.

    axis: integer or ``None``
        Upon which axis shall the operation be performed on.
        If ``None``, works on flattened array.

    sorted: bool
        NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
        Defaults to ``True``

        If True, the result array would be sorted in descending order.

    idx_dtype: string
        Specify output dtype used in indices, defaults to ``int64``, must be integer type.
        This option is here because indices are needed for gradient.

    Returns
    -------
    Tensor variable with same dtype as `x`.

    Notes
    -----
    - ``sorted=True`` is not supported yet.

    """
    if axis is None:
        # No axis requested: flatten the input and work along its only axis.
        x, axis = flatten(x), 0
    top_k_op = TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)
    op_outputs = top_k_op(x, kth)
    # The values are the first output of the Op.
    return op_outputs[0]
513 |  | - | 
514 |  | - | 
def argtopk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
    """
    Returns the indices of k-largest elements along an axis.

    Parameters
    ----------

    x: tensor instance

    kth: integer constant/variable
        Must not be 0. If negative, gives k-smallest elements instead.

    axis: integer or ``None``
        Upon which axis shall the operation be performed on.
        If ``None``, works on flattened array.

    sorted: bool
        NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
        Defaults to ``True``

        If True, the result array of corresponding indices would be sorted in descending order.

    idx_dtype: string
        Specify output dtype, defaults to ``int64``, must be integer type.

    Returns
    -------
    Tensor variable with dtype specified in `idx_dtype`.

    Notes
    -----
    - ``sorted=True`` is not supported yet.

    - If the top-k-th value is not unique, we cannot guarantee the output
      indices are deterministically chosen.

    """
    if axis is None:
        # No axis requested: flatten the input and work along its only axis.
        x, axis = flatten(x), 0
    top_k_op = TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)
    op_outputs = top_k_op(x, kth)
    # The indices are the second output of the Op.
    return op_outputs[1]
557 |  | - | 
558 |  | - | 
def topk_and_argtopk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
    """
    Returns the results of both topk() and argtopk() in one Op.

    See the respective documentation for details.

    Returns
    -------
    The outputs of the Op: values first, then indices.

    """
    if axis is None:
        # No axis requested: flatten the input and work along its only axis.
        x, axis = flatten(x), 0
    top_k_op = TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)
    return top_k_op(x, kth)
0 commit comments