Tensor Comprehensions にみる
Halide IRの汎用性
Fixstars Solutions, Inc.
Takuro Iizuka
Takuro Iizuka / @iitaku
北米子会社のFixstars Solutions, Inc. にて
HalideのFPGAバックエンドおよびツールチェイン”GENESIS”の開発やってます
もくじ
 TC: Tensor Comprehensions 概要
 TC言語
 Inside TC
 まとめ
TC: Tensor Comprehensionsとは?
 テンソル計算の記述言語および
最適化コンパイラフレームワーク
 2018.2.14にFacebook AI Researchからリリース
 TC言語でアルゴリズムを書くと
ライブラリがいい感じに最適化してくれる
 PyTorchとシームレスに統合できる
TC: Tensor Comprehensionsとは?
 テンソル計算の記述言語および
最適化コンパイラフレームワーク
 2018.2にFacebook AI Researchからリリース
 TC言語でアルゴリズムを書くと
ライブラリがいい感じに最適化してくれる
 PyTorchとシームレスに統合できる
コンパイラの中間表現としてHalide IRを採用
TCアーキテクチャ
https://research.fb.com/announcing-tensor-comprehensions/
TCベンチマーク結果
MLP: Multi-Layer Perceptron
TMM: Transposed Matrix Multiplication
TBMM: Transposed Batched Matrix Multiplication
GCOV: Grouped Convolutions
https://research.fb.com/announcing-tensor-comprehensions/
TC in PyTorch
で書いて
で動せる
TC in PyTorch
$ conda create –y –name pytorch python=3.6
$ conda activate pytorch
$ conda install -y -c pytorch -c tensorcomp tensor_comprehensions
$ python ./matmul.py
Variable containing:
-2.4028 2.8492 7.6141 3.3159 3.7171
1.3839 0.6650 -1.7253 0.7447 1.3988
0.1396 -0.0661 -1.0574 0.2163 0.1711
[torch.cuda.FloatTensor of size 3x5 (GPU 0)]
import tensor_comprehensions as tc
import torch
mm = """
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
"""
matmul = tc.define(mm, name="matmul")
A, B = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda()
C = matmul(A, B, options=tc.Options("naive"))
print(C)
1. TCをセットアップ
2. TC言語でカスタム
レイヤを書く
3. 実行する
TC言語
num ::= <number literal with C syntax>
id ::= [_a-zA-Z0-9]*[_a-zA-Z][_a-zA-Z0-9]*
exp ::= num
| ( '-' | '!' | ... ) exp
| exp ( [+-*/%] | '==' | '!=' | '<=' | ... ) exp
| exp '?' exp ':' exp
| id '.' num # range of num-th dimension of id
| id '(' exp_list ')' # builtin call or tensor access
reduction ::= <associative reduction operator>
| '+=' | '*=' | 'min=' | 'max='
| '+=!' | '*=!' | 'min=!' | 'max=!'
range_constraint ::= id 'in' exp ':' exp
stmt ::= id '(' id_list ')' [ '=' | reduction ] exp
[ 'where' range_constraint_list ]
| id_list = id '('id_list ')' # TC function call
arg ::= type id
return ::= id # inferred return type and range
scalar_type ::= 'double' | 'float' | 'half'
| 'int32' | 'byte' | 'uint32' | ...
type ::= scalar_type [ '(' id_list ')' ]
func ::= # TC function definition
'def' id '(' arg_list ')' '->' '(' return_list ')' '{'
stmt_list
'}'
id_list ::= <comma separated id list>
exp_list ::= <comma separated exp list>
arg_list ::= <comma separated arg list>
stmt_list ::= <whitespace separated stmt list>
return_list ::= <comma separated return list>
range_constraint_list ::= <non-empty comma separated
range_constraint list>
TC言語の特徴
 テンソル計算記述特化言語
 Halide言語よりさらに簡素な言語体型
 超ミニマルなプリミティブ型
 畳み込み演算用の特殊なオペレータ群
 where節でレンジ制約を記述
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
def 関数名 {} で関数定義
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
type(X,Y,…)で入力引数に制約
(要素型制約およびレンジ制約)を付与
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
レンジ制約に同一シンボルを使用することで、
異なる引数間の制約を表現できる
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
出力の制約はコンパイラによって自動推論される
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
初期化付き畳み込み用記述のオペレータ
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
左辺で定義した誘導変数を右辺で使用すると
ループが形成される
for m := 0, M
for n := 0, N
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
ループのレンジは入力制約から決定される
e.g. m := [0. M)
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
左辺で定義していない誘導変数を右辺で使用した場合、
既存ループの最内に新たにループが形成される
for m := 0, M
for n := 0, N
for r_k := 0, K
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
その場合も入力の制約をもとに制約チェックが行われる
Halide言語と比較してみる
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
TC言語
Func matmul;
ImageParam A{float, 2};
ImageParam B{float, 2};
Var m, n;
Rdom r_k{0, K};
matmul(m, n) = sum(A(m, r_k) * B(r_k, n));
matmul.realize({M, N});
Halide言語
Halide言語と比較してみる
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
Func matmul;
ImageParam A{float, 2};
ImageParam B{float, 2};
Var m, n;
Rdom r_k{0, K};
matmul(m, n) = sum(A(m, r_k) * B(r_k, n));
matmul.realize({M, N});
レンジ制約は前向きに推論
レンジ制約は後ろ向きに推論
Inside TC
TC GPU Backendコンパイルフロー
TC言語
AST
Halide IR
パース/AST構築: tc/core/compiler.cc, parse(…)
Halide IR変換: tc/core/tc2halide.cc, translate(…)
Halide IR
レンジ推論/ループ形成/簡約化: tc/core/tc2halide.cc, translate(…)
Polyhedral IR
CUDA C
Polyhedral IR変換: tc/core/polyhedral/scop.cc, makeScop(…)
Polyhedral IR
ループ変形/スレッディング: tc/core/polyhedral/mapped_scop.cc,
makeWithOuterBlockInnerThreadStrategy(…)
コード生成: tc/core/polyhedral/mapped_scop.cc, codegen(…)
Halide IR
オリジナルのHalide実装では、
中間表現の変換過程においてHalide IRを2つに大別できる
Halide IR
Halide IR
Halide IR
Halide IR
Halide IR
Call/Provide系
Load/Store系
storage_flattening以前/以後
Call/Provide系
抽象度高い
Load/Store系
抽象度低い
for (y, 0, out.extent.1) {
for (x, 0, out.extent.0) {
Provide(out, {x, y}) = Call(in, {x, y})
}
}
for (y, 0, out.extent.1) {
for (x, 0, out.extent.0) {
Store(out, y*out.stride.1+x,
Load(in, y*in.stride.1.x))
}
}
TCが取り扱うHalide IRは
抽象度の高いCall/Provide系のみ
for (y, 0, out.extent.1) {
for (x, 0, out.extent.0) {
Provide(out, {x, y}) = Call(in, {x, y})
}
}
for (y, 0, out.extent.1) {
for (x, 0, out.extent.0) {
Store(out, y*out.stride.1+x,
Load(in, y*in.stride.1.x))
}
}
ターゲットコードに近い中間表現は
Halide IRではなくPolyhedral IRを使う
Halide IR上でのLowering
 コンパイラインフラストラクチャとしてのHalide IR
