|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# pylint: disable=invalid-name, unused-argument |
| 18 | +""" |
| 19 | +External compiler related feature registration. |
| 20 | +
|
| 21 | +It implements dispatchers that check if an operator should use a given compiler |
| 22 | +to generate code. |
| 23 | +
|
| 24 | +Each compiler can customize the support of an operator. For example, they can |
| 25 | +check the attribute of the operator and/or the features of the input arguments |
| 26 | +to decide if we should use the compiler for codegen. |
| 27 | +""" |
| 28 | +from __future__ import absolute_import |
| 29 | + |
| 30 | +import logging |
| 31 | +import pkgutil |
| 32 | +from pathlib import Path |
| 33 | +from importlib import import_module |
| 34 | + |
| 35 | +from .. import op as reg |
| 36 | + |
| 37 | +logger = logging.getLogger('AnnotateCompiler') |
| 38 | + |
| 39 | +# Load available contrib compilers |
| 40 | +compilers = {} |
| 41 | +for _, name, _ in pkgutil.iter_modules([Path(__file__).parent]): |
| 42 | + compilers[name] = import_module( |
| 43 | + '.%s' % name, package='.'.join(__name__.split('.')[:-1])) |
| 44 | + |
| 45 | + |
| 46 | +def get_annotate_compiler(compiler, op_name): |
| 47 | + """Get the annotate_compiler function from the registered compilers. |
| 48 | +
|
| 49 | + Parameters |
| 50 | + ---------- |
| 51 | + compiler : Str |
| 52 | + The name of a compiler that is used to generate code. |
| 53 | +
|
| 54 | + op_name : Str |
| 55 | + The name of an operator. |
| 56 | +
|
| 57 | + Returns |
| 58 | + ------- |
| 59 | + ret : bool |
| 60 | + If the operator uses the provided compiler for codegen. |
| 61 | + """ |
| 62 | + if compiler in compilers: |
| 63 | + if hasattr(compilers[compiler], 'annotate_compiler'): |
| 64 | + annotate_compiler = getattr(compilers[compiler], 'annotate_compiler') |
| 65 | + if hasattr(annotate_compiler, op_name): |
| 66 | + return getattr(annotate_compiler, op_name) |
| 67 | + |
| 68 | + logger.warning("%s in %s is not registered. Fallback to CPU", op_name, |
| 69 | + compiler) |
| 70 | + return lambda x, y: False |
| 71 | + |
| 72 | + |
| 73 | +@reg.register_annotate_compiler("nn.conv2d") |
| 74 | +def annotate_conv2d(attrs, args, compiler): |
| 75 | + """Check if the provided compiler should be used for conv2d. |
| 76 | + """ |
| 77 | + return get_annotate_compiler(compiler, 'conv2d')(attrs, args) |
| 78 | + |
| 79 | + |
| 80 | +@reg.register_annotate_compiler("nn.dense") |
| 81 | +def annotate_dense(attrs, args, compiler): |
| 82 | + """Check if the provided compiler should be used for dense. |
| 83 | + """ |
| 84 | + return get_annotate_compiler(compiler, 'dense')(attrs, args) |
| 85 | + |
| 86 | + |
| 87 | +@reg.register_annotate_compiler("nn.relu") |
| 88 | +def annotate_relu(attrs, args, compiler): |
| 89 | + """Check if the provided compiler should be used for relu. |
| 90 | + """ |
| 91 | + return get_annotate_compiler(compiler, 'relu')(attrs, args) |
| 92 | + |
| 93 | + |
| 94 | +@reg.register_annotate_compiler("nn.batch_norm") |
| 95 | +def annotate_batch_norm(attrs, args, compiler): |
| 96 | + """Check if the provided compiler should be used for batch_norm. |
| 97 | + """ |
| 98 | + return get_annotate_compiler(compiler, 'batch_norm')(attrs, args) |
| 99 | + |
| 100 | + |
| 101 | +@reg.register_annotate_compiler("subtract") |
| 102 | +def annotate_subtract(attrs, args, compiler): |
| 103 | + """Check if the provided compiler should be used for subtract. |
| 104 | + """ |
| 105 | + return get_annotate_compiler(compiler, 'subtract')(attrs, args) |
| 106 | + |
| 107 | + |
| 108 | +@reg.register_annotate_compiler("add") |
| 109 | +def annotate_add(attrs, args, compiler): |
| 110 | + """Check if the provided compiler should be used for add. |
| 111 | + """ |
| 112 | + return get_annotate_compiler(compiler, 'add')(attrs, args) |
| 113 | + |
| 114 | + |
| 115 | +@reg.register_annotate_compiler("multiply") |
| 116 | +def annotate_multiply(attrs, args, compiler): |
| 117 | + """Check if the provided compiler should be used for multiply. |
| 118 | + """ |
| 119 | + return get_annotate_compiler(compiler, 'multiply')(attrs, args) |
0 commit comments