|
6 | 6 | * to you under the Apache License, Version 2.0 (the |
7 | 7 | * "License"); you may not use this file except in compliance |
8 | 8 | * with the License. You may obtain a copy of the License at |
9 | | - * |
| 9 | + * |
10 | 10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | | - * |
| 11 | + * |
12 | 12 | * Unless required by applicable law or agreed to in writing, |
13 | 13 | * software distributed under the License is distributed on an |
14 | 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
|
18 | 18 | */ |
19 | 19 |
|
20 | 20 | /*! |
21 | | - * Copyright (c) 2018 by Contributors |
| 21 | + * Copyright (c) 2019 by Contributors |
22 | 22 | * \file src/tvm/relay/expr_mutator.cc |
23 | 23 | * \brief A wrapper around ExprFunctor which functionally updates the AST. |
24 | 24 | * |
25 | 25 | * ExprMutator uses memoization and self return in order to amortize |
26 | 26 | * the cost of using functional updates. |
27 | 27 | */ |
| 28 | +#include <tvm/relay/analysis.h> |
28 | 29 | #include <tvm/relay/expr_functor.h> |
29 | 30 | #include "type_functor.h" |
30 | 31 |
|
@@ -400,11 +401,27 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) { |
400 | 401 | new_params.size() == func->params.size()) { |
401 | 402 | return expr; |
402 | 403 | } |
403 | | - return FunctionNode::make(new_params, |
404 | | - new_body, |
405 | | - func->ret_type, |
406 | | - func->type_params, |
407 | | - func->attrs); |
| 404 | + auto ret = FunctionNode::make(new_params, |
| 405 | + new_body, |
| 406 | + func->ret_type, |
| 407 | + func->type_params, |
| 408 | + func->attrs); |
| 409 | + std::unordered_set<Var, NodeHash, NodeEqual> set; |
| 410 | + for (const auto& v : FreeVars(expr)) { |
| 411 | + set.insert(v); |
| 412 | + } |
| 413 | + for (const auto& v : FreeVars(ret)) { |
| 414 | + if (set.count(v) == 0) { |
| 415 | + new_params.push_back(v); |
| 416 | + } |
| 417 | + } |
| 418 | + ret = FunctionNode::make(new_params, |
| 419 | + new_body, |
| 420 | + func->ret_type, |
| 421 | + func->type_params, |
| 422 | + func->attrs); |
| 423 | + CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); |
| 424 | + return ret; |
408 | 425 | } else { |
409 | 426 | return ExprBinder(args_map).Mutate(expr); |
410 | 427 | } |
|
0 commit comments