– Halide中にはHalide IR (Halide::Stmt/Expr)を操作する
関数が多数実装済み
• 簡約器
• ソルバー
• 範囲演算
• CSE
• などなど
– IRMutatorやIRVisitorクラスを使用すれば
Halide IRに対する変換や解析を独自に実装できる
 TCのHalide IR Loweringは以下を行っている
– レンジ推論
– ループ形成
– 簡約化
レンジ推論で使用されるHalide API
 solve_for_inner_interval(c, v)
– 条件式cを必ず満たす変数vの最大範囲を計算する
 and_condition_over_domain(c, varying)
– 変数範囲varyingの仮定のもとで条件式cを簡約化する
 simplify(e)
– 式eを簡約化する
これらを組み合わせてレンジ推論を行い、後段で行われる
Polyhedral最適化に必要な条件を満たすかをテストしておく
ループ形成で使用されるHalide API
 realization_order(…)
– Provide/Call間の依存グラフを
トポロジカルソートによって順序付けをする
 schedule_functions(…)
– 出力が依存するすべての計算ループを形成し、
Halide IR (Halide::Stmt) を返す
後の解析や変換のベースとなるループ構築を行う
ループ形成結果
後の解析や変換のベースとなるループ構築を行う
produce output {
let output.s0.n.loop_max = output.s0.n.max
let output.s0.n.loop_min = output.s0.n.min
let output.s0.n.loop_extent = ((output.s0.n.max + 1) - output.s0.n.min)
let output.s0.m.loop_max = output.s0.m.max
let output.s0.m.loop_min = output.s0.m.min
let output.s0.m.loop_extent = ((output.s0.m.max + 1) - output.s0.m.min)
for (output.s0.m, output.s0.m.loop_min, output.s0.m.loop_extent) {
for (output.s0.n, output.s0.n.loop_min, output.s0.n.loop_extent) {
output(output.s0.m, output.s0.n) = 0.000000f
}
}
let output.s1.r_k.loop_extent = ((output.s1.r_k.max - output.s1.r_k.min) + 1)
let output.s1.r_k.loop_max = output.s1.r_k.max
let output.s1.r_k.loop_min = output.s1.r_k.min
let output.s1.n.loop_max = output.s1.n.max
let output.s1.n.loop_min = output.s1.n.min
let output.s1.n.loop_extent = ((output.s1.n.max + 1) - output.s1.n.min)
let output.s1.m.loop_max = output.s1.m.max
let output.s1.m.loop_min = output.s1.m.min
let output.s1.m.loop_extent = ((output.s1.m.max + 1) - output.s1.m.min)
for (output.s1.m, output.s1.m.loop_min, output.s1.m.loop_extent) {
for (output.s1.n, output.s1.n.loop_min, output.s1.n.loop_extent) {
for (output.s1.r_k, output.s1.r_k.loop_min, output.s1.r_k.loop_extent) {
output(output.s1.m, output.s1.n) =
ReductionUpdate((output(output.s1.m, output.s1.n) +
(A(output.s1.m, output.s1.r_k)*
B(output.s1.r_k, output.s1.n))))
}
}
}
}
初期化
計算
簡約化で使用されるHalide API
 LetStmt::make(n, e, s)
– 文s中で式eをシンボルnに束縛する
 simplify(s)
