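Traditionally, each kind of hardware came with its own programming language and its own compiler to produce machine code. As the number of languages and hardware targets grows, maintaining a separate compiler for every combination becomes increasingly difficult. Modern compilers have solved this problem.
To do so, compiler designers split the compiler into a front end, a middle end, and a back end, and introduced an IR (Intermediate Representation) shared between them. The structure is as follows: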
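Deep learning compilers borrow this design. Models from different deep learning frameworks are fed to the compiler, which emits an IR. In deep learning this IR is essentially the computation graph, so it can simply be called Graph IR. The Graph IR then goes through graph-level optimizations and is lowered to a further IR that is dispatched to the various hardware back ends.
The analogous structure for a deep learning compiler is shown in the figure below.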
TVM: an inference framework built on compiler optimization
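NNVM has since been superseded by Relay, its successor, which combines features of a programming language with the ability to construct deep learning graphs. Together with TVM's code generation tools and the rich operator library in TOPI, it supports the full workflow of compiling and deploying deep learning models.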
This example imports the deep learning model through ONNX.
Loading the resnet18.onnx model with onnx:
import onnx
import numpy as np
import tvm
from tvm import te
import tvm.relay as relay

onnx_model = onnx.load('resnet18.onnx')

from PIL import Image
image_path = 'cat.png'
img = Image.open(image_path).resize((224, 224))

# Preprocess the image and convert to tensor
from torchvision import transforms
my_preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
img = my_preprocess(img)
x = np.expand_dims(img, 0)
Next, TVM's Relay front end converts the ONNX model into a Graph IR that TVM understands. Relay provides frontend.from_onnx, which loads an ONNX model and converts it to Relay IR.
# The target is set to "llvm", meaning the Relay IR will run on the CPU back end
target = "llvm"
input_name = "input.1"
shape_dict = {input_name: x.shape}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
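Print mod:
Calling relay.frontend.from_onnx converts the ONNX model to IR. Printing mod shows that it holds a Relay function whose inputs carry the shape information of the ONNX model's input tensors. The following looks at how TVM converts ONNX to Relay IR.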
The execution flow of from_onnx is as follows:
relay.frontend.from_onnx is defined in tvm/python/tvm/relay/frontend/onnx.py. Its implementation is as follows (reference: https://zhuanlan.zhihu.com/p/365800737):
def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=False):
    """Convert an ONNX model into an equivalent Relay function.

    ONNX graphs are represented as Python Protobuf objects, and the companion
    parameters are handled automatically.
    However, the input names of an ONNX graph are vague, mixing inputs and
    network weights/biases such as "1", "2"...
    For convenience, the "real" input names are renamed to "input_0", "input_1"...
    and the parameters are renamed to "param_0", "param_1"...

    By default, ONNX defines models in terms of dynamic shapes.
    The ONNX importer retains that dynamism on import, and the compiler attempts
    to convert the model to static shapes at compile time.
    If this fails, there may still be dynamic operations in the model.
    Not all TVM kernels currently support dynamic shapes; if you hit an error
    with dynamic kernels, please ask on ask.tvm.apache.org.

    Parameters
    ----------
    model : protobuf object
        ONNX ModelProto after ONNX v1.1.0
    shape : dict of str to tuple, optional
        The input shapes of the graph
    dtype : str or dict of str to str
        The input types of the graph (a single str, or a dict when there are
        several inputs)
    opset : int, optional
        Override the autodetected opset. This can be helpful for some testing.
    freeze_params: bool
        If this parameter is true, the importer will take any provided onnx
        input values (weights, shapes, etc) and embed them into the relay model
        as Constants instead of variables. This allows more aggressive
        optimizations at compile time and helps in making models static if
        certain inputs represent attributes relay would traditionally consider
        compile-time constants.
        In short: with freeze_params enabled, every input value available from
        the ONNX model, including weights and shapes, is embedded in the Relay
        IR as a constant. This helps compile-time optimization and produces a
        static model.

    Returns
    -------
    mod : tvm.IRModule
        The Relay IR used for compilation
    params : dict of str to tvm.nd.NDArray
        The parameter dict used by Relay, holding the weights
    """
    try:
        import onnx

        if hasattr(onnx.checker, "check_model"):
            # try use onnx's own model checker before converting any model
            try:
                onnx.checker.check_model(model)
            except Exception as e:  # pylint: disable=c-extension-no-member, broad-except
                # the checker is a bit violent about errors, so simply print warnings here
                warnings.warn(str(e))
    except ImportError:
        pass
    # A helper class, copied from pb2.GraphProto, that handles the Relay IR.
    g = GraphProto(shape, dtype, freeze_params)
    # The GraphProto of the ONNX model
    graph = model.graph
    if opset is None:
        try:
            opset = model.opset_import[0].version if model.opset_import else 1
        except AttributeError:
            opset = 1
    # Use the graph proto as a scope so that ops can access other nodes if needed.
    with g:
        mod, params = g.from_onnx(graph, opset)
    return mod, params
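When freeze_params is True, the resulting static model can only handle the input shape specified by the user. For example, a fully convolutional network can normally accept any input resolution, but if the user builds the Relay IR with a (224, 224) resolution and sets this parameter to True, inference raises an exception when it receives an image whose height and width are not 224 x 224.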
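After receiving the ONNX model, the front end creates a GraphProto object that manages the conversion of ONNX ops and the generation of Relay IR. The core function is g.from_onnx(graph, opset), implemented as follows: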
def from_onnx(self, graph, opset, get_output_expr=False):
    """Construct Relay expression from ONNX graph.

    Onnx graph is a python protobuf object.
    The companion parameters will be handled automatically.
    However, the input names from onnx graph is vague, mixing inputs and
    network weights/bias such as "1", "2"...
    For convenience, we rename the `real` input names to "input_0",
    "input_1"... And renaming parameters to "param_0", "param_1"...

    Parameters
    ----------
    graph : onnx protobuf object
        The loaded onnx graph
    opset : opset version
    get_output_expr: bool
        If set to true, this conversion will return each output expression rather
        than a packaged module. This can be useful when converting subgraphs to
        relay.

    Returns
    -------
    mod : tvm.IRModule
        The returned relay module
    params : dict
        A dict of name: tvm.nd.array pairs, used as pretrained weights
    """
    self.opset = opset
    self._parse_graph_initializers(graph)
    self._parse_graph_input(graph)
    self._check_user_inputs_in_outermost_graph_scope()
    self._check_for_unsupported_ops(graph)
    self._construct_nodes(graph)

    # now return the outputs
    outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
    outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
    # If requested, directly return the converted expressions.
    if get_output_expr:
        return outputs
    ## Maintain the order of inputs and parameters from the ONNX graph, but only include
    ## those parameters that are needed to execute the relay graph
    free_vars = analysis.free_vars(outputs)
    nodes = {v: k for k, v in self._nodes.items()}
    free_vars = [nodes[var] for var in free_vars]
    for i_name in self._params:
        if i_name in free_vars and i_name not in self._inputs:
            self._inputs[i_name] = self._nodes[i_name]
    # Create a function from our output expression and all input variables.
    func = _function.Function([v for k, v in self._inputs.items()], outputs)
    return IRModule.from_expr(func), self._params
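This function converts the ONNX graph into a Relay expression.
First, _parse_graph_initializers(graph) parses the graph's initializers into Relay. These are the model parameters: ONNX uses initializers to store the weights.
_parse_graph_input parses the model inputs.
Then _check_user_inputs_in_outermost_graph_scope validates the inputs the user supplied for the outermost graph, and _check_for_unsupported_ops checks for unsupported operators, raising an exception if any are found.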
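The unsupported-operator check can be pictured with a minimal sketch. This is not TVM's code: the op list, the SUPPORTED_OPS set, and the use of NotImplementedError as a stand-in exception are all assumptions made for illustration. It only mirrors the idea of collecting every op type that has no converter and failing once with the full list.

```python
# Hypothetical stand-ins: a "graph" is just a list of op-type strings, and
# SUPPORTED_OPS plays the role of the keys of the converter map.
SUPPORTED_OPS = {"Conv", "Relu", "Add", "GlobalAveragePool", "Gemm"}


def check_for_unsupported_ops(op_types, supported=SUPPORTED_OPS):
    """Raise once, listing every op type that has no converter."""
    unsupported = sorted({op for op in op_types if op not in supported})
    if unsupported:
        raise NotImplementedError(
            "The following operators are not supported: " + ", ".join(unsupported)
        )


# A graph made only of supported ops passes silently.
check_for_unsupported_ops(["Conv", "Relu", "Gemm"])
```

Reporting all missing operators at once, rather than stopping at the first, tells the user up front how much work a new model needs.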
If no exception is raised, every operator in the ONNX model is supported by Relay and the conversion can proceed. The main work happens in _construct_nodes:
def _construct_nodes(self, graph):
    """Nodes are stored as directed acyclic graph."""
    for node in graph.node:
        op_name = node.op_type
        attr = self._parse_attr(node.attribute)
        # Create and populate input list.
        inputs = onnx_input()
        for i in node.input:
            if i != "":
                # self._renames.get(i, i) resolves the (possibly renamed) input of each ONNX node
                inputs.append(self._nodes[self._renames.get(i, i)])
            else:
                inputs.append(None)
        i_name = self._parse_value_proto(node)
        node_output = self._fix_outputs(op_name, node.output)
        attr["tvm_custom"] = {}
        attr["tvm_custom"]["name"] = i_name
        attr["tvm_custom"]["num_outputs"] = len(node_output)

        op = self._convert_operator(op_name, inputs, attr, self.opset)
        if not isinstance(op, _expr.TupleWrapper):
            outputs_num = 1
        else:
            outputs_num = len(op)

        if outputs_num == 1:
            op = fold_constant(op)
        else:
            op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))

        if outputs_num > 1:
            # ONNX supports optional outputs for some nodes.
            # This block searches for missing outputs in the ONNX graph
            # and removes any unneeded ops
            valid_outputs = [False] * outputs_num
            for i, output in enumerate(node_output):
                if output != "":
                    valid_outputs[i] = True
            # If we have outputs ONNX isn't expecting, we need to drop them
            if not all(valid_outputs):
                tup = op.astuple()
                # TupleWrapper can also wrap ops with TupleType outputs
                if isinstance(tup, _expr.Tuple):
                    # For tuples, we extract the fields instead of using GetTupleItem
                    outputs = [tup.fields[i] for i, valid in enumerate(valid_outputs) if valid]
                else:
                    # For call nodes, we need to GetTupleItem
                    outputs = [op[i] for i, valid in enumerate(valid_outputs) if valid]
                # Create the new op with valid outputs
                if len(outputs) == 1:
                    op = outputs[0]
                elif len(outputs) != outputs_num:
                    op = _expr.TupleWrapper(_expr.Tuple(outputs), len(outputs))
                # Drop invalid outputs for the onnx node
                outputs_num = len(outputs)
                node_output = [output for output in node_output if output != ""]
        assert (
            len(node_output) == outputs_num
        ), "Number of output mismatch {} vs {} in {}.".format(
            len(node_output), outputs_num, op_name
        )

        if outputs_num == 1:
            self._nodes[node_output[0]] = op
        else:
            for k, i in zip(list(node_output), range(len(node_output))):
                self._nodes[k] = op[i]
The conversion itself is performed by _convert_operator:
def _convert_operator(self, op_name, inputs, attrs, opset):
    """Convert ONNX operator into a Relay operator.
    The converter must specify conversions explicitly for incompatible name, and
    apply handlers to operator attributes.

    Parameters
    ----------
    op_name : str
        Operator name, such as Convolution, FullyConnected
    inputs : list of tvm.relay.function.Function
        List of inputs.
    attrs : dict
        Dict of operator attributes
    opset : int
        Opset version

    Returns
    -------
    sym : tvm.relay.function.Function
        Converted relay function
    """
    convert_map = _get_convert_map(opset)
    if op_name in _identity_list:
        sym = get_relay_op(op_name)(*inputs, **attrs)
    elif op_name in convert_map:
        sym = convert_map[op_name](inputs, attrs, self._params)
    else:
        raise NotImplementedError("Operator {} not implemented.".format(op_name))
    return sym
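The dispatch in _convert_operator is a dictionary lookup keyed by op type name. The sketch below imitates that pattern in plain Python. The names renamer, convert_gemm, and CONVERT_MAP are hypothetical, and "Relay expressions" are replaced by nested tuples, so it shows the shape of the mechanism rather than TVM's behavior.

```python
def renamer(new_name):
    """Return a converter that only changes the operator name (1-to-1 mapping)."""
    def convert(inputs, attrs):
        return (new_name, tuple(inputs), dict(attrs))
    return convert


def convert_gemm(inputs, attrs):
    """A custom converter: one source op becomes several target ops (1-to-N)."""
    a, b, c = inputs
    return ("add", (("dense", (a, b), {}), c), {})


# Key: source op type name. Value: a callable that builds the target expression.
CONVERT_MAP = {
    "Relu": renamer("relu"),
    "Sqrt": renamer("sqrt"),
    "Gemm": convert_gemm,
}


def convert_operator(op_name, inputs, attrs):
    """Look up the converter for op_name and apply it."""
    if op_name not in CONVERT_MAP:
        raise NotImplementedError("Operator {} not implemented.".format(op_name))
    return CONVERT_MAP[op_name](inputs, attrs)


print(convert_operator("Relu", ["x"], {}))
```

Keeping the values as callables rather than finished expressions is what lets the same table serve every node in the graph: the expression is only built when a node's inputs and attributes are known.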
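_get_convert_map returns a dictionary of the ops that TVM supports for a given ONNX opset version. Each key is the ONNX op type name, and each value is the converter that produces the corresponding Relay expression.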
# _convert_map defines maps of name to converter functor(callable) # for 1 to 1 mapping, use Renamer if nothing but name is different # use AttrCvt if attributes need to be converted # for 1 to N mapping(composed), use custom callable functions # for N to 1 mapping, currently not supported(?) def _get_convert_map(opset): return { # defs/experimental "Identity": Renamer("copy"), "Affine": Affine.get_converter(opset), "BitShift": BitShift.get_converter(opset), "ThresholdedRelu": ThresholdedRelu.get_converter(opset), "ScaledTanh": ScaledTanh.get_converter(opset), "ParametricSoftplus": ParametricSoftPlus.get_converter(opset), "Constant": Constant.get_converter(opset), "ConstantOfShape": ConstantOfShape.get_converter(opset), # 'GivenTensorFill' "FC": AttrCvt("dense", ignores=["axis", "axis_w"]), "Scale": Scale.get_converter(opset), # 'GRUUnit' # 'ATen' # 'ImageScaler' "MeanVarianceNormalization": MeanVarianceNormalization.get_converter(opset), # 'Crop' # 'Embedding' "Upsample": Upsample.get_converter(opset), "SpatialBN": BatchNorm.get_converter(opset), # defs/generator # 'Constant' # Implemented # 'RandomUniform' # 'RandomNormal' # 'RandomUniformLike' # 'RandomNormalLike' # defs/logical # defs/math "Add": Add.get_converter(opset), "Sub": Sub.get_converter(opset), "Mul": Mul.get_converter(opset), "Div": Div.get_converter(opset), "Neg": Renamer("negative"), "Abs": Absolute.get_converter(opset), "Reciprocal": Reciprocal.get_converter(opset), "Floor": Renamer("floor"), "Ceil": Renamer("ceil"), "Round": Round.get_converter(opset), "IsInf": IsInf.get_converter(opset), "IsNaN": Renamer("isnan"), "Sqrt": Renamer("sqrt"), "Relu": Renamer("relu"), "Celu": Celu.get_converter(opset), "LeakyRelu": Renamer("leaky_relu"), "Selu": Selu.get_converter(opset), "Elu": Elu.get_converter(opset), "Gelu": Gelu.get_converter(opset), "BiasGelu": BiasGelu.get_converter(opset), # TODO: We need a better way to handle different domains, in case # of name collisions. 
EmbedLayerNormalization, SkipLayerNormalization, and Attention # are in the `com.microsoft` domain. "EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset), "SkipLayerNormalization": SkipLayerNormalization.get_converter(opset), "Attention": Attention.get_converter(opset), "Exp": Renamer("exp"), "Greater": Renamer("greater"), "GreaterOrEqual": Renamer("greater_equal"), "Less": Renamer("less"), "LessOrEqual": Renamer("less_equal"), "Log": Renamer("log"), "Acos": Renamer("acos"), "Acosh": Renamer("acosh"), "Asin": Renamer("asin"), "Asinh": Renamer("asinh"), "Atan": Renamer("atan"), "Atanh": Renamer("atanh"), "Cos": Renamer("cos"), "Cosh": Renamer("cosh"), "Sin": Renamer("sin"), "Sinh": Renamer("sinh"), "Tan": Renamer("tan"), "Tanh": Renamer("tanh"), "Pow": Pow.get_converter(opset), "PRelu": Prelu.get_converter(opset), "Sigmoid": Renamer("sigmoid"), "HardSigmoid": HardSigmoid.get_converter(opset), "HardSwish": HardSwish.get_converter(opset), "Max": Maximum.get_converter(opset), "Min": Minimum.get_converter(opset), "Sum": Sum.get_converter(opset), "Mean": Mean.get_converter(opset), "Clip": Clip.get_converter(opset), "Softplus": Softplus.get_converter(opset), # softmax default axis is different in onnx "Softmax": Softmax.get_converter(opset), "LogSoftmax": LogSoftmax.get_converter(opset), "OneHot": OneHot.get_converter(opset), "Hardmax": Hardmax.get_converter(opset), "Shrink": Shrink.get_converter(opset), "Softsign": Softsign.get_converter(opset), "Gemm": Gemm.get_converter(opset), "MatMul": MatMul.get_converter(opset), "MatMulInteger": MatMulInteger.get_converter(opset), "MatMulInteger16": MatMulInteger16.get_converter(opset), "Mod": Mod.get_converter(opset), "Xor": Renamer("logical_xor"), # defs/nn "AveragePool": AveragePool.get_converter(opset), "LpPool": LpPool.get_converter(opset), "GlobalLpPool": GlobalLpPool.get_converter(opset), "MaxPool": MaxPool.get_converter(opset), "MaxUnpool": MaxUnpool.get_converter(opset), "Conv": Conv.get_converter(opset), 
"ConvTranspose": ConvTranspose.get_converter(opset), "GlobalAveragePool": GlobalAveragePool.get_converter(opset), "GlobalMaxPool": GlobalMaxPool.get_converter(opset), "BatchNormalization": BatchNorm.get_converter(opset), "InstanceNormalization": InstanceNorm.get_converter(opset), # 'LpNormalization' "Dropout": AttrCvt("dropout", {"ratio": "rate"}, ignores=["is_test"]), "Flatten": Flatten.get_converter(opset), "LRN": LRN.get_converter(opset), # Recurrent Layers "LSTM": LSTM.get_converter(opset), "GRU": GRU.get_converter(opset), # defs/vision "MaxRoiPool": MaxRoiPool.get_converter(opset), "RoiAlign": RoiAlign.get_converter(opset), "NonMaxSuppression": NonMaxSuppression.get_converter(opset), # defs/reduction "ReduceMax": ReduceMax.get_converter(opset), "ReduceMin": ReduceMin.get_converter(opset), "ReduceSum": ReduceSum.get_converter(opset), "ReduceMean": ReduceMean.get_converter(opset), "ReduceProd": ReduceProd.get_converter(opset), "ReduceLogSumExp": ReduceLogSumExp.get_converter(opset), "ReduceLogSum": ReduceLogSum.get_converter(opset), "ReduceSumSquare": ReduceSumSquare.get_converter(opset), "ReduceL1": ReduceL1.get_converter(opset), "ReduceL2": ReduceL2.get_converter(opset), # defs/sorting "ArgMax": ArgMax.get_converter(opset), "ArgMin": ArgMin.get_converter(opset), "TopK": TopK.get_converter(opset), # defs/tensor "Cast": Cast.get_converter(opset), "Reshape": Reshape.get_converter(opset), "Expand": Expand.get_converter(opset), "Concat": Concat.get_converter(opset), "Split": Split.get_converter(opset), "Slice": Slice.get_converter(opset), "Transpose": AttrCvt("transpose", {"perm": "axes"}), "DepthToSpace": DepthToSpace.get_converter(opset), "SpaceToDepth": SpaceToDepth.get_converter(opset), "Gather": Gather.get_converter(opset), "GatherElements": GatherElements.get_converter(opset), "GatherND": GatherND.get_converter(opset), "Compress": Compress.get_converter(opset), "Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}), "Scatter": 
Scatter.get_converter(opset), "ScatterElements": Scatter.get_converter(opset), "ScatterND": ScatterND.get_converter(opset), "EyeLike": EyeLike.get_converter(opset), "Squeeze": Squeeze.get_converter(opset), "Unsqueeze": Unsqueeze.get_converter(opset), "Pad": Pad.get_converter(opset), "Shape": Shape.get_converter(opset), "Sign": Sign.get_converter(opset), "Equal": Equal.get_converter(opset), "Not": Not.get_converter(opset), "And": And.get_converter(opset), "Tile": Tile.get_converter(opset), "Erf": Erf.get_converter(opset), "Where": Where.get_converter(opset), "Or": Or.get_converter(opset), "Resize": Resize.get_converter(opset), "NonZero": NonZero.get_converter(opset), "Range": Range.get_converter(opset), "CumSum": CumSum.get_converter(opset), "Unique": Unique.get_converter(opset), "Einsum": Einsum.get_converter(opset), # defs/control_flow "Loop": Loop.get_converter(opset), "If": If.get_converter(opset), # Torch ATen Dispatcher. "ATen": ATen.get_converter(opset), # Quantization "QuantizeLinear": QuantizeLinear.get_converter(opset), "DequantizeLinear": DequantizeLinear.get_converter(opset), "DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset), "ReverseSequence": ReverseSequence.get_converter(opset), "QLinearConv": QLinearConv.get_converter(opset), "QLinearConcat": QLinearConcat.get_converter(opset), "QLinearAdd": QLinearAdd.get_converter(opset), "QLinearMatMul": QLinearMatMul.get_converter(opset), "QLinearMul": QLinearMul.get_converter(opset), "QLinearSigmoid": QLinearSigmoid.get_converter(opset), "ConvInteger": ConvInteger.get_converter(opset), "QLinearAveragePool": QLinearAveragePool.get_converter(opset), "QLinearGlobalAveragePool": QLinearGlobalAveragePool.get_converter(opset), "QLinearLeakyRelu": QLinearLeakyRelu.get_converter(opset), # Random number generation. 
"RandomNormal": RandomNormal.get_converter(opset), "RandomNormalLike": RandomNormalLike.get_converter(opset), "RandomUniform": RandomUniform.get_converter(opset), "RandomUniformLike": RandomUniformLike.get_converter(opset), # Loss functions / training "NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset), "SoftmaxCrossEntropyLoss": SoftmaxCrossEntropyLoss.get_converter(opset), "Adagrad": Adagrad.get_converter(opset), "Adam": Adam.get_converter(opset), "Momentum": Momentum.get_converter(opset), "Scan": Scan.get_converter(opset), # ML "LinearRegressor": LinearRegressor.get_converter(opset), # Sequence operators "SequenceConstruct": SequenceConstruct.get_converter(opset), "SequenceInsert": SequenceInsert.get_converter(opset), "ConcatFromSequence": ConcatFromSequence.get_converter(opset), }
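Take the convolution layer as an example of how an ONNX op is converted into a Relay expression. A convolution op generally has three inputs: data, weight, and bias, which correspond to inputs[0], inputs[1], and inputs[2] in the function below.
The auto_pad attribute is specific to ONNX and is not supported by the Relay convolution op, so the padding required by the ONNX convolution has to be computed and handled case by case (either by padding the input explicitly or by passing a padding argument to the Relay convolution op, depending on the situation). Note also that inputs[0] in this converter is a Relay expression rather than actual data, which can be confirmed by printing inputs[0] in the code below.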
class Conv(OnnxOpConverter):
    """Operator converter for Conv."""

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Use shape of input to determine convolution type.
        data = inputs[0]
        kernel = inputs[1]
        input_shape = infer_shape(data)
        ndim = len(input_shape)

        kernel_type = infer_type(inputs[1])
        kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)]
        if "kernel_shape" not in attr:
            attr["kernel_shape"] = kernel_shapes[0][2:]

        if "auto_pad" in attr:
            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
                # Warning: Convolution does not yet support dynamic shapes,
                # one will need to run dynamic_to_static on this model after import
                data = autopad(
                    data,
                    attr.get("strides", [1] * (ndim - 2)),
                    attr["kernel_shape"],
                    attr.get("dilations", [1] * (ndim - 2)),
                    mode=attr["auto_pad"],
                )
            elif attr["auto_pad"] == "VALID":
                attr["pads"] = [0 for i in range(ndim - 2)]
            elif attr["auto_pad"] == "NOTSET":
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
            attr.pop("auto_pad")

        attr["channels"] = kernel_shapes[0][0]
        out = AttrCvt(
            op_name=dimension_picker("conv"),
            transforms={
                "kernel_shape": "kernel_size",
                "dilations": ("dilation", 1),
                "pads": ("padding", 0),
                "group": ("groups", 1),
            },
            custom_check=dimension_constraint(),
        )([data, kernel], attr, params)

        use_bias = len(inputs) == 3
        if use_bias:
            out = _op.nn.bias_add(out, inputs[2])
        return out
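To make the auto_pad handling concrete, here is a small sketch of how the padding for one spatial axis can be computed under the ONNX SAME_UPPER and SAME_LOWER rules. The function name same_auto_pad is hypothetical and this is not TVM's autopad helper, which works on Relay expressions. It only shows the arithmetic for static sizes.

```python
import math


def same_auto_pad(in_size, kernel, stride=1, dilation=1, mode="SAME_UPPER"):
    """Return (pad_begin, pad_end) for one spatial axis.

    With SAME_* padding the output size is ceil(in_size / stride); the total
    padding is whatever is needed to reach that size.
    """
    out_size = math.ceil(in_size / stride)
    effective_kernel = (kernel - 1) * dilation + 1
    total = max((out_size - 1) * stride + effective_kernel - in_size, 0)
    small, large = total // 2, total - total // 2
    if mode == "SAME_UPPER":
        # An odd total puts the extra element at the end.
        return small, large
    if mode == "SAME_LOWER":
        # An odd total puts the extra element at the beginning.
        return large, small
    raise ValueError("unsupported auto_pad mode: " + mode)


# First convolution of a ResNet-style network: 224 input, 7x7 kernel, stride 2.
print(same_auto_pad(224, 7, stride=2))                      # (2, 3)
print(same_auto_pad(224, 7, stride=2, mode="SAME_LOWER"))   # (3, 2)
```

The asymmetric result for an odd total is why the padding cannot always be expressed as a single symmetric value, and therefore why the two modes are handled explicitly.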
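The Conv converter also uses the AttrCvt class, which performs attribute conversion: it maps the attributes of an op in the ONNX graph to the attributes of the corresponding TVM Relay op, as described above. Finally, if the convolution has a bias, _op.nn.bias_add adds it. Note that this op still returns a Relay expression.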
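The attribute mapping can be illustrated with a simplified stand-in. The function convert_attrs below is hypothetical and deliberately smaller than AttrCvt, which has further options such as ignores and custom checks and may differ in details such as how defaults are applied. The sketch assumes that a string rule means "rename" and a tuple rule means "rename, with this default if the attribute is absent".

```python
# The same transforms table that the Conv converter passes to AttrCvt.
CONV_TRANSFORMS = {
    "kernel_shape": "kernel_size",
    "dilations": ("dilation", 1),
    "pads": ("padding", 0),
    "group": ("groups", 1),
}


def convert_attrs(attrs, transforms):
    """Rename attributes and fill defaults (a simplification, not AttrCvt itself)."""
    new_attrs = {}
    for old_name, rule in transforms.items():
        new_name, default = rule if isinstance(rule, tuple) else (rule, None)
        if old_name in attrs:
            new_attrs[new_name] = attrs[old_name]
        elif default is not None:
            new_attrs[new_name] = default
    # Attributes without a rule pass through unchanged.
    for name, value in attrs.items():
        if name not in transforms:
            new_attrs[name] = value
    return new_attrs


onnx_attrs = {"kernel_shape": (3, 3), "group": 1, "strides": (1, 1)}
print(convert_attrs(onnx_attrs, CONV_TRANSFORMS))
```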
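Other ops are handled much like the convolution op. Once every ONNX op has been converted, we obtain the Relay IR and the weight parameters from the second section, i.e. the result of this line: mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)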
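We now know how TVM converts ONNX to Relay IR. What if some ops in a custom model are not yet supported by TVM? In that case we need to define the op ourselves. A custom op can either be composed from existing ops or be implemented independently in TVM, after which a conversion entry is added to the front end. The following uses SeLU as a brief example of what adding an op involves.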
First, implement a Selu class that inherits from OnnxOpConverter and implements the _impl_v1 method. The code is as follows:
class Selu(OnnxOpConverter):
    """Operator converter for Selu."""

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get("alpha", 1.6732))
        gamma = float(attr.get("gamma", 1.0507))
        return _expr.const(gamma) * (
            _expr.const(-alpha) * _op.nn.relu(_expr.const(1.0) - _op.exp(inputs[0]))
            + _op.nn.relu(inputs[0])
        )
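The composition used in the Selu converter can be checked against the textbook definition of SeLU, gamma * x for x > 0 and gamma * alpha * (exp(x) - 1) otherwise. The sketch below does this with scalar Python floats instead of Relay expressions. The helper names are made up for the check, and the default alpha and gamma are the ones in the converter above.

```python
import math

ALPHA, GAMMA = 1.6732, 1.0507


def relu(x):
    return max(x, 0.0)


def selu_reference(x):
    """Piecewise definition of SeLU."""
    return GAMMA * (x if x > 0 else ALPHA * (math.exp(x) - 1.0))


def selu_composed(x):
    """Same expression as the converter: built only from exp, relu, mul, add."""
    return GAMMA * (-ALPHA * relu(1.0 - math.exp(x)) + relu(x))


for value in (-3.0, -0.5, 0.0, 0.5, 3.0):
    assert math.isclose(selu_reference(value), selu_composed(value), abs_tol=1e-12)
```

For x < 0, relu(1 - exp(x)) equals 1 - exp(x), so the first term becomes alpha * (exp(x) - 1) and the second vanishes. For x >= 0 the first term vanishes and only x remains. This is how a branch in the definition is expressed without any conditional op.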
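The Selu converter composes the op from common operators according to the SeLU formula. After implementing this conversion logic, register the op in the convert map by adding one line to _get_convert_map: "Selu": Selu.get_converter(opset), and then save the source and rebuild or reinstall TVM so the change takes effect. The OnnxOpConverter class that Selu inherits from is implemented as follows: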
class OnnxOpConverter(object):
    """A helper class for holding onnx op converters."""

    @classmethod
    def get_converter(cls, opset):
        """Get the converter that matches the given opset.

        Parameters
        ----------
        opset: int
            opset from model.

        Returns
        -------
        converter, which should be `_impl_vx`. Number x is the biggest
            number smaller than or equal to opset belongs to all support versions.
        """
        # Each _impl_vX method is the concrete implementation of the op;
        # X is the version and corresponds to the ONNX opset version.
        versions = [int(d.replace("_impl_v", "")) for d in dir(cls) if "_impl_v" in d]
        versions = sorted(versions + [opset])
        version = versions[max([i for i, v in enumerate(versions) if v == opset]) - 1]
        if hasattr(cls, "_impl_v{}".format(version)):
            return getattr(cls, "_impl_v{}".format(version))
        raise NotImplementedError(
            "opset version {} of {} not implemented".format(version, cls.__name__)
        )
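The version selection in get_converter is compact, so a standalone trace may help. The function pick_version below is a hypothetical extraction of just the index arithmetic, and it assumes the opset is not smaller than every implemented version.

```python
def pick_version(implemented, opset):
    """Return the largest implemented version that is <= opset."""
    # Insert the requested opset into the sorted list of implemented versions...
    versions = sorted(list(implemented) + [opset])
    # ...then step one position left of the last occurrence of opset.
    index = max(i for i, v in enumerate(versions) if v == opset) - 1
    return versions[index]


# A converter class that defines _impl_v1, _impl_v11 and _impl_v13:
print(pick_version([1, 11, 13], 12))  # 11
print(pick_version([1, 11, 13], 13))  # 13
print(pick_version([1, 11, 13], 9))   # 1
```

When the opset itself is implemented, it appears twice in the sorted list, and stepping left from the last occurrence lands on the first one, which is why an exact match is returned rather than the previous version.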
Once that is done, the custom model can be deployed.
With the model and parameters available as Relay IR, the next step is compilation:
target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)
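These few lines show TVM's compilation flow. It covers both the optimizations applied to the Relay IR, called passes, which among other things remove redundant operators, and code generation (codegen), which compiles the Relay program into code that a specific back end (here llvm) can execute.
Calling relay.build(mod, target=target, params=params) enters this compilation flow.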
Here mod and params are the graph structure and the weight parameters of the model. relay.build is defined in the file tvm/python/tvm/relay/build_module.py, and the entry code is as follows:
@register_func("tvm.relay.build")
def _build_module_no_factory_impl(mod, target, target_host, params, mod_name):
    return build(
        mod, target=target, target_host=target_host, params=params, mod_name=mod_name
    ).module
A target of llvm means TVM compiles the model into a program that runs on the CPU.
Note: register_func() registers a global function, which is what makes tvm.relay.build reachable through the global function table. Its implementation is as follows:
def register_func(func_name, f=None, override=False):
    """Register global function

    Parameters
    ----------
    func_name : str or function
        The function name
    f : function, optional
        The function to be registered.
    override: boolean optional
        Whether override existing entry.

    Returns
    -------
    fregister : function
        Register function if f is not specified.

    Examples
    --------
    The following code registers my_packed_func as global function.
    Note that we simply get it back from global function table to invoke
    it from python side. However, we can also invoke the same function
    from C++ backend, or in the compiled TVM code.

    .. code-block:: python

      targs = (10, 10.0, "hello")
      @tvm.register_func
      def my_packed_func(*args):
          assert(tuple(args) == targs)
          return 10
      # Get it out from global function table
      f = tvm.get_global_func("my_packed_func")
      assert isinstance(f, tvm.PackedFunc)
      y = f(*targs)
      assert y == 10
    """
    if callable(func_name):
        f = func_name
        func_name = f.__name__

    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    ioverride = ctypes.c_int(override)

    def register(myf):
        """internal register function"""
        if not isinstance(myf, PackedFuncBase):
            myf = convert_to_tvm_func(myf)
        check_call(_LIB.TVMFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride))
        return myf

    if f:
        return register(f)
    return register
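The decorator logic of register_func can be reproduced in a few lines of plain Python. In the sketch below a module-level dict, _GLOBAL_FUNCS, is a hypothetical stand-in for the global function table that TVM keeps on the C++ side, and the duplicate check stands in for the override flag. It is meant to show why the same function works both as @register_func and as @register_func("name").

```python
_GLOBAL_FUNCS = {}  # hypothetical stand-in for the C++ global function table


def register_func(func_name, f=None, override=False):
    """Register f under func_name; usable as a decorator with or without a name."""
    if callable(func_name):
        # Used as a bare decorator: the "name" argument is really the function.
        f, func_name = func_name, func_name.__name__
    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    def register(myf):
        if func_name in _GLOBAL_FUNCS and not override:
            raise ValueError("Global function {} is already registered".format(func_name))
        _GLOBAL_FUNCS[func_name] = myf
        return myf

    return register(f) if f else register


def get_global_func(name):
    return _GLOBAL_FUNCS[name]


@register_func("demo.add")
def _add(a, b):
    return a + b


@register_func
def my_packed_func(*args):
    return len(args)


print(get_global_func("demo.add")(2, 3))
```

Because registration is keyed by a string, the caller and the implementation never need to import each other, which is the property TVM relies on to call between Python and C++.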
Next, look at the build function:
def build(
    ir_mod,
    target=None,
    target_host=None,
    executor=Executor("graph"),
    runtime=Runtime("cpp"),
    workspace_memory_pools=None,
    constant_memory_pools=None,
    params=None,
    mod_name="default",
):
    # fmt: off
    # pylint: disable=line-too-long
    """Helper function that builds a Relay function to run on TVM graph executor.

    Parameters
    ----------
    ir_mod : :py:class:`~tvm.IRModule`
        The IR module to build. Using relay.Function is deprecated.
    target : None, or any multi-target like object, see Target.canon_multi_target
        For homogeneous compilation, the unique build target.
        For heterogeneous compilation, a dictionary or list of possible build targets.
        Defaults to the current target in the environment if None.
    target_host : None, or any target like object, see Target.canon_target
        Host compilation target, if target is device.
    executor : Optional[Executor]
        The executor configuration with which to build the model.
        Defaults to "graph" if no executor specified.
    runtime : Optional[Runtime]
        Runtime configuration to use when building the model.
        Defaults to "cpp" if no runtime specified.
    workspace_memory_pools : Optional[WorkspaceMemoryPools]
        The object that contains an Array of WorkspacePoolInfo objects
        that hold properties of read-write workspace pools that could be
        used by the inference.
    constant_memory_pools : Optional[ConstantMemoryPools]
        The object that contains an Array of ConstantPoolInfo objects
        that hold properties of read-only pools that could be
        used by the inference.
    params : dict of str to NDArray
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.
    mod_name: Optional[str]
        The module name we will build

    Returns
    -------
    factory_module : tvm.relay.backend.executor_factory.ExecutorFactoryModule
        The runtime factory for the TVM graph executor.
    """
    # pylint: enable=line-too-long
    # fmt: on
    if not isinstance(ir_mod, (IRModule, _function.Function)):
        raise ValueError("Type of input parameter mod must be tvm.IRModule")

    if isinstance(ir_mod, _function.Function):
        if params:
            ir_mod = bind_params_by_name(ir_mod, params)  # bind params to the relay.Function
        ir_mod = IRModule.from_expr(ir_mod)
        warnings.warn(
            "Please use input parameter mod (tvm.IRModule) "
            "instead of deprecated parameter mod (tvm.relay.function.Function)",
            DeprecationWarning,
        )

    # Check and update the target device type and the host type that belongs to it
    raw_targets = Target.canon_multi_target_and_host(Target.target_or_current(target), target_host)
    assert len(raw_targets) > 0
    target_host = raw_targets[0].host

    # All of this logic is to raise deprecation warnings for various parameters
    # TODO(Mousius) Remove these after some time
    deprecated_params_target = target_host or list(raw_targets)[0]
    deprecated_executor, deprecated_runtime = _reconstruct_from_deprecated_options(
        deprecated_params_target
    )
    if deprecated_executor:
        executor = deprecated_executor
    if deprecated_runtime:
        runtime = deprecated_runtime

    # If current dispatch context is fallback context (the default root context),
    # then load pre-tuned parameters from TopHub
    if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
        # Look up pre-tuned AutoTVM records for these targets on TopHub
        tophub_context = autotvm.tophub.context(list(raw_targets))
    else:
        tophub_context = autotvm.utils.EmptyContext()

    with tophub_context:
        bld_mod = BuildModule()  # create the BuildModule object
        # BuildModule.build produces the lower-level IR that the hardware can execute
        graph_json, runtime_mod, params = bld_mod.build(
            mod=ir_mod,
            target=raw_targets,
            params=params,
            executor=executor,
            runtime=runtime,
            workspace_memory_pools=workspace_memory_pools,
            constant_memory_pools=constant_memory_pools,
            mod_name=mod_name,
        )
        func_metadata = bld_mod.get_function_metadata()
        devices = bld_mod.get_devices()
        lowered_ir_mods = bld_mod.get_irmodule()
        executor_codegen_metadata = bld_mod.get_executor_codegen_metadata()

        if executor.name == "aot":
            executor_factory = _executor_factory.AOTExecutorFactoryModule(
                ir_mod,
                lowered_ir_mods,
                raw_targets,
                executor,
                runtime,
                runtime_mod,
                mod_name,
                params,
                func_metadata,
                executor_codegen_metadata,
                devices,
            )
        elif executor.name == "graph":
            executor_factory = _executor_factory.GraphExecutorFactoryModule(
                ir_mod,
                raw_targets,
                executor,
                graph_json,
                runtime_mod,
                mod_name,
                params,
                func_metadata,
            )
        else:
            assert False, "Executor " + executor + " not supported"

        return executor_factory
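The function above first wraps the relay.Function and params into an IRModule, then checks and updates the target device type and the host type that belongs to it. Next, Relay inspects the AutoTVM dispatch context: if the current context is the default autotvm.FallbackContext, it loads pre-tuned parameters from TopHub; otherwise it uses an empty context. All subsequent operations run inside that scope (with tophub_context:). Note that Relay also covers code generation for heterogeneous scenarios, so the user can specify multiple targets.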
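Inside with tophub_context:, a BuildModule object bld_mod is created and its build function is called to produce a lower-level IR that the hardware can execute, together with a tvm.Module containing the required runtime libraries and the parameters of the optimized computation graph. _executor_factory.GraphExecutorFactoryModule then packs the IR, the runtime library, and the parameters into a single tvm.Module, so the user only needs to save this module to skip compilation next time and run directly on the hardware.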
The core of compiling Relay IR in TVM is the build function of the BuildModule class. Continuing with the source:
class BuildModule(object):
    """Build an IR module to run on TVM graph executor. This class is used
    to expose the `RelayBuildModule` APIs implemented in C++.
    """

    def __init__(self):
        self.mod = _build_module._BuildModule()
        self._get_graph_json = self.mod["get_graph_json"]
        self._get_module = self.mod["get_module"]
        self._build = self.mod["build"]
        self._optimize = self.mod["optimize"]
        self._set_params_func = self.mod["set_params"]
        self._get_params_func = self.mod["get_params"]
        self._get_function_metadata = self.mod["get_function_metadata"]
        self._get_executor_codegen_metadata = self.mod["get_executor_codegen_metadata"]
        self._get_devices = self.mod["get_devices"]
        self._get_irmodule = self.mod["get_irmodule"]
Constructing a BuildModule object runs BuildModule.__init__, which obtains the corresponding C++ module through self.mod = _build_module._BuildModule(). Here _build_module refers to /tvm/python/tvm/relay/_build_module.py, implemented as follows:
"""The interface for building Relay functions exposed from C++."""
import tvm._ffi

tvm._ffi._init_api("relay.build_module", __name__)
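As shown, this file is only an interface: it passes the namespace relay.build_module to _init_api(), which exposes the global functions registered under that prefix, _BuildModule among them, as attributes of the module. _init_api() is defined in /tvm/python/tvm/_ffi/registry.py and is implemented as follows: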
def _init_api(namespace, target_module_name=None):
    """Initialize api for a given module name

    namespace : str
       The namespace of the source registry

    target_module_name : str
       The target module name if different from namespace
    """
    target_module_name = target_module_name if target_module_name else namespace
    if namespace.startswith("tvm."):
        _init_api_prefix(target_module_name, namespace[4:])
    else:
        _init_api_prefix(target_module_name, namespace)
Either branch ends up in _init_api_prefix(). Continuing the analysis:
def _init_api_prefix(module_name, prefix):
    module = sys.modules[module_name]

    for name in list_global_func_names():
        if not name.startswith(prefix):
            continue

        fname = name[len(prefix) + 1 :]
        target_module = module

        if fname.find(".") != -1:
            continue
        f = get_global_func(name)
        ff = _get_api(f)
        ff.__name__ = fname
        ff.__doc__ = "TVM PackedFunc %s. " % fname
        setattr(target_module, ff.__name__, ff)
It ultimately calls get_global_func(name), which looks up a function's implementation in the global function table by its name.
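The prefix-based export performed by _init_api_prefix can be imitated without TVM. In the sketch below, GLOBAL_FUNCS is a hypothetical table with made-up entries and fake_build_module is an empty module object created just for the demonstration. The filtering rules (matching prefix, no further dot in the remaining name) follow the code above.

```python
import types

# Hypothetical global function table; the values are placeholder callables.
GLOBAL_FUNCS = {
    "relay.build_module._BuildModule": lambda: "build-module-handle",
    "relay.build_module.BindParamsByName": lambda: "bound",
    "relay.build_module.sub.Hidden": lambda: "skipped",
    "relay.op.add": lambda: "other-namespace",
}


def init_api_prefix(target_module, prefix):
    """Attach every global function directly under prefix to target_module."""
    for name, func in GLOBAL_FUNCS.items():
        if not name.startswith(prefix):
            continue
        fname = name[len(prefix) + 1:]  # drop "prefix."
        if "." in fname:
            continue  # belongs to a nested namespace, not this module
        setattr(target_module, fname, func)


fake_build_module = types.ModuleType("fake_build_module")
init_api_prefix(fake_build_module, "relay.build_module")
print(fake_build_module._BuildModule())
```

This is why _build_module.py contains no function definitions yet _build_module._BuildModule() is callable: the attribute is created at import time from the registry.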
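TVM's registration mechanism: _BuildModule is a global function registered on the C++ side. Searching the source tree for the name _BuildModule finds the place where it is registered: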
TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = RelayBuildCreate();
});
This registration is located in /tvm/src/relay/backend/build_module.cc.
First, look at the TVM_REGISTER_GLOBAL macro, which is defined as follows:
#define TVM_REGISTER_GLOBAL(OpName) \
  TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::runtime::Registry::Register(OpName)
The Register function is implemented in /tvm/src/runtime/registry.cc:
Registry& Registry::Register(const std::string& name, bool can_override) {  // NOLINT(*)
  Manager* m = Manager::Global();
  std::lock_guard<std::mutex> lock(m->mutex);
  if (m->fmap.count(name)) {
    ICHECK(can_override) << "Global PackedFunc " << name << " is already registered";
  }

  Registry* r = new Registry();
  r->name_ = name;
  m->fmap[name] = r;
  return *r;
}
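Register creates a Registry object and stores it in the global table under the function name; set_body then binds the function body to that entry.
Here the body passed to set_body simply executes *rv = RelayBuildCreate();. RelayBuildCreate is implemented as follows: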
runtime::Module RelayBuildCreate() {
  auto exec = make_object<RelayBuildModule>();
  return runtime::Module(exec);
}
It constructs a RelayBuildModule object. The RelayBuildModule class inherits from runtime::ModuleNode and is implemented as follows:
class RelayBuildModule : public runtime::ModuleNode {
 public:
  RelayBuildModule() = default;

  /*!
   * \brief Get member function to front-end
   * \param name The name of the function.
   * \param sptr_to_self The pointer to the module node.
   * \return The corresponding member function.
   */
  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
    if (name == "get_graph_json") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraphJSON(); });
    } else if (name == "get_module") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); });
    } else if (name == "build") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        ICHECK_EQ(args.num_args, 8);
        this->Build(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]);
      });
    } else if (name == "list_params") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->ListParamNames(); });
    } else if (name == "get_params") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); });
    } else if (name == "set_params") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        Map<String, Constant> params = args[0];
        for (const auto& kv : params) {
          this->SetParam(kv.first, kv.second->data);
        }
      });
    } else if (name == "get_devices") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->executor_codegen_->ListDevices();
      });
    } else if (name == "get_irmodule") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->executor_codegen_->GetIRModule();
      });
    } else if (name == "get_external_modules") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->executor_codegen_->GetExternalModules();
      });
    } else if (name == "get_function_metadata") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->executor_codegen_->GetFunctionMetadata();
      });
    } else if (name == "get_executor_codegen_metadata") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->executor_codegen_->GetExecutorCodegenMetadata();
      });
    } else if (name == "optimize") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        ICHECK_EQ(args.num_args, 2);
        *rv = this->Optimize(args[0], args[1]);
      });
    } else {
      LOG(FATAL) << "Unknown packed function: " << name;
      return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
    }
  }
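This class has a GetFunction method. It looks up the requested function by name, wraps it in a PackedFunc, and returns it. This is the mapping behind lookups such as self.mod["build"] in the __init__ shown above.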
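To make the name-based lookup concrete, here is a minimal Python sketch of the same pattern. The class ToyBuildModule and all of its methods are hypothetical illustrations, not TVM APIs; the real dispatch is the C++ GetFunction above, and the KeyError only loosely stands in for LOG(FATAL).

```python
from typing import Any, Callable, Dict


class ToyBuildModule:
    """Hypothetical stand-in for RelayBuildModule (not a TVM class)."""

    def __init__(self) -> None:
        self._graph_json = ""
        self._params: Dict[str, Any] = {}

    def _build(self, mod: str) -> None:
        # Pretend that "building" just records a JSON-like string.
        self._graph_json = '{"built_from": "%s"}' % mod

    def get_function(self, name: str) -> Callable[..., Any]:
        """Look up a member function by name, like GetFunction in C++."""
        table: Dict[str, Callable[..., Any]] = {
            "build": self._build,
            "get_graph_json": lambda: self._graph_json,
            "get_params": lambda: dict(self._params),
        }
        if name not in table:
            # The C++ code uses LOG(FATAL) here; KeyError is the toy analogue.
            raise KeyError(f"Unknown packed function: {name}")
        return table[name]

    # Mirrors the self.mod["build"] syntax used on the Python side.
    __getitem__ = get_function


toy = ToyBuildModule()
build = toy["build"]                    # resolve once, as __init__ does
get_graph_json = toy["get_graph_json"]
build("resnet18")
print(get_graph_json())
```

Resolving the callables once and storing them keeps later calls cheap, which is presumably why the Python wrapper caches them as attributes in its constructor.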
self.mod["build"] therefore dispatches to the following branch:
if (name == "build") {
  return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
    ICHECK_EQ(args.num_args, 8);
    this->Build(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]);
  });
Build is implemented as follows:
/*!
 * \brief Build relay IRModule for graph executor
 *
 * \param mod Relay IRModule
 * \param raw_targets List of available targets for kernels.
 * \param executor Executor to target
 * \param runtime Runtime to codegen for
 * \param mod_name Name of the module
 */
void Build(IRModule mod, const Array<Target>& raw_targets, const tvm::Target& target_host,
           const Executor& executor, const Runtime& runtime,
           const WorkspaceMemoryPools& workspace_memory_pools,
           const ConstantMemoryPools& constant_memory_pools, const String mod_name) {
  VLOG_CONTEXT << "Build";
  executor_ = executor;
  runtime_ = runtime;
  workspace_memory_pools_ = workspace_memory_pools;
  constant_memory_pools_ = constant_memory_pools;
  config_ = CompilationConfig(PassContext::Current(), raw_targets);
  VLOG(1) << "Using compilation config:" << std::endl << config_;
  BuildRelay(std::move(mod), mod_name);
}
BuildRelay is implemented as follows:
/*!
 * \brief Compile a Relay IR module to runtime module.
 *
 * \param relay_module The Relay IR module.
 * \param params The parameters.
 */
void BuildRelay(IRModule relay_module, const String& mod_name) {
  // Relay IRModule -> IRModule optimizations.
  IRModule module = WithAttrs(
      relay_module, {{tvm::attr::kExecutor, executor_}, {tvm::attr::kRuntime, runtime_}});
  relay_module = OptimizeImpl(std::move(module));

  // Get the updated function and new IRModule to build.
  // Instead of recreating the IRModule, we should look at the differences between this and the
  // incoming IRModule to see if we can just pass (IRModule, Function) to the code generator.
  Function func = Downcast<Function>(relay_module->Lookup("main"));
  IRModule func_module = WithAttrs(IRModule::FromExpr(func),
                                   {{tvm::attr::kExecutor, executor_},
                                    {tvm::attr::kRuntime, runtime_},
                                    {tvm::attr::kWorkspaceMemoryPools, workspace_memory_pools_},
                                    {tvm::attr::kConstantMemoryPools, constant_memory_pools_}});

  // Generate code for the updated function.
  executor_codegen_ = MakeExecutorCodegen(executor_->name);
  executor_codegen_->Init(nullptr, config_->primitive_targets);
  executor_codegen_->Codegen(func_module, func, mod_name);
  executor_codegen_->UpdateOutput(&ret_);
  ret_.params = executor_codegen_->GetParams();

  auto lowered_funcs = executor_codegen_->GetIRModule();

  // No need to build for external functions.
  Target ext_dev("ext_dev");
  if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) {
    lowered_funcs.Set(ext_dev, IRModule());
  }

  const Target& host_target = config_->host_virtual_device->target;
  const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");
  // When there is no lowered_funcs due to reasons such as optimization.
  if (lowered_funcs.size() == 0) {
    if (host_target->kind->name == "llvm") {
      CHECK(pf != nullptr) << "Unable to create empty module for llvm without llvm codegen.";
      // If we can decide the target is LLVM, we then create an empty LLVM module.
      ret_.mod = (*pf)(host_target->str(), "empty_module");
    } else {
      // If we cannot decide the target is LLVM, we create an empty CSourceModule.
      // The code content is initialized with ";" to prevent complaining
      // from CSourceModuleNode::SaveToFile.
      ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array<String>{});
    }
  } else {
    ret_.mod = tvm::TIRToRuntime(lowered_funcs, host_target);
  }

  auto ext_mods = executor_codegen_->GetExternalModules();
  ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, host_target,
                                                runtime_, executor_,
                                                executor_codegen_->GetExecutorCodegenMetadata());
  // Remove external params which were stored in metadata module.
  for (tvm::runtime::Module mod : ext_mods) {
    auto pf_var = mod.GetFunction("get_const_vars");
    if (pf_var != nullptr) {
      Array<String> variables = pf_var();
      for (size_t i = 0; i < variables.size(); i++) {
        auto it = ret_.params.find(variables[i].operator std::string());
        if (it != ret_.params.end()) {
          VLOG(1) << "constant '" << variables[i] << "' has been captured in external module";
          ret_.params.erase(it);
        }
      }
    }
  }
}
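This function is the core of the compilation flow. As the code shows, it consists of two phases: Optimize (the call to OptimizeImpl) and Codegen (the calls on executor_codegen_, followed by tvm::TIRToRuntime for the lowered functions).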
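The optimize-then-codegen ordering can be illustrated with a toy pipeline in Python. This is only a sketch under loud assumptions: the tuple-based expression format and the functions optimize, codegen, and build below are hypothetical, and constant folding merely stands in for the far richer set of Relay passes.

```python
from typing import Tuple, Union

# A toy expression IR (hypothetical, not Relay): an int constant, a variable
# name, or a tuple (op, lhs, rhs) with op in {"add", "mul"}.
Expr = Union[int, str, Tuple[str, "Expr", "Expr"]]


def optimize(expr: Expr) -> Expr:
    """Phase 1: constant folding, a stand-in for IR-to-IR optimization."""
    if not isinstance(expr, tuple):
        return expr
    op, lhs, rhs = expr
    lhs, rhs = optimize(lhs), optimize(rhs)
    if isinstance(lhs, int) and isinstance(rhs, int):
        return lhs + rhs if op == "add" else lhs * rhs
    return (op, lhs, rhs)


def codegen(expr: Expr) -> str:
    """Phase 2: emit target code (here, a Python expression string)."""
    if not isinstance(expr, tuple):
        return str(expr)
    op, lhs, rhs = expr
    symbol = "+" if op == "add" else "*"
    return f"({codegen(lhs)} {symbol} {codegen(rhs)})"


def build(expr: Expr) -> str:
    """Optimize first, then generate code, mirroring BuildRelay's order."""
    return codegen(optimize(expr))


program: Expr = ("add", "x", ("mul", 2, 3))
source = build(program)
print(source)                      # (x + 6)
print(eval(source, {"x": 4}))      # 10
```

Running the optimizer before code generation means the generator only ever sees the simplified program, so the constant subexpression never reaches the emitted code.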
At this point, the __init__ function of the BuildModule class has finished.
Next, the class's build function is called. It is a member function of BuildModule, implemented as follows:
def build(
    self,
    mod,
    target=None,
    target_host=None,
    executor=Executor("graph"),
    runtime=Runtime("cpp"),
    workspace_memory_pools=None,
    constant_memory_pools=None,
    params=None,
    mod_name=None,
):
    """
    Parameters
    ----------
    mod : :py:class:`~tvm.IRModule`
        The IRModule to build.

    target : any multi-target like object, see Target.canon_multi_target
        For homogeneous compilation, the unique build target.
        For heterogeneous compilation, a dictionary or list of possible build targets.

    target_host : None, or any target-like object, see Target.canon_target
        Host compilation target, if target is device.
        When TVM compiles device specific program such as CUDA,
        we also need host(CPU) side code to interact with the driver
        to setup the dimensions and parameters correctly.
        target_host is used to specify the host side codegen target.
        By default, llvm is used if it is enabled,
        otherwise a stackvm interpreter is used.

    executor : Optional[Executor]
        The executor configuration with which to build the model.
        Defaults to "graph" if no executor specified.

    runtime : Optional[Runtime]
        Runtime configuration to use when building the model.
        Defaults to "cpp" if no runtime specified.

    workspace_memory_pools : Optional[WorkspaceMemoryPools]
        The object that contains an Array of WorkspacePoolInfo objects
        that hold properties of read-write workspace pools that could be
        used by the inference.

    constant_memory_pools : Optional[ConstantMemoryPools]
        The object that contains an Array of ConstantPoolInfo objects
        that hold properties of read-only memory pools that could be
        used by the inference.

    params : dict of str to NDArray
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.

    mod_name: Optional[str]
        The module name we will build

    Returns
    -------
    graph_json : str
        The json string that can be accepted by graph executor.

    mod : tvm.Module
        The module containing necessary libraries.

    params : dict
        The parameters of the final graph.
    """
    # pylint: disable=import-outside-toplevel
    from tvm.auto_scheduler import is_auto_scheduler_enabled
    from tvm.meta_schedule import is_meta_schedule_enabled

    # pylint: enable=import-outside-toplevel
    # Setup the params.
    if params:
        self._set_params(params)

    # Build the IR module. If auto_scheduler is not enabled,
    # then use the TOPI-defined schedule.

    # Turn off AutoTVM config not found warnings if auto_scheduler is enabled.
    old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
    autotvm.GLOBAL_SCOPE.silent = (
        is_auto_scheduler_enabled() or is_meta_schedule_enabled() or old_autotvm_silent
    )

    mod_name = mangle_module_name(mod_name)

    self._build(
        mod,
        target,
        target_host,
        executor,
        runtime,
        workspace_memory_pools,
        constant_memory_pools,
        mod_name,
    )
    autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent

    # Get artifacts
    mod = self.get_module()
    params = self.get_params()
    executor_config = self.get_graph_json() if executor.name == "graph" else None

    return executor_config, mod, params
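This member function returns three values: graph_json (the JSON string accepted by the graph executor; as the last lines of the code show, None is returned in its place when the executor is not "graph"), mod (a tvm.Module containing the necessary libraries), and params (a dict holding the parameters of the final graph).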
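The return convention described above, together with the temporary change to the AutoTVM silent flag in the function body, can be sketched in plain Python. Everything here is hypothetical: GLOBAL_SCOPE is a stand-in object rather than autotvm.GLOBAL_SCOPE, toy_build is not a TVM function, and the placeholder strings replace real build artifacts. Note one deliberate difference: the sketch restores the flag in a try/finally via a context manager, whereas the code above restores it with a plain assignment after _build returns.

```python
from contextlib import contextmanager
from types import SimpleNamespace
from typing import Dict, Iterator, Optional, Tuple

# Hypothetical stand-in for autotvm.GLOBAL_SCOPE (not the real object).
GLOBAL_SCOPE = SimpleNamespace(silent=False)


@contextmanager
def silenced(enabled: bool) -> Iterator[None]:
    """Temporarily set GLOBAL_SCOPE.silent, restoring it even on error."""
    old = GLOBAL_SCOPE.silent
    GLOBAL_SCOPE.silent = enabled or old
    try:
        yield
    finally:
        GLOBAL_SCOPE.silent = old


def toy_build(
    executor_name: str, tuning_enabled: bool
) -> Tuple[Optional[str], str, Dict[str, float]]:
    """Return (executor_config, mod, params), shaped like BuildModule.build."""
    with silenced(tuning_enabled):
        mod = "compiled-library"          # placeholder for the runtime module
        params = {"weight": 1.0}          # placeholder for the final params
        graph_json = '{"nodes": []}'      # placeholder for the graph JSON
    # Only the graph executor consumes a JSON graph description.
    executor_config = graph_json if executor_name == "graph" else None
    return executor_config, mod, params


graph_json, mod, params = toy_build("graph", tuning_enabled=True)
aot_config, _, _ = toy_build("aot", tuning_enabled=False)
print(graph_json, aot_config, GLOBAL_SCOPE.silent)
```

Callers that unpack the tuple should therefore be prepared for the first element to be None whenever a non-graph executor is selected.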
Finally, the _executor_factory.GraphExecutorFactoryModule() function returns the result.
Now we can try deploying the compiled model on target.
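from tvm.contrib import graph_executor

dtype = "float32"
m = graph_executor.GraphModule(lib["default"](dev))
# Set inputs. Use x, the batched NumPy array built during preprocessing;
# img is a torch tensor at this point and has no astype method.
m.set_input(input_name, tvm.nd.array(x.astype(dtype)))
# Execute
m.run()
# Get outputs
tvm_output = m.get_output(0)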
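The output tensor still has to be interpreted. Assuming the model emits one unnormalized score per class (an assumption about resnet18.onnx, not something stated above), a typical next step is a softmax followed by picking the highest-scoring classes. The sketch below uses only the standard library and made-up scores; the helper names softmax and top_k are illustrative, and in practice NumPy would usually be applied to the real output instead.

```python
import math
from typing import List, Tuple


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs: List[float], k: int) -> List[Tuple[int, float]]:
    """Return the k (class_index, probability) pairs with highest probability."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up scores for a 4-class toy problem; a real run would instead flatten
# the first row of the model output into a list of floats.
scores = [1.0, 3.0, 0.5, 2.0]
probs = softmax(scores)
best = top_k(probs, k=2)
print(best[0][0], best[1][0])
```

The class indices only become labels once they are mapped through the label file that matches the dataset the model was trained on.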
References: https://zhuanlan.zhihu.com/p/50529704
https://zhuanlan.zhihu.com/p/376863322