3 changes: 0 additions & 3 deletions python/tvm/auto_scheduler/relay_integration.py
@@ -483,7 +483,4 @@ def is_auto_scheduler_enabled():
return PassContext.current().config.get(
"relay.backend.use_auto_scheduler",
False,
) or PassContext.current().config.get(
"relay.backend.use_meta_schedule",
False,
)
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/__init__.py
@@ -30,10 +30,10 @@
search_strategy,
space_generator,
)
from .profiler import Profiler
from .apply_history_best import ApplyHistoryBest
from .extracted_task import ExtractedTask
from .relay_integration import extract_task_from_relay
from .profiler import Profiler
from .relay_integration import extract_task_from_relay, is_meta_schedule_enabled
from .search_strategy import MeasureCandidate
from .tune import TuneConfig, tune_extracted_tasks, tune_relay, tune_te, tune_tir
from .tune_context import TuneContext
14 changes: 14 additions & 0 deletions python/tvm/meta_schedule/relay_integration.py
@@ -103,3 +103,17 @@ def extract_task_from_relay(
disabled_pass=disabled_pass,
):
return list(extract_task_func(mod, target, relay_params, te_filter_func))


def is_meta_schedule_enabled() -> bool:
"""Return whether the meta-schedule is enabled.

Returns
-------
enabled: bool
Whether the meta schedule is enabled
"""
return transform.PassContext.current().config.get(
"relay.backend.use_meta_schedule",
False,
)
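
Not part of the diff: a minimal usage sketch of the new helper, assuming only the relay.backend.use_meta_schedule config key (the same key removed from the auto-scheduler check above) and the export added to python/tvm/meta_schedule/__init__.py above.

# Sketch only (not from this PR): toggling the flag that
# is_meta_schedule_enabled() reads via PassContext.
import tvm
from tvm import meta_schedule as ms

with tvm.transform.PassContext(
    opt_level=3,
    config={"relay.backend.use_meta_schedule": True},
):
    # Inside this context the op strategies changed below pick the
    # meta-schedule-aware implementations.
    print(ms.is_meta_schedule_enabled())  # prints a truthy value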
13 changes: 9 additions & 4 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -15,16 +15,19 @@
# specific language governing permissions and limitations
# under the License.
"""Definition of ARM CPU operator strategy."""
import logging

# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import re
import logging

from tvm import relay, topi

from ....auto_scheduler import is_auto_scheduler_enabled
from ....meta_schedule import is_meta_schedule_enabled
from ....target import arm_isa
from ....topi.generic import conv2d as conv2d_generic
from ....auto_scheduler import is_auto_scheduler_enabled
from .generic import *
from .. import op as _op
from .generic import *

logger = logging.getLogger("strategy")

@@ -477,7 +480,9 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
logger.warning("dense is not optimized for arm cpu.")
strategy.add_implementation(
wrap_compute_dense(
topi.nn.dense, need_auto_scheduler_layout=is_auto_scheduler_enabled()
topi.nn.dense,
need_auto_scheduler_layout=is_auto_scheduler_enabled(),
need_meta_schedule_layout=is_meta_schedule_enabled(),
),
wrap_topi_schedule(topi.generic.schedule_dense),
name="dense.generic",
32 changes: 28 additions & 4 deletions python/tvm/relay/op/strategy/cuda.py
@@ -20,11 +20,12 @@
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.contrib import nvcc
from tvm.contrib.thrust import can_use_thrust
from tvm.meta_schedule import is_meta_schedule_enabled
from tvm.te import SpecializedCondition

from .. import op as _op
from ....target import Target
from ....tir import IntImm
from .. import op as _op
from .generic import *


@@ -251,7 +252,17 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
)

# register auto-scheduler implementations
if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
if (
is_auto_scheduler_enabled() or is_meta_schedule_enabled()
) and judge_winograd_auto_scheduler:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
naive_schedule, # this implementation should never be picked by autotvm
name="conv2d_nhwc.winograd",
plevel=15,
)
# register meta-schedule implementations
if is_meta_schedule_enabled() and judge_winograd_auto_scheduler:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
naive_schedule, # this implementation should never be picked by autotvm
@@ -534,7 +545,14 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda",
)

if is_auto_scheduler_enabled():
if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
naive_schedule, # this implementation should never be picked by autotvm
name="conv2d_nhwc_winograd_without_weight_transform",
plevel=15,
)
if is_meta_schedule_enabled():
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
naive_schedule, # this implementation should never be picked by autotvm
@@ -805,7 +823,13 @@ def matmul_strategy_cuda(attrs, inputs, out_type, target):
"""Matmul cuda strategy."""
strategy = _op.OpStrategy()

if is_auto_scheduler_enabled():
if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
strategy.add_implementation(
wrap_compute_matmul(topi.nn.matmul),
naive_schedule,
name="matmul.cuda",
)
elif is_meta_schedule_enabled():
strategy.add_implementation(
wrap_compute_matmul(topi.nn.matmul),
naive_schedule,
53 changes: 48 additions & 5 deletions python/tvm/relay/op/strategy/generic.py
@@ -21,7 +21,12 @@

from tvm import _ffi, ir, te, topi
from tvm.target import generic_func, override_native_generic_func
from tvm.topi.utils import get_const_float, get_const_int, get_const_tuple, get_float_tuple
from tvm.topi.utils import (
get_const_float,
get_const_int,
get_const_tuple,
get_float_tuple,
)

from .. import op as _op

@@ -211,6 +216,9 @@ def schedule_bitpack(attrs, outs, target):
get_auto_scheduler_rewritten_layout = _ffi.get_global_func(
"relay.attrs.get_auto_scheduler_rewritten_layout"
)
get_meta_schedule_original_shape = _ffi.get_global_func(
"relay.attrs.get_meta_schedule_original_shape"
)

# conv2d
def wrap_compute_conv2d(
@@ -219,6 +227,7 @@ def wrap_compute_conv2d(
need_out_layout=False,
has_groups=False,
need_auto_scheduler_layout=False,
need_meta_schedule_layout=False,
):
"""Wrap conv2d topi compute"""

@@ -240,6 +249,9 @@ def _compute_conv2d(attrs, inputs, out_type):
args.append(out_dtype)
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
elif need_meta_schedule_layout:
args.append("")
args.append(get_meta_schedule_original_shape(attrs))
return [topi_compute(*args)]

return _compute_conv2d
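
Not part of the diff: the same three-way branching (auto-scheduler layout, meta-schedule layout, neither) repeats in the conv3d, matmul, dense, and batch_matmul wrappers below, so a topi compute that opts in has to accept two extra trailing arguments. A hedged sketch of that contract follows; the parameter names are illustrative and not taken from this PR.

# Illustrative only: the trailing arguments the wrappers append when
# need_auto_scheduler_layout or need_meta_schedule_layout is set.
def dense_like_compute(
    data,
    weight,
    bias=None,
    out_dtype=None,
    auto_scheduler_rewritten_layout="",  # "" when the meta-schedule branch is taken
    meta_schedule_original_shape=None,  # original shape recovered from the op attrs
):
    # Compute body elided; a real topi compute would build its output tensor here.
    raise NotImplementedError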
@@ -530,7 +542,12 @@ def conv3d_transpose_strategy(attrs, inputs, out_type, target):


# conv3d
def wrap_compute_conv3d(topi_compute, need_layout=False, need_auto_scheduler_layout=False):
def wrap_compute_conv3d(
topi_compute,
need_layout=False,
need_auto_scheduler_layout=False,
need_meta_schedule_layout=False,
):
"""wrap conv3d topi compute"""

def _compute_conv3d(attrs, inputs, out_type):
@@ -552,6 +569,9 @@ def _compute_conv3d(attrs, inputs, out_type):
args.append(out_dtype)
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
elif need_meta_schedule_layout:
args.append("")
args.append(get_meta_schedule_original_shape(attrs))
return [topi_compute(*args)]

return _compute_conv3d
@@ -782,7 +802,11 @@ def copy_if_identical(tensor_a, tensor_b):


# matmul
def wrap_compute_matmul(topi_compute, need_auto_scheduler_layout=False):
def wrap_compute_matmul(
topi_compute,
need_auto_scheduler_layout=False,
need_meta_schedule_layout=False,
):
"""wrap matmul topi compute"""

def _compute_matmul(attrs, inputs, out_type):
@@ -799,6 +823,9 @@ def _compute_matmul(attrs, inputs, out_type):
]
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
elif need_meta_schedule_layout:
args.append("")
args.append(get_meta_schedule_original_shape(attrs))
args[1] = copy_if_identical(inputs[0], inputs[1])
return [topi_compute(*args)]

@@ -819,7 +846,11 @@ def matmul_strategy(attrs, inputs, out_type, target):


# dense
def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
def wrap_compute_dense(
topi_compute,
need_auto_scheduler_layout=False,
need_meta_schedule_layout=False,
):
"""wrap dense topi compute"""

def _compute_dense(attrs, inputs, out_type):
@@ -829,6 +860,9 @@ def _compute_dense(attrs, inputs, out_type):
args = [inputs[0], inputs[1], None, out_dtype]
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
elif need_meta_schedule_layout:
args.append("")
args.append(get_meta_schedule_original_shape(attrs))
args[1] = copy_if_identical(inputs[0], inputs[1])
return [topi_compute(*args)]

Expand Down Expand Up @@ -862,7 +896,13 @@ def dense_pack_strategy(attrs, inputs, out_type, target):


# batch_matmul
def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, need_out_dtype=False):
def wrap_compute_batch_matmul(
topi_compute,
*,
need_auto_scheduler_layout=False,
need_meta_schedule_layout=False,
need_out_dtype=False,
):
"""wrap batch_matmul topi compute"""

def _compute_batch_matmul(attrs, inputs, out_type):
@@ -872,6 +912,9 @@ def _compute_batch_matmul(attrs, inputs, out_type):
args.append(attrs.transpose_b)
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
elif need_meta_schedule_layout:
args.append("")
args.append(get_meta_schedule_original_shape(attrs))
args[1] = copy_if_identical(inputs[0], inputs[1])
return [topi_compute(*args)]

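
Not part of the diff: a hedged end-to-end sketch of how the pieces fit together, using the names exported from tvm.meta_schedule in this PR (extract_task_from_relay, tune_extracted_tasks, TuneConfig, ApplyHistoryBest). The TuneConfig fields and the database handling are assumptions, not taken from this diff.

# Sketch only: extract tasks, tune them, then build the model with the tuned
# records applied and with relay.backend.use_meta_schedule turned on, so the
# strategies above select the meta-schedule implementations.
import tvm
from tvm import meta_schedule as ms
from tvm import relay

def tune_and_build(mod, params, target, work_dir):
    tasks = ms.extract_task_from_relay(mod, target, params)
    # Assumed call and config fields; adjust to the actual tuning API.
    database = ms.tune_extracted_tasks(
        tasks,
        config=ms.TuneConfig(max_trials_global=2000),
        work_dir=work_dir,
    )
    with ms.ApplyHistoryBest(database), tvm.transform.PassContext(
        opt_level=3,
        config={"relay.backend.use_meta_schedule": True},
    ):
        return relay.build(mod, target=target, params=params)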