Skip to content

Commit da8471c

Browse files
authored
【PaddlePaddle Hackathon 4 No.230】为TVM 添加tile、mish、stack、unstack、silu、softshrink、where算子支持 (#402)
* support tile/mish/unstack/silu/softshrink/where op for paddle frontend * add tvm pr
1 parent e8d229a commit da8471c

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
2+
# 在TVM中为paddle框架新增7个不支持的算子
3+
4+
|任务名称 | TVM项目5-为Paddle框架新增TVM算子 |
5+
|---|---|
6+
|提交作者 | 郑学贵 |
7+
|提交时间 | 2023-3-2 |
8+
|版本号 | V0.0 |
9+
|依赖飞桨版本 | v2.4.2 |
10+
|文件名 | add_tvm_op_for_paddle_frontend_0.md |
11+
12+
# 一、方案名称
13+
14+
tvm前端支持paddle算子
15+
16+
# 二、方案描述
17+
18+
tvm前端目前暂不支持paddle框架的`tile``stack``mish``unstack``silu``softshrink``where`算子,需要在tvm前端中适配这些算子,以支撑更多的paddle模型通过tvm进行部署。
19+
20+
# 三、方案流程
21+
22+
## 流程设计
23+
24+
1. 调研paddle中`tile``stack``mish``unstack``silu``softshrink``where`接口的实现,了解具体的计算逻辑和公式
25+
1. 调用并参考paddle2onnx的流程。
26+
2. 在tvm中新增相应的convert函数,对于不支持的算子通过Relay IR组合实现。
27+
3. 根据paddle框架中算子参数的可能情况,构建测试函数,覆盖所有使用场景。
28+
29+
## 算子实现
30+
31+
### 1.tile
32+
33+
tvm relay中也有相应的`tile`函数,因此只需要针对输入进行处理,再调用`_op.tile`即可。输入的`repeat_times`有三种类型:
34+
35+
- Tensor: 存储在`op.input("RepeatTimes")`,需要`infer`常量值。
36+
- list|tuple且元素为Tensor:存储在`op.input("repeat_times_tensor")`,需要逐个`infer`常量值,再拼接。
37+
- list|tuple且元素为整数:存储在`op.attr("repeat_times")`
38+
39+
### 2. mish
40+
41+
激活函数,根据API文档中计算公式,通过`Relay`中已有的`exp``mul``log`等函数组合实现
42+
43+
### 3. stack
44+
45+
通过Relay中已有的`stack`实现
46+
47+
### 4. unstack
48+
49+
Relay中没有实现`unstack`,可以采用`split``squeeze`组合实现
50+
51+
### 5.silu
52+
53+
激活函数,根据API文档中计算公式,通过`Relay`中已有的`sigmoid``mul`函数组合实现
54+
55+
### 6. softshrink
56+
57+
激活函数,根据API文档中计算公式,该函数是个三段的分段函数,可以通过组合`where``add`等函数实现
58+
59+
### 7. where
60+
61+
通过Relay中已有的`where`实现
62+
63+
# 四、方案运行效果
64+
65+
## 测试用例
66+
67+
根据API的参数所有可能的类型进行组合,输入通过随机以及手工构造边界样例生成不同`shape`的Tensor,覆盖所有使用场景。
68+
69+
## 运行结果
70+
71+
paddle框架中`tile``stack``mish``unstack``silu``softshrink``where`算子能够导入tvm并执行,计算结果和paddle框架保持一致。
72+
73+
# 五、项目提交时间计划
74+
75+
3-1日已完成代码,通过单测并提交到tvm [pr地址](https://github.com/apache/tvm/pull/14160)

0 commit comments

Comments
 (0)