在上一节中,我们将Toy Dialect的部分Operation Lowering到Affine Dialect,MemRef Dialect和Standard Dialect,而toy.print
操作保持不变,所以又被叫作部分Lowering。通过这个Lowering可以将Toy Dialect的Operation更底层的实现逻辑表达出来,以寻求更多的优化机会,得到更好的MLIR表达式。这一节,我们将在上一节得到的混合型MLIR表达式完全Lowering到LLVM Dialect上,然后生成LLVM IR,并且我们可以使用MLIR的JIT编译引擎来运行最终的MLIR表达式并输出计算结果。
这一小节我们将来介绍如何将上一节结束的MLIR表达式完全Lowering为LLVM Dialect,我们还是回顾一下上一节最终的MLIR表达式:
func @main() { %cst = arith.constant 1.000000e+00 : f64 %cst_0 = arith.constant 2.000000e+00 : f64 %cst_1 = arith.constant 3.000000e+00 : f64 %cst_2 = arith.constant 4.000000e+00 : f64 %cst_3 = arith.constant 5.000000e+00 : f64 %cst_4 = arith.constant 6.000000e+00 : f64 // Allocating buffers for the inputs and outputs. %0 = memref.alloc() : memref<3x2xf64> %1 = memref.alloc() : memref<2x3xf64> // Initialize the input buffer with the constant values. affine.store %cst, %1[0, 0] : memref<2x3xf64> affine.store %cst_0, %1[0, 1] : memref<2x3xf64> affine.store %cst_1, %1[0, 2] : memref<2x3xf64> affine.store %cst_2, %1[1, 0] : memref<2x3xf64> affine.store %cst_3, %1[1, 1] : memref<2x3xf64> affine.store %cst_4, %1[1, 2] : memref<2x3xf64> affine.for %arg0 = 0 to 3 { affine.for %arg1 = 0 to 2 { // Load the transpose value from the input buffer. %2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64> // Multiply and store into the output buffer. %3 = arith.mulf %2, %2 : f64 affine.store %3, %0[%arg0, %arg1] : memref<3x2xf64> } } // Print the value held by the buffer. toy.print %0 : memref<3x2xf64> memref.dealloc %1 : memref<2x3xf64> memref.dealloc %0 : memref<3x2xf64> return }
我们要将这个三种Dialect混合的MLIR表达式完全Lowering为LLVM Dialect,注意LLVM Dialect是MLIR的一种特殊的Dialect层次的中间表示,它并不是LLVM IR。Lowering为LLVM Dialect的整体过程可以分为如下几步:
之前部分Lowering的时候并没有对toy.print
操作进行Lowering,所以这里优先将toy.print
进行Lowering。我们把toy.print
Lowering到一个非仿射循环嵌套,它为每个元素调用printf
。Dialect转换框架支持传递Lowering,不需要直接Lowering为LLVM Dialect。通过应用传递Lowering可以应用多种模式来使得操作合法化(合法化的意思在这里指的就是完全Lowering到LLVM Dialect)。 传递Lowering在这里体现为将toy.print
先Lowering到循环嵌套Dialect里面,而不是直接Lowering为LLVM Dialect。
在Lowering过程中,printf
的声明在mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
中,代码如下:
/// Return a symbol reference to the printf function, inserting it into the /// module if necessary. static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, ModuleOp module) { auto *context = module.getContext(); if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf")) return SymbolRefAttr::get(context, "printf"); // Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` auto llvmI32Ty = IntegerType::get(context, 32); auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy, /*isVarArg=*/true); // Insert the printf function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType); return SymbolRefAttr::get(context, "printf"); }
这部分代码返回了printf
函数的符号引用,必要时将其插入Module。在函数中,为printf
创建了函数声明,然后将printf
函数插入到父Module的主体中。
第一个需要确定的是转换目标(ConversionTarget),对于这个Lowering我们除了顶层的Module将所有的内容都Lowering为LLVM Dialect。这里代码表达的信息和官方文档有一些出入,以最新的代码为准。
// The first thing to define is the conversion target. This will define the // final target for this lowering. For this lowering, we are only targeting // the LLVM dialect. LLVMConversionTarget target(getContext()); target.addLegalOp<ModuleOp>();
然后需要确定类型转换器(Type Converter),我们现存的MLIR表达式还有MemRef
类型,我们需要将其转换为LLVM的类型。为了执行这个转化,我们使用TypeConverter
作为Lowering的一部分。这个转换器指定一种类型如何映射到另外一种类型。由于现存的操作中已经不存在任何Toy Dialect操作,因此使用MLIR默认的转换器就可以满足需求。定义如下:
// During this lowering, we will also be lowering the MemRef types, that are // currently being operated on, to a representation in LLVM. To perform this // conversion we use a TypeConverter as part of the lowering. This converter // details how one type maps to another. This is necessary now that we will be // doing more complicated lowerings, involving loop region arguments. LLVMTypeConverter typeConverter(&getContext());
再然后还需要确定转换模式(Conversion Patterns)。这部分代码为:
// Now that the conversion target has been defined, we need to provide the // patterns used for lowering. At this point of the compilation process, we // have a combination of `toy`, `affine`, and `std` operations. Luckily, there // are already exists a set of patterns to transform `affine` and `std` // dialects. These patterns lowering in multiple stages, relying on transitive // lowerings. Transitive lowering, or A->B->C lowering, is when multiple // patterns must be applied to fully transform an illegal operation into a // set of legal ones. RewritePatternSet patterns(&getContext()); populateAffineToStdConversionPatterns(patterns); populateLoopToStdConversionPatterns(patterns); populateMemRefToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns); // The only remaining operation to lower from the `toy` dialect, is the // PrintOp. patterns.add<PrintOpLowering>(&getContext());
上面这段代码展示了为Affine Dialect,Standard Dialect以及遗留的toy.print
定义匹配重写规则。首先将Affine Dialect下降到Standard Dialect,即populateAffineToStdConversionPatterns
。然后将Loop(针对的是toy.print
操作,它已经Lowering到了循环嵌套Dialect)下降到Standard Dialect,即populateLoopToStdConversionPatterns
。最后,将Standard Dialect转换到LLVM Dialect,即populateMemRefToLLVMConversionPatterns
。以及不要忘了把toy.print
的Lowering模式PrintOpLowering
加到patterns
里面。
定义了Lowering过程需要的所有组件之后,就可以执行完全Lowering了。使用applyFullConversion(module, target, std::move(patterns)))
函数可以保证转换的结果只存在合法的操作,上一篇部分Lowering的笔记调用的是mlir::applyPartialConversion(function, target, patterns)
可以对比着看一下。
// We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. auto module = getOperation(); if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure();
这段代码在mlir/examples/toy/Ch6/toyc.cpp
中:
if (isLoweringToLLVM) { // Finish lowering the toy IR to the LLVM dialect. pm.addPass(mlir::toy::createLowerToLLVMPass()); }
这段代码在优化Pipline中添加了mlir::toy::createLowerToLLVMPass()
这个完全Lowering的Pass,可以把MLIR 表达式下降为LLVM Dialect表达式。我们运行一下示例程序看下结果:
执行下面的命令:
cd llvm-project/build/bin ./toyc-ch6 ../../mlir/test/Examples/Toy/Ch6/llvm-lowering.mlir -emit=mlir-llvm
即获得了完全Lowering之后的MLIR表达式,结果比较长,这里只展示一部分。可以看到目前MLIR表达式已经完全在LLVM Dialect空间下了。
llvm.func @free(!llvm<"i8*">) llvm.func @printf(!llvm<"i8*">, ...) -> i32 llvm.func @malloc(i64) -> !llvm<"i8*"> llvm.func @main() { %0 = llvm.mlir.constant(1.000000e+00 : f64) : f64 %1 = llvm.mlir.constant(2.000000e+00 : f64) : f64 ... ^bb16: %221 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }"> %222 = llvm.mlir.constant(0 : index) : i64 %223 = llvm.mlir.constant(2 : index) : i64 %224 = llvm.mul %214, %223 : i64 %225 = llvm.add %222, %224 : i64 %226 = llvm.mlir.constant(1 : index) : i64 %227 = llvm.mul %219, %226 : i64 %228 = llvm.add %225, %227 : i64 %229 = llvm.getelementptr %221[%228] : (!llvm."double*">, i64) -> !llvm<"f64*"> %230 = llvm.load %229 : !llvm<"double*"> %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, f64) -> i32 %232 = llvm.add %219, %218 : i64 llvm.br ^bb15(%232 : i64) ... ^bb18: %235 = llvm.extractvalue %65[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }"> %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*"> llvm.call @free(%236) : (!llvm<"i8*">) -> () %237 = llvm.extractvalue %45[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }"> %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*"> llvm.call @free(%238) : (!llvm<"i8*">) -> () %239 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }"> %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*"> llvm.call @free(%240) : (!llvm<"i8*">) -> () llvm.return }
我们可以使用JIT编译引擎来运行上面得到的LLVM Dialect IR,获得推理结果。这里我们使用了mlir::ExecutionEngine
基础架构来运行LLVM Dialect IR。程序位于:mlir/examples/toy/Ch6/toyc.cpp
。
int runJit(mlir::ModuleOp module) { // Initialize LLVM targets. llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); // Register the translation from MLIR to LLVM IR, which must happen before we // can JIT-compile. mlir::registerLLVMDialectTranslation(*module->getContext()); // An optimization pipeline to use within the execution engine. auto optPipeline = mlir::makeOptimizingTransformer( /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, /*targetMachine=*/nullptr); // Create an MLIR execution engine. The execution engine eagerly JIT-compiles // the module. auto maybeEngine = mlir::ExecutionEngine::create( module, /*llvmModuleBuilder=*/nullptr, optPipeline); assert(maybeEngine && "failed to construct an execution engine"); auto &engine = maybeEngine.get(); // Invoke the JIT-compiled function. auto invocationResult = engine->invokePacked("main"); if (invocationResult) { llvm::errs() << "JIT invocation failed\n"; return -1; } return 0; }
这里尤其需要注意这行:mlir::registerLLVMDialectTranslation(*module->getContext());。从代码的注释来看这个是将LLVM Dialect表达式翻译成LLVM IR,在JIT编译的时候起到缓存作用,也就是说下次执行的时候不会重复执行上面的各种MLIR表达式变换。
这里创建一个MLIR执行引擎mlir::ExecutionEngine
来运行表达式中的main
函数。可以使用下面的命令来输出最终的计算结果:
cd llvm-project/build/bin ./toyc-ch6 ../../mlir/test/Examples/Toy/Ch6/codegen.toy -emit=jit -opt
结果为:
1.000000 16.000000 4.000000 25.000000 9.000000 36.000000
到这里,我们就将原始的MLIR表达式经过一系列Pass进行优化,以及部分Lowering到三种Dialect混合的表达式,和完全Lowering为LLVM Dialect表达式,最后翻译到LLVM IR使用MLIR的Jit执行引擎进行执行,获得了最终结果。
另外,mlir/examples/toy/Ch6/toyc.cpp
中还提供了一个dumpLLVMIR
函数,可以将MLIR表达式翻译成LLVM IR表达式。然后再经过LLVM IR的优化处理。使用如下命令可以打印出生成的LLVM IR:
$ cd llvm-project/build/bin $ ./toyc-ch6 ../../mlir/test/Examples/Toy/Ch6/codegen.toy -emit=llvm -opt
这篇文章介绍了如何将部分Lowering之后的MLIR表达式进一步完全Lowering到LLVM Dialect上,然后通过JIT编译引擎来执行代码并获得推理结果,另外还可以输出LLVM Dialect生成的LLVM IR。