Skip to content

Commit f6da494

Browse files
authored
【Hackathon 6th No.17】为 Paddle 新增 sparse.mask_as API (#6663)
* [Add] sparse mask doc * [Update] doc * [Update] doc * [Update] doc * [Update] doc
1 parent e368ab4 commit f6da494

File tree

4 files changed

+58
-0
lines changed

4 files changed

+58
-0
lines changed

docs/api/paddle/sparse/Overview_cn.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ paddle.sparse 目录包含飞桨框架支持稀疏数据存储和计算相关的
5959
" :ref:`paddle.sparse.reshape <cn_api_paddle_sparse_reshape>` ", "改变一个 SparseTensor 的形状"
6060
" :ref:`paddle.sparse.coalesce<cn_api_paddle_sparse_coalesce>` ", "对 SparseCooTensor 进行排序并合并"
6161
" :ref:`paddle.sparse.transpose <cn_api_paddle_sparse_transpose>` ", "在不改变数据的情况下改变 ``x`` 的维度顺序, 支持 COO 格式的多维 SparseTensor 以及 COO 格式的 2 维和 3 维 SparseTensor"
62+
" :ref:`paddle.sparse.mask_as<cn_api_paddle_sparse_mask_as>` ", "稀疏张量的掩码逻辑,使用稀疏张量 `mask` 的索引过滤输入的稠密张量 `x`"
6263

6364
.. _about_sparse_nn:
6465

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
.. _cn_api_paddle_sparse_mask_as:
2+
3+
mask_as
4+
-------------------------------
5+
6+
.. py:function:: paddle.sparse.mask_as(x, mask, name=None)
7+
8+
使用稀疏张量 `mask` 的索引过滤输入的稠密张量 `x`,并生成相应格式的稀疏张量。输入的 `x` 和 `mask` 必须具有相同的形状,且返回的稀疏张量具有与 `mask` 相同的索引,即使对应的索引中存在 `` 值。
9+
10+
参数
11+
:::::::::
12+
- **x** (DenseTensor) - 输入的 DenseTensor。数据类型为 float32,float64,int32,int64,complex64,complex128,int8,int16,float16。
13+
- **mask** (SparseTensor) - 输入的稀疏张量,可以为 SparseCooTensor、SparseCsrTensor。当其为 SparseCsrTensor 时,应该是 2D 或 3D 的形式。
14+
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。
15+
16+
返回
17+
:::::::::
18+
SparseTensor: 其稀疏格式、dtype、shape 均与 `mask` 相同。
19+
20+
21+
代码示例
22+
:::::::::
23+
24+
COPY-FROM: paddle.sparse.mask_as
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
## [ 参数完全一致 ] torch.Tensor.sparse_mask
2+
3+
### [torch.Tensor.sparse_mask](https://pytorch.org/docs/stable/generated/torch.Tensor.sparse_mask.html)
4+
5+
```python
6+
torch.Tensor.sparse_mask(mask)
7+
```
8+
9+
### [paddle.sparse.mask_as](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/sparse/mask_as_cn.html)
10+
11+
```python
12+
paddle.sparse.mask_as(x, mask, name=None)
13+
```
14+
15+
两者功能一致,但调用方式不同,torch 通过 Tensor 类方法调用,而 paddle 是直接调用函数,具体如下:
16+
17+
### 参数映射
18+
19+
| PyTorch | PaddlePaddle | 备注 |
20+
| ---------- | ------------ | ------------------------------------ |
21+
| - | x | 输入的 DenseTensor。 |
22+
| mask | mask | 掩码逻辑的 mask,参数完全一致。 |
23+
24+
### 转写示例
25+
26+
```python
27+
# torch 调用 Tensor 类方法
28+
x.sparse_mask(mask)
29+
30+
# paddle 直接调用函数
31+
paddle.sparse.mask_as(x, mask)
32+
```

docs/guides/model_convert/convert_from_pytorch/pytorch_api_mapping_cn.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,7 @@
10661066
| REFERENCE-MAPPING-ITEM(`torch.Tensor.softmax`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.softmax.md) |
10671067
| REFERENCE-MAPPING-ITEM(`torch.Tensor.sort`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.sort.md) |
10681068
| REFERENCE-MAPPING-ITEM(`torch.Tensor.split`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.split.md) |
1069+
| REFERENCE-MAPPING-ITEM(`torch.Tensor.sparse_mask`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.sparse_mask.md) |
10691070
| REFERENCE-MAPPING-ITEM(`torch.Tensor.sqrt`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.sqrt.md) |
10701071
| REFERENCE-MAPPING-ITEM(`torch.Tensor.sqrt_`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.sqrt_.md) |
10711072
| REFERENCE-MAPPING-ITEM(`torch.Tensor.square`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.square.md) |

0 commit comments

Comments
 (0)