Skip to content

Commit ba19c56

Browse files
author
zhangkaihuo
authored
新增一个融合算子:fused_feedforward (#3999)
* 修改错别字 * 备注CPU不支持float16 * update example * add fused_feedforward * add fused_feedforward * add fused_feedforward * opt the description * update docs * update docs * update doc * move fused_feedforward docs position
1 parent ac1220a commit ba19c56

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
.. _cn_api_incubate_nn_functional_fused_feedforward:
2+
3+
fused_feedforward
4+
-------------------------------
5+
6+
.. py:function:: paddle.incubate.nn.functional.fused_feedforward(x, linear1_weight, linear2_weight, linear1_bias=None, linear2_bias=None, ln1_scale=None, ln1_bias=None, ln2_scale=None, ln2_bias=None, dropout1_rate=0.5, dropout2_rate=0.5,activation="relu", ln1_epsilon=1e-5, ln2_epsilon=1e-5, pre_layer_norm=False, name=None):
7+
8+
这是一个融合算子,该算子是对transformer模型中feed forward层的多个算子进行融合,该算子只支持在GPU下运行,该算子与如下伪代码表达一样的功能:
9+
10+
.. code-block:: ipython
11+
12+
residual = src;
13+
if pre_layer_norm:
14+
src = layer_norm(src)
15+
src = linear(dropout(activation(dropout(linear(src)))))
16+
if not pre_layer_norm:
17+
src = layer_norm(out)
18+
19+
参数
20+
:::::::::
21+
- **x** (Tensor) - 输入Tensor,数据类型支持float16, float32 和float64, 输入的形状是`[batch_size, sequence_length, d_model]`。
22+
- **linear1_weight** (Tensor) - 第一个linear算子的权重数据,数据类型与`x`一样,形状是`[d_model, dim_feedforward]`。
23+
- **linear2_weight** (Tensor) - 第二个linear算子的权重数据,数据类型与`x`一样,形状是`[dim_feedforward, d_model]`。
24+
- **linear1_bias** (Tensor, 可选) - 第一个linear算子的偏置数据,数据类型与`x`一样,形状是`[dim_feedforward]`。默认值为None。
25+
- **linear2_bias** (Tensor, 可选) - 第二个linear算子的偏置数据,数据类型与`x`一样,形状是`[d_model]`。默认值为None。
26+
- **ln1_scale** (Tensor, 可选) - 第一个layer_norm算子的权重数据,数据类型可以是float32或者float64,形状和`x`一样。默认值为None。
27+
- **ln1_bias** (Tensor, 可选) - 第一个layer_norm算子的偏置数据,数据类型和`ln1_scale`一样, 形状是`[d_model]`。默认值为None。
28+
- **ln2_scale** (Tensor, 可选) - 第二个layer_norm算子的权重数据,数据类型可以是float32或者float64,形状和`x`一样。默认值为None。
29+
- **ln2_bias** (Tensor, 可选) - 第二个layer_norm算子的偏置数据,数据类型和`ln2_scale`一样, 形状是`[d\_model]`。默认值为None。
30+
- **dropout1_rate** (float, 可选) - 第一个dropout算子置零的概率。默认是0.5。
31+
- **dropout2_rate** (float, 可选) - 第二个dropout算子置零的概率。默认是0.5。
32+
- **activation** (string, 可选) - 激活函数。默认值是relu。
33+
- **ln1_epsilon** (float, 可选) - 一个很小的浮点数,被第一个layer_norm算子加到分母,避免出现除零的情况。默认值是1e-5。
34+
- **ln2_epsilon** (float, 可选) - 一个很小的浮点数,被第二个layer_norm算子加到分母,避免出现除零的情况。默认值是1e-5。
35+
- **pre_layer_norm** (bool, 可选) - 在预处理阶段加上layer_norm,或者在后处理阶段加上layer_norm。默认值是False。
36+
- **name** (string, 可选) – fused_feedforward的名称, 默认值为None。更多信息请参见 :ref:`api_guide_Name` 。
37+
38+
返回
39+
:::::::::
40+
- Tensor, 输出Tensor,数据类型与`x`一样。
41+
42+
代码示例
43+
::::::::::
44+
45+
.. code-block:: python
46+
47+
# required: gpu
48+
import paddle
49+
import numpy as np
50+
x_data = np.random.random((1, 8, 8)).astype("float32")
51+
linear1_weight_data = np.random.random((8, 8)).astype("float32")
52+
linear2_weight_data = np.random.random((8, 8)).astype("float32")
53+
x = paddle.to_tensor(x_data)
54+
linear1_weight = paddle.to_tensor(linear1_weight_data)
55+
linear2_weight = paddle.to_tensor(linear2_weight_data)
56+
out = paddle.incubate.nn.functional.fused_feedforward(x, linear1_weight, linear2_weight)
57+
print(out.numpy().shape)
58+
# (1, 8, 8)
59+

0 commit comments

Comments
 (0)