– 文sを簡約化する
レンジ制約を適用しループ範囲の簡約化を行う
簡約化結果
後段のPolyhedral Transformationで
解析・変形可能なループ構造に簡約できた
for (output.s0.m, 0, M) {
for (output.s0.n, 0, K) {
output(output.s0.m, output.s0.n) = 0.000000f
}
}
for (output.s1.m, 0, M) {
for (output.s1.n, 0, K) {
for (output.s1.r_k, 0, K) {
output(output.s1.m, output.s1.n) =
ReductionUpdate((output(output.s1.m, output.s1.n) +
(A(output.s1.m, output.s1.r_k)*
B(output.s1.r_k, output.s1.n))))
}
}
}
ループ範囲がTC言語のレンジ制約と対応している
 tc::polyhedral::Scop
– ISL (Integer Set Library) /Polyhedral Compilation
Libraryを用いて計算されたスケジューリング
– RAW, WAR, WAW 依存関係
– メモリ配置
– Halide IR関連
• パラメータ
• 入出力
• Stmt
Polyhedral IR
Polyhedral IR
Halide IR
パラメータ/
入出力
Polyhedral IR変換
Stmt of Provide A Stmt of Provide B
パラメータ/
入出力
スケジューリング 依存関係
メモリ配置
Polyhedral IR 上でのLowering
1. ループ融合
2. タイリング
– パラメータ: タイリング戦略
3. スレッドマッピング
– パラメータ: スレッドサイズ
4. ブロックマッピング
– パラメータ: ブロックサイズ
5. メモリマッピング
– パラメータ: 有効/無効、マッピング先、マッピング量
Polyhedral Transformationとは、
ループ構造をPolytope=多面体と見立ててアフィン変換を施すことで
Legalなループ変形を行う最適化手法
domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <=
output_s1_n < N and 0 <= output_s1_r_k < K })
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
-----------------------------------------------------------------------
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
-----------------------------------------------------------------------
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
初期状態のスケジューリングツリー
雑な解説: band=ループ、 filter=ステートメント
domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <=
output_s1_n < N and 0 <= output_s1_r_k < K })
band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
ループ融合後
domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <=
output_s1_n < N and 0 <= output_s1_r_k < K })
band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
band(n(2) permutable(1) coincident(1, 0) unroll(0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
タイリング後
domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : K = 4 and M = 3 and N = 5 and 0 <= output_s0_m <= 2 and 0 <= output_s0_n <= 4 }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : K = 4 and M = 3 and N = 5 and 0 <= output_s1_m <= 2 and
0 <= output_s1_n <= 4 and 0 <= output_s1_r_k <= 3 })
context([K, M, N, t1, t0, t2, b2, b1, b0] -> { [] : t1 = 0 and t2 = 0 and b2 = 0 and b1 = 0 and 0 <= t0 <= 127 and 0 <= b0 <= 127 })
band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
mapping_filter(ids(t0, )
[K, M, N, t0] -> { S_0[output_s0_m, output_s0_n] : (-t0 + output_s0_n) mod 128 = 0 and 0 <= t0 <= 127 }
[K, M, N, t0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-t0 + output_s1_n) mod 128 = 0 and 0 <= t0 <= 127 })
band(n(1) permutable(1) coincident(1) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
thread_specific()
band(n(1) permutable(1) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
スレッドマッピング後
domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : K = 4 and M = 3 and N = 5 and 0 <= output_s0_m <= 2 and 0 <= output_s0_n <= 4 }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : K = 4 and M = 3 and N = 5 and 0 <= output_s1_m <= 2 and 0 <=
output_s1_n <= 4 and 0 <= output_s1_r_k <= 3 })
context([K, M, N, t1, t0, t2, b2, b1, b0] -> { [] : t1 = 0 and t2 = 0 and b2 = 0 and b1 = 0 and 0 <= t0 <= 127 and 0 <= b0 <= 127 })
mapping_filter(ids(b0, )
[K, M, N, b0] -> { S_0[output_s0_m, output_s0_n] : (-b0 + output_s0_m) mod 128 = 0 and 0 <= b0 <= 127 }
[K, M, N, b0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-b0 + output_s1_m) mod 128 = 0 and 0 <= b0 <= 127 })
band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
mapping_filter(ids(t0, )
[K, M, N, t0] -> { S_0[output_s0_m, output_s0_n] : (-t0 + output_s0_n) mod 128 = 0 and 0 <= t0 <= 127 }
[K, M, N, t0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-t0 + output_s1_n) mod 128 = 0 and 0 <= t0 <= 127 })
band(n(1) permutable(1) coincident(1) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
thread_specific()
band(n(1) permutable(1) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
ブロックマッピング後
template<typename T> inline __device__ T floord(T n, T d) {
return n < 0 ? - (-n + d - 1)/d : n / d;
}
#define if_then_else(cond,a,b) ((cond) ? (a) : (b))
// Can't include system dependencies with NVRTC
// Can't include cuda_fp16.h with NVRTC due to transitive system dependencies
// #include <cuda_fp16.h>
// Halide type handling
typedef char int8;
typedef short int16;
typedef int int32;
typedef long int64;
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned long uint64;
// typedef half float16;
typedef float float32;
typedef double float64;
#define inff __int_as_float(0x7f800000)
#define inf __longlong_as_double(0x7ff0000000000000LL)
// Before CUDA 9, syncwarp is a noop since warps are always synchronized.
#if __CUDACC_VER_MAJOR__ < 9
__device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {}
#endif
extern "C" {
__global__ void matmul_4_3_5(int32 K, int32 M, int32 N, float32* poutput, const float32* pA, const float32* pB) {
int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
float32 (*output)[5] = reinterpret_cast<float32 (*)[5]>(poutput);
const float32 (*A)[4] = reinterpret_cast<const float32 (*)[4]>(pA);
const float32 (*B)[5] = reinterpret_cast<const float32 (*)[5]>(pB);
output[b0][t0] = 0.000000f;
for (int c4 = 0; c4 <= 3; c4 += 1) {
output[b0][t0] = (output[b0][t0] + (A[b0][c4]*B[c4][t0]));
}
}
}
コード生成
パラメータのオートチューニング
 遺伝的アルゴリズムを用いて
より良いパラメータを探索する
https://research.fb.com/announcing-tensor-comprehensions/
まとめ
Halide IR良くできてる
– 制約は正義
• Polyhedral Transformation等応用的な最適化手法を適用可能
• 解析時の計算量爆発がおきにくい
– IRを操作する関数の実装が揃ってる
• 自前の解析や変形を行う場合でも多くの機能を転用可能
– IRの変換が書きやすい
• IRMutator/IRVisitorのクラスが割とシンプルで書きやすい
– TC/TVM/Tiramisu等、
Halide IRを再利用する取り組みも出てきた
High Level Compiler IRとしてのHalideに今後も要注目!

More Related Content

KEY
関東GPGPU勉強会 LLVM meets GPU
PDF
TVM の紹介
PDF
Xeon PhiとN体計算コーディング x86/x64最適化勉強会6(@k_nitadoriさんの代理アップ)
PDF
llvm入門
PDF
フラグを愛でる
PDF
きつねさんでもわかるLlvm読書会 第2回
PDF
LLVMで遊ぶ(整数圧縮とか、x86向けの自動ベクトル化とか)
PDF
Emcjp item33,34
関東GPGPU勉強会 LLVM meets GPU
TVM の紹介
Xeon PhiとN体計算コーディング x86/x64最適化勉強会6(@k_nitadoriさんの代理アップ)
llvm入門
フラグを愛でる
きつねさんでもわかるLlvm読書会 第2回
LLVMで遊ぶ(整数圧縮とか、x86向けの自動ベクトル化とか)
Emcjp item33,34

What's hot (20)

PDF
LLVM最適化のこつ
PPTX
Prosym2012
PDF
Boost.SIMD
PDF
新しい並列for構文のご提案
PDF
準同型暗号の実装とMontgomery, Karatsuba, FFT の性能
PPTX
C++ AMPを使ってみよう
PDF
Cython ことはじめ
PDF
条件分岐とcmovとmaxps
PDF
Haswellサーベイと有限体クラスの紹介
PDF
Wrapping a C++ library with Cython
PDF
GPUが100倍速いという神話をぶち殺せたらいいな ver.2013
PDF
C++による数値解析の並列化手法
PDF
SSE4.2の文字列処理命令の紹介
PDF
高速な倍精度指数関数expの実装
KEY
GPGPU deいろんな問題解いてみた
PDF
NumPyが物足りない人へのCython入門
PDF
boost tour 1.48.0 all
PDF
Boost Tour 1.50.0 All
PPTX
Brief introduction of Boost.ICL
PDF
組み込み関数(intrinsic)によるSIMD入門
LLVM最適化のこつ
Prosym2012
Boost.SIMD
新しい並列for構文のご提案
準同型暗号の実装とMontgomery, Karatsuba, FFT の性能
C++ AMPを使ってみよう
Cython ことはじめ
条件分岐とcmovとmaxps
Haswellサーベイと有限体クラスの紹介
Wrapping a C++ library with Cython
GPUが100倍速いという神話をぶち殺せたらいいな ver.2013
C++による数値解析の並列化手法
SSE4.2の文字列処理命令の紹介
高速な倍精度指数関数expの実装
GPGPU deいろんな問題解いてみた
NumPyが物足りない人へのCython入門
boost tour 1.48.0 all
Boost Tour 1.50.0 All
Brief introduction of Boost.ICL
組み込み関数(intrinsic)によるSIMD入門
Ad

Similar to 20180728 halide-study (20)

PPTX
Prelude to Halide
PDF
HalideでつくるDomain Specific Architectureの世界
PDF
多次元配列の効率的利用法の検討
PDF
暗号技術の実装と数学
PPT
Or seminar2011final
PPT
融合変換による最適化の理論的基盤と正当性 (2006-06-27)
PDF
[第2版]Python機械学習プログラミング 第14章
PDF
Adding simpl GVN path into GHC
PPTX
画像処理の高性能計算
PDF
関数モデル 【クラウドアプリケーションのためのオブジェクト指向分析設計講座 第8回】
PDF
数式をnumpyに落としこむコツ
PDF
Halide による画像処理プログラミング入門
PPTX
AtCoder Beginner Contest 012 解説
PDF
Rの高速化
PPT
Clean
PPT
Clean
KEY
PyOpenCLによるGPGPU入門 Tokyo.SciPy#4 編
PDF
Shunsuke Horii
PDF
PFI Christmas seminar 2009
PPTX
2014年の社内新人教育テキスト #2(関数型言語からオブジェクト指向言語へ)
Prelude to Halide
HalideでつくるDomain Specific Architectureの世界
多次元配列の効率的利用法の検討
暗号技術の実装と数学
Or seminar2011final
融合変換による最適化の理論的基盤と正当性 (2006-06-27)
[第2版]Python機械学習プログラミング 第14章
Adding simpl GVN path into GHC
画像処理の高性能計算
関数モデル 【クラウドアプリケーションのためのオブジェクト指向分析設計講座 第8回】
数式をnumpyに落としこむコツ
Halide による画像処理プログラミング入門
AtCoder Beginner Contest 012 解説
Rの高速化
Clean
Clean
PyOpenCLによるGPGPU入門 Tokyo.SciPy#4 編
Shunsuke Horii
PFI Christmas seminar 2009
2014年の社内新人教育テキスト #2(関数型言語からオブジェクト指向言語へ)
Ad

More from Fixstars Corporation (20)

PPTX
製造業向け量子コンピュータ時代のDXセミナー_生産計画最適化_20220323.pptx
PDF
CPU / GPU高速化セミナー!性能モデルの理論と実践:実践編
PDF
製造業向け量子コンピュータ時代のDXセミナー~ 最適化の中身を覗いてみよう~
PPTX
製造業向け量子コンピュータ時代のDXセミナー ~見える化、分析、予測、その先の最適化へ~
PDF
株式会社フィックスターズの会社説明資料(抜粋)
PPTX
CPU / GPU高速化セミナー!性能モデルの理論と実践:理論編
PDF
Fpga online seminar by fixstars (1st)
PDF
Jetson活用セミナー ROS2自律走行実現に向けて
PDF
いまさら聞けない!CUDA高速化入門
PDF
量子コンピュータ時代の製造業におけるDXセミナー~生産工程効率化に向けた新たなご提案~
PDF
金融業界向けセミナー 量子コンピュータ時代を見据えた組合せ最適化
PDF
いまさら聞けないarmを使ったNEONの基礎と活用事例
PDF
ARM CPUにおけるSIMDを用いた高速計算入門
PDF
株式会社フィックスターズ 会社説明資料(抜粋)
PDF
株式会社フィックスターズ 会社説明資料(抜粋)
PDF
ソフト高速化の専門家が教える!AI・IoTエッジデバイスの選び方
PDF
AIチップ戦国時代における深層学習モデルの推論の最適化と実用的な運用を可能にするソフトウェア技術について
PDF
株式会社フィックスターズ 会社説明資料(抜粋)
PPTX
第8回 社内プログラミングコンテスト 結果発表会
PPTX
第8回 社内プログラミングコンテスト 第1位 taiyo
製造業向け量子コンピュータ時代のDXセミナー_生産計画最適化_20220323.pptx
CPU / GPU高速化セミナー!性能モデルの理論と実践:実践編
製造業向け量子コンピュータ時代のDXセミナー~ 最適化の中身を覗いてみよう~
製造業向け量子コンピュータ時代のDXセミナー ~見える化、分析、予測、その先の最適化へ~
株式会社フィックスターズの会社説明資料(抜粋)
CPU / GPU高速化セミナー!性能モデルの理論と実践:理論編
Fpga online seminar by fixstars (1st)
Jetson活用セミナー ROS2自律走行実現に向けて
いまさら聞けない!CUDA高速化入門
量子コンピュータ時代の製造業におけるDXセミナー~生産工程効率化に向けた新たなご提案~
金融業界向けセミナー 量子コンピュータ時代を見据えた組合せ最適化
いまさら聞けないarmを使ったNEONの基礎と活用事例
ARM CPUにおけるSIMDを用いた高速計算入門
株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)
ソフト高速化の専門家が教える!AI・IoTエッジデバイスの選び方
AIチップ戦国時代における深層学習モデルの推論の最適化と実用的な運用を可能にするソフトウェア技術について
株式会社フィックスターズ 会社説明資料(抜粋)
第8回 社内プログラミングコンテスト 結果発表会
第8回 社内プログラミングコンテスト 第1位 taiyo

20180728 halide-study

  • 1. Tensor Comprehensions にみる Halide IRの汎用性 Fixstars Solutions, Inc. Takuro Iizuka
  • 2. Takuro Iizuka / @iitaku 北米子会社のFixstars Solutions, Inc. にて HalideのFPGAバックエンドおよびツールチェイン”GENESIS”の開発やってます
  • 3. もくじ  TC: Tensor Comprehensions 概要  TC言語  Inside TC  まとめ
  • 4. TC: Tensor Comprehensionsとは?  テンソル計算の記述言語および 最適化コンパイラフレームワーク  2018.2.14にFacebook AI Researchからリリース  TC言語でアルゴリズムを書くと ライブラリがいい感じに最適化してくれる  PyTorchとシームレスに統合できる
  • 5. TC: Tensor Comprehensionsとは?  テンソル計算の記述言語および 最適化コンパイラフレームワーク  2018.2にFacebook AI Researchからリリース  TC言語でアルゴリズムを書くと ライブラリがいい感じに最適化してくれる  PyTorchとシームレスに統合できる コンパイラの中間表現としてHalide IRを採用
  • 7. TCベンチマーク結果 MLP: Multi-Layer Perceptron TMM: Transposed Matrix Multiplication TBMM: Transposed Batched Matrix Multiplication GCOV: Grouped Convolutions https://research.fb.com/announcing-tensor-comprehensions/
  • 9. TC in PyTorch $ conda create –y –name pytorch python=3.6 $ conda activate pytorch $ conda install -y -c pytorch -c tensorcomp tensor_comprehensions $ python ./matmul.py Variable containing: -2.4028 2.8492 7.6141 3.3159 3.7171 1.3839 0.6650 -1.7253 0.7447 1.3988 0.1396 -0.0661 -1.0574 0.2163 0.1711 [torch.cuda.FloatTensor of size 3x5 (GPU 0)] import tensor_comprehensions as tc import torch mm = """ def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } """ matmul = tc.define(mm, name="matmul") A, B = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda() C = matmul(A, B, options=tc.Options("naive")) print(C) 1. TCをセットアップ 2. TC言語でカスタム レイヤを書く 3. 実行する
  • 11. num ::= <number literal with C syntax> id ::= [_a-zA-Z0-9]*[_a-zA-Z][_a-zA-Z0-9]* exp ::= num | ( '-' | '!' | ... ) exp | exp ( [+-*/%] | '==' | '!=' | '<=' | ... ) exp | exp '?' exp ':' exp | id '.' num # range of num-th dimension of id | id '(' exp_list ')' # builtin call or tensor access reduction ::= <associative reduction operator> | '+=' | '*=' | 'min=' | 'max=' | '+=!' | '*=!' | 'min=!' | 'max=!' range_constraint ::= id 'in' exp ':' exp stmt ::= id '(' id_list ')' [ '=' | reduction ] exp [ 'where' range_constraint_list ] | id_list = id '('id_list ')' # TC function call arg ::= type id return ::= id # inferred return type and range scalar_type ::= 'double' | 'float' | 'half' | 'int32' | 'byte' | 'uint32' | ... type ::= scalar_type [ '(' id_list ')' ] func ::= # TC function definition 'def' id '(' arg_list ')' '->' '(' return_list ')' '{' stmt_list '}' id_list ::= <comma separated id list> exp_list ::= <comma separated exp list> arg_list ::= <comma separated arg list> stmt_list ::= <whitespace separated stmt list> return_list ::= <comma separated return list> range_constraint_list ::= <non-empty comma separated range_constraint list>
  • 12. TC言語の特徴  テンソル計算記述特化言語  Halide言語よりさらに簡素な言語体型  超ミニマルなプリミティブ型  畳み込み演算用の特殊なオペレータ群  where節でレンジ制約を記述
  • 13. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } def 関数名 {} で関数定義
  • 14. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } type(X,Y,…)で入力引数に制約 (要素型制約およびレンジ制約)を付与
  • 15. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } レンジ制約に同一シンボルを使用することで、 異なる引数間の制約を表現できる
  • 16. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } 出力の制約はコンパイラによって自動推論される
  • 17. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } 初期化付き畳み込み用記述のオペレータ
  • 18. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } 左辺で定義した誘導変数を右辺で使用すると ループが形成される for m := 0, M for n := 0, N
  • 19. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } ループのレンジは入力制約から決定される e.g. m := [0. M)
  • 20. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } 左辺で定義していない誘導変数を右辺で使用した場合、 既存ループの最内に新たにループが形成される for m := 0, M for n := 0, N for r_k := 0, K
  • 21. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } その場合も入力の制約をもとに制約チェックが行われる
  • 22. Halide言語と比較してみる def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } TC言語 Func matmul; ImageParam A{float, 2}; ImageParam B{float, 2}; Var m, n; Rdom r_k{0, K}; matmul(m, n) = sum(A(m, r_k) * B(r_k, n)); matmul.realize({M, N}); Halide言語
  • 23. Halide言語と比較してみる def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } Func matmul; ImageParam A{float, 2}; ImageParam B{float, 2}; Var m, n; Rdom r_k{0, K}; matmul(m, n) = sum(A(m, r_k) * B(r_k, n)); matmul.realize({M, N}); レンジ制約は前向きに推論 レンジ制約は後ろ向きに推論
  • 25. TC GPU Backendコンパイルフロー TC言語 AST Halide IR パース/AST構築: tc/core/compiler.cc, parse(…) Halide IR変換: tc/core/tc2halide.cc, translate(…) Halide IR レンジ推論/ループ形成/簡約化: tc/core/tc2halide.cc, translate(…) Polyhedral IR CUDA C Polyhedral IR変換: tc/core/polyhedral/scop.cc, makeScop(…) Polyhedral IR ループ変形/スレッディング: tc/core/polyhedral/mapped_scop.cc, makeWithOuterBlockInnerThreadStrategy(…) コード生成: tc/core/polyhedral/mapped_scop.cc, codegen(…)
  • 26. Halide IR オリジナルのHalide実装では、 中間表現の変換過程においてHalide IRを2つに大別できる Halide IR Halide IR Halide IR Halide IR Halide IR Call/Provide系 Load/Store系 storage_flattening以前/以後
  • 27. Call/Provide系 抽象度高い Load/Store系 抽象度低い for (y, 0, out.extent.1) { for (x, 0, out.extent.0) { Provide(out, {x, y}) = Call(in, {x, y}) } } for (y, 0, out.extent.1) { for (x, 0, out.extent.0) { Store(out, y*out.stride.1+x, Load(in, y*in.stride.1.x)) } }
  • 28. TCが取り扱うHalide IRは 抽象度の高いCall/Provide系のみ for (y, 0, out.extent.1) { for (x, 0, out.extent.0) { Provide(out, {x, y}) = Call(in, {x, y}) } } for (y, 0, out.extent.1) { for (x, 0, out.extent.0) { Store(out, y*out.stride.1+x, Load(in, y*in.stride.1.x)) } } ターゲットコードに近い中間表現は Halide IRではなくPolyhedral IRを使う
  • 29. Halide IR上でのLowering  コンパイラインフラストラクチャとしてのHalide IR – Halide中にはHalide IR (Halide::Stmt/Expr)を操作する 関数が多数実装済み • 簡約器 • ソルバー • 範囲演算 • CSE • などなど – IRMutatorやIRVisitorクラスを使用すれば Halide IRに対する変換や解析を独自に実装できる  TCのHalide IR Loweringは以下を行っている – レンジ推論 – ループ形成 – 簡約化
  • 30. レンジ推論で使用されるHalide API  solve_for_inner_interval(c, v) – 条件式cを必ず満たす変数vの最大範囲を計算する  and_condition_over_domain(c, varying) – 変数範囲varyingの仮定のもとで条件式cを簡約化する  simplify(e) – 式eを簡約化する これらを組み合わせてレンジ推論を行い、後段で行われる Polyhedral最適化に必要な条件を満たすかをテストしておく
  • 31. ループ形成で使用されるHalide API  realization_order(…) – Provide/Call間の依存グラフを トポロジカルソートによって順序付けをする  schedule_functions(…) – 出力が依存するすべての計算ループを形成し、 Halide IR (Halide::Stmt) を返す 後の解析や変換のベースとなるループ構築を行う
  • 32. ループ形成結果 後の解析や変換のベースとなるループ構築を行う produce output { let output.s0.n.loop_max = output.s0.n.max let output.s0.n.loop_min = output.s0.n.min let output.s0.n.loop_extent = ((output.s0.n.max + 1) - output.s0.n.min) let output.s0.m.loop_max = output.s0.m.max let output.s0.m.loop_min = output.s0.m.min let output.s0.m.loop_extent = ((output.s0.m.max + 1) - output.s0.m.min) for (output.s0.m, output.s0.m.loop_min, output.s0.m.loop_extent) { for (output.s0.n, output.s0.n.loop_min, output.s0.n.loop_extent) { output(output.s0.m, output.s0.n) = 0.000000f } } let output.s1.r_k.loop_extent = ((output.s1.r_k.max - output.s1.r_k.min) + 1) let output.s1.r_k.loop_max = output.s1.r_k.max let output.s1.r_k.loop_min = output.s1.r_k.min let output.s1.n.loop_max = output.s1.n.max let output.s1.n.loop_min = output.s1.n.min let output.s1.n.loop_extent = ((output.s1.n.max + 1) - output.s1.n.min) let output.s1.m.loop_max = output.s1.m.max let output.s1.m.loop_min = output.s1.m.min let output.s1.m.loop_extent = ((output.s1.m.max + 1) - output.s1.m.min) for (output.s1.m, output.s1.m.loop_min, output.s1.m.loop_extent) { for (output.s1.n, output.s1.n.loop_min, output.s1.n.loop_extent) { for (output.s1.r_k, output.s1.r_k.loop_min, output.s1.r_k.loop_extent) { output(output.s1.m, output.s1.n) = ReductionUpdate((output(output.s1.m, output.s1.n) + (A(output.s1.m, output.s1.r_k)* B(output.s1.r_k, output.s1.n)))) } } } } 初期化 計算
  • 33. 簡約化で使用されるHalide API  LetStmt::make(n, e, s) – 文s中で式eをシンボルnに束縛する  simplify(s) – 文sを簡約化する レンジ制約を適用しループ範囲の簡約化を行う
  • 34. 簡約化結果 後段のPolyhedral Transformationで 解析・変形可能なループ構造に簡約できた for (output.s0.m, 0, M) { for (output.s0.n, 0, K) { output(output.s0.m, output.s0.n) = 0.000000f } } for (output.s1.m, 0, M) { for (output.s1.n, 0, K) { for (output.s1.r_k, 0, K) { output(output.s1.m, output.s1.n) = ReductionUpdate((output(output.s1.m, output.s1.n) + (A(output.s1.m, output.s1.r_k)* B(output.s1.r_k, output.s1.n)))) } } } ループ範囲がTC言語のレンジ制約と対応している
  • 35.  tc::polyhedral::Scop – ISL (Integer Set Library) /Polyhedral Compilation Libraryを用いて計算されたスケジューリング – RAW, WAR, WAW 依存関係 – メモリ配置 – Halide IR関連 • パラメータ • 入出力 • Stmt Polyhedral IR Polyhedral IR Halide IR パラメータ/ 入出力 Polyhedral IR変換 Stmt of Provide A Stmt of Provide B パラメータ/ 入出力 スケジューリング 依存関係 メモリ配置
  • 36. Polyhedral IR 上でのLowering 1. ループ融合 2. タイリング – パラメータ: タイリング戦略 3. スレッドマッピング – パラメータ: スレッドサイズ 4. ブロックマッピング – パラメータ: ブロックサイズ 5. メモリマッピング – パラメータ: 有効/無効、マッピング先、マッピング量 Polyhedral Transformationとは、 ループ構造をPolytope=多面体と見立ててアフィン変換を施すことで Legalなループ変形を行う最適化手法
  • 37. domain( [K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N } [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <= output_s1_n < N and 0 <= output_s1_r_k < K }) sequence() filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] }) band(n(1) permutable(0) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] } ----------------------------------------------------------------------- band(n(1) permutable(0) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] } ----------------------------------------------------------------------- filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] }) band(n(1) permutable(0) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] } ----------------------------------------------------------------------- band(n(1) permutable(0) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] } ----------------------------------------------------------------------- band(n(1) permutable(0) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] } ----------------------------------------------------------------------- 初期状態のスケジューリングツリー 雑な解説: band=ループ、 filter=ステートメント
  • 38. domain( [K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N } [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <= output_s1_n < N and 0 <= output_s1_r_k < K }) band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] } ----------------------------------------------------------------------- sequence() filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] }) filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] }) ループ融合後
  • 39. domain( [K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N } [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <= output_s1_n < N and 0 <= output_s1_r_k < K }) band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] } ----------------------------------------------------------------------- band(n(2) permutable(1) coincident(1, 0) unroll(0, 0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] } ----------------------------------------------------------------------- sequence() filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] }) filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] }) タイリング後
  • 40. domain( [K, M, N] -> { S_0[output_s0_m, output_s0_n] : K = 4 and M = 3 and N = 5 and 0 <= output_s0_m <= 2 and 0 <= output_s0_n <= 4 } [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : K = 4 and M = 3 and N = 5 and 0 <= output_s1_m <= 2 and 0 <= output_s1_n <= 4 and 0 <= output_s1_r_k <= 3 }) context([K, M, N, t1, t0, t2, b2, b1, b0] -> { [] : t1 = 0 and t2 = 0 and b2 = 0 and b1 = 0 and 0 <= t0 <= 127 and 0 <= b0 <= 127 }) band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] } ----------------------------------------------------------------------- mapping_filter(ids(t0, ) [K, M, N, t0] -> { S_0[output_s0_m, output_s0_n] : (-t0 + output_s0_n) mod 128 = 0 and 0 <= t0 <= 127 } [K, M, N, t0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-t0 + output_s1_n) mod 128 = 0 and 0 <= t0 <= 127 }) band(n(1) permutable(1) coincident(1) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] } ----------------------------------------------------------------------- thread_specific() band(n(1) permutable(1) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] } ----------------------------------------------------------------------- sequence() filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] }) filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] }) スレッドマッピング後
  • 41. domain( [K, M, N] -> { S_0[output_s0_m, output_s0_n] : K = 4 and M = 3 and N = 5 and 0 <= output_s0_m <= 2 and 0 <= output_s0_n <= 4 } [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : K = 4 and M = 3 and N = 5 and 0 <= output_s1_m <= 2 and 0 <= output_s1_n <= 4 and 0 <= output_s1_r_k <= 3 }) context([K, M, N, t1, t0, t2, b2, b1, b0] -> { [] : t1 = 0 and t2 = 0 and b2 = 0 and b1 = 0 and 0 <= t0 <= 127 and 0 <= b0 <= 127 }) mapping_filter(ids(b0, ) [K, M, N, b0] -> { S_0[output_s0_m, output_s0_n] : (-b0 + output_s0_m) mod 128 = 0 and 0 <= b0 <= 127 } [K, M, N, b0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-b0 + output_s1_m) mod 128 = 0 and 0 <= b0 <= 127 }) band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] } ----------------------------------------------------------------------- mapping_filter(ids(t0, ) [K, M, N, t0] -> { S_0[output_s0_m, output_s0_n] : (-t0 + output_s0_n) mod 128 = 0 and 0 <= t0 <= 127 } [K, M, N, t0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-t0 + output_s1_n) mod 128 = 0 and 0 <= t0 <= 127 }) band(n(1) permutable(1) coincident(1) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] } ----------------------------------------------------------------------- thread_specific() band(n(1) permutable(1) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] } ----------------------------------------------------------------------- sequence() filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] }) filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] }) ブロックマッピング後
  • 42. template<typename T> inline __device__ T floord(T n, T d) { return n < 0 ? - (-n + d - 1)/d : n / d; } #define if_then_else(cond,a,b) ((cond) ? (a) : (b)) // Can't include system dependencies with NVRTC // Can't include cuda_fp16.h with NVRTC due to transitive system dependencies // #include <cuda_fp16.h> // Halide type handling typedef char int8; typedef short int16; typedef int int32; typedef long int64; typedef unsigned char uint8; typedef unsigned short uint16; typedef unsigned int uint32; typedef unsigned long uint64; // typedef half float16; typedef float float32; typedef double float64; #define inff __int_as_float(0x7f800000) #define inf __longlong_as_double(0x7ff0000000000000LL) // Before CUDA 9, syncwarp is a noop since warps are always synchronized. #if __CUDACC_VER_MAJOR__ < 9 __device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {} #endif extern "C" { __global__ void matmul_4_3_5(int32 K, int32 M, int32 N, float32* poutput, const float32* pA, const float32* pB) { int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z; int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z; float32 (*output)[5] = reinterpret_cast<float32 (*)[5]>(poutput); const float32 (*A)[4] = reinterpret_cast<const float32 (*)[4]>(pA); const float32 (*B)[5] = reinterpret_cast<const float32 (*)[5]>(pB); output[b0][t0] = 0.000000f; for (int c4 = 0; c4 <= 3; c4 += 1) { output[b0][t0] = (output[b0][t0] + (A[b0][c4]*B[c4][t0])); } } } コード生成
  • 45. Halide IR良くできてる – 制約は正義 • Polyhedral Transformation等応用的な最適化手法を適用可能 • 解析時の計算量爆発がおきにくい – IRを操作する関数の実装が揃ってる • 自前の解析や変形を行う場合でも多くの機能を転用可能 – IRの変換が書きやすい • IRMutator/IRVisitorのクラスが割とシンプルで書きやすい – TC/TVM/Tiramisu等、 Halide IRを再利用する取り組みも出てきた High Level Compiler IRとしてのHalideに今後も要注目!