20180728 halide-study
2. Takuro Iizuka / @iitaku
At Fixstars Solutions, Inc., the North American subsidiary,
I develop the FPGA backend for Halide and the "GENESIS" toolchain
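4. TC: What is Tensor Comprehensions?
A description language for tensor computations and
an optimizing compiler framework
Released by Facebook AI Research on February 14, 2018
Write an algorithm in the TC language and
the library optimizes it for you
Integrates seamlessly with PyTorch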
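5. TC: What is Tensor Comprehensions?
A description language for tensor computations and
an optimizing compiler framework
Released by Facebook AI Research in February 2018
Write an algorithm in the TC language and
the library optimizes it for you
Integrates seamlessly with PyTorch
Uses Halide IR as the compiler's intermediate representation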
7. TC benchmark results
MLP: Multi-Layer Perceptron
TMM: Transposed Matrix Multiplication
TBMM: Transposed Batched Matrix Multiplication
GCOV: Grouped Convolutions
https://research.fb.com/announcing-tensor-comprehensions/
9. TC in PyTorch
Step 1: Set up TC
$ conda create -y --name pytorch python=3.6
$ conda activate pytorch
$ conda install -y -c pytorch -c tensorcomp tensor_comprehensions
Step 2: Write a custom layer in the TC language
import tensor_comprehensions as tc
import torch
mm = """
def matmul(float(M,K) A, float(K,N) B) -> (output) {
    output(m, n) +=! A(m, r_k) * B(r_k, n)
}
"""
matmul = tc.define(mm, name="matmul")
A, B = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda()
C = matmul(A, B, options=tc.Options("naive"))
print(C)
Step 3: Run it
$ python ./matmul.py
Variable containing:
-2.4028 2.8492 7.6141 3.3159 3.7171
1.3839 0.6650 -1.7253 0.7447 1.3988
0.1396 -0.0661 -1.0574 0.2163 0.1711
[torch.cuda.FloatTensor of size 3x5 (GPU 0)]
11. num ::= <number literal with C syntax>
id ::= [_a-zA-Z0-9]*[_a-zA-Z][_a-zA-Z0-9]*
exp ::= num
    | ( '-' | '!' | ... ) exp
    | exp ( [+-*/%] | '==' | '!=' | '<=' | ... ) exp
    | exp '?' exp ':' exp
    | id '.' num              # range of num-th dimension of id
    | id '(' exp_list ')'     # builtin call or tensor access
reduction ::= <associative reduction operator>
    | '+=' | '*=' | 'min=' | 'max='
    | '+=!' | '*=!' | 'min=!' | 'max=!'
range_constraint ::= id 'in' exp ':' exp
stmt ::= id '(' id_list ')' [ '=' | reduction ] exp
        [ 'where' range_constraint_list ]
    | id_list = id '('id_list ')'     # TC function call
arg ::= type id
return ::= id                 # inferred return type and range
scalar_type ::= 'double' | 'float' | 'half'
    | 'int32' | 'byte' | 'uint32' | ...
type ::= scalar_type [ '(' id_list ')' ]
func ::=                      # TC function definition
    'def' id '(' arg_list ')' '->' '(' return_list ')' '{'
        stmt_list
    '}'
id_list ::= <comma separated id list>
exp_list ::= <comma separated exp list>
arg_list ::= <comma separated arg list>
stmt_list ::= <whitespace separated stmt list>
return_list ::= <comma separated return list>
range_constraint_list ::= <non-empty comma separated
    range_constraint list>
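The grammar lists every reduction operator twice, with and without a trailing `!`. In TC the `!` form first initializes the output to the reduction's neutral element, while the plain form accumulates into whatever the output already holds. A minimal Python sketch of that difference for `+=` follows; the helper name `reduce_sum` is hypothetical and not part of TC.

```python
def reduce_sum(output, terms, init=False):
    """Emulate TC's '+=' (init=False) and '+=!' (init=True) on one output cell."""
    if init:
        output = 0.0  # '+=!' resets the cell to the neutral element of '+'
    for t in terms:
        output += t
    return output

stale = 10.0  # value left in the output buffer by an earlier computation
plain = reduce_sum(stale, [1.0, 2.0, 3.0])             # '+='  keeps the old value
fresh = reduce_sum(stale, [1.0, 2.0, 3.0], init=True)  # '+=!' ignores it
print(plain, fresh)  # 16.0 6.0
```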
14. def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
type(X,Y,…) places constraints on the input arguments
(an element-type constraint and range constraints)
15. def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
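Using the same symbol in several range constraints
expresses a constraint across different arguments
(here K ties the second dimension of A to the first dimension of B)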
16. def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
The constraints on the output are inferred automatically by the compiler
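As a rough illustration of slides 14 to 16, the sketch below binds the symbols of `float(M,K) A` and `float(K,N) B` to concrete sizes, rejects inputs where the shared symbol `K` disagrees, and derives the output shape. The function `infer_matmul_output` is a hypothetical helper written for this note; it is not how TC implements inference.

```python
def infer_matmul_output(a_shape, b_shape):
    """Bind the symbols of float(M,K) A and float(K,N) B, then infer output(M,N)."""
    symbols = {}
    for names, shape in ((("M", "K"), a_shape), (("K", "N"), b_shape)):
        if len(names) != len(shape):
            raise ValueError("rank mismatch")
        for name, extent in zip(names, shape):
            # A symbol seen twice (K here) must bind to the same extent.
            if symbols.setdefault(name, extent) != extent:
                raise ValueError(f"conflicting sizes for {name}")
    return (symbols["M"], symbols["N"])

out_shape = infer_matmul_output((3, 4), (4, 5))
print(out_shape)  # (3, 5)
```

Calling it with `(3, 4)` and `(6, 5)` raises `ValueError`, because `K` would have to be both 4 and 6.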
18. def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
Induction variables defined on the left-hand side and used on the right-hand side
form loops
for m := 0, M
  for n := 0, N
19. def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
Loop ranges are determined from the input constraints
e.g. m := [0, M)
20. def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
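When an induction variable that is not defined on the left-hand side is used on the right-hand side,
a new loop is formed innermost within the existing loops
for m := 0, M
  for n := 0, N
    for r_k := 0, K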
21. def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
In that case too, constraints are checked against the input constraints
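Putting slides 18 to 21 together, the one-line TC statement corresponds to the loop nest below. This is a plain-Python sketch of the semantics only, using nested lists instead of tensors; `tc_matmul` is a name chosen for this note.

```python
def tc_matmul(A, B):
    """Loop nest implied by: output(m, n) +=! A(m, r_k) * B(r_k, n)."""
    M, K = len(A), len(A[0])      # ranges come from float(M,K) A
    K2, N = len(B), len(B[0])     # and from float(K,N) B
    if K != K2:
        raise ValueError("shared symbol K must agree")
    output = [[0.0] * N for _ in range(M)]   # '!' initializes to zero
    for m in range(M):                # m appears on the left-hand side
        for n in range(N):            # n appears on the left-hand side
            for r_k in range(K):      # r_k appears only on the right: innermost
                output[m][n] += A[m][r_k] * B[r_k][n]
    return output

C = tc_matmul([[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]])
print(C)  # [[19.0, 22.0], [43.0, 50.0]]
```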
22. Comparison with the Halide language
TC language:
def matmul(float(M,K) A, float(K,N) B) -> (output) {
    output(m, n) +=! A(m, r_k) * B(r_k, n)
}
Halide language:
// M, N, and K are assumed to be defined elsewhere
Func matmul;
ImageParam A{Float(32), 2};
ImageParam B{Float(32), 2};
Var m, n;
RDom r_k{0, K};
matmul(m, n) = sum(A(m, r_k) * B(r_k, n));
matmul.realize({M, N});
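23. Comparison with the Halide language
TC language: range constraints are inferred forward (from the inputs to the output)
def matmul(float(M,K) A, float(K,N) B) -> (output) {
    output(m, n) +=! A(m, r_k) * B(r_k, n)
}
Halide language: range constraints are inferred backward (from the output to the inputs)
// M, N, and K are assumed to be defined elsewhere
Func matmul;
ImageParam A{Float(32), 2};
ImageParam B{Float(32), 2};
Var m, n;
RDom r_k{0, K};
matmul(m, n) = sum(A(m, r_k) * B(r_k, n));
matmul.realize({M, N});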
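To make the backward direction concrete: given the output region a caller asks for, one works out which parts of each input are read. The sketch below does this by hand for the matmul definition, assuming "backward" means from the requested output region toward the inputs. `required_input_regions` is a hypothetical helper, not a Halide API.

```python
def required_input_regions(m_range, n_range, K):
    """Backward inference: from a requested output region to the input regions it reads.

    Ranges are half-open (start, stop) pairs.
    """
    k_range = (0, K)                 # the reduction domain r_k
    a_region = (m_range, k_range)    # A(m, r_k)
    b_region = (k_range, n_range)    # B(r_k, n)
    return a_region, b_region

a_region, b_region = required_input_regions((0, 3), (2, 5), K=4)
print(a_region, b_region)  # ((0, 3), (0, 4)) ((0, 4), (2, 5))
```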
25. TC GPU backend compile flow
TC language -> AST
    Parsing / AST construction: tc/core/compiler.cc, parse(…)
AST -> Halide IR
    Halide IR conversion: tc/core/tc2halide.cc, translate(…)
Halide IR -> Halide IR
    Range inference / loop formation / simplification: tc/core/tc2halide.cc, translate(…)
Halide IR -> Polyhedral IR
    Polyhedral IR conversion: tc/core/polyhedral/scop.cc, makeScop(…)
Polyhedral IR -> Polyhedral IR
    Loop transformation / threading: tc/core/polyhedral/mapped_scop.cc,
    makeWithOuterBlockInnerThreadStrategy(…)
Polyhedral IR -> CUDA C
    Code generation: tc/core/polyhedral/mapped_scop.cc, codegen(…)
27. Call/Provide style
High level of abstraction
for (y, 0, out.extent.1) {
  for (x, 0, out.extent.0) {
    Provide(out, {x, y}) = Call(in, {x, y})
  }
}
Load/Store style
Low level of abstraction
for (y, 0, out.extent.1) {
  for (x, 0, out.extent.0) {
    Store(out, y*out.stride.1+x,
          Load(in, y*in.stride.1+x))
  }
}
28. The Halide IR that TC handles is
only the high-abstraction Call/Provide style
for (y, 0, out.extent.1) {
  for (x, 0, out.extent.0) {
    Provide(out, {x, y}) = Call(in, {x, y})
  }
}
for (y, 0, out.extent.1) {
  for (x, 0, out.extent.0) {
    Store(out, y*out.stride.1+x,
          Load(in, y*in.stride.1+x))
  }
}
For the intermediate representation close to the target code,
TC uses Polyhedral IR instead of Halide IR
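The difference between the two styles can be seen in a small Python analogy: the Call/Provide form addresses elements by coordinates, while the Load/Store form has already committed to a flat buffer and a row stride. The variable names below are illustrative only.

```python
width, height = 4, 3
src = [[10 * y + x for x in range(width)] for y in range(height)]

# Call/Provide style: multi-dimensional coordinates, no layout decided yet.
out_nd = [[0] * width for _ in range(height)]
for y in range(height):
    for x in range(width):
        out_nd[y][x] = src[y][x]

# Load/Store style: flat buffers addressed through an explicit row stride.
stride_1 = width                         # elements between consecutive rows
src_flat = [v for row in src for v in row]
out_flat = [0] * (width * height)
for y in range(height):
    for x in range(width):
        out_flat[y * stride_1 + x] = src_flat[y * stride_1 + x]

print(out_flat == [v for row in out_nd for v in row])  # True
```

Keeping the coordinate form leaves the memory layout open, which is what lets a later stage choose it.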
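29. Lowering on Halide IR
Halide IR as compiler infrastructure
– Halide already implements many functions that operate on
Halide IR (Halide::Stmt/Expr)
• Simplifier
• Solver
• Range (interval) arithmetic
• CSE
• and more
– With the IRMutator and IRVisitor classes you can
implement your own transformations and analyses on Halide IR
TC's Halide IR lowering performs the following
– Range inference
– Loop formation
– Simplification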
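The mutator pattern mentioned above can be sketched on a toy expression tree. The classes below are a Python analogy written for this note and are not Halide's C++ IRMutator API: expressions are ints, variable names, or `(op, a, b)` tuples with `op` being `"add"` or `"sub"`.

```python
class Mutator:
    """Rebuilds an expression tree bottom-up; subclasses override one node kind."""
    def mutate(self, e):
        if isinstance(e, tuple):
            op, a, b = e
            return self.visit_binary(op, self.mutate(a), self.mutate(b))
        return e  # int constant or variable name

    def visit_binary(self, op, a, b):
        return (op, a, b)

class FoldConstants(Mutator):
    """Folds constant operands and drops '+ 0' / '- 0' (ops: 'add' and 'sub' only)."""
    def visit_binary(self, op, a, b):
        if isinstance(a, int) and isinstance(b, int):
            return a + b if op == "add" else a - b
        if b == 0:
            return a          # x + 0 and x - 0 both reduce to x
        return (op, a, b)

# ((max + 1) - min) with max = 4 and min = 0 already substituted
extent = ("sub", ("add", 4, 1), 0)
print(FoldConstants().mutate(extent))  # 5
```

The base class handles the traversal, so a new transformation only needs to override the node kinds it cares about.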
32. Loop formation result
Builds the loops that later analyses and transformations are based on
produce output {
let output.s0.n.loop_max = output.s0.n.max
let output.s0.n.loop_min = output.s0.n.min
let output.s0.n.loop_extent = ((output.s0.n.max + 1) - output.s0.n.min)
let output.s0.m.loop_max = output.s0.m.max
let output.s0.m.loop_min = output.s0.m.min
let output.s0.m.loop_extent = ((output.s0.m.max + 1) - output.s0.m.min)
for (output.s0.m, output.s0.m.loop_min, output.s0.m.loop_extent) {
for (output.s0.n, output.s0.n.loop_min, output.s0.n.loop_extent) {
output(output.s0.m, output.s0.n) = 0.000000f
}
}
let output.s1.r_k.loop_extent = ((output.s1.r_k.max - output.s1.r_k.min) + 1)
let output.s1.r_k.loop_max = output.s1.r_k.max
let output.s1.r_k.loop_min = output.s1.r_k.min
let output.s1.n.loop_max = output.s1.n.max
let output.s1.n.loop_min = output.s1.n.min
let output.s1.n.loop_extent = ((output.s1.n.max + 1) - output.s1.n.min)
let output.s1.m.loop_max = output.s1.m.max
let output.s1.m.loop_min = output.s1.m.min
let output.s1.m.loop_extent = ((output.s1.m.max + 1) - output.s1.m.min)
for (output.s1.m, output.s1.m.loop_min, output.s1.m.loop_extent) {
for (output.s1.n, output.s1.n.loop_min, output.s1.n.loop_extent) {
for (output.s1.r_k, output.s1.r_k.loop_min, output.s1.r_k.loop_extent) {
output(output.s1.m, output.s1.n) =
ReductionUpdate((output(output.s1.m, output.s1.n) +
(A(output.s1.m, output.s1.r_k)*
B(output.s1.r_k, output.s1.n))))
}
}
}
}
output.s0 loop nest: initialization
output.s1 loop nest: computation
34. Simplification result
The code is simplified into a loop structure that the later
polyhedral transformation stage can analyze and transform
for (output.s0.m, 0, M) {
  for (output.s0.n, 0, N) {
    output(output.s0.m, output.s0.n) = 0.000000f
  }
}
for (output.s1.m, 0, M) {
  for (output.s1.n, 0, N) {
    for (output.s1.r_k, 0, K) {
      output(output.s1.m, output.s1.n) =
        ReductionUpdate((output(output.s1.m, output.s1.n) +
                         (A(output.s1.m, output.s1.r_k)*
                          B(output.s1.r_k, output.s1.n))))
    }
  }
}
The loop ranges correspond to the range constraints in the TC language
35. tc::polyhedral::Scop
– Schedule computed using ISL (Integer Set Library), a
polyhedral compilation library
– RAW, WAR, and WAW dependences
– Memory layout
– Halide IR related
• Parameters
• Inputs and outputs
• Stmt
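[Figure: Polyhedral IR conversion. The Halide IR (parameters / inputs and outputs, Stmt of Provide A, Stmt of Provide B) is converted into the Polyhedral IR, which holds the schedule, the dependences, the memory layout, and the parameters / inputs and outputs.]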
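36. Lowering on Polyhedral IR
1. Loop fusion
2. Tiling
– Parameter: tiling strategy
3. Thread mapping
– Parameter: thread size
4. Block mapping
– Parameter: block size
5. Memory mapping
– Parameters: enabled/disabled, mapping target, mapping amount
Polyhedral transformation is an optimization technique that
treats a loop structure as a polytope (polyhedron) and applies affine transformations to it
in order to perform legal loop transformations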
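A small sketch of what "legal" means here: enumerate the iteration domain of the reduction statement, list its dependences, and accept an affine reordering only if every dependence source still runs before its sink. The sketch conservatively treats the reduction over `r_k` as strictly ordered, which is an assumption of this note; the sizes and names are illustrative.

```python
from itertools import product

M, N, K = 3, 5, 4
domain = list(product(range(M), range(N), range(K)))   # points of S_1[m, n, r_k]

# Dependences of the reduction: iteration (m, n, k) must precede (m, n, k + 1).
deps = [((m, n, k), (m, n, k + 1)) for m, n, k in domain if k + 1 < K]

def is_legal(schedule):
    """A schedule is legal if every dependence source runs before its sink."""
    return all(schedule(src) < schedule(dst) for src, dst in deps)

interchange = lambda p: (p[1], p[0], p[2])     # swap the m and n loops
reverse_k = lambda p: (p[0], p[1], -p[2])      # run the reduction backwards

print(is_legal(interchange), is_legal(reverse_k))  # True False
```

Each schedule maps a point to a time tuple, and tuples are compared lexicographically, which matches how nested loops order their iterations.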
37. domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <=
output_s1_n < N and 0 <= output_s1_r_k < K })
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
-----------------------------------------------------------------------
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
-----------------------------------------------------------------------
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
Schedule tree in its initial state
Rough explanation: band = loop, filter = statement
38. domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <=
output_s1_n < N and 0 <= output_s1_r_k < K })
band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
After loop fusion
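One way to read the fused tree: the three-dimensional band gives S_0[m, n] the time (m, n, 0) and S_1[m, n, r_k] the time (m, n, r_k), and the trailing sequence node breaks ties by running S_0 first. The sketch below executes statement instances in that lexicographic order on small illustrative inputs; it is an interpretation written for this note, not output from TC.

```python
M, N, K = 2, 2, 3
A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Each statement instance gets a time tuple: the three band dimensions,
# then its position in the trailing sequence (S_0 before S_1).
instances = [((m, n, 0, 0), ("S_0", m, n, None)) for m in range(M) for n in range(N)]
instances += [((m, n, k, 1), ("S_1", m, n, k))
              for m in range(M) for n in range(N) for k in range(K)]

output = [[None] * N for _ in range(M)]
for _, (stmt, m, n, k) in sorted(instances, key=lambda inst: inst[0]):
    if stmt == "S_0":
        output[m][n] = 0.0                      # initialization
    else:
        output[m][n] += A[m][k] * B[k][n]       # reduction update
print(output)  # [[4.0, 5.0], [10.0, 11.0]]
```

If an update ran before its initialization, adding to `None` would raise a `TypeError`, so a clean run also confirms the ordering.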
39. domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <=
output_s1_n < N and 0 <= output_s1_r_k < K })
band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
band(n(2) permutable(1) coincident(1, 0) unroll(0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
After tiling
40. domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : K = 4 and M = 3 and N = 5 and 0 <= output_s0_m <= 2 and 0 <= output_s0_n <= 4 }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : K = 4 and M = 3 and N = 5 and 0 <= output_s1_m <= 2 and
0 <= output_s1_n <= 4 and 0 <= output_s1_r_k <= 3 })
context([K, M, N, t1, t0, t2, b2, b1, b0] -> { [] : t1 = 0 and t2 = 0 and b2 = 0 and b1 = 0 and 0 <= t0 <= 127 and 0 <= b0 <= 127 })
band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
mapping_filter(ids(t0, )
[K, M, N, t0] -> { S_0[output_s0_m, output_s0_n] : (-t0 + output_s0_n) mod 128 = 0 and 0 <= t0 <= 127 }
[K, M, N, t0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-t0 + output_s1_n) mod 128 = 0 and 0 <= t0 <= 127 })
band(n(1) permutable(1) coincident(1) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
thread_specific()
band(n(1) permutable(1) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
After thread mapping
41. domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : K = 4 and M = 3 and N = 5 and 0 <= output_s0_m <= 2 and 0 <= output_s0_n <= 4 }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : K = 4 and M = 3 and N = 5 and 0 <= output_s1_m <= 2 and 0 <=
output_s1_n <= 4 and 0 <= output_s1_r_k <= 3 })
context([K, M, N, t1, t0, t2, b2, b1, b0] -> { [] : t1 = 0 and t2 = 0 and b2 = 0 and b1 = 0 and 0 <= t0 <= 127 and 0 <= b0 <= 127 })
mapping_filter(ids(b0, )
[K, M, N, b0] -> { S_0[output_s0_m, output_s0_n] : (-b0 + output_s0_m) mod 128 = 0 and 0 <= b0 <= 127 }
[K, M, N, b0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-b0 + output_s1_m) mod 128 = 0 and 0 <= b0 <= 127 })
band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
mapping_filter(ids(t0, )
[K, M, N, t0] -> { S_0[output_s0_m, output_s0_n] : (-t0 + output_s0_n) mod 128 = 0 and 0 <= t0 <= 127 }
[K, M, N, t0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-t0 + output_s1_n) mod 128 = 0 and 0 <= t0 <= 127 })
band(n(1) permutable(1) coincident(1) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
thread_specific()
band(n(1) permutable(1) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
After block mapping
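The two mapping_filter nodes can be read as ownership rules: block b0 keeps the points whose m is congruent to b0 modulo 128, and thread t0 keeps the points whose n is congruent to t0 modulo 128. A brute-force sketch with the sizes from the slide (M = 3, N = 5); the helper `owned_points` is hypothetical.

```python
M, N = 3, 5
BLOCKS = THREADS = 128   # 0 <= b0 <= 127 and 0 <= t0 <= 127 in the context node

def owned_points(b0, t0):
    """Points (m, n) kept by the filters (m - b0) mod 128 = 0 and (n - t0) mod 128 = 0."""
    return [(m, n) for m in range(M) for n in range(N)
            if (m - b0) % BLOCKS == 0 and (n - t0) % THREADS == 0]

assignment = {(b0, t0): owned_points(b0, t0)
              for b0 in range(BLOCKS) for t0 in range(THREADS)}
busy = {key: pts for key, pts in assignment.items() if pts}
print(len(busy), busy[(2, 4)])  # 15 [(2, 4)]
```

With these sizes only 3 x 5 of the 128 x 128 (block, thread) pairs own a point, and each owns exactly one.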
42. template<typename T> inline __device__ T floord(T n, T d) {
return n < 0 ? - (-n + d - 1)/d : n / d;
}
#define if_then_else(cond,a,b) ((cond) ? (a) : (b))
// Can't include system dependencies with NVRTC
// Can't include cuda_fp16.h with NVRTC due to transitive system dependencies
// #include <cuda_fp16.h>
// Halide type handling
typedef char int8;
typedef short int16;
typedef int int32;
typedef long int64;
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned long uint64;
// typedef half float16;
typedef float float32;
typedef double float64;
#define inff __int_as_float(0x7f800000)
#define inf __longlong_as_double(0x7ff0000000000000LL)
// Before CUDA 9, syncwarp is a noop since warps are always synchronized.
#if __CUDACC_VER_MAJOR__ < 9
__device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {}
#endif
extern "C" {
__global__ void matmul_4_3_5(int32 K, int32 M, int32 N, float32* poutput, const float32* pA, const float32* pB) {
int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
float32 (*output)[5] = reinterpret_cast<float32 (*)[5]>(poutput);
const float32 (*A)[4] = reinterpret_cast<const float32 (*)[4]>(pA);
const float32 (*B)[5] = reinterpret_cast<const float32 (*)[5]>(pB);
output[b0][t0] = 0.000000f;
for (int c4 = 0; c4 <= 3; c4 += 1) {
output[b0][t0] = (output[b0][t0] + (A[b0][c4]*B[c4][t0]));
}
}
}
Code generation
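To check that the generated kernel body computes the matmul, it can be emulated in Python by calling the body once per (blockIdx.x, threadIdx.x) pair. The launch shape of M blocks with N threads each is an assumption of this sketch, made because the kernel contains no bounds checks; the input values are arbitrary.

```python
M, K, N = 3, 4, 5
A = [[float(m + k) for k in range(K)] for m in range(M)]
B = [[float(k * n) for n in range(N)] for k in range(K)]
output = [[None] * N for _ in range(M)]

def matmul_4_3_5(b0, t0):
    """Body of the generated kernel for one (blockIdx.x, threadIdx.x) pair."""
    output[b0][t0] = 0.0
    for c4 in range(K):                      # c4 = 0 .. 3
        output[b0][t0] = output[b0][t0] + A[b0][c4] * B[c4][t0]

# Assumed launch: M blocks of N threads, since the kernel has no bounds checks.
for b0 in range(M):
    for t0 in range(N):
        matmul_4_3_5(b0, t0)

reference = [[sum(A[m][k] * B[k][n] for k in range(K)) for n in range(N)]
             for m in range(M)]
print(output == reference)  # True
```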
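45. Halide IR is well designed
– Restrictions are a strength
• Advanced optimization techniques such as polyhedral transformation can be applied
• Analysis is unlikely to blow up in computational cost
– A full set of functions for manipulating the IR is already implemented
• Much of it can be reused even when you write your own analyses and transformations
– IR transformations are easy to write
• The IRMutator/IRVisitor classes are fairly simple and easy to use
– Efforts to reuse Halide IR, such as TC, TVM, and Tiramisu,
are starting to appear
Keep an eye on Halide as a high-level compiler IR!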