HAL::HALTransformPassPipeline performs tiling, vectorization, bufferization, and related transformations, distributes the computational workload, and finally generates code for the target device. For the cuda target, for example, the dispatch source code is lowered to NVVM IR.

  • buildHALConfigurationPassPipeline

    • addCleanupPatterns

    • createAssignTargetDevicesPass

      Attaches the device targets attribute to the outermost module; multiple target devices can be specified.

      module attributes {hal.device.targets = [#hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}>], legacy_sync}>]} {
      ...
      }
    • createVerifyTargetEnvironmentPass

      Verifies that the device targets are set correctly and that the corresponding compiler backends have been registered.

    • createMaterializeInterfacesPass

      Creates a device-target-specific variant for each executable, one executable variant per device target. Both the executable's export func and its source func are converted into parameterless funcs, unifying the calling interface of dispatch, export, and source func: the dispatch site specifies how inputs map to bindings, while the source func fetches its input arguments through binding ids.

      stream.executable private @test_dispatch_0 {
      stream.executable.export public @test_dispatch_0_generic_100000x100 workgroups(%arg0: index, %arg1: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1
      stream.return %x, %y, %z : index, index, index
      }
      builtin.module {
      func.func @test_dispatch_0_generic_100000x100(%arg0: !stream.binding {stream.alignment = 64 : index}, %arg1: !stream.binding {stream.alignment = 64 : index}, %arg2: !stream.binding {stream.alignment = 64 : index}) {
      ...
      return
      }
      }
      }

      func.func @test(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
      ...
      %3 = stream.cmd.execute with(%0 as %arg2: !stream.resource<external>{%c40000000}, %1 as %arg3: !stream.resource<external>{%c40000000}, %2 as %arg4: !stream.resource<external>{%c400000}) {
      stream.cmd.fill %c0_i8, %arg4[%c0 for %c400000] : i8 -> !stream.resource<external>{%c400000}
      stream.cmd.dispatch @test_dispatch_0::@test_dispatch_0_generic_100000x100[%c100000, %c1] {
      ro %arg2[%c0 for %c40000000] : !stream.resource<external>{%c40000000},
      ro %arg3[%c0 for %c40000000] : !stream.resource<external>{%c40000000},
      rw %arg4[%c0 for %c400000] : !stream.resource<external>{%c400000}
      }
      } => !stream.timepoint
      ...
      }

      is converted to

      hal.executable private @test_dispatch_0 {
      hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}> {
      hal.executable.export public @test_dispatch_0_generic_100000x100 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) {
      ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
      hal.return %x, %y, %z : index, index, index
      }
      builtin.module {
      func.func @test_dispatch_0_generic_100000x100() {
      %c0 = arith.constant 0 : index
      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>>
      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>>
      %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readwrite:tensor<100000xf32>>
      ...
      return
      }
      }
      }
      }

      func.func @test(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
      ...
      %3 = stream.cmd.execute with(%0 as %arg2: !stream.resource<external>{%c40000000}, %1 as %arg3: !stream.resource<external>{%c40000000}, %2 as %arg4: !stream.resource<external>{%c400000}) {
      stream.cmd.fill %c0_i8, %arg4[%c0 for %c400000] : i8 -> !stream.resource<external>{%c400000}
      stream.cmd.dispatch @test_dispatch_0::@test_dispatch_0_generic_100000x100[%c100000, %c1] {
      ro %arg2[%c0 for %c40000000] : !stream.resource<external>{%c40000000},
      ro %arg3[%c0 for %c40000000] : !stream.resource<external>{%c40000000},
      rw %arg4[%c0 for %c400000] : !stream.resource<external>{%c400000}
      } attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]}
      } => !stream.timepoint
      ...
      }
  • createTranslateExecutablesPass

    Compiles each hal.executable.variant with the backend that matches its target device. For cuda this is CUDATargetBackend, which in effect runs the following sequence of passes.

    • buildLLVMGPUTransformPassPipeline

      • createTypePropagationPass

        Normalizes integer element types and propagates the changed types to their users.
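
        A minimal sketch of the idea, assuming the common case where an i1 element type (which a storage buffer cannot load or store directly) is widened to i8 at the dispatch boundary:

        // before: the dispatch traffics in i1 element types
        %0 = flow.dispatch.tensor.load %in, ... : !flow.dispatch.tensor<readonly:tensor<8xi1>> -> tensor<8xi1>
        // after: the element type is normalized to i8 and propagated to all users;
        // arith.trunci / arith.extui are inserted where the original i1 values
        // are produced or consumed
        %0 = flow.dispatch.tensor.load %in, ... : !flow.dispatch.tensor<readonly:tensor<8xi8>> -> tensor<8xi8>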

      • createBufferizeCopyOnlyDispatchesPass

        Converts dispatches that only copy data (i.e. contain nothing but a tensor load and store) into a linalg generic op and bufferizes them.

        func.func @test_dispatch_0() {
        %c0 = arith.constant 0 : index
        %0 = hal.interface.constant.load[0] : i32
        %1 = arith.index_castui %0 : i32 to index
        %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<?xf32>>{%1}
        %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%1}
        %4 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [%1], strides = [1] : !flow.dispatch.tensor<readonly:tensor<?xf32>>{%1} -> tensor<?xf32>
        flow.dispatch.tensor.store %4, %3, offsets = [0], sizes = [%1], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%1}
        return
        }

        is converted to

        func.func @test_dispatch_0() {
        %c0 = arith.constant 0 : index
        %0 = hal.interface.constant.load[0] : i32
        %1 = arith.index_castui %0 : i32 to index
        %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<?xf32, #hal.descriptor_type<storage_buffer>>{%1}
        memref.assume_alignment %2, 64 : memref<?xf32, #hal.descriptor_type<storage_buffer>>
        %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<?xf32, #hal.descriptor_type<storage_buffer>>{%1}
        memref.assume_alignment %3, 64 : memref<?xf32, #hal.descriptor_type<storage_buffer>>
        linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%2 : memref<?xf32, #hal.descriptor_type<storage_buffer>>) outs(%3 : memref<?xf32, #hal.descriptor_type<storage_buffer>>) {
        ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
        }
        return
        }
      • createEraseHALDescriptorTypeFromMemRefPass

        Erases the #hal.descriptor_type memory space from memref values.
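
        For example, on a subspan result like those in the dumps below, only the memory-space attribute changes:

        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<100000x100xf32, #hal.descriptor_type<storage_buffer>>
        // becomes
        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<100000x100xf32>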

      • createLLVMGPULowerExecutableTargetPass

        • initGPULaunchConfig

          Computes the GPU launch configuration from the concrete workload and op types, including the tiling strategy, workgroup count, thread count, and which lowering pipeline subsequent passes should dispatch to.

          hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}> {
          hal.executable.export public @test_dispatch_0_generic_100000x100 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) {
          ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
          %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
          hal.return %x, %y, %z : index, index, index
          }
          builtin.module {
          func.func @test_dispatch_0_generic_100000x100() {
          ...
          %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%3, %4 : tensor<100000x100xf32>, tensor<100000x100xf32>) outs(%5 : tensor<100000xf32>) {
          ^bb0(%in: f32, %in_0: f32, %out: f32):
          %7 = arith.addf %in, %in_0 : f32
          %8 = arith.addf %7, %out : f32
          linalg.yield %8 : f32
          } -> tensor<100000xf32>
          ...
          }
          }
          }

          is converted to

          hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}> {
          hal.executable.export public @test_dispatch_0_generic_100000x100 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize>, workgroup_size = [64 : index, 1 : index, 1 : index]} {
          ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
          %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
          hal.return %x, %y, %z : index, index, index
          }
          builtin.module {
          func.func @test_dispatch_0_generic_100000x100() {
          ...
          %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%3, %4 : tensor<100000x100xf32>, tensor<100000x100xf32>) outs(%5 : tensor<100000xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
          ^bb0(%in: f32, %in_0: f32, %out: f32):
          %7 = arith.addf %in, %in_0 : f32
          %8 = arith.addf %7, %out : f32
          linalg.yield %8 : f32
          } -> tensor<100000xf32>
          ...
          }
          }
          }

          Note that the export func has gained two attributes, translation_info and workgroup_size, and the source func has gained a lowering_config attribute. translation_info routes subsequent lowering to the LLVMGPUVectorize pipeline. workgroup_size can be viewed as the 3-D GPU block dim; here each thread block has 64 threads. lowering_config specifies the tile sizes at each loop level: a thread block computes a 256-row tile of 100xf32 data (256 / 64 = 4 rows per thread), and each thread processes one 4xf32 vector at a time.

        • DispatchLoweringPassPipeline

          Dispatches to one of the following pipelines for further lowering, according to translation_info.

          • GPUSimpleDistributePassPipeline

          • GPUVectorizationPassPipeline

            • getTileAndDistributeConfig

              Locates the root op of the dispatch (normally the last linalg reduction op; if there is none, the last linalg generic op), reads the lowering_config (tile sizes) from its attributes, zeroes out the tile sizes of the non-parallel loops (so only the parallel loops are tiled and distributed in this step), and computes the loop ranges of the parallel loops.

            • LowerDispatchWorkgroupCountForDagRootOp

              Computes the workgroup count from the loop ranges and tile sizes.

              hal.executable.export public @test_dispatch_0_generic_100000x100 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize>, workgroup_size = [64 : index, 1 : index, 1 : index]} {
              ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
              %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
              hal.return %x, %y, %z : index, index, index
              }

              is converted to

              hal.executable.export public @test_dispatch_0_generic_100000x100 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize>, workgroup_size = [64 : index, 1 : index, 1 : index]} {
              ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
              %c391 = arith.constant 391 : index
              %c1 = arith.constant 1 : index
              hal.return %c391, %c1, %c1 : index, index, index
              }

              The computed workgroup count is (391, 1, 1), where 391 = ceilDiv(100000, 256).

            • populateTileAndDistributeToWorkgroupsPatterns

              Tiles the parallel loops, turning the source func into the computation performed by a single workgroup.

              func.func @test_dispatch_0_generic_100000x100() {
              ...
              %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%3, %4 : tensor<100000x100xf32>, tensor<100000x100xf32>) outs(%5 : tensor<100000xf32>) attrs = {__internal_linalg_transform__ = "__workgroup_tiling__", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
              ^bb0(%in: f32, %in_0: f32, %out: f32):
              %7 = arith.addf %in, %in_0 : f32
              %8 = arith.addf %7, %out : f32
              linalg.yield %8 : f32
              } -> tensor<100000xf32>
              ...
              }

              is converted to

              func.func @test_dispatch_0_generic_100000x100() {
              ...
              %workgroup_id_x = hal.interface.workgroup.id[0] : index
              %workgroup_count_x = hal.interface.workgroup.count[0] : index
              %3 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_id_x]
              %4 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_count_x]
              scf.for %arg0 = %3 to %c100000 step %4 {
              %5 = affine.min affine_map<(d0) -> (256, -d0 + 100000)>(%arg0)
              %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%5, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<?x100xf32>
              %7 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0], sizes = [%5, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<?x100xf32>
              %8 = flow.dispatch.tensor.load %2, offsets = [%arg0], sizes = [%5], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<100000xf32>> -> tensor<?xf32>
              %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%6, %7 : tensor<?x100xf32>, tensor<?x100xf32>) outs(%8 : tensor<?xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
              ^bb0(%in: f32, %in_0: f32, %out: f32):
              %10 = arith.addf %in, %in_0 : f32
              %11 = arith.addf %10, %out : f32
              linalg.yield %11 : f32
              } -> tensor<?xf32>
              ...
              }
            • createWorkgroupSpecializationPass

              Splits the tiled computation into two paths: one for tiles with the full static shape, and one for the dynamically shaped remainder.

              func.func @test_dispatch_0_generic_100000x100() {
              ...
              %workgroup_id_x = hal.interface.workgroup.id[0] : index
              %workgroup_count_x = hal.interface.workgroup.count[0] : index
              %3 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_id_x]
              %4 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_count_x]
              scf.for %arg0 = %3 to %c100000 step %4 {
              %5 = affine.min affine_map<(d0) -> (256, -d0 + 100000)>(%arg0)
              %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%5, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<?x100xf32>
              %7 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0], sizes = [%5, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<?x100xf32>
              %8 = flow.dispatch.tensor.load %2, offsets = [%arg0], sizes = [%5], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<100000xf32>> -> tensor<?xf32>
              %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%6, %7 : tensor<?x100xf32>, tensor<?x100xf32>) outs(%8 : tensor<?xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
              ^bb0(%in: f32, %in_0: f32, %out: f32):
              %10 = arith.addf %in, %in_0 : f32
              %11 = arith.addf %10, %out : f32
              linalg.yield %11 : f32
              } -> tensor<?xf32>
              ...
              }

              is converted to

              func.func @test_dispatch_0_generic_100000x100() {
              ...
              %workgroup_id_x = hal.interface.workgroup.id[0] : index
              %workgroup_count_x = hal.interface.workgroup.count[0] : index
              %3 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_id_x]
              %4 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_count_x]
              scf.for %arg0 = %3 to %c100000 step %4 {
              %5 = affine.min affine_map<(d0) -> (-d0 + 100000, 256)>(%arg0)
              %c256 = arith.constant 256 : index
              %6 = arith.cmpi eq, %5, %c256 : index
              scf.if %6 {
              // compute the statically shaped [256, 100] tile
              %7 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%c256, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<?x100xf32>
              %8 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0], sizes = [%c256, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<?x100xf32>
              %9 = flow.dispatch.tensor.load %2, offsets = [%arg0], sizes = [%c256], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<100000xf32>> -> tensor<?xf32>
              %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%7, %8 : tensor<?x100xf32>, tensor<?x100xf32>) outs(%9 : tensor<?xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
              ^bb0(%in: f32, %in_0: f32, %out: f32):
              %11 = arith.addf %in, %in_0 : f32
              %12 = arith.addf %11, %out : f32
              linalg.yield %12 : f32
              } -> tensor<?xf32>
              flow.dispatch.tensor.store %10, %2, offsets = [%arg0], sizes = [%c256], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<readwrite:tensor<100000xf32>>
              } else {
              // compute the remaining dynamically shaped [%5, 100] tile
              %7 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%5, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<?x100xf32>
              %8 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0], sizes = [%5, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<?x100xf32>
              %9 = flow.dispatch.tensor.load %2, offsets = [%arg0], sizes = [%5], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<100000xf32>> -> tensor<?xf32>
              %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%7, %8 : tensor<?x100xf32>, tensor<?x100xf32>) outs(%9 : tensor<?xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
              ^bb0(%in: f32, %in_0: f32, %out: f32):
              %11 = arith.addf %in, %in_0 : f32
              %12 = arith.addf %11, %out : f32
              linalg.yield %12 : f32
              } -> tensor<?xf32>
              flow.dispatch.tensor.store %10, %2, offsets = [%arg0], sizes = [%5], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<readwrite:tensor<100000xf32>>
              }
              }
              return
              }
            • createRemoveSingleIterationLoopPass

              Removes loops that provably execute exactly once. The scf.for %arg0 = %3 to %c100000 step %4 above is such a loop: the step is 256 * 391 = 100096 > 100000, while the largest lower bound is 390 * 256 = 99840 < 100000, so every workgroup runs the body exactly once. The loop is therefore eliminated, yielding the following code.

              func.func @test_dispatch_0_generic_100000x100() {
              ...
              %c256 = arith.constant 256 : index
              %workgroup_id_x = hal.interface.workgroup.id[0] : index
              %3 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_id_x]
              %4 = affine.min affine_map<(d0) -> (-d0 + 100000, 256)>(%3)
              %5 = arith.cmpi eq, %4, %c256 : index
              scf.if %5 {
              %6 = flow.dispatch.tensor.load %0, offsets = [%3, 0], sizes = [256, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<256x100xf32>
              %7 = flow.dispatch.tensor.load %1, offsets = [%3, 0], sizes = [256, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<256x100xf32>
              %8 = flow.dispatch.tensor.load %2, offsets = [%3], sizes = [256], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<100000xf32>> -> tensor<256xf32>
              %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%6, %7 : tensor<256x100xf32>, tensor<256x100xf32>) outs(%8 : tensor<256xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
              ^bb0(%in: f32, %in_0: f32, %out: f32):
              %10 = arith.addf %in, %in_0 : f32
              %11 = arith.addf %10, %out : f32
              linalg.yield %11 : f32
              } -> tensor<256xf32>
              flow.dispatch.tensor.store %9, %2, offsets = [%3], sizes = [256], strides = [1] : tensor<256xf32> -> !flow.dispatch.tensor<readwrite:tensor<100000xf32>>
              } else {
              %6 = flow.dispatch.tensor.load %0, offsets = [%3, 0], sizes = [%4, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<?x100xf32>
              %7 = flow.dispatch.tensor.load %1, offsets = [%3, 0], sizes = [%4, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<?x100xf32>
              %8 = flow.dispatch.tensor.load %2, offsets = [%3], sizes = [%4], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<100000xf32>> -> tensor<?xf32>
              %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%6, %7 : tensor<?x100xf32>, tensor<?x100xf32>) outs(%8 : tensor<?xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
              ^bb0(%in: f32, %in_0: f32, %out: f32):
              %10 = arith.addf %in, %in_0 : f32
              %11 = arith.addf %10, %out : f32
              linalg.yield %11 : f32
              } -> tensor<?xf32>
              flow.dispatch.tensor.store %9, %2, offsets = [%3], sizes = [%4], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<readwrite:tensor<100000xf32>>
              }
              return
              }
            • createLLVMGPUTileTensor

              The passes so far handled the tiling of the outer parallel loops, producing the computation of a single thread block. This pass now distributes that workload onto individual threads and also tiles the inner reduction loop. The code above is further converted into the following.

              func.func @test_dispatch_0_generic_100000x100() {
              %c100 = arith.constant 100 : index
              %c4 = arith.constant 4 : index
              %c64 = arith.constant 64 : index
              %c256 = arith.constant 256 : index
              %c0 = arith.constant 0 : index
              ...
              %workgroup_id_x = hal.interface.workgroup.id[0] : index
              %3 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_id_x]
              %4 = affine.min affine_map<()[s0] -> (s0 * -256 + 100000, 256)>()[%workgroup_id_x]
              %5 = arith.cmpi eq, %4, %c256 : index
              scf.if %5 {
              %6 = flow.dispatch.tensor.load %0, offsets = [%3, 0], sizes = [256, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<256x100xf32>
              %7 = flow.dispatch.tensor.load %1, offsets = [%3, 0], sizes = [256, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<256x100xf32>
              %8 = flow.dispatch.tensor.load %2, offsets = [%3], sizes = [256], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<100000xf32>> -> tensor<256xf32>
              // 64 threads run concurrently; each thread computes a [4, 100] tile
              %9 = scf.foreach_thread (%arg0) in (%c64) shared_outs(%arg1 = %8) -> (tensor<256xf32>) {
              %10 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg0)
              %11 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg0)
              %12 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg0)
              %extracted_slice = tensor.extract_slice %6[%10, 0] [4, 100] [1, 1] : tensor<256x100xf32> to tensor<4x100xf32>
              %extracted_slice_0 = tensor.extract_slice %7[%11, 0] [4, 100] [1, 1] : tensor<256x100xf32> to tensor<4x100xf32>
              %extracted_slice_1 = tensor.extract_slice %arg1[%12] [4] [1] : tensor<256xf32> to tensor<4xf32>
              // tile the inner reduction loop (vectorized later)
              %13 = scf.for %arg2 = %c0 to %c100 step %c4 iter_args(%arg3 = %extracted_slice_1) -> (tensor<4xf32>) {
              %extracted_slice_2 = tensor.extract_slice %extracted_slice[0, %arg2] [4, 4] [1, 1] : tensor<4x100xf32> to tensor<4x4xf32>
              %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[0, %arg2] [4, 4] [1, 1] : tensor<4x100xf32> to tensor<4x4xf32>
              %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg3 : tensor<4xf32>) attrs = {__internal_linalg_transform__ = "workgroup_k_tiled", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
              ^bb0(%in: f32, %in_4: f32, %out: f32):
              %16 = arith.addf %in, %in_4 : f32
              %17 = arith.addf %16, %out : f32
              linalg.yield %17 : f32
              } -> tensor<4xf32>
              scf.yield %15 : tensor<4xf32>
              }
              %14 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg0)
              scf.foreach_thread.perform_concurrently {
              tensor.parallel_insert_slice %13 into %arg1[%14] [4] [1] : tensor<4xf32> into tensor<256xf32>
              }
              } {mapping = [#gpu.thread<x>]}
              flow.dispatch.tensor.store %9, %2, offsets = [%3], sizes = [256], strides = [1] : tensor<256xf32> -> !flow.dispatch.tensor<readwrite:tensor<100000xf32>>
              } else {
              %6 = flow.dispatch.tensor.load %0, offsets = [%3, 0], sizes = [%4, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<?x100xf32>
              %7 = flow.dispatch.tensor.load %1, offsets = [%3, 0], sizes = [%4, 100], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>> -> tensor<?x100xf32>
              %8 = flow.dispatch.tensor.load %2, offsets = [%3], sizes = [%4], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<100000xf32>> -> tensor<?xf32>
              %dim = tensor.dim %6, %c0 : tensor<?x100xf32>
              // 64 threads run concurrently; each thread computes a [%11, 100] tile
              %9 = scf.foreach_thread (%arg0) in (%c64) shared_outs(%arg1 = %8) -> (tensor<?xf32>) {
              %10 = affine.min affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 64)) + s0, s0 ceildiv 64)>(%arg0)[%dim]
              %11 = affine.max affine_map<(d0) -> (0, d0)>(%10)
              %12 = affine.apply affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 64))>(%arg0)[%dim]
              %13 = affine.apply affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 64))>(%arg0)[%dim]
              %14 = affine.apply affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 64))>(%arg0)[%dim]
              %extracted_slice = tensor.extract_slice %6[%12, 0] [%11, 100] [1, 1] : tensor<?x100xf32> to tensor<?x100xf32>
              %extracted_slice_0 = tensor.extract_slice %7[%13, 0] [%11, 100] [1, 1] : tensor<?x100xf32> to tensor<?x100xf32>
              %extracted_slice_1 = tensor.extract_slice %arg1[%14] [%11] [1] : tensor<?xf32> to tensor<?xf32>
              // tile the inner reduction loop (vectorized later)
              %15 = scf.for %arg2 = %c0 to %c100 step %c4 iter_args(%arg3 = %extracted_slice_1) -> (tensor<?xf32>) {
              %extracted_slice_2 = tensor.extract_slice %extracted_slice[0, %arg2] [%11, 4] [1, 1] : tensor<?x100xf32> to tensor<?x4xf32>
              %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[0, %arg2] [%11, 4] [1, 1] : tensor<?x100xf32> to tensor<?x4xf32>
              %extracted_slice_4 = tensor.extract_slice %arg3[0] [%11] [1] : tensor<?xf32> to tensor<?xf32>
              %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<?x4xf32>, tensor<?x4xf32>) outs(%extracted_slice_4 : tensor<?xf32>) attrs = {__internal_linalg_transform__ = "workgroup_k_tiled", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
              ^bb0(%in: f32, %in_5: f32, %out: f32):
              %18 = arith.addf %in, %in_5 : f32
              %19 = arith.addf %18, %out : f32
              linalg.yield %19 : f32
              } -> tensor<?xf32>
              %inserted_slice = tensor.insert_slice %17 into %arg3[0] [%11] [1] : tensor<?xf32> into tensor<?xf32>
              scf.yield %inserted_slice : tensor<?xf32>
              }
              %16 = affine.apply affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 64))>(%arg0)[%dim]
              scf.foreach_thread.perform_concurrently {
              tensor.parallel_insert_slice %15 into %arg1[%16] [%11] [1] : tensor<?xf32> into tensor<?xf32>
              }
              } {mapping = [#gpu.thread<x>]}
              flow.dispatch.tensor.store %9, %2, offsets = [%3], sizes = [%4], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<readwrite:tensor<100000xf32>>
              }
              return
              }
            • createRemoveSingleIterationLoopPass

            • createGPUVectorizationPass

              Converts the vectorizable inner linalg ops into vector ops.

              %11 = scf.for %arg2 = %c0 to %c100 step %c4 iter_args(%arg3 = %extracted_slice_1) -> (tensor<4xf32>) {
              %extracted_slice_2 = tensor.extract_slice %extracted_slice[0, %arg2] [4, 4] [1, 1] : tensor<4x100xf32> to tensor<4x4xf32>
              %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[0, %arg2] [4, 4] [1, 1] : tensor<4x100xf32> to tensor<4x4xf32>
              %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg3 : tensor<4xf32>) attrs = {__internal_linalg_transform__ = "workgroup_k_tiled", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
              ^bb0(%in: f32, %in_4: f32, %out: f32):
              %13 = arith.addf %in, %in_4 : f32
              %14 = arith.addf %13, %out : f32
              linalg.yield %14 : f32
              } -> tensor<4xf32>
              scf.yield %12 : tensor<4xf32>
              }

              is converted to

              %11 = vector.transfer_read %extracted_slice_1[%c0], %cst {in_bounds = [true]} : tensor<4xf32>, vector<4xf32>
              %12 = scf.for %arg2 = %c0 to %c100 step %c4 iter_args(%arg3 = %11) -> (vector<4xf32>) {
              %extracted_slice_2 = tensor.extract_slice %extracted_slice[0, %arg2] [4, 4] [1, 1] : tensor<4x100xf32> to tensor<4x4xf32>
              %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[0, %arg2] [4, 4] [1, 1] : tensor<4x100xf32> to tensor<4x4xf32>
              %14 = vector.transfer_read %extracted_slice_2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
              %15 = vector.transfer_read %extracted_slice_3[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
              %16 = arith.addf %14, %15 : vector<4x4xf32>
              %17 = vector.multi_reduction <add>, %16, %arg3 [1] : vector<4x4xf32> to vector<4xf32>
              scf.yield %17 : vector<4xf32>
              }
              %13 = vector.transfer_write %12, %extracted_slice_1[%c0] {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
            • addBufferizePasses

              Converts tensor semantics into memref semantics. The complete source func above is converted to the following code:

              func.func @test_dispatch_0_generic_100000x100() {
              %cst = arith.constant 0.000000e+00 : f32
              %c100 = arith.constant 100 : index
              %c4 = arith.constant 4 : index
              %c64 = arith.constant 64 : index
              %c256 = arith.constant 256 : index
              %c0 = arith.constant 0 : index
              %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<100000x100xf32, #hal.descriptor_type<storage_buffer>>
              memref.assume_alignment %0, 64 : memref<100000x100xf32, #hal.descriptor_type<storage_buffer>>
              %1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>>
              %2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<100000x100xf32, #hal.descriptor_type<storage_buffer>>
              memref.assume_alignment %2, 64 : memref<100000x100xf32, #hal.descriptor_type<storage_buffer>>
              %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<100000x100xf32>>
              %4 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<100000xf32, #hal.descriptor_type<storage_buffer>>
              memref.assume_alignment %4, 64 : memref<100000xf32, #hal.descriptor_type<storage_buffer>>
              %5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readwrite:tensor<100000xf32>>
              %workgroup_id_x = hal.interface.workgroup.id[0] : index
              %6 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_id_x]
              %7 = affine.min affine_map<()[s0] -> (s0 * -256 + 100000, 256)>()[%workgroup_id_x]
              %8 = arith.cmpi eq, %7, %c256 : index
              scf.if %8 {
              %subview = memref.subview %0[%6, 0] [256, 100] [1, 1] : memref<100000x100xf32, #hal.descriptor_type<storage_buffer>> to memref<256x100xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              %subview_0 = memref.subview %2[%6, 0] [256, 100] [1, 1] : memref<100000x100xf32, #hal.descriptor_type<storage_buffer>> to memref<256x100xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              %subview_1 = memref.subview %4[%6] [256] [1] : memref<100000xf32, #hal.descriptor_type<storage_buffer>> to memref<256xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              scf.foreach_thread (%arg0) in (%c64) {
              %9 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg0)
              %subview_2 = memref.subview %subview_1[%9] [4] [1] : memref<256xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              %10 = vector.transfer_read %subview_1[%9], %cst {in_bounds = [true]} : memref<256xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
              %11 = scf.for %arg1 = %c0 to %c100 step %c4 iter_args(%arg2 = %10) -> (vector<4xf32>) {
              %12 = vector.transfer_read %subview[%9, %arg1], %cst {in_bounds = [true, true]} : memref<256x100xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4x4xf32>
              %13 = vector.transfer_read %subview_0[%9, %arg1], %cst {in_bounds = [true, true]} : memref<256x100xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4x4xf32>
              %14 = arith.addf %12, %13 : vector<4x4xf32>
              %15 = vector.multi_reduction <add>, %14, %arg2 [1] : vector<4x4xf32> to vector<4xf32>
              scf.yield %15 : vector<4xf32>
              }
              vector.transfer_write %11, %subview_2[%c0] {in_bounds = [true]} : vector<4xf32>, memref<4xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              } {mapping = [#gpu.thread<x>]}
              } else {
              %subview = memref.subview %0[%6, 0] [%7, 100] [1, 1] : memref<100000x100xf32, #hal.descriptor_type<storage_buffer>> to memref<?x100xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              %subview_0 = memref.subview %2[%6, 0] [%7, 100] [1, 1] : memref<100000x100xf32, #hal.descriptor_type<storage_buffer>> to memref<?x100xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              %subview_1 = memref.subview %4[%6] [%7] [1] : memref<100000xf32, #hal.descriptor_type<storage_buffer>> to memref<?xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              scf.foreach_thread (%arg0) in (%c64) {
              %9 = affine.min affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 64)) + s0, s0 ceildiv 64)>(%arg0)[%7]
              %10 = affine.max affine_map<(d0) -> (0, d0)>(%9)
              %11 = affine.apply affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 64))>(%arg0)[%7]
              %subview_2 = memref.subview %subview[%11, 0] [%10, 100] [1, 1] : memref<?x100xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x100xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              %subview_3 = memref.subview %subview_0[%11, 0] [%10, 100] [1, 1] : memref<?x100xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x100xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              %subview_4 = memref.subview %subview_1[%11] [%10] [1] : memref<?xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              scf.for %arg1 = %c0 to %c100 step %c4 {
              %subview_5 = memref.subview %subview_2[0, %arg1] [%10, 4] [1, 1] : memref<?x100xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x4xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              %subview_6 = memref.subview %subview_3[0, %arg1] [%10, 4] [1, 1] : memref<?x100xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x4xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%subview_5, %subview_6 : memref<?x4xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<?x4xf32, strided<[100, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_4 : memref<?xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>) attrs = {__internal_linalg_transform__ = "workgroup_k_tiled", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
              ^bb0(%in: f32, %in_7: f32, %out: f32):
              %12 = arith.addf %in, %in_7 : f32
              %13 = arith.addf %12, %out : f32
              linalg.yield %13 : f32
              }
              }
              } {mapping = [#gpu.thread<x>]}
              }
              return
              }
            • createLLVMGPUDistribute

              Assigns the work to individual threads: the source func goes from the computation of a thread block to the computation of a single thread, i.e. scf.foreach_thread is replaced with gpu.thread_id (x, y, z).

              func.func @test_dispatch_0_generic_100000x100() {
              %cst = arith.constant 0.000000e+00 : f32
              %c100 = arith.constant 100 : index
              %c4 = arith.constant 4 : index
              %c64 = arith.constant 64 : index
              %c256 = arith.constant 256 : index
              %c0 = arith.constant 0 : index
              %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<100000x100xf32>
              memref.assume_alignment %0, 64 : memref<100000x100xf32>
              %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<100000x100xf32>
              memref.assume_alignment %1, 64 : memref<100000x100xf32>
              %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<100000xf32>
              memref.assume_alignment %2, 64 : memref<100000xf32>
              %workgroup_id_x = hal.interface.workgroup.id[0] : index
              %3 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_id_x]
              %4 = affine.min affine_map<()[s0] -> (s0 * -256 + 100000, 256)>()[%workgroup_id_x]
              %5 = arith.cmpi eq, %4, %c256 : index
              scf.if %5 {
              %subview = memref.subview %0[%3, 0] [256, 100] [1, 1] : memref<100000x100xf32> to memref<256x100xf32, strided<[100, 1], offset: ?>>
              %subview_0 = memref.subview %1[%3, 0] [256, 100] [1, 1] : memref<100000x100xf32> to memref<256x100xf32, strided<[100, 1], offset: ?>>
              %subview_1 = memref.subview %2[%3] [256] [1] : memref<100000xf32> to memref<256xf32, strided<[1], offset: ?>>
              %c1 = arith.constant 1 : index
              %6 = gpu.thread_id x
              %7 = gpu.thread_id y
              %8 = gpu.thread_id z
              %9 = affine.apply affine_map<(d0) -> (d0 * 4)>(%6)
              %subview_2 = memref.subview %subview_1[%9] [4] [1] : memref<256xf32, strided<[1], offset: ?>> to memref<4xf32, strided<[1], offset: ?>>
              %10 = vector.transfer_read %subview_1[%9], %cst {in_bounds = [true]} : memref<256xf32, strided<[1], offset: ?>>, vector<4xf32>
              %11 = scf.for %arg0 = %c0 to %c100 step %c4 iter_args(%arg1 = %10) -> (vector<4xf32>) {
              %12 = vector.transfer_read %subview[%9, %arg0], %cst {in_bounds = [true, true]} : memref<256x100xf32, strided<[100, 1], offset: ?>>, vector<4x4xf32>
              %13 = vector.transfer_read %subview_0[%9, %arg0], %cst {in_bounds = [true, true]} : memref<256x100xf32, strided<[100, 1], offset: ?>>, vector<4x4xf32>
              %14 = arith.addf %12, %13 : vector<4x4xf32>
              %15 = vector.multi_reduction <add>, %14, %arg1 [1] : vector<4x4xf32> to vector<4xf32>
              scf.yield %15 : vector<4xf32>
              }
              vector.transfer_write %11, %subview_2[%c0] {in_bounds = [true]} : vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
              } else {
              %subview = memref.subview %0[%3, 0] [%4, 100] [1, 1] : memref<100000x100xf32> to memref<?x100xf32, strided<[100, 1], offset: ?>>
              %subview_0 = memref.subview %1[%3, 0] [%4, 100] [1, 1] : memref<100000x100xf32> to memref<?x100xf32, strided<[100, 1], offset: ?>>
              %subview_1 = memref.subview %2[%3] [%4] [1] : memref<100000xf32> to memref<?xf32, strided<[1], offset: ?>>
              %c1 = arith.constant 1 : index
              %6 = gpu.thread_id x
              %7 = gpu.thread_id y
              %8 = gpu.thread_id z
              %9 = affine.min affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 64)) + s0, s0 ceildiv 64)>(%6)[%4]
              %10 = affine.max affine_map<(d0) -> (0, d0)>(%9)
              %11 = affine.apply affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 64))>(%6)[%4]
              %subview_2 = memref.subview %subview[%11, 0] [%10, 100] [1, 1] : memref<?x100xf32, strided<[100, 1], offset: ?>> to memref<?x100xf32, strided<[100, 1], offset: ?>>
              %subview_3 = memref.subview %subview_0[%11, 0] [%10, 100] [1, 1] : memref<?x100xf32, strided<[100, 1], offset: ?>> to memref<?x100xf32, strided<[100, 1], offset: ?>>
              %subview_4 = memref.subview %subview_1[%11] [%10] [1] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
              scf.for %arg0 = %c0 to %c100 step %c4 {
              %subview_5 = memref.subview %subview_2[0, %arg0] [%10, 4] [1, 1] : memref<?x100xf32, strided<[100, 1], offset: ?>> to memref<?x4xf32, strided<[100, 1], offset: ?>>
              %subview_6 = memref.subview %subview_3[0, %arg0] [%10, 4] [1, 1] : memref<?x100xf32, strided<[100, 1], offset: ?>> to memref<?x4xf32, strided<[100, 1], offset: ?>>
              linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%subview_5, %subview_6 : memref<?x4xf32, strided<[100, 1], offset: ?>>, memref<?x4xf32, strided<[100, 1], offset: ?>>) outs(%subview_4 : memref<?xf32, strided<[1], offset: ?>>) attrs = {__internal_linalg_transform__ = "workgroup_k_tiled", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[256, 4]]>} {
              ^bb0(%in: f32, %in_7: f32, %out: f32):
              %12 = arith.addf %in, %in_7 : f32
              %13 = arith.addf %12, %out : f32
              linalg.yield %13 : f32
              }
              }
              }
              return
              }
            • createLoopInvariantCodeMotionPass

            • memref::createFoldMemRefAliasOpsPass

            • createOptimizeVectorTransferPass

          • GPUMatmulSimtPassPipeline

          • GPUMatmulTensorCorePassPipeline

          • GPUTransposePassPipeline

          • GPUWarpReductionPassPipeline

          • GPUTransformDialectPasses

      • addLowerToLLVMGPUPasses

        Continues lowering the device code into the affine and gpu dialects, and finally into NVVM IR or ROCDL IR.

        • IREE::LinalgExt::createLinalgExtToLoopsPass

          Converts LinalgExt ops into loops.

        • createMemrefCopyToLinalgPass

          Converts memref.copy into a linalg generic op.
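
          A minimal sketch of the rewrite, mirroring the identity linalg.generic that createBufferizeCopyOnlyDispatchesPass produced earlier:

          memref.copy %src, %dst : memref<?xf32> to memref<?xf32>
          // becomes
          linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%src : memref<?xf32>) outs(%dst : memref<?xf32>) {
          ^bb0(%in: f32, %out: f32):
          linalg.yield %in : f32
          }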

        • createConvertLinalgToLoopsPass

          Converts linalg generic ops into loops.
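
          A minimal sketch (%lhs, %rhs, %out, and %dim stand in for the subviews and dynamic size from the code above): the element-wise reduction generic becomes an explicit load/compute/store loop nest, as can also be seen in the else branch of the dump after createLowerAffinePass below.

          scf.for %i = %c0 to %dim step %c1 {
          scf.for %j = %c0 to %c4 step %c1 {
          %a = memref.load %lhs[%i, %j] : memref<?x4xf32, strided<[100, 1], offset: ?>>
          %b = memref.load %rhs[%i, %j] : memref<?x4xf32, strided<[100, 1], offset: ?>>
          %acc = memref.load %out[%i] : memref<?xf32, strided<[1], offset: ?>>
          %0 = arith.addf %a, %b : f32
          %1 = arith.addf %0, %acc : f32
          memref.store %1, %out[%i] : memref<?xf32, strided<[1], offset: ?>>
          }
          }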

        • createPadDynamicAlloc

          Allocates dynamically sized memory by padding to an upper bound. For instance, if the allocation size depends on a dim, say %dim = affine_max(0, %src), the allocation is made with %dim's maximum size, namely %src.
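
          A minimal sketch following the example in the text (hypothetical values; %src is the value the bound is derived from):

          %dim = affine.max affine_map<(d0) -> (0, d0)>(%src)
          %buf = memref.alloc(%dim) : memref<?xf32>
          // becomes an allocation with the upper bound of %dim, namely %src
          %buf = memref.alloc(%src) : memref<?xf32>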

        • createLowerAffinePass

          Lowers affine ops (affine.for, affine.if, affine.apply, etc.) into lower-level arith, memref, and scf ops. The complete source func above is converted to the following code:

          func.func @test_dispatch_0_generic_100000x100() {
          %c-1 = arith.constant -1 : index
          %c64 = arith.constant 64 : index
          %c100000 = arith.constant 100000 : index
          %c-256 = arith.constant -256 : index
          %c1 = arith.constant 1 : index
          %cst = arith.constant 0.000000e+00 : f32
          %c100 = arith.constant 100 : index
          %c4 = arith.constant 4 : index
          %c256 = arith.constant 256 : index
          %c0 = arith.constant 0 : index
          %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<100000x100xf32>
          memref.assume_alignment %0, 64 : memref<100000x100xf32>
          %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<100000x100xf32>
          memref.assume_alignment %1, 64 : memref<100000x100xf32>
          %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<100000xf32>
          memref.assume_alignment %2, 64 : memref<100000xf32>
          %workgroup_id_x = hal.interface.workgroup.id[0] : index
          %3 = arith.muli %workgroup_id_x, %c-256 : index
          %4 = arith.addi %3, %c100000 : index
          %5 = arith.cmpi slt, %4, %c256 : index
          %6 = arith.select %5, %4, %c256 : index
          %7 = arith.cmpi eq, %6, %c256 : index
          scf.if %7 {
          %8 = gpu.thread_id x
          %9 = arith.muli %8, %c4 : index
          %10 = arith.muli %workgroup_id_x, %c256 : index
          %11 = arith.addi %9, %10 : index
          %12 = vector.transfer_read %2[%11], %cst {in_bounds = [true]} : memref<100000xf32>, vector<4xf32>
          %13 = scf.for %arg0 = %c0 to %c100 step %c4 iter_args(%arg1 = %12) -> (vector<4xf32>) {
          %14 = vector.transfer_read %0[%11, %arg0], %cst {in_bounds = [true, true]} : memref<100000x100xf32>, vector<4x4xf32>
          %15 = vector.transfer_read %1[%11, %arg0], %cst {in_bounds = [true, true]} : memref<100000x100xf32>, vector<4x4xf32>
          %16 = arith.addf %14, %15 : vector<4x4xf32>
          %17 = vector.multi_reduction <add>, %16, %arg1 [1] : vector<4x4xf32> to vector<4xf32>
          scf.yield %17 : vector<4xf32>
          }
          vector.transfer_write %13, %2[%11] {in_bounds = [true]} : vector<4xf32>, memref<100000xf32>
          } else {
          %8 = gpu.thread_id x
          %9 = arith.cmpi sle, %6, %c0 : index
          %10 = arith.subi %c0, %6 : index
          %11 = arith.subi %6, %c1 : index
          %12 = arith.select %9, %10, %11 : index
          %13 = arith.divsi %12, %c64 : index
          %14 = arith.subi %c0, %13 : index
          %15 = arith.addi %13, %c1 : index
          %16 = arith.select %9, %14, %15 : index
          %17 = arith.muli %8, %16 : index
          %18 = arith.muli %17, %c-1 : index
          %19 = arith.addi %18, %6 : index
          %20 = arith.cmpi slt, %19, %16 : index
          %21 = arith.select %20, %19, %16 : index
          %22 = arith.cmpi slt, %21, %c0 : index
          %23 = arith.select %22, %c0, %21 : index
          %24 = arith.muli %workgroup_id_x, %c256 : index
          %25 = arith.addi %17, %24 : index
          %subview = memref.subview %2[%25] [%23] [1] : memref<100000xf32> to memref<?xf32, strided<[1], offset: ?>>
          scf.for %arg0 = %c0 to %c100 step %c4 {
          %subview_0 = memref.subview %0[%25, %arg0] [%23, 4] [1, 1] : memref<100000x100xf32> to memref<?x4xf32, strided<[100, 1], offset: ?>>
          %subview_1 = memref.subview %1[%25, %arg0] [%23, 4] [1, 1] : memref<100000x100xf32> to memref<?x4xf32, strided<[100, 1], offset: ?>>
          scf.for %arg1 = %c0 to %23 step %c1 {
          scf.for %arg2 = %c0 to %c4 step %c1 {
          %26 = memref.load %subview_0[%arg1, %arg2] : memref<?x4xf32, strided<[100, 1], offset: ?>>
          %27 = memref.load %subview_1[%arg1, %arg2] : memref<?x4xf32, strided<[100, 1], offset: ?>>
          %28 = memref.load %subview[%arg1] : memref<?xf32, strided<[1], offset: ?>>
          %29 = arith.addf %26, %27 : f32
          %30 = arith.addf %29, %28 : f32
          memref.store %30, %subview[%arg1] : memref<?xf32, strided<[1], offset: ?>>
          }
          }
          }
          }
          return
          }
        • arith::createConstantBufferizePass
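
          Bufferizes constant tensors: arith.constant ops of tensor type become memref.global definitions read through memref.get_global.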

        • createFoldTensorExtractOpPass
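
          Folds tensor.extract ops left over after bufferization (reads through bufferization.to_tensor) into direct memref.load ops.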

        • createLLVMGPUVectorLoweringPass

          Unrolls multi-dimensional vector ops into one-dimensional vector ops (e.g. the vector.transfer_read/write above become plain vector.load/vector.store plus vector.insert/extract). The full source func above becomes the following code:

          func.func @test_dispatch_0_generic_100000x100() {
          %cst = arith.constant dense<0.000000e+00> : vector<4x4xf32>
          %c-1 = arith.constant -1 : index
          %c64 = arith.constant 64 : index
          %c100000 = arith.constant 100000 : index
          %c-256 = arith.constant -256 : index
          %c1 = arith.constant 1 : index
          %c100 = arith.constant 100 : index
          %c4 = arith.constant 4 : index
          %c256 = arith.constant 256 : index
          %c0 = arith.constant 0 : index
          %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<100000x100xf32>
          memref.assume_alignment %0, 64 : memref<100000x100xf32>
          %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<100000x100xf32>
          memref.assume_alignment %1, 64 : memref<100000x100xf32>
          %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<100000xf32>
          memref.assume_alignment %2, 64 : memref<100000xf32>
          %workgroup_id_x = hal.interface.workgroup.id[0] : index
          %3 = arith.muli %workgroup_id_x, %c-256 : index
          %4 = arith.addi %3, %c100000 : index
          %5 = arith.cmpi slt, %4, %c256 : index
          %6 = arith.select %5, %4, %c256 : index
          %7 = arith.cmpi eq, %6, %c256 : index
          scf.if %7 {
          %8 = gpu.thread_id x
          %9 = arith.muli %8, %c4 : index
          %10 = arith.muli %workgroup_id_x, %c256 : index
          %11 = arith.addi %9, %10 : index
          %12 = vector.load %2[%11] : memref<100000xf32>, vector<4xf32>
          %13 = scf.for %arg0 = %c0 to %c100 step %c4 iter_args(%arg1 = %12) -> (vector<4xf32>) {
          %14 = vector.load %0[%11, %arg0] : memref<100000x100xf32>, vector<4xf32>
          %15 = vector.insert %14, %cst [0] : vector<4xf32> into vector<4x4xf32>
          %16 = affine.apply affine_map<(d0) -> (d0 + 1)>(%11)
          %17 = vector.load %0[%16, %arg0] : memref<100000x100xf32>, vector<4xf32>
          %18 = vector.insert %17, %15 [1] : vector<4xf32> into vector<4x4xf32>
          %19 = affine.apply affine_map<(d0) -> (d0 + 2)>(%11)
          %20 = vector.load %0[%19, %arg0] : memref<100000x100xf32>, vector<4xf32>
          %21 = vector.insert %20, %18 [2] : vector<4xf32> into vector<4x4xf32>
          %22 = affine.apply affine_map<(d0) -> (d0 + 3)>(%11)
          %23 = vector.load %0[%22, %arg0] : memref<100000x100xf32>, vector<4xf32>
          %24 = vector.insert %23, %21 [3] : vector<4xf32> into vector<4x4xf32>
          %25 = vector.load %1[%11, %arg0] : memref<100000x100xf32>, vector<4xf32>
          %26 = vector.insert %25, %cst [0] : vector<4xf32> into vector<4x4xf32>
          %27 = affine.apply affine_map<(d0) -> (d0 + 1)>(%11)
          %28 = vector.load %1[%27, %arg0] : memref<100000x100xf32>, vector<4xf32>
          %29 = vector.insert %28, %26 [1] : vector<4xf32> into vector<4x4xf32>
          %30 = affine.apply affine_map<(d0) -> (d0 + 2)>(%11)
          %31 = vector.load %1[%30, %arg0] : memref<100000x100xf32>, vector<4xf32>
          %32 = vector.insert %31, %29 [2] : vector<4xf32> into vector<4x4xf32>
          %33 = affine.apply affine_map<(d0) -> (d0 + 3)>(%11)
          %34 = vector.load %1[%33, %arg0] : memref<100000x100xf32>, vector<4xf32>
          %35 = vector.insert %34, %32 [3] : vector<4xf32> into vector<4x4xf32>
          %36 = arith.addf %24, %35 : vector<4x4xf32>
          %37 = vector.transpose %36, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
          %38 = vector.extract %37[0] : vector<4x4xf32>
          %39 = arith.addf %38, %arg1 : vector<4xf32>
          %40 = vector.extract %37[1] : vector<4x4xf32>
          %41 = arith.addf %40, %39 : vector<4xf32>
          %42 = vector.extract %37[2] : vector<4x4xf32>
          %43 = arith.addf %42, %41 : vector<4xf32>
          %44 = vector.extract %37[3] : vector<4x4xf32>
          %45 = arith.addf %44, %43 : vector<4xf32>
          scf.yield %45 : vector<4xf32>
          }
          vector.store %13, %2[%11] : memref<100000xf32>, vector<4xf32>
          } else {
          %8 = gpu.thread_id x
          %9 = arith.cmpi sle, %6, %c0 : index
          %10 = arith.subi %c0, %6 : index
          %11 = arith.subi %6, %c1 : index
          %12 = arith.select %9, %10, %11 : index
          %13 = arith.divsi %12, %c64 : index
          %14 = arith.subi %c0, %13 : index
          %15 = arith.addi %13, %c1 : index
          %16 = arith.select %9, %14, %15 : index
          %17 = arith.muli %8, %16 : index
          %18 = arith.muli %17, %c-1 : index
          %19 = arith.addi %18, %6 : index
          %20 = arith.cmpi slt, %19, %16 : index
          %21 = arith.select %20, %19, %16 : index
          %22 = arith.cmpi slt, %21, %c0 : index
          %23 = arith.select %22, %c0, %21 : index
          %24 = arith.muli %workgroup_id_x, %c256 : index
          %25 = arith.addi %17, %24 : index
          scf.for %arg0 = %c0 to %c100 step %c4 {
          scf.for %arg1 = %c0 to %23 step %c1 {
          scf.for %arg2 = %c0 to %c4 step %c1 {
          %26 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg1)[%25]
          %27 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg2)[%arg0]
          %28 = memref.load %0[%26, %27] : memref<100000x100xf32>
          %29 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg1)[%25]
          %30 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg2)[%arg0]
          %31 = memref.load %1[%29, %30] : memref<100000x100xf32>
          %32 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg1)[%25]
          %33 = memref.load %2[%32] : memref<100000xf32>
          %34 = arith.addf %28, %31 : f32
          %35 = arith.addf %34, %33 : f32
          %36 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg1)[%25]
          memref.store %35, %2[%36] : memref<100000xf32>
          }
          }
          }
          }
          return
          }
        • createConvertSCFToCFPass

          Converts structured control flow (scf.if, scf.for) into a CFG of basic blocks with cf.br/cf.cond_br. The full source func above becomes the following code:

          func.func @test_dispatch_0_generic_100000x100() {
          %cst = arith.constant dense<0.000000e+00> : vector<4x4xf32>
          %c-1 = arith.constant -1 : index
          %c64 = arith.constant 64 : index
          %c100000 = arith.constant 100000 : index
          %c-256 = arith.constant -256 : index
          %c1 = arith.constant 1 : index
          %c100 = arith.constant 100 : index
          %c4 = arith.constant 4 : index
          %c256 = arith.constant 256 : index
          %c0 = arith.constant 0 : index
          %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<100000x100xf32>
          memref.assume_alignment %0, 64 : memref<100000x100xf32>
          %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<100000x100xf32>
          memref.assume_alignment %1, 64 : memref<100000x100xf32>
          %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<100000xf32>
          memref.assume_alignment %2, 64 : memref<100000xf32>
          %workgroup_id_x = hal.interface.workgroup.id[0] : index
          %3 = arith.muli %workgroup_id_x, %c-256 : index
          %4 = arith.addi %3, %c100000 : index
          %5 = arith.cmpi slt, %4, %c256 : index
          %6 = arith.select %5, %4, %c256 : index
          %7 = arith.cmpi eq, %6, %c256 : index
          cf.cond_br %7, ^bb1, ^bb5
          ^bb1: // pred: ^bb0
          %8 = gpu.thread_id x
          %9 = arith.muli %8, %c4 : index
          %10 = arith.muli %workgroup_id_x, %c256 : index
          %11 = arith.addi %9, %10 : index
          %12 = vector.load %2[%11] : memref<100000xf32>, vector<4xf32>
          cf.br ^bb2(%c0, %12 : index, vector<4xf32>)
          ^bb2(%13: index, %14: vector<4xf32>): // 2 preds: ^bb1, ^bb3
          %15 = arith.cmpi slt, %13, %c100 : index
          cf.cond_br %15, ^bb3, ^bb4
          ^bb3: // pred: ^bb2
          %16 = vector.load %0[%11, %13] : memref<100000x100xf32>, vector<4xf32>
          %17 = vector.insert %16, %cst [0] : vector<4xf32> into vector<4x4xf32>
          %c1_0 = arith.constant 1 : index
          %18 = arith.addi %11, %c1_0 : index
          %19 = vector.load %0[%18, %13] : memref<100000x100xf32>, vector<4xf32>
          %20 = vector.insert %19, %17 [1] : vector<4xf32> into vector<4x4xf32>
          %c2 = arith.constant 2 : index
          %21 = arith.addi %11, %c2 : index
          %22 = vector.load %0[%21, %13] : memref<100000x100xf32>, vector<4xf32>
          %23 = vector.insert %22, %20 [2] : vector<4xf32> into vector<4x4xf32>
          %c3 = arith.constant 3 : index
          %24 = arith.addi %11, %c3 : index
          %25 = vector.load %0[%24, %13] : memref<100000x100xf32>, vector<4xf32>
          %26 = vector.insert %25, %23 [3] : vector<4xf32> into vector<4x4xf32>
          %27 = vector.load %1[%11, %13] : memref<100000x100xf32>, vector<4xf32>
          %28 = vector.insert %27, %cst [0] : vector<4xf32> into vector<4x4xf32>
          %29 = vector.load %1[%18, %13] : memref<100000x100xf32>, vector<4xf32>
          %30 = vector.insert %29, %28 [1] : vector<4xf32> into vector<4x4xf32>
          %31 = vector.load %1[%21, %13] : memref<100000x100xf32>, vector<4xf32>
          %32 = vector.insert %31, %30 [2] : vector<4xf32> into vector<4x4xf32>
          %33 = vector.load %1[%24, %13] : memref<100000x100xf32>, vector<4xf32>
          %34 = vector.insert %33, %32 [3] : vector<4xf32> into vector<4x4xf32>
          %35 = arith.addf %26, %34 : vector<4x4xf32>
          %36 = vector.transpose %35, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
          %37 = vector.extract %36[0] : vector<4x4xf32>
          %38 = arith.addf %37, %14 : vector<4xf32>
          %39 = vector.extract %36[1] : vector<4x4xf32>
          %40 = arith.addf %39, %38 : vector<4xf32>
          %41 = vector.extract %36[2] : vector<4x4xf32>
          %42 = arith.addf %41, %40 : vector<4xf32>
          %43 = vector.extract %36[3] : vector<4x4xf32>
          %44 = arith.addf %43, %42 : vector<4xf32>
          %45 = arith.addi %13, %c4 : index
          cf.br ^bb2(%45, %44 : index, vector<4xf32>)
          ^bb4: // pred: ^bb2
          vector.store %14, %2[%11] : memref<100000xf32>, vector<4xf32>
          cf.br ^bb12
          ^bb5: // pred: ^bb0
          %46 = gpu.thread_id x
          %47 = arith.cmpi sle, %6, %c0 : index
          %48 = arith.subi %c0, %6 : index
          %49 = arith.subi %6, %c1 : index
          %50 = arith.select %47, %48, %49 : index
          %51 = arith.divsi %50, %c64 : index
          %52 = arith.subi %c0, %51 : index
          %53 = arith.addi %51, %c1 : index
          %54 = arith.select %47, %52, %53 : index
          %55 = arith.muli %46, %54 : index
          %56 = arith.muli %55, %c-1 : index
          %57 = arith.addi %56, %6 : index
          %58 = arith.cmpi slt, %57, %54 : index
          %59 = arith.select %58, %57, %54 : index
          %60 = arith.cmpi slt, %59, %c0 : index
          %61 = arith.select %60, %c0, %59 : index
          %62 = arith.muli %workgroup_id_x, %c256 : index
          %63 = arith.addi %55, %62 : index
          cf.br ^bb6(%c0 : index)
          ^bb6(%64: index): // 2 preds: ^bb5, ^bb11
          %65 = arith.cmpi slt, %64, %c100 : index
          cf.cond_br %65, ^bb7(%c0 : index), ^bb12
          ^bb7(%66: index): // 2 preds: ^bb6, ^bb10
          %67 = arith.cmpi slt, %66, %61 : index
          cf.cond_br %67, ^bb8(%c0 : index), ^bb11
          ^bb8(%68: index): // 2 preds: ^bb7, ^bb9
          %69 = arith.cmpi slt, %68, %c4 : index
          cf.cond_br %69, ^bb9, ^bb10
          ^bb9: // pred: ^bb8
          %70 = arith.addi %63, %66 : index
          %71 = arith.addi %64, %68 : index
          %72 = memref.load %0[%70, %71] : memref<100000x100xf32>
          %73 = memref.load %1[%70, %71] : memref<100000x100xf32>
          %74 = memref.load %2[%70] : memref<100000xf32>
          %75 = arith.addf %72, %73 : f32
          %76 = arith.addf %75, %74 : f32
          memref.store %76, %2[%70] : memref<100000xf32>
          %77 = arith.addi %68, %c1 : index
          cf.br ^bb8(%77 : index)
          ^bb10: // pred: ^bb8
          %78 = arith.addi %66, %c1 : index
          cf.br ^bb7(%78 : index)
          ^bb11: // pred: ^bb7
          %79 = arith.addi %64, %c4 : index
          cf.br ^bb6(%79 : index)
          ^bb12: // 2 preds: ^bb4, ^bb6
          return
          }
        • createPolynomialApproximationPass
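
          Replaces math ops that lack a direct device intrinsic (for example math.tanh or math.erf) with polynomial approximations built from basic arith ops. A schematic sketch; the coefficients below are placeholders, not the ones the pass actually emits:

          // Before: a transcendental op with no direct NVVM lowering.
          %y = math.tanh %x : f32

          // After (schematic): a polynomial in x evaluated with mul/add;
          // the real pass rewrites %y in place with the proper coefficients.
          %a0 = arith.constant 1.0 : f32   // placeholder coefficient
          %a1 = arith.constant 0.5 : f32   // placeholder coefficient
          %x2 = arith.mulf %x, %x : f32
          %t0 = arith.mulf %a1, %x2 : f32
          %t1 = arith.addf %t0, %a0 : f32
          %y_approx = arith.mulf %x, %t1 : f32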

        • arith::createArithExpandOpsPass
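
          Expands arith ops that have no direct LLVM equivalent (such as arith.ceildivsi and arith.floordivsi) into sequences of basic arith ops.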

        • memref::createExpandOpsPass
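
          Legalizes memref ops for LLVM conversion by expanding forms that lack a direct lowering into simpler memref ops.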

        • memref::createExpandStridedMetadataPass
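
          Makes the offset/size/stride arithmetic of view-like memref ops explicit. A minimal sketch on a 1-D subview:

          // Before: a subview whose offset computation is implicit in the op.
          %sv = memref.subview %buf[%off] [%sz] [1]
              : memref<100000xf32> to memref<?xf32, strided<[1], offset: ?>>

          // After (schematic): extract the base buffer metadata and recompose
          // the view with memref.reinterpret_cast.
          %base, %offset, %size, %stride = memref.extract_strided_metadata %buf
              : memref<100000xf32> -> memref<f32>, index, index, index
          %sv2 = memref.reinterpret_cast %base to
              offset: [%off], sizes: [%sz], strides: [1]
              : memref<f32> to memref<?xf32, strided<[1], offset: ?>>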

        • createLowerAffinePass
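
          Runs again here to lower the affine.apply ops introduced by the expansion passes above.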

        • createStripDebugInfoPass
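
          Strips file/line debug location information from all ops before final translation.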

        • createConvertToROCDLPass or createConvertToNVVMPass

          Lowers the code to ROCDL IR or NVVM IR. The full source func above becomes the following code:

          llvm.func @test_dispatch_0_generic_100000x100(%arg0: !llvm.ptr<f32> {llvm.align = 16 : i32}, %arg1: !llvm.ptr<f32> {llvm.align = 16 : i32}, %arg2: !llvm.ptr<f32> {llvm.align = 16 : i32}) {
          %0 = llvm.mlir.constant(3 : index) : i64
          %1 = llvm.mlir.constant(2 : index) : i64
          %2 = llvm.mlir.constant(dense<0.000000e+00> : vector<4x4xf32>) : !llvm.array<4 x vector<4xf32>>
          %3 = llvm.mlir.constant(-1 : index) : i64
          %4 = llvm.mlir.constant(64 : index) : i64
          %5 = llvm.mlir.constant(100000 : index) : i64
          %6 = llvm.mlir.constant(-256 : index) : i64
          %7 = llvm.mlir.constant(1 : index) : i64
          %8 = llvm.mlir.constant(100 : index) : i64
          %9 = llvm.mlir.constant(4 : index) : i64
          %10 = llvm.mlir.constant(256 : index) : i64
          %11 = llvm.mlir.constant(0 : index) : i64
          %12 = llvm.bitcast %arg0 : !llvm.ptr<f32> to !llvm.ptr<i8>
          %13 = llvm.getelementptr %12[%11] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
          %14 = llvm.bitcast %13 : !llvm.ptr<i8> to !llvm.ptr<f32>
          %15 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %16 = llvm.insertvalue %14, %15[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %17 = llvm.insertvalue %14, %16[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %18 = llvm.mlir.constant(0 : index) : i64
          %19 = llvm.insertvalue %18, %17[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %20 = llvm.mlir.constant(100000 : index) : i64
          %21 = llvm.insertvalue %20, %19[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %22 = llvm.mlir.constant(100 : index) : i64
          %23 = llvm.insertvalue %22, %21[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %24 = llvm.mlir.constant(100 : index) : i64
          %25 = llvm.insertvalue %24, %23[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %26 = llvm.mlir.constant(1 : index) : i64
          %27 = llvm.insertvalue %26, %25[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %28 = llvm.extractvalue %27[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %29 = llvm.mlir.constant(0 : index) : i64
          %30 = llvm.mlir.constant(63 : index) : i64
          %31 = llvm.ptrtoint %28 : !llvm.ptr<f32> to i64
          %32 = llvm.and %31, %30 : i64
          %33 = llvm.icmp "eq" %32, %29 : i64
          "llvm.intr.assume"(%33) : (i1) -> ()
          %34 = llvm.bitcast %arg1 : !llvm.ptr<f32> to !llvm.ptr<i8>
          %35 = llvm.getelementptr %34[%11] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
          %36 = llvm.bitcast %35 : !llvm.ptr<i8> to !llvm.ptr<f32>
          %37 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %38 = llvm.insertvalue %36, %37[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %39 = llvm.insertvalue %36, %38[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %40 = llvm.mlir.constant(0 : index) : i64
          %41 = llvm.insertvalue %40, %39[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %42 = llvm.mlir.constant(100000 : index) : i64
          %43 = llvm.insertvalue %42, %41[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %44 = llvm.mlir.constant(100 : index) : i64
          %45 = llvm.insertvalue %44, %43[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %46 = llvm.mlir.constant(100 : index) : i64
          %47 = llvm.insertvalue %46, %45[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %48 = llvm.mlir.constant(1 : index) : i64
          %49 = llvm.insertvalue %48, %47[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %50 = llvm.extractvalue %49[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %51 = llvm.mlir.constant(0 : index) : i64
          %52 = llvm.mlir.constant(63 : index) : i64
          %53 = llvm.ptrtoint %50 : !llvm.ptr<f32> to i64
          %54 = llvm.and %53, %52 : i64
          %55 = llvm.icmp "eq" %54, %51 : i64
          "llvm.intr.assume"(%55) : (i1) -> ()
          %56 = llvm.bitcast %arg2 : !llvm.ptr<f32> to !llvm.ptr<i8>
          %57 = llvm.getelementptr %56[%11] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
          %58 = llvm.bitcast %57 : !llvm.ptr<i8> to !llvm.ptr<f32>
          %59 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
          %60 = llvm.insertvalue %58, %59[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
          %61 = llvm.insertvalue %58, %60[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
          %62 = llvm.mlir.constant(0 : index) : i64
          %63 = llvm.insertvalue %62, %61[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
          %64 = llvm.mlir.constant(100000 : index) : i64
          %65 = llvm.insertvalue %64, %63[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
          %66 = llvm.mlir.constant(1 : index) : i64
          %67 = llvm.insertvalue %66, %65[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
          %68 = llvm.extractvalue %67[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
          %69 = llvm.mlir.constant(0 : index) : i64
          %70 = llvm.mlir.constant(63 : index) : i64
          %71 = llvm.ptrtoint %68 : !llvm.ptr<f32> to i64
          %72 = llvm.and %71, %70 : i64
          %73 = llvm.icmp "eq" %72, %69 : i64
          "llvm.intr.assume"(%73) : (i1) -> ()
          %74 = nvvm.read.ptx.sreg.ctaid.x : i32
          %75 = llvm.sext %74 : i32 to i64
          %76 = llvm.mul %75, %6 : i64
          %77 = llvm.add %76, %5 : i64
          %78 = llvm.icmp "slt" %77, %10 : i64
          %79 = llvm.select %78, %77, %10 : i1, i64
          %80 = llvm.icmp "eq" %79, %10 : i64
          llvm.cond_br %80, ^bb1, ^bb5
          ^bb1: // pred: ^bb0
          %81 = nvvm.read.ptx.sreg.tid.x : i32
          %82 = llvm.sext %81 : i32 to i64
          %83 = llvm.mul %82, %9 : i64
          %84 = llvm.mul %75, %10 : i64
          %85 = llvm.add %83, %84 : i64
          %86 = llvm.extractvalue %67[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
          %87 = llvm.getelementptr %86[%85] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %88 = llvm.bitcast %87 : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
          %89 = llvm.load %88 {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
          llvm.br ^bb2(%11, %89 : i64, vector<4xf32>)
          ^bb2(%90: i64, %91: vector<4xf32>): // 2 preds: ^bb1, ^bb3
          %92 = llvm.icmp "slt" %90, %8 : i64
          llvm.cond_br %92, ^bb3, ^bb4
          ^bb3: // pred: ^bb2
          %93 = llvm.extractvalue %27[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %94 = llvm.mlir.constant(100 : index) : i64
          %95 = llvm.mul %85, %94 : i64
          %96 = llvm.add %95, %90 : i64
          %97 = llvm.getelementptr %93[%96] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %98 = llvm.bitcast %97 : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
          %99 = llvm.load %98 {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
          %100 = llvm.insertvalue %99, %2[0] : !llvm.array<4 x vector<4xf32>>
          %101 = llvm.add %85, %7 : i64
          %102 = llvm.extractvalue %27[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %103 = llvm.mlir.constant(100 : index) : i64
          %104 = llvm.mul %101, %103 : i64
          %105 = llvm.add %104, %90 : i64
          %106 = llvm.getelementptr %102[%105] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %107 = llvm.bitcast %106 : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
          %108 = llvm.load %107 {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
          %109 = llvm.insertvalue %108, %100[1] : !llvm.array<4 x vector<4xf32>>
          %110 = llvm.add %85, %1 : i64
          %111 = llvm.extractvalue %27[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %112 = llvm.mlir.constant(100 : index) : i64
          %113 = llvm.mul %110, %112 : i64
          %114 = llvm.add %113, %90 : i64
          %115 = llvm.getelementptr %111[%114] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %116 = llvm.bitcast %115 : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
          %117 = llvm.load %116 {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
          %118 = llvm.insertvalue %117, %109[2] : !llvm.array<4 x vector<4xf32>>
          %119 = llvm.add %85, %0 : i64
          %120 = llvm.extractvalue %27[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %121 = llvm.mlir.constant(100 : index) : i64
          %122 = llvm.mul %119, %121 : i64
          %123 = llvm.add %122, %90 : i64
          %124 = llvm.getelementptr %120[%123] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %125 = llvm.bitcast %124 : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
          %126 = llvm.load %125 {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
          %127 = llvm.insertvalue %126, %118[3] : !llvm.array<4 x vector<4xf32>>
          %128 = llvm.extractvalue %49[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %129 = llvm.mlir.constant(100 : index) : i64
          %130 = llvm.mul %85, %129 : i64
          %131 = llvm.add %130, %90 : i64
          %132 = llvm.getelementptr %128[%131] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %133 = llvm.bitcast %132 : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
          %134 = llvm.load %133 {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
          %135 = llvm.insertvalue %134, %2[0] : !llvm.array<4 x vector<4xf32>>
          %136 = llvm.extractvalue %49[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %137 = llvm.mlir.constant(100 : index) : i64
          %138 = llvm.mul %101, %137 : i64
          %139 = llvm.add %138, %90 : i64
          %140 = llvm.getelementptr %136[%139] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %141 = llvm.bitcast %140 : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
          %142 = llvm.load %141 {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
          %143 = llvm.insertvalue %142, %135[1] : !llvm.array<4 x vector<4xf32>>
          %144 = llvm.extractvalue %49[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %145 = llvm.mlir.constant(100 : index) : i64
          %146 = llvm.mul %110, %145 : i64
          %147 = llvm.add %146, %90 : i64
          %148 = llvm.getelementptr %144[%147] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %149 = llvm.bitcast %148 : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
          %150 = llvm.load %149 {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
          %151 = llvm.insertvalue %150, %143[2] : !llvm.array<4 x vector<4xf32>>
          %152 = llvm.extractvalue %49[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %153 = llvm.mlir.constant(100 : index) : i64
          %154 = llvm.mul %119, %153 : i64
          %155 = llvm.add %154, %90 : i64
          %156 = llvm.getelementptr %152[%155] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %157 = llvm.bitcast %156 : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
          %158 = llvm.load %157 {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
          %159 = llvm.insertvalue %158, %151[3] : !llvm.array<4 x vector<4xf32>>
          %160 = llvm.mlir.undef : !llvm.array<4 x vector<4xf32>>
          %161 = llvm.extractvalue %127[0] : !llvm.array<4 x vector<4xf32>>
          %162 = llvm.extractvalue %159[0] : !llvm.array<4 x vector<4xf32>>
          %163 = llvm.fadd %161, %162 : vector<4xf32>
          %164 = llvm.insertvalue %163, %160[0] : !llvm.array<4 x vector<4xf32>>
          %165 = llvm.extractvalue %127[1] : !llvm.array<4 x vector<4xf32>>
          %166 = llvm.extractvalue %159[1] : !llvm.array<4 x vector<4xf32>>
          %167 = llvm.fadd %165, %166 : vector<4xf32>
          %168 = llvm.insertvalue %167, %164[1] : !llvm.array<4 x vector<4xf32>>
          %169 = llvm.extractvalue %127[2] : !llvm.array<4 x vector<4xf32>>
          %170 = llvm.extractvalue %159[2] : !llvm.array<4 x vector<4xf32>>
          %171 = llvm.fadd %169, %170 : vector<4xf32>
          %172 = llvm.insertvalue %171, %168[2] : !llvm.array<4 x vector<4xf32>>
          %173 = llvm.extractvalue %127[3] : !llvm.array<4 x vector<4xf32>>
          %174 = llvm.extractvalue %159[3] : !llvm.array<4 x vector<4xf32>>
          %175 = llvm.fadd %173, %174 : vector<4xf32>
          %176 = llvm.insertvalue %175, %172[3] : !llvm.array<4 x vector<4xf32>>
          %177 = llvm.extractvalue %176[0] : !llvm.array<4 x vector<4xf32>>
          %178 = llvm.mlir.constant(0 : i64) : i64
          %179 = llvm.extractelement %177[%178 : i64] : vector<4xf32>
          %180 = llvm.extractvalue %2[0] : !llvm.array<4 x vector<4xf32>>
          %181 = llvm.mlir.constant(0 : i64) : i64
          %182 = llvm.insertelement %179, %180[%181 : i64] : vector<4xf32>
          %183 = llvm.insertvalue %182, %2[0] : !llvm.array<4 x vector<4xf32>>
          %184 = llvm.extractvalue %176[0] : !llvm.array<4 x vector<4xf32>>
          %185 = llvm.mlir.constant(1 : i64) : i64
          %186 = llvm.extractelement %184[%185 : i64] : vector<4xf32>
          %187 = llvm.extractvalue %183[1] : !llvm.array<4 x vector<4xf32>>
          %188 = llvm.mlir.constant(0 : i64) : i64
          %189 = llvm.insertelement %186, %187[%188 : i64] : vector<4xf32>
          %190 = llvm.insertvalue %189, %183[1] : !llvm.array<4 x vector<4xf32>>
          %191 = llvm.extractvalue %176[0] : !llvm.array<4 x vector<4xf32>>
          %192 = llvm.mlir.constant(2 : i64) : i64
          %193 = llvm.extractelement %191[%192 : i64] : vector<4xf32>
          %194 = llvm.extractvalue %190[2] : !llvm.array<4 x vector<4xf32>>
          %195 = llvm.mlir.constant(0 : i64) : i64
          %196 = llvm.insertelement %193, %194[%195 : i64] : vector<4xf32>
          %197 = llvm.insertvalue %196, %190[2] : !llvm.array<4 x vector<4xf32>>
          %198 = llvm.extractvalue %176[0] : !llvm.array<4 x vector<4xf32>>
          %199 = llvm.mlir.constant(3 : i64) : i64
          %200 = llvm.extractelement %198[%199 : i64] : vector<4xf32>
          %201 = llvm.extractvalue %197[3] : !llvm.array<4 x vector<4xf32>>
          %202 = llvm.mlir.constant(0 : i64) : i64
          %203 = llvm.insertelement %200, %201[%202 : i64] : vector<4xf32>
          %204 = llvm.insertvalue %203, %197[3] : !llvm.array<4 x vector<4xf32>>
          %205 = llvm.extractvalue %176[1] : !llvm.array<4 x vector<4xf32>>
          %206 = llvm.mlir.constant(0 : i64) : i64
          %207 = llvm.extractelement %205[%206 : i64] : vector<4xf32>
          %208 = llvm.extractvalue %204[0] : !llvm.array<4 x vector<4xf32>>
          %209 = llvm.mlir.constant(1 : i64) : i64
          %210 = llvm.insertelement %207, %208[%209 : i64] : vector<4xf32>
          %211 = llvm.insertvalue %210, %204[0] : !llvm.array<4 x vector<4xf32>>
          %212 = llvm.extractvalue %176[1] : !llvm.array<4 x vector<4xf32>>
          %213 = llvm.mlir.constant(1 : i64) : i64
          %214 = llvm.extractelement %212[%213 : i64] : vector<4xf32>
          %215 = llvm.extractvalue %211[1] : !llvm.array<4 x vector<4xf32>>
          %216 = llvm.mlir.constant(1 : i64) : i64
          %217 = llvm.insertelement %214, %215[%216 : i64] : vector<4xf32>
          %218 = llvm.insertvalue %217, %211[1] : !llvm.array<4 x vector<4xf32>>
          %219 = llvm.extractvalue %176[1] : !llvm.array<4 x vector<4xf32>>
          %220 = llvm.mlir.constant(2 : i64) : i64
          %221 = llvm.extractelement %219[%220 : i64] : vector<4xf32>
          %222 = llvm.extractvalue %218[2] : !llvm.array<4 x vector<4xf32>>
          %223 = llvm.mlir.constant(1 : i64) : i64
          %224 = llvm.insertelement %221, %222[%223 : i64] : vector<4xf32>
          %225 = llvm.insertvalue %224, %218[2] : !llvm.array<4 x vector<4xf32>>
          %226 = llvm.extractvalue %176[1] : !llvm.array<4 x vector<4xf32>>
          %227 = llvm.mlir.constant(3 : i64) : i64
          %228 = llvm.extractelement %226[%227 : i64] : vector<4xf32>
          %229 = llvm.extractvalue %225[3] : !llvm.array<4 x vector<4xf32>>
          %230 = llvm.mlir.constant(1 : i64) : i64
          %231 = llvm.insertelement %228, %229[%230 : i64] : vector<4xf32>
          %232 = llvm.insertvalue %231, %225[3] : !llvm.array<4 x vector<4xf32>>
          %233 = llvm.extractvalue %176[2] : !llvm.array<4 x vector<4xf32>>
          %234 = llvm.mlir.constant(0 : i64) : i64
          %235 = llvm.extractelement %233[%234 : i64] : vector<4xf32>
          %236 = llvm.extractvalue %232[0] : !llvm.array<4 x vector<4xf32>>
          %237 = llvm.mlir.constant(2 : i64) : i64
          %238 = llvm.insertelement %235, %236[%237 : i64] : vector<4xf32>
          %239 = llvm.insertvalue %238, %232[0] : !llvm.array<4 x vector<4xf32>>
          %240 = llvm.extractvalue %176[2] : !llvm.array<4 x vector<4xf32>>
          %241 = llvm.mlir.constant(1 : i64) : i64
          %242 = llvm.extractelement %240[%241 : i64] : vector<4xf32>
          %243 = llvm.extractvalue %239[1] : !llvm.array<4 x vector<4xf32>>
          %244 = llvm.mlir.constant(2 : i64) : i64
          %245 = llvm.insertelement %242, %243[%244 : i64] : vector<4xf32>
          %246 = llvm.insertvalue %245, %239[1] : !llvm.array<4 x vector<4xf32>>
          %247 = llvm.extractvalue %176[2] : !llvm.array<4 x vector<4xf32>>
          %248 = llvm.mlir.constant(2 : i64) : i64
          %249 = llvm.extractelement %247[%248 : i64] : vector<4xf32>
          %250 = llvm.extractvalue %246[2] : !llvm.array<4 x vector<4xf32>>
          %251 = llvm.mlir.constant(2 : i64) : i64
          %252 = llvm.insertelement %249, %250[%251 : i64] : vector<4xf32>
          %253 = llvm.insertvalue %252, %246[2] : !llvm.array<4 x vector<4xf32>>
          %254 = llvm.extractvalue %176[2] : !llvm.array<4 x vector<4xf32>>
          %255 = llvm.mlir.constant(3 : i64) : i64
          %256 = llvm.extractelement %254[%255 : i64] : vector<4xf32>
          %257 = llvm.extractvalue %253[3] : !llvm.array<4 x vector<4xf32>>
          %258 = llvm.mlir.constant(2 : i64) : i64
          %259 = llvm.insertelement %256, %257[%258 : i64] : vector<4xf32>
          %260 = llvm.insertvalue %259, %253[3] : !llvm.array<4 x vector<4xf32>>
          %261 = llvm.extractvalue %176[3] : !llvm.array<4 x vector<4xf32>>
          %262 = llvm.mlir.constant(0 : i64) : i64
          %263 = llvm.extractelement %261[%262 : i64] : vector<4xf32>
          %264 = llvm.extractvalue %260[0] : !llvm.array<4 x vector<4xf32>>
          %265 = llvm.mlir.constant(3 : i64) : i64
          %266 = llvm.insertelement %263, %264[%265 : i64] : vector<4xf32>
          %267 = llvm.insertvalue %266, %260[0] : !llvm.array<4 x vector<4xf32>>
          %268 = llvm.extractvalue %176[3] : !llvm.array<4 x vector<4xf32>>
          %269 = llvm.mlir.constant(1 : i64) : i64
          %270 = llvm.extractelement %268[%269 : i64] : vector<4xf32>
          %271 = llvm.extractvalue %267[1] : !llvm.array<4 x vector<4xf32>>
          %272 = llvm.mlir.constant(3 : i64) : i64
          %273 = llvm.insertelement %270, %271[%272 : i64] : vector<4xf32>
          %274 = llvm.insertvalue %273, %267[1] : !llvm.array<4 x vector<4xf32>>
          %275 = llvm.extractvalue %176[3] : !llvm.array<4 x vector<4xf32>>
          %276 = llvm.mlir.constant(2 : i64) : i64
          %277 = llvm.extractelement %275[%276 : i64] : vector<4xf32>
          %278 = llvm.extractvalue %274[2] : !llvm.array<4 x vector<4xf32>>
          %279 = llvm.mlir.constant(3 : i64) : i64
          %280 = llvm.insertelement %277, %278[%279 : i64] : vector<4xf32>
          %281 = llvm.insertvalue %280, %274[2] : !llvm.array<4 x vector<4xf32>>
          %282 = llvm.extractvalue %176[3] : !llvm.array<4 x vector<4xf32>>
          %283 = llvm.mlir.constant(3 : i64) : i64
          %284 = llvm.extractelement %282[%283 : i64] : vector<4xf32>
          %285 = llvm.extractvalue %281[3] : !llvm.array<4 x vector<4xf32>>
          %286 = llvm.mlir.constant(3 : i64) : i64
          %287 = llvm.insertelement %284, %285[%286 : i64] : vector<4xf32>
          %288 = llvm.insertvalue %287, %281[3] : !llvm.array<4 x vector<4xf32>>
          %289 = llvm.extractvalue %288[0] : !llvm.array<4 x vector<4xf32>>
          %290 = llvm.fadd %289, %91 : vector<4xf32>
          %291 = llvm.extractvalue %288[1] : !llvm.array<4 x vector<4xf32>>
          %292 = llvm.fadd %291, %290 : vector<4xf32>
          %293 = llvm.extractvalue %288[2] : !llvm.array<4 x vector<4xf32>>
          %294 = llvm.fadd %293, %292 : vector<4xf32>
          %295 = llvm.extractvalue %288[3] : !llvm.array<4 x vector<4xf32>>
          %296 = llvm.fadd %295, %294 : vector<4xf32>
          %297 = llvm.add %90, %9 : i64
          llvm.br ^bb2(%297, %296 : i64, vector<4xf32>)
          ^bb4: // pred: ^bb2
          %298 = llvm.extractvalue %67[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
          %299 = llvm.getelementptr %298[%85] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %300 = llvm.bitcast %299 : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
          llvm.store %91, %300 {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
          llvm.br ^bb12
          ^bb5: // pred: ^bb0
          %301 = nvvm.read.ptx.sreg.tid.x : i32
          %302 = llvm.sext %301 : i32 to i64
          %303 = llvm.icmp "sle" %79, %11 : i64
          %304 = llvm.sub %11, %79 : i64
          %305 = llvm.sub %79, %7 : i64
          %306 = llvm.select %303, %304, %305 : i1, i64
          %307 = llvm.sdiv %306, %4 : i64
          %308 = llvm.sub %11, %307 : i64
          %309 = llvm.add %307, %7 : i64
          %310 = llvm.select %303, %308, %309 : i1, i64
          %311 = llvm.mul %302, %310 : i64
          %312 = llvm.mul %311, %3 : i64
          %313 = llvm.add %312, %79 : i64
          %314 = llvm.icmp "slt" %313, %310 : i64
          %315 = llvm.select %314, %313, %310 : i1, i64
          %316 = llvm.icmp "slt" %315, %11 : i64
          %317 = llvm.select %316, %11, %315 : i1, i64
          %318 = llvm.mul %75, %10 : i64
          %319 = llvm.add %311, %318 : i64
          llvm.br ^bb6(%11 : i64)
          ^bb6(%320: i64): // 2 preds: ^bb5, ^bb11
          %321 = llvm.icmp "slt" %320, %8 : i64
          llvm.cond_br %321, ^bb7(%11 : i64), ^bb12
          ^bb7(%322: i64): // 2 preds: ^bb6, ^bb10
          %323 = llvm.icmp "slt" %322, %317 : i64
          llvm.cond_br %323, ^bb8(%11 : i64), ^bb11
          ^bb8(%324: i64): // 2 preds: ^bb7, ^bb9
          %325 = llvm.icmp "slt" %324, %9 : i64
          llvm.cond_br %325, ^bb9, ^bb10
          ^bb9: // pred: ^bb8
          %326 = llvm.add %319, %322 : i64
          %327 = llvm.add %320, %324 : i64
          %328 = llvm.extractvalue %27[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %329 = llvm.mlir.constant(100 : index) : i64
          %330 = llvm.mul %326, %329 : i64
          %331 = llvm.add %330, %327 : i64
          %332 = llvm.getelementptr %328[%331] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %333 = llvm.load %332 : !llvm.ptr<f32>
          %334 = llvm.extractvalue %49[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
          %335 = llvm.mlir.constant(100 : index) : i64
          %336 = llvm.mul %326, %335 : i64
          %337 = llvm.add %336, %327 : i64
          %338 = llvm.getelementptr %334[%337] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %339 = llvm.load %338 : !llvm.ptr<f32>
          %340 = llvm.extractvalue %67[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
          %341 = llvm.getelementptr %340[%326] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          %342 = llvm.load %341 : !llvm.ptr<f32>
          %343 = llvm.fadd %333, %339 : f32
          %344 = llvm.fadd %343, %342 : f32
          %345 = llvm.extractvalue %67[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
          %346 = llvm.getelementptr %345[%326] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
          llvm.store %344, %346 : !llvm.ptr<f32>
          %347 = llvm.add %324, %7 : i64
          llvm.br ^bb8(%347 : i64)
          ^bb10: // pred: ^bb8
          %348 = llvm.add %322, %7 : i64
          llvm.br ^bb7(%348 : i64)
          ^bb11: // pred: ^bb7
          %349 = llvm.add %320, %9 : i64
          llvm.br ^bb6(%349 : i64)
          ^bb12: // 2 preds: ^bb4, ^bb6
          llvm.return
          }
  • createConvertToHALPass
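
    Converts the remaining stream dialect ops into their hal dialect equivalents, e.g. stream.cmd.dispatch becomes hal.command_buffer.dispatch and stream resources become HAL buffers.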

  • createFixupLegacySyncPass

  • addCleanupPatterns

  • createLinkExecutablesPass
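
    Links the per-dispatch executables into combined executables per target backend where possible, reducing duplicated code and runtime loading overhead.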

  • createResolveExportOrdinalsPass

  • createMaterializeResourceCachesPass

  • createInlineDeviceSwitchesPass

  • createMemoizeDeviceQueriesPass

  • addCleanupPatterns

  • createElideRedundantCommandsPass

  • mlir::createLowerAffinePass

  • mlir::createConvertSCFToCFPass

  • IREE::Util::createCombineInitializersPass

  • addCleanupPatterns

  • createSerializeExecutablesPass
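
    Invokes each registered target backend to serialize its hal.executable.variant contents into binary form; for the cuda target the NVVM IR above is compiled down to PTX and embedded in the module.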

  • mlir::createSymbolDCEPass
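
    Deletes symbols (functions, globals, executables) that are no longer referenced after the preceding passes.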
