Skip to content

Commit 70a2ccf

Browse files
authored
API Improvement for paddle.nn.initializer.TruncatedNormal 易用性提升 (#6642)
* update truncated normal docs * update docs * update docs * update docs * update api difference * update api difference
1 parent d715da6 commit 70a2ccf

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

docs/api/paddle/nn/initializer/TruncatedNormal_cn.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@
33
TruncatedNormal
44
-------------------------------
55

6-
.. py:class:: paddle.nn.initializer.TruncatedNormal(mean=0.0, std=1.0, name=None)
7-
6+
.. py:class:: paddle.nn.initializer.TruncatedNormal(mean=0.0, std=1.0, a=-2.0, b=2.0, name=None)
87
98
截断正态分布(高斯分布)初始化方法。
109

10+
.. note::
11+
在参数设置时建议将 `mean` 设为 :math:`a \le mean \le b`。
12+
:math:`mean < a - 2 \cdot std` 或 :math:`mean > b + 2 \cdot std`,采样值的分布可能是有误的。
13+
1114
参数
1215
- **mean** (float,可选) - 正态分布的均值,默认值为 :math:`0.0`。
1316
- **std** (float,可选) - 正态分布的标准差,默认值为 :math:`1.0`。
17+
- **a** (float,可选) - 截断正态分布的下界,默认值为 :math:`-2.0`。
18+
- **b** (float,可选) - 截断正态分布的上界,默认值为 :math:`2.0`。
1419
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。
1520

1621
返回
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
## [ 组合替代实现 ]torch.nn.init.trunc_normal_
2+
3+
### [torch.nn.init.trunc_normal_](https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.trunc_normal_)
4+
5+
```python
6+
torch.nn.init.trunc_normal_(tensor,
7+
mean=0.0,
8+
std=1.0,
9+
a=-2.0,
10+
b=2.0)
11+
```
12+
13+
### [paddle.nn.initializer.TruncatedNormal](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/initializer/TruncatedNormal_cn.html)
14+
15+
```python
16+
paddle.nn.initializer.TruncatedNormal(mean=0.0,
17+
std=1.0,
18+
a=-2.0,
19+
b=2.0,
20+
name=None)
21+
```
22+
23+
两者用法不同:torch 是 inplace 的用法,paddle 是类设置的,具体如下:
24+
25+
### 参数映射
26+
27+
| PyTorch | PaddlePaddle | 备注 |
28+
| ------------- | ------------ | ------------------------------------------------------ |
29+
| tensor | - | n 维 tensor。Paddle 无此参数,因为是通过调用类的 __call__ 函数来进行 tensor 的初始化。 |
30+
| mean | mean | 正态分布的平均值。参数名和默认值一致。 |
31+
| std | std | 正态分布的标准差。参数名和默认值一致。 |
32+
| a | a | 截断正态分布的下界。参数名和默认值一致。 |
33+
| b | b | 截断正态分布的上界。参数名和默认值一致。 |
34+
35+
### 转写示例
36+
```python
37+
# PyTorch 写法
38+
conv = torch.nn.Conv2d(4, 6, (3, 3))
39+
torch.nn.init.trunc_normal_(conv.weight)
40+
41+
# Paddle 写法
42+
conv = nn.Conv2D(in_channels=4, out_channels=6, kernel_size=(3,3))
43+
init_trunc_normal = paddle.nn.initializer.TruncatedNormal()
44+
init_trunc_normal(conv.weight)
45+
```

0 commit comments

Comments
 (0)