SlideShare uma empresa Scribd logo
1 de 45
Tensor Comprehensions にみる
Halide IRの汎用性
Fixstars Solutions, Inc.
Takuro Iizuka
Takuro Iizuka / @iitaku
北米子会社のFixstars Solutions, Inc. にて
HalideのFPGAバックエンドおよびツールチェイン”GENESIS”の開発やってます
もくじ
 TC: Tensor Comprehensions 概要
 TC言語
 Inside TC
 まとめ
TC: Tensor Comprehensionsとは?
 テンソル計算の記述言語および
最適化コンパイラフレームワーク
 2018.2.14にFacebook AI Researchからリリース
 TC言語でアルゴリズムを書くと
ライブラリがいい感じに最適化してくれる
 PyTorchとシームレスに統合できる
TC: Tensor Comprehensionsとは?
 テンソル計算の記述言語および
最適化コンパイラフレームワーク
 2018.2にFacebook AI Researchからリリース
 TC言語でアルゴリズムを書くと
ライブラリがいい感じに最適化してくれる
 PyTorchとシームレスに統合できる
コンパイラの中間表現としてHalide IRを採用
TCアーキテクチャ
https://research.fb.com/announcing-tensor-comprehensions/
TCベンチマーク結果
MLP: Multi-Layer Perceptron
TMM: Transposed Matrix Multiplication
TBMM: Transposed Batched Matrix Multiplication
GCOV: Grouped Convolutions
https://research.fb.com/announcing-tensor-comprehensions/
TC in PyTorch
で書いて
で動せる
TC in PyTorch
$ conda create –y –name pytorch python=3.6
$ conda activate pytorch
$ conda install -y -c pytorch -c tensorcomp tensor_comprehensions
$ python ./matmul.py
Variable containing:
-2.4028 2.8492 7.6141 3.3159 3.7171
1.3839 0.6650 -1.7253 0.7447 1.3988
0.1396 -0.0661 -1.0574 0.2163 0.1711
[torch.cuda.FloatTensor of size 3x5 (GPU 0)]
import tensor_comprehensions as tc
import torch
mm = """
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
"""
matmul = tc.define(mm, name="matmul")
A, B = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda()
C = matmul(A, B, options=tc.Options("naive"))
print(C)
1. TCをセットアップ
2. TC言語でカスタム
レイヤを書く
3. 実行する
TC言語
num ::= <number literal with C syntax>
id ::= [_a-zA-Z0-9]*[_a-zA-Z][_a-zA-Z0-9]*
exp ::= num
| ( '-' | '!' | ... ) exp
| exp ( [+-*/%] | '==' | '!=' | '<=' | ... ) exp
| exp '?' exp ':' exp
| id '.' num # range of num-th dimension of id
| id '(' exp_list ')' # builtin call or tensor access
reduction ::= <associative reduction operator>
| '+=' | '*=' | 'min=' | 'max='
| '+=!' | '*=!' | 'min=!' | 'max=!'
range_constraint ::= id 'in' exp ':' exp
stmt ::= id '(' id_list ')' [ '=' | reduction ] exp
[ 'where' range_constraint_list ]
| id_list = id '('id_list ')' # TC function call
arg ::= type id
return ::= id # inferred return type and range
scalar_type ::= 'double' | 'float' | 'half'
| 'int32' | 'byte' | 'uint32' | ...
type ::= scalar_type [ '(' id_list ')' ]
func ::= # TC function definition
'def' id '(' arg_list ')' '->' '(' return_list ')' '{'
stmt_list
'}'
id_list ::= <comma separated id list>
exp_list ::= <comma separated exp list>
arg_list ::= <comma separated arg list>
stmt_list ::= <whitespace separated stmt list>
return_list ::= <comma separated return list>
range_constraint_list ::= <non-empty comma separated
range_constraint list>
TC言語の特徴
 テンソル計算記述特化言語
 Halide言語よりさらに簡素な言語体型
 超ミニマルなプリミティブ型
 畳み込み演算用の特殊なオペレータ群
 where節でレンジ制約を記述
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
def 関数名 {} で関数定義
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
type(X,Y,…)で入力引数に制約
(要素型制約およびレンジ制約)を付与
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
レンジ制約に同一シンボルを使用することで、
異なる引数間の制約を表現できる
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
出力の制約はコンパイラによって自動推論される
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
初期化付き畳み込み用記述のオペレータ
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
左辺で定義した誘導変数を右辺で使用すると
ループが形成される
for m := 0, M
for n := 0, N
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
ループのレンジは入力制約から決定される
e.g. m := [0. M)
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
左辺で定義していない誘導変数を右辺で使用した場合、
既存ループの最内に新たにループが形成される
for m := 0, M
for n := 0, N
for r_k := 0, K
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
その場合も入力の制約をもとに制約チェックが行われる
Halide言語と比較してみる
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
TC言語
Func matmul;
ImageParam A{float, 2};
ImageParam B{float, 2};
Var m, n;
Rdom r_k{0, K};
matmul(m, n) = sum(A(m, r_k) * B(r_k, n));
matmul.realize({M, N});
Halide言語
Halide言語と比較してみる
def matmul(float(M,K) A, float(K,N) B) -> (output) {
output(m, n) +=! A(m, r_k) * B(r_k, n)
}
Func matmul;
ImageParam A{float, 2};
ImageParam B{float, 2};
Var m, n;
Rdom r_k{0, K};
matmul(m, n) = sum(A(m, r_k) * B(r_k, n));
matmul.realize({M, N});
レンジ制約は前向きに推論
レンジ制約は後ろ向きに推論
Inside TC
TC GPU Backendコンパイルフロー
TC言語
AST
Halide IR
パース/AST構築: tc/core/compiler.cc, parse(…)
Halide IR変換: tc/core/tc2halide.cc, translate(…)
Halide IR
レンジ推論/ループ形成/簡約化: tc/core/tc2halide.cc, translate(…)
Polyhedral IR
CUDA C
Polyhedral IR変換: tc/core/polyhedral/scop.cc, makeScop(…)
Polyhedral IR
ループ変形/スレッディング: tc/core/polyhedral/mapped_scop.cc,
makeWithOuterBlockInnerThreadStrategy(…)
コード生成: tc/core/polyhedral/mapped_scop.cc, codegen(…)
Halide IR
オリジナルのHalide実装では、
中間表現の変換過程においてHalide IRを2つに大別できる
Halide IR
Halide IR
Halide IR
Halide IR
Halide IR
Call/Provide系
Load/Store系
storage_flattening以前/以後
Call/Provide系
抽象度高い
Load/Store系
抽象度低い
for (y, 0, out.extent.1) {
for (x, 0, out.extent.0) {
Provide(out, {x, y}) = Call(in, {x, y})
}
}
for (y, 0, out.extent.1) {
for (x, 0, out.extent.0) {
Store(out, y*out.stride.1+x,
Load(in, y*in.stride.1.x))
}
}
TCが取り扱うHalide IRは
抽象度の高いCall/Provide系のみ
for (y, 0, out.extent.1) {
for (x, 0, out.extent.0) {
Provide(out, {x, y}) = Call(in, {x, y})
}
}
for (y, 0, out.extent.1) {
for (x, 0, out.extent.0) {
Store(out, y*out.stride.1+x,
Load(in, y*in.stride.1.x))
}
}
ターゲットコードに近い中間表現は
Halide IRではなくPolyhedral IRを使う
Halide IR上でのLowering
 コンパイラインフラストラクチャとしてのHalide IR
– Halide中にはHalide IR (Halide::Stmt/Expr)を操作する
関数が多数実装済み
• 簡約器
• ソルバー
• 範囲演算
• CSE
• などなど
– IRMutatorやIRVisitorクラスを使用すれば
Halide IRに対する変換や解析を独自に実装できる
 TCのHalide IR Loweringは以下を行っている
– レンジ推論
– ループ形成
– 簡約化
レンジ推論で使用されるHalide API
 solve_for_inner_interval(c, v)
– 条件式cを必ず満たす変数vの最大範囲を計算する
 and_condition_over_domain(c, varying)
– 変数範囲varyingの仮定のもとで条件式cを簡約化する
 simplify(e)
– 式eを簡約化する
これらを組み合わせてレンジ推論を行い、後段で行われる
Polyhedral最適化に必要な条件を満たすかをテストしておく
ループ形成で使用されるHalide API
 realization_order(…)
– Provide/Call間の依存グラフを
トポロジカルソートによって順序付けをする
 schedule_functions(…)
– 出力が依存するすべての計算ループを形成し、
Halide IR (Halide::Stmt) を返す
後の解析や変換のベースとなるループ構築を行う
ループ形成結果
後の解析や変換のベースとなるループ構築を行う
produce output {
let output.s0.n.loop_max = output.s0.n.max
let output.s0.n.loop_min = output.s0.n.min
let output.s0.n.loop_extent = ((output.s0.n.max + 1) - output.s0.n.min)
let output.s0.m.loop_max = output.s0.m.max
let output.s0.m.loop_min = output.s0.m.min
let output.s0.m.loop_extent = ((output.s0.m.max + 1) - output.s0.m.min)
for (output.s0.m, output.s0.m.loop_min, output.s0.m.loop_extent) {
for (output.s0.n, output.s0.n.loop_min, output.s0.n.loop_extent) {
output(output.s0.m, output.s0.n) = 0.000000f
}
}
let output.s1.r_k.loop_extent = ((output.s1.r_k.max - output.s1.r_k.min) + 1)
let output.s1.r_k.loop_max = output.s1.r_k.max
let output.s1.r_k.loop_min = output.s1.r_k.min
let output.s1.n.loop_max = output.s1.n.max
let output.s1.n.loop_min = output.s1.n.min
let output.s1.n.loop_extent = ((output.s1.n.max + 1) - output.s1.n.min)
let output.s1.m.loop_max = output.s1.m.max
let output.s1.m.loop_min = output.s1.m.min
let output.s1.m.loop_extent = ((output.s1.m.max + 1) - output.s1.m.min)
for (output.s1.m, output.s1.m.loop_min, output.s1.m.loop_extent) {
for (output.s1.n, output.s1.n.loop_min, output.s1.n.loop_extent) {
for (output.s1.r_k, output.s1.r_k.loop_min, output.s1.r_k.loop_extent) {
output(output.s1.m, output.s1.n) =
ReductionUpdate((output(output.s1.m, output.s1.n) +
(A(output.s1.m, output.s1.r_k)*
B(output.s1.r_k, output.s1.n))))
}
}
}
}
初期化
計算
簡約化で使用されるHalide API
 LetStmt::make(n, e, s)
– 文s中で式eをシンボルnに束縛する
 simplify(s)
– 文sを簡約化する
レンジ制約を適用しループ範囲の簡約化を行う
簡約化結果
後段のPolyhedral Transformationで
解析・変形可能なループ構造に簡約できた
for (output.s0.m, 0, M) {
for (output.s0.n, 0, K) {
output(output.s0.m, output.s0.n) = 0.000000f
}
}
for (output.s1.m, 0, M) {
for (output.s1.n, 0, K) {
for (output.s1.r_k, 0, K) {
output(output.s1.m, output.s1.n) =
ReductionUpdate((output(output.s1.m, output.s1.n) +
(A(output.s1.m, output.s1.r_k)*
B(output.s1.r_k, output.s1.n))))
}
}
}
ループ範囲がTC言語のレンジ制約と対応している
 tc::polyhedral::Scop
– ISL (Integer Set Library) /Polyhedral Compilation
Libraryを用いて計算されたスケジューリング
– RAW, WAR, WAW 依存関係
– メモリ配置
– Halide IR関連
• パラメータ
• 入出力
• Stmt
Polyhedral IR
Polyhedral IR
Halide IR
パラメータ/
入出力
Polyhedral IR変換
Stmt of Provide A Stmt of Provide B
パラメータ/
入出力
スケジューリング 依存関係
メモリ配置
Polyhedral IR 上でのLowering
1. ループ融合
2. タイリング
– パラメータ: タイリング戦略
3. スレッドマッピング
– パラメータ: スレッドサイズ
4. ブロックマッピング
– パラメータ: ブロックサイズ
5. メモリマッピング
– パラメータ: 有効/無効、マッピング先、マッピング量
Polyhedral Transformationとは、
ループ構造をPolytope=多面体と見立ててアフィン変換を施すことで
Legalなループ変形を行う最適化手法
domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <=
output_s1_n < N and 0 <= output_s1_r_k < K })
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
-----------------------------------------------------------------------
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
-----------------------------------------------------------------------
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
初期状態のスケジューリングツリー
雑な解説: band=ループ、 filter=ステートメント
domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <=
output_s1_n < N and 0 <= output_s1_r_k < K })
band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
ループ融合後
domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <=
output_s1_n < N and 0 <= output_s1_r_k < K })
band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
band(n(2) permutable(1) coincident(1, 0) unroll(0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
タイリング後
domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : K = 4 and M = 3 and N = 5 and 0 <= output_s0_m <= 2 and 0 <= output_s0_n <= 4 }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : K = 4 and M = 3 and N = 5 and 0 <= output_s1_m <= 2 and
0 <= output_s1_n <= 4 and 0 <= output_s1_r_k <= 3 })
context([K, M, N, t1, t0, t2, b2, b1, b0] -> { [] : t1 = 0 and t2 = 0 and b2 = 0 and b1 = 0 and 0 <= t0 <= 127 and 0 <= b0 <= 127 })
band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
mapping_filter(ids(t0, )
[K, M, N, t0] -> { S_0[output_s0_m, output_s0_n] : (-t0 + output_s0_n) mod 128 = 0 and 0 <= t0 <= 127 }
[K, M, N, t0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-t0 + output_s1_n) mod 128 = 0 and 0 <= t0 <= 127 })
band(n(1) permutable(1) coincident(1) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
thread_specific()
band(n(1) permutable(1) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
スレッドマッピング後
domain(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] : K = 4 and M = 3 and N = 5 and 0 <= output_s0_m <= 2 and 0 <= output_s0_n <= 4 }
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : K = 4 and M = 3 and N = 5 and 0 <= output_s1_m <= 2 and 0 <=
output_s1_n <= 4 and 0 <= output_s1_r_k <= 3 })
context([K, M, N, t1, t0, t2, b2, b1, b0] -> { [] : t1 = 0 and t2 = 0 and b2 = 0 and b1 = 0 and 0 <= t0 <= 127 and 0 <= b0 <= 127 })
mapping_filter(ids(b0, )
[K, M, N, b0] -> { S_0[output_s0_m, output_s0_n] : (-b0 + output_s0_m) mod 128 = 0 and 0 <= b0 <= 127 }
[K, M, N, b0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-b0 + output_s1_m) mod 128 = 0 and 0 <= b0 <= 127 })
band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
-----------------------------------------------------------------------
mapping_filter(ids(t0, )
[K, M, N, t0] -> { S_0[output_s0_m, output_s0_n] : (-t0 + output_s0_n) mod 128 = 0 and 0 <= t0 <= 127 }
[K, M, N, t0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-t0 + output_s1_n) mod 128 = 0 and 0 <= t0 <= 127 })
band(n(1) permutable(1) coincident(1) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
-----------------------------------------------------------------------
thread_specific()
band(n(1) permutable(1) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
| [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] }
-----------------------------------------------------------------------
sequence()
filter(
[K, M, N] -> { S_0[output_s0_m, output_s0_n] })
filter(
[K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
ブロックマッピング後
template<typename T> inline __device__ T floord(T n, T d) {
return n < 0 ? - (-n + d - 1)/d : n / d;
}
#define if_then_else(cond,a,b) ((cond) ? (a) : (b))
// Can't include system dependencies with NVRTC
// Can't include cuda_fp16.h with NVRTC due to transitive system dependencies
// #include <cuda_fp16.h>
// Halide type handling
typedef char int8;
typedef short int16;
typedef int int32;
typedef long int64;
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned long uint64;
// typedef half float16;
typedef float float32;
typedef double float64;
#define inff __int_as_float(0x7f800000)
#define inf __longlong_as_double(0x7ff0000000000000LL)
// Before CUDA 9, syncwarp is a noop since warps are always synchronized.
#if __CUDACC_VER_MAJOR__ < 9
__device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {}
#endif
extern "C" {
__global__ void matmul_4_3_5(int32 K, int32 M, int32 N, float32* poutput, const float32* pA, const float32* pB) {
int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
float32 (*output)[5] = reinterpret_cast<float32 (*)[5]>(poutput);
const float32 (*A)[4] = reinterpret_cast<const float32 (*)[4]>(pA);
const float32 (*B)[5] = reinterpret_cast<const float32 (*)[5]>(pB);
output[b0][t0] = 0.000000f;
for (int c4 = 0; c4 <= 3; c4 += 1) {
output[b0][t0] = (output[b0][t0] + (A[b0][c4]*B[c4][t0]));
}
}
}
コード生成
パラメータのオートチューニング
 遺伝的アルゴリズムを用いて
より良いパラメータを探索する
https://research.fb.com/announcing-tensor-comprehensions/
まとめ
Halide IR良くできてる
– 制約は正義
• Polyhedral Transformation等応用的な最適化手法を適用可能
• 解析時の計算量爆発がおきにくい
– IRを操作する関数の実装が揃ってる
• 自前の解析や変形を行う場合でも多くの機能を転用可能
– IRの変換が書きやすい
• IRMutator/IRVisitorのクラスが割とシンプルで書きやすい
– TC/TVM/Tiramisu等、
Halide IRを再利用する取り組みも出てきた
High Level Compiler IRとしてのHalideに今後も要注目!

Mais conteúdo relacionado

Mais procurados

新しい並列for構文のご提案
新しい並列for構文のご提案新しい並列for構文のご提案
新しい並列for構文のご提案yohhoy
 
準同型暗号の実装とMontgomery, Karatsuba, FFT の性能
準同型暗号の実装とMontgomery, Karatsuba, FFT の性能準同型暗号の実装とMontgomery, Karatsuba, FFT の性能
準同型暗号の実装とMontgomery, Karatsuba, FFT の性能MITSUNARI Shigeo
 
C++ AMPを使ってみよう
C++ AMPを使ってみようC++ AMPを使ってみよう
C++ AMPを使ってみようOsamu Masutani
 
Cython ことはじめ
Cython ことはじめCython ことはじめ
Cython ことはじめgion_XY
 
条件分岐とcmovとmaxps
条件分岐とcmovとmaxps条件分岐とcmovとmaxps
条件分岐とcmovとmaxpsMITSUNARI Shigeo
 
Haswellサーベイと有限体クラスの紹介
Haswellサーベイと有限体クラスの紹介Haswellサーベイと有限体クラスの紹介
Haswellサーベイと有限体クラスの紹介MITSUNARI Shigeo
 
Wrapping a C++ library with Cython
Wrapping a C++ library with CythonWrapping a C++ library with Cython
Wrapping a C++ library with Cythonfuzzysphere
 
GPUが100倍速いという神話をぶち殺せたらいいな ver.2013
GPUが100倍速いという神話をぶち殺せたらいいな ver.2013GPUが100倍速いという神話をぶち殺せたらいいな ver.2013
GPUが100倍速いという神話をぶち殺せたらいいな ver.2013Ryo Sakamoto
 
C++による数値解析の並列化手法
C++による数値解析の並列化手法C++による数値解析の並列化手法
C++による数値解析の並列化手法dc1394
 
SSE4.2の文字列処理命令の紹介
SSE4.2の文字列処理命令の紹介SSE4.2の文字列処理命令の紹介
SSE4.2の文字列処理命令の紹介MITSUNARI Shigeo
 
高速な倍精度指数関数expの実装
高速な倍精度指数関数expの実装高速な倍精度指数関数expの実装
高速な倍精度指数関数expの実装MITSUNARI Shigeo
 
GPGPU deいろんな問題解いてみた
GPGPU deいろんな問題解いてみたGPGPU deいろんな問題解いてみた
GPGPU deいろんな問題解いてみたRyo Sakamoto
 
NumPyが物足りない人へのCython入門
NumPyが物足りない人へのCython入門NumPyが物足りない人へのCython入門
NumPyが物足りない人へのCython入門Shiqiao Du
 
Brief introduction of Boost.ICL
Brief introduction of Boost.ICLBrief introduction of Boost.ICL
Brief introduction of Boost.ICLyak1ex
 
組み込み関数(intrinsic)によるSIMD入門
組み込み関数(intrinsic)によるSIMD入門組み込み関数(intrinsic)によるSIMD入門
組み込み関数(intrinsic)によるSIMD入門Norishige Fukushima
 

Mais procurados (20)

LLVM最適化のこつ
LLVM最適化のこつLLVM最適化のこつ
LLVM最適化のこつ
 
Prosym2012
Prosym2012Prosym2012
Prosym2012
 
Boost.SIMD
Boost.SIMDBoost.SIMD
Boost.SIMD
 
新しい並列for構文のご提案
新しい並列for構文のご提案新しい並列for構文のご提案
新しい並列for構文のご提案
 
準同型暗号の実装とMontgomery, Karatsuba, FFT の性能
準同型暗号の実装とMontgomery, Karatsuba, FFT の性能準同型暗号の実装とMontgomery, Karatsuba, FFT の性能
準同型暗号の実装とMontgomery, Karatsuba, FFT の性能
 
C++ AMPを使ってみよう
C++ AMPを使ってみようC++ AMPを使ってみよう
C++ AMPを使ってみよう
 
Cython ことはじめ
Cython ことはじめCython ことはじめ
Cython ことはじめ
 
条件分岐とcmovとmaxps
条件分岐とcmovとmaxps条件分岐とcmovとmaxps
条件分岐とcmovとmaxps
 
Haswellサーベイと有限体クラスの紹介
Haswellサーベイと有限体クラスの紹介Haswellサーベイと有限体クラスの紹介
Haswellサーベイと有限体クラスの紹介
 
Wrapping a C++ library with Cython
Wrapping a C++ library with CythonWrapping a C++ library with Cython
Wrapping a C++ library with Cython
 
GPUが100倍速いという神話をぶち殺せたらいいな ver.2013
GPUが100倍速いという神話をぶち殺せたらいいな ver.2013GPUが100倍速いという神話をぶち殺せたらいいな ver.2013
GPUが100倍速いという神話をぶち殺せたらいいな ver.2013
 
C++による数値解析の並列化手法
C++による数値解析の並列化手法C++による数値解析の並列化手法
C++による数値解析の並列化手法
 
SSE4.2の文字列処理命令の紹介
SSE4.2の文字列処理命令の紹介SSE4.2の文字列処理命令の紹介
SSE4.2の文字列処理命令の紹介
 
高速な倍精度指数関数expの実装
高速な倍精度指数関数expの実装高速な倍精度指数関数expの実装
高速な倍精度指数関数expの実装
 
GPGPU deいろんな問題解いてみた
GPGPU deいろんな問題解いてみたGPGPU deいろんな問題解いてみた
GPGPU deいろんな問題解いてみた
 
NumPyが物足りない人へのCython入門
NumPyが物足りない人へのCython入門NumPyが物足りない人へのCython入門
NumPyが物足りない人へのCython入門
 
boost tour 1.48.0 all
boost tour 1.48.0 allboost tour 1.48.0 all
boost tour 1.48.0 all
 
Boost Tour 1.50.0 All
Boost Tour 1.50.0 AllBoost Tour 1.50.0 All
Boost Tour 1.50.0 All
 
Brief introduction of Boost.ICL
Brief introduction of Boost.ICLBrief introduction of Boost.ICL
Brief introduction of Boost.ICL
 
組み込み関数(intrinsic)によるSIMD入門
組み込み関数(intrinsic)によるSIMD入門組み込み関数(intrinsic)によるSIMD入門
組み込み関数(intrinsic)によるSIMD入門
 

Semelhante a 20180728 halide-study

文字列カーネルによる辞書なしツイート分類 〜文字列カーネル入門〜
文字列カーネルによる辞書なしツイート分類 〜文字列カーネル入門〜文字列カーネルによる辞書なしツイート分類 〜文字列カーネル入門〜
文字列カーネルによる辞書なしツイート分類 〜文字列カーネル入門〜Takeshi Arabiki
 
Ekmett勉強会発表資料
Ekmett勉強会発表資料Ekmett勉強会発表資料
Ekmett勉強会発表資料時響 逢坂
 
2014年の社内新人教育テキスト #2(関数型言語からオブジェクト指向言語へ)
2014年の社内新人教育テキスト #2(関数型言語からオブジェクト指向言語へ)2014年の社内新人教育テキスト #2(関数型言語からオブジェクト指向言語へ)
2014年の社内新人教育テキスト #2(関数型言語からオブジェクト指向言語へ)Shin-ya Koga
 
Python standard 2022 Spring
Python standard 2022 SpringPython standard 2022 Spring
Python standard 2022 Springanyakichi
 
Node.jsでつくるNode.js ミニインタープリター&コンパイラー
Node.jsでつくるNode.js ミニインタープリター&コンパイラーNode.jsでつくるNode.js ミニインタープリター&コンパイラー
Node.jsでつくるNode.js ミニインタープリター&コンパイラーmganeko
 
多値で簡単パーサーコンビネーター
多値で簡単パーサーコンビネーター多値で簡単パーサーコンビネーター
多値で簡単パーサーコンビネーターKeiichiro Shikano
 
C++ tips 3 カンマ演算子編
C++ tips 3 カンマ演算子編C++ tips 3 カンマ演算子編
C++ tips 3 カンマ演算子編道化師 堂華
 
Ekmett勉強会発表資料
Ekmett勉強会発表資料Ekmett勉強会発表資料
Ekmett勉強会発表資料時響 逢坂
 
Chainerの使い方と自然言語処理への応用
Chainerの使い方と自然言語処理への応用Chainerの使い方と自然言語処理への応用
Chainerの使い方と自然言語処理への応用Seiya Tokui
 
Boost jp9 program_options
Boost jp9 program_optionsBoost jp9 program_options
Boost jp9 program_optionsnyaocat
 
最新C++事情 C++14-C++20 (2018年10月)
最新C++事情 C++14-C++20 (2018年10月)最新C++事情 C++14-C++20 (2018年10月)
最新C++事情 C++14-C++20 (2018年10月)Akihiko Matuura
 
Lambda in template_final
Lambda in template_finalLambda in template_final
Lambda in template_finalCryolite
 
プログラミング言語のパラダイムシフトーScalaから見る関数型と並列性時代の幕開けー
プログラミング言語のパラダイムシフトーScalaから見る関数型と並列性時代の幕開けープログラミング言語のパラダイムシフトーScalaから見る関数型と並列性時代の幕開けー
プログラミング言語のパラダイムシフトーScalaから見る関数型と並列性時代の幕開けーTanUkkii
 
これから Haskell を書くにあたって
これから Haskell を書くにあたってこれから Haskell を書くにあたって
これから Haskell を書くにあたってTsuyoshi Matsudate
 
有限オートマトンとスティッカー系に関するCoqによる形式証明について
有限オートマトンとスティッカー系に関するCoqによる形式証明について有限オートマトンとスティッカー系に関するCoqによる形式証明について
有限オートマトンとスティッカー系に関するCoqによる形式証明についてYoshihiro Mizoguchi
 

Semelhante a 20180728 halide-study (20)

Sml#探検隊
Sml#探検隊Sml#探検隊
Sml#探検隊
 
文字列カーネルによる辞書なしツイート分類 〜文字列カーネル入門〜
文字列カーネルによる辞書なしツイート分類 〜文字列カーネル入門〜文字列カーネルによる辞書なしツイート分類 〜文字列カーネル入門〜
文字列カーネルによる辞書なしツイート分類 〜文字列カーネル入門〜
 
Ekmett勉強会発表資料
Ekmett勉強会発表資料Ekmett勉強会発表資料
Ekmett勉強会発表資料
 
たのしい関数型
たのしい関数型たのしい関数型
たのしい関数型
 
2014年の社内新人教育テキスト #2(関数型言語からオブジェクト指向言語へ)
2014年の社内新人教育テキスト #2(関数型言語からオブジェクト指向言語へ)2014年の社内新人教育テキスト #2(関数型言語からオブジェクト指向言語へ)
2014年の社内新人教育テキスト #2(関数型言語からオブジェクト指向言語へ)
 
Python standard 2022 Spring
Python standard 2022 SpringPython standard 2022 Spring
Python standard 2022 Spring
 
Swiftおさらい
SwiftおさらいSwiftおさらい
Swiftおさらい
 
Node.jsでつくるNode.js ミニインタープリター&コンパイラー
Node.jsでつくるNode.js ミニインタープリター&コンパイラーNode.jsでつくるNode.js ミニインタープリター&コンパイラー
Node.jsでつくるNode.js ミニインタープリター&コンパイラー
 
多値で簡単パーサーコンビネーター
多値で簡単パーサーコンビネーター多値で簡単パーサーコンビネーター
多値で簡単パーサーコンビネーター
 
C++ tips 3 カンマ演算子編
C++ tips 3 カンマ演算子編C++ tips 3 カンマ演算子編
C++ tips 3 カンマ演算子編
 
Ekmett勉強会発表資料
Ekmett勉強会発表資料Ekmett勉強会発表資料
Ekmett勉強会発表資料
 
Chainerの使い方と自然言語処理への応用
Chainerの使い方と自然言語処理への応用Chainerの使い方と自然言語処理への応用
Chainerの使い方と自然言語処理への応用
 
Boost jp9 program_options
Boost jp9 program_optionsBoost jp9 program_options
Boost jp9 program_options
 
最新C++事情 C++14-C++20 (2018年10月)
最新C++事情 C++14-C++20 (2018年10月)最新C++事情 C++14-C++20 (2018年10月)
最新C++事情 C++14-C++20 (2018年10月)
 
Lambda in template_final
Lambda in template_finalLambda in template_final
Lambda in template_final
 
会津合宿2015Day3:D問題
会津合宿2015Day3:D問題会津合宿2015Day3:D問題
会津合宿2015Day3:D問題
 
プログラミング言語のパラダイムシフトーScalaから見る関数型と並列性時代の幕開けー
プログラミング言語のパラダイムシフトーScalaから見る関数型と並列性時代の幕開けープログラミング言語のパラダイムシフトーScalaから見る関数型と並列性時代の幕開けー
プログラミング言語のパラダイムシフトーScalaから見る関数型と並列性時代の幕開けー
 
これから Haskell を書くにあたって
これから Haskell を書くにあたってこれから Haskell を書くにあたって
これから Haskell を書くにあたって
 
Rcppのすすめ
RcppのすすめRcppのすすめ
Rcppのすすめ
 
有限オートマトンとスティッカー系に関するCoqによる形式証明について
有限オートマトンとスティッカー系に関するCoqによる形式証明について有限オートマトンとスティッカー系に関するCoqによる形式証明について
有限オートマトンとスティッカー系に関するCoqによる形式証明について
 

Mais de Fixstars Corporation

製造業向け量子コンピュータ時代のDXセミナー_生産計画最適化_20220323.pptx
製造業向け量子コンピュータ時代のDXセミナー_生産計画最適化_20220323.pptx製造業向け量子コンピュータ時代のDXセミナー_生産計画最適化_20220323.pptx
製造業向け量子コンピュータ時代のDXセミナー_生産計画最適化_20220323.pptxFixstars Corporation
 
CPU / GPU高速化セミナー!性能モデルの理論と実践:実践編
CPU / GPU高速化セミナー!性能モデルの理論と実践:実践編CPU / GPU高速化セミナー!性能モデルの理論と実践:実践編
CPU / GPU高速化セミナー!性能モデルの理論と実践:実践編Fixstars Corporation
 
製造業向け量子コンピュータ時代のDXセミナー~ 最適化の中身を覗いてみよう~
製造業向け量子コンピュータ時代のDXセミナー~ 最適化の中身を覗いてみよう~製造業向け量子コンピュータ時代のDXセミナー~ 最適化の中身を覗いてみよう~
製造業向け量子コンピュータ時代のDXセミナー~ 最適化の中身を覗いてみよう~Fixstars Corporation
 
製造業向け量子コンピュータ時代のDXセミナー ~見える化、分析、予測、その先の最適化へ~
製造業向け量子コンピュータ時代のDXセミナー ~見える化、分析、予測、その先の最適化へ~製造業向け量子コンピュータ時代のDXセミナー ~見える化、分析、予測、その先の最適化へ~
製造業向け量子コンピュータ時代のDXセミナー ~見える化、分析、予測、その先の最適化へ~Fixstars Corporation
 
株式会社フィックスターズの会社説明資料(抜粋)
株式会社フィックスターズの会社説明資料(抜粋)株式会社フィックスターズの会社説明資料(抜粋)
株式会社フィックスターズの会社説明資料(抜粋)Fixstars Corporation
 
CPU / GPU高速化セミナー!性能モデルの理論と実践:理論編
CPU / GPU高速化セミナー!性能モデルの理論と実践:理論編CPU / GPU高速化セミナー!性能モデルの理論と実践:理論編
CPU / GPU高速化セミナー!性能モデルの理論と実践:理論編Fixstars Corporation
 
Fpga online seminar by fixstars (1st)
Fpga online seminar by fixstars (1st)Fpga online seminar by fixstars (1st)
Fpga online seminar by fixstars (1st)Fixstars Corporation
 
Jetson活用セミナー ROS2自律走行実現に向けて
Jetson活用セミナー ROS2自律走行実現に向けてJetson活用セミナー ROS2自律走行実現に向けて
Jetson活用セミナー ROS2自律走行実現に向けてFixstars Corporation
 
いまさら聞けない!CUDA高速化入門
いまさら聞けない!CUDA高速化入門いまさら聞けない!CUDA高速化入門
いまさら聞けない!CUDA高速化入門Fixstars Corporation
 
量子コンピュータ時代の製造業におけるDXセミナー~生産工程効率化に向けた新たなご提案~
量子コンピュータ時代の製造業におけるDXセミナー~生産工程効率化に向けた新たなご提案~量子コンピュータ時代の製造業におけるDXセミナー~生産工程効率化に向けた新たなご提案~
量子コンピュータ時代の製造業におけるDXセミナー~生産工程効率化に向けた新たなご提案~Fixstars Corporation
 
金融業界向けセミナー 量子コンピュータ時代を見据えた組合せ最適化
金融業界向けセミナー 量子コンピュータ時代を見据えた組合せ最適化金融業界向けセミナー 量子コンピュータ時代を見据えた組合せ最適化
金融業界向けセミナー 量子コンピュータ時代を見据えた組合せ最適化Fixstars Corporation
 
いまさら聞けないarmを使ったNEONの基礎と活用事例
いまさら聞けないarmを使ったNEONの基礎と活用事例いまさら聞けないarmを使ったNEONの基礎と活用事例
いまさら聞けないarmを使ったNEONの基礎と活用事例Fixstars Corporation
 
ARM CPUにおけるSIMDを用いた高速計算入門
ARM CPUにおけるSIMDを用いた高速計算入門ARM CPUにおけるSIMDを用いた高速計算入門
ARM CPUにおけるSIMDを用いた高速計算入門Fixstars Corporation
 
株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)Fixstars Corporation
 
株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)Fixstars Corporation
 
ソフト高速化の専門家が教える!AI・IoTエッジデバイスの選び方
ソフト高速化の専門家が教える!AI・IoTエッジデバイスの選び方ソフト高速化の専門家が教える!AI・IoTエッジデバイスの選び方
ソフト高速化の専門家が教える!AI・IoTエッジデバイスの選び方Fixstars Corporation
 
AIチップ戦国時代における深層学習モデルの推論の最適化と実用的な運用を可能にするソフトウェア技術について
AIチップ戦国時代における深層学習モデルの推論の最適化と実用的な運用を可能にするソフトウェア技術についてAIチップ戦国時代における深層学習モデルの推論の最適化と実用的な運用を可能にするソフトウェア技術について
AIチップ戦国時代における深層学習モデルの推論の最適化と実用的な運用を可能にするソフトウェア技術についてFixstars Corporation
 
株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)Fixstars Corporation
 
第8回 社内プログラミングコンテスト 結果発表会
第8回社内プログラミングコンテスト 結果発表会第8回社内プログラミングコンテスト 結果発表会
第8回 社内プログラミングコンテスト 結果発表会Fixstars Corporation
 
第8回 社内プログラミングコンテスト 第1位 taiyo
第8回社内プログラミングコンテスト 第1位 taiyo第8回社内プログラミングコンテスト 第1位 taiyo
第8回 社内プログラミングコンテスト 第1位 taiyoFixstars Corporation
 

Mais de Fixstars Corporation (20)

製造業向け量子コンピュータ時代のDXセミナー_生産計画最適化_20220323.pptx
製造業向け量子コンピュータ時代のDXセミナー_生産計画最適化_20220323.pptx製造業向け量子コンピュータ時代のDXセミナー_生産計画最適化_20220323.pptx
製造業向け量子コンピュータ時代のDXセミナー_生産計画最適化_20220323.pptx
 
CPU / GPU高速化セミナー!性能モデルの理論と実践:実践編
CPU / GPU高速化セミナー!性能モデルの理論と実践:実践編CPU / GPU高速化セミナー!性能モデルの理論と実践:実践編
CPU / GPU高速化セミナー!性能モデルの理論と実践:実践編
 
製造業向け量子コンピュータ時代のDXセミナー~ 最適化の中身を覗いてみよう~
製造業向け量子コンピュータ時代のDXセミナー~ 最適化の中身を覗いてみよう~製造業向け量子コンピュータ時代のDXセミナー~ 最適化の中身を覗いてみよう~
製造業向け量子コンピュータ時代のDXセミナー~ 最適化の中身を覗いてみよう~
 
製造業向け量子コンピュータ時代のDXセミナー ~見える化、分析、予測、その先の最適化へ~
製造業向け量子コンピュータ時代のDXセミナー ~見える化、分析、予測、その先の最適化へ~製造業向け量子コンピュータ時代のDXセミナー ~見える化、分析、予測、その先の最適化へ~
製造業向け量子コンピュータ時代のDXセミナー ~見える化、分析、予測、その先の最適化へ~
 
株式会社フィックスターズの会社説明資料(抜粋)
株式会社フィックスターズの会社説明資料(抜粋)株式会社フィックスターズの会社説明資料(抜粋)
株式会社フィックスターズの会社説明資料(抜粋)
 
CPU / GPU高速化セミナー!性能モデルの理論と実践:理論編
CPU / GPU高速化セミナー!性能モデルの理論と実践:理論編CPU / GPU高速化セミナー!性能モデルの理論と実践:理論編
CPU / GPU高速化セミナー!性能モデルの理論と実践:理論編
 
Fpga online seminar by fixstars (1st)
Fpga online seminar by fixstars (1st)Fpga online seminar by fixstars (1st)
Fpga online seminar by fixstars (1st)
 
Jetson活用セミナー ROS2自律走行実現に向けて
Jetson活用セミナー ROS2自律走行実現に向けてJetson活用セミナー ROS2自律走行実現に向けて
Jetson活用セミナー ROS2自律走行実現に向けて
 
いまさら聞けない!CUDA高速化入門
いまさら聞けない!CUDA高速化入門いまさら聞けない!CUDA高速化入門
いまさら聞けない!CUDA高速化入門
 
量子コンピュータ時代の製造業におけるDXセミナー~生産工程効率化に向けた新たなご提案~
量子コンピュータ時代の製造業におけるDXセミナー~生産工程効率化に向けた新たなご提案~量子コンピュータ時代の製造業におけるDXセミナー~生産工程効率化に向けた新たなご提案~
量子コンピュータ時代の製造業におけるDXセミナー~生産工程効率化に向けた新たなご提案~
 
金融業界向けセミナー 量子コンピュータ時代を見据えた組合せ最適化
金融業界向けセミナー 量子コンピュータ時代を見据えた組合せ最適化金融業界向けセミナー 量子コンピュータ時代を見据えた組合せ最適化
金融業界向けセミナー 量子コンピュータ時代を見据えた組合せ最適化
 
いまさら聞けないarmを使ったNEONの基礎と活用事例
いまさら聞けないarmを使ったNEONの基礎と活用事例いまさら聞けないarmを使ったNEONの基礎と活用事例
いまさら聞けないarmを使ったNEONの基礎と活用事例
 
ARM CPUにおけるSIMDを用いた高速計算入門
ARM CPUにおけるSIMDを用いた高速計算入門ARM CPUにおけるSIMDを用いた高速計算入門
ARM CPUにおけるSIMDを用いた高速計算入門
 
株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)
 
株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)
 
ソフト高速化の専門家が教える!AI・IoTエッジデバイスの選び方
ソフト高速化の専門家が教える!AI・IoTエッジデバイスの選び方ソフト高速化の専門家が教える!AI・IoTエッジデバイスの選び方
ソフト高速化の専門家が教える!AI・IoTエッジデバイスの選び方
 
AIチップ戦国時代における深層学習モデルの推論の最適化と実用的な運用を可能にするソフトウェア技術について
AIチップ戦国時代における深層学習モデルの推論の最適化と実用的な運用を可能にするソフトウェア技術についてAIチップ戦国時代における深層学習モデルの推論の最適化と実用的な運用を可能にするソフトウェア技術について
AIチップ戦国時代における深層学習モデルの推論の最適化と実用的な運用を可能にするソフトウェア技術について
 
株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)株式会社フィックスターズ 会社説明資料(抜粋)
株式会社フィックスターズ 会社説明資料(抜粋)
 
第8回 社内プログラミングコンテスト 結果発表会
第8回社内プログラミングコンテスト 結果発表会第8回社内プログラミングコンテスト 結果発表会
第8回 社内プログラミングコンテスト 結果発表会
 
第8回 社内プログラミングコンテスト 第1位 taiyo
第8回社内プログラミングコンテスト 第1位 taiyo第8回社内プログラミングコンテスト 第1位 taiyo
第8回 社内プログラミングコンテスト 第1位 taiyo
 

20180728 halide-study

  • 1. Tensor Comprehensions にみる Halide IRの汎用性 Fixstars Solutions, Inc. Takuro Iizuka
  • 2. Takuro Iizuka / @iitaku 北米子会社のFixstars Solutions, Inc. にて HalideのFPGAバックエンドおよびツールチェイン”GENESIS”の開発やってます
  • 3. もくじ  TC: Tensor Comprehensions 概要  TC言語  Inside TC  まとめ
  • 4. TC: Tensor Comprehensionsとは?  テンソル計算の記述言語および 最適化コンパイラフレームワーク  2018.2.14にFacebook AI Researchからリリース  TC言語でアルゴリズムを書くと ライブラリがいい感じに最適化してくれる  PyTorchとシームレスに統合できる
  • 5. TC: Tensor Comprehensionsとは?  テンソル計算の記述言語および 最適化コンパイラフレームワーク  2018.2にFacebook AI Researchからリリース  TC言語でアルゴリズムを書くと ライブラリがいい感じに最適化してくれる  PyTorchとシームレスに統合できる コンパイラの中間表現としてHalide IRを採用
  • 7. TCベンチマーク結果 MLP: Multi-Layer Perceptron TMM: Transposed Matrix Multiplication TBMM: Transposed Batched Matrix Multiplication GCOV: Grouped Convolutions https://research.fb.com/announcing-tensor-comprehensions/
  • 9. TC in PyTorch $ conda create –y –name pytorch python=3.6 $ conda activate pytorch $ conda install -y -c pytorch -c tensorcomp tensor_comprehensions $ python ./matmul.py Variable containing: -2.4028 2.8492 7.6141 3.3159 3.7171 1.3839 0.6650 -1.7253 0.7447 1.3988 0.1396 -0.0661 -1.0574 0.2163 0.1711 [torch.cuda.FloatTensor of size 3x5 (GPU 0)] import tensor_comprehensions as tc import torch mm = """ def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } """ matmul = tc.define(mm, name="matmul") A, B = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda() C = matmul(A, B, options=tc.Options("naive")) print(C) 1. TCをセットアップ 2. TC言語でカスタム レイヤを書く 3. 実行する
  • 11. num ::= <number literal with C syntax> id ::= [_a-zA-Z0-9]*[_a-zA-Z][_a-zA-Z0-9]* exp ::= num | ( '-' | '!' | ... ) exp | exp ( [+-*/%] | '==' | '!=' | '<=' | ... ) exp | exp '?' exp ':' exp | id '.' num # range of num-th dimension of id | id '(' exp_list ')' # builtin call or tensor access reduction ::= <associative reduction operator> | '+=' | '*=' | 'min=' | 'max=' | '+=!' | '*=!' | 'min=!' | 'max=!' range_constraint ::= id 'in' exp ':' exp stmt ::= id '(' id_list ')' [ '=' | reduction ] exp [ 'where' range_constraint_list ] | id_list = id '('id_list ')' # TC function call arg ::= type id return ::= id # inferred return type and range scalar_type ::= 'double' | 'float' | 'half' | 'int32' | 'byte' | 'uint32' | ... type ::= scalar_type [ '(' id_list ')' ] func ::= # TC function definition 'def' id '(' arg_list ')' '->' '(' return_list ')' '{' stmt_list '}' id_list ::= <comma separated id list> exp_list ::= <comma separated exp list> arg_list ::= <comma separated arg list> stmt_list ::= <whitespace separated stmt list> return_list ::= <comma separated return list> range_constraint_list ::= <non-empty comma separated range_constraint list>
  • 12. TC言語の特徴  テンソル計算記述特化言語  Halide言語よりさらに簡素な言語体型  超ミニマルなプリミティブ型  畳み込み演算用の特殊なオペレータ群  where節でレンジ制約を記述
  • 13. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } def 関数名 {} で関数定義
  • 14. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } type(X,Y,…)で入力引数に制約 (要素型制約およびレンジ制約)を付与
  • 15. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } レンジ制約に同一シンボルを使用することで、 異なる引数間の制約を表現できる
  • 16. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } 出力の制約はコンパイラによって自動推論される
  • 17. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } 初期化付き畳み込み用記述のオペレータ
  • 18. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } 左辺で定義した誘導変数を右辺で使用すると ループが形成される for m := 0, M for n := 0, N
  • 19. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } ループのレンジは入力制約から決定される e.g. m := [0. M)
  • 20. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } 左辺で定義していない誘導変数を右辺で使用した場合、 既存ループの最内に新たにループが形成される for m := 0, M for n := 0, N for r_k := 0, K
  • 21. def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } その場合も入力の制約をもとに制約チェックが行われる
  • 22. Halide言語と比較してみる def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } TC言語 Func matmul; ImageParam A{float, 2}; ImageParam B{float, 2}; Var m, n; Rdom r_k{0, K}; matmul(m, n) = sum(A(m, r_k) * B(r_k, n)); matmul.realize({M, N}); Halide言語
  • 23. Halide言語と比較してみる def matmul(float(M,K) A, float(K,N) B) -> (output) { output(m, n) +=! A(m, r_k) * B(r_k, n) } Func matmul; ImageParam A{float, 2}; ImageParam B{float, 2}; Var m, n; Rdom r_k{0, K}; matmul(m, n) = sum(A(m, r_k) * B(r_k, n)); matmul.realize({M, N}); レンジ制約は前向きに推論 レンジ制約は後ろ向きに推論
  • 25. TC GPU Backendコンパイルフロー TC言語 AST Halide IR パース/AST構築: tc/core/compiler.cc, parse(…) Halide IR変換: tc/core/tc2halide.cc, translate(…) Halide IR レンジ推論/ループ形成/簡約化: tc/core/tc2halide.cc, translate(…) Polyhedral IR CUDA C Polyhedral IR変換: tc/core/polyhedral/scop.cc, makeScop(…) Polyhedral IR ループ変形/スレッディング: tc/core/polyhedral/mapped_scop.cc, makeWithOuterBlockInnerThreadStrategy(…) コード生成: tc/core/polyhedral/mapped_scop.cc, codegen(…)
  • 26. Halide IR オリジナルのHalide実装では、 中間表現の変換過程においてHalide IRを2つに大別できる Halide IR Halide IR Halide IR Halide IR Halide IR Call/Provide系 Load/Store系 storage_flattening以前/以後
  • 27. Call/Provide系 抽象度高い Load/Store系 抽象度低い for (y, 0, out.extent.1) { for (x, 0, out.extent.0) { Provide(out, {x, y}) = Call(in, {x, y}) } } for (y, 0, out.extent.1) { for (x, 0, out.extent.0) { Store(out, y*out.stride.1+x, Load(in, y*in.stride.1.x)) } }
  • 28. TCが取り扱うHalide IRは 抽象度の高いCall/Provide系のみ for (y, 0, out.extent.1) { for (x, 0, out.extent.0) { Provide(out, {x, y}) = Call(in, {x, y}) } } for (y, 0, out.extent.1) { for (x, 0, out.extent.0) { Store(out, y*out.stride.1+x, Load(in, y*in.stride.1.x)) } } ターゲットコードに近い中間表現は Halide IRではなくPolyhedral IRを使う
  • 29. Halide IR上でのLowering  コンパイラインフラストラクチャとしてのHalide IR – Halide中にはHalide IR (Halide::Stmt/Expr)を操作する 関数が多数実装済み • 簡約器 • ソルバー • 範囲演算 • CSE • などなど – IRMutatorやIRVisitorクラスを使用すれば Halide IRに対する変換や解析を独自に実装できる  TCのHalide IR Loweringは以下を行っている – レンジ推論 – ループ形成 – 簡約化
  • 30. レンジ推論で使用されるHalide API  solve_for_inner_interval(c, v) – 条件式cを必ず満たす変数vの最大範囲を計算する  and_condition_over_domain(c, varying) – 変数範囲varyingの仮定のもとで条件式cを簡約化する  simplify(e) – 式eを簡約化する これらを組み合わせてレンジ推論を行い、後段で行われる Polyhedral最適化に必要な条件を満たすかをテストしておく
  • 31. ループ形成で使用されるHalide API  realization_order(…) – Provide/Call間の依存グラフを トポロジカルソートによって順序付けをする  schedule_functions(…) – 出力が依存するすべての計算ループを形成し、 Halide IR (Halide::Stmt) を返す 後の解析や変換のベースとなるループ構築を行う
  • 32. ループ形成結果 後の解析や変換のベースとなるループ構築を行う produce output { let output.s0.n.loop_max = output.s0.n.max let output.s0.n.loop_min = output.s0.n.min let output.s0.n.loop_extent = ((output.s0.n.max + 1) - output.s0.n.min) let output.s0.m.loop_max = output.s0.m.max let output.s0.m.loop_min = output.s0.m.min let output.s0.m.loop_extent = ((output.s0.m.max + 1) - output.s0.m.min) for (output.s0.m, output.s0.m.loop_min, output.s0.m.loop_extent) { for (output.s0.n, output.s0.n.loop_min, output.s0.n.loop_extent) { output(output.s0.m, output.s0.n) = 0.000000f } } let output.s1.r_k.loop_extent = ((output.s1.r_k.max - output.s1.r_k.min) + 1) let output.s1.r_k.loop_max = output.s1.r_k.max let output.s1.r_k.loop_min = output.s1.r_k.min let output.s1.n.loop_max = output.s1.n.max let output.s1.n.loop_min = output.s1.n.min let output.s1.n.loop_extent = ((output.s1.n.max + 1) - output.s1.n.min) let output.s1.m.loop_max = output.s1.m.max let output.s1.m.loop_min = output.s1.m.min let output.s1.m.loop_extent = ((output.s1.m.max + 1) - output.s1.m.min) for (output.s1.m, output.s1.m.loop_min, output.s1.m.loop_extent) { for (output.s1.n, output.s1.n.loop_min, output.s1.n.loop_extent) { for (output.s1.r_k, output.s1.r_k.loop_min, output.s1.r_k.loop_extent) { output(output.s1.m, output.s1.n) = ReductionUpdate((output(output.s1.m, output.s1.n) + (A(output.s1.m, output.s1.r_k)* B(output.s1.r_k, output.s1.n)))) } } } } 初期化 計算
  • 33. 簡約化で使用されるHalide API  LetStmt::make(n, e, s) – 文s中で式eをシンボルnに束縛する  simplify(s) – 文sを簡約化する レンジ制約を適用しループ範囲の簡約化を行う
  • 34. 簡約化結果 後段のPolyhedral Transformationで 解析・変形可能なループ構造に簡約できた for (output.s0.m, 0, M) { for (output.s0.n, 0, K) { output(output.s0.m, output.s0.n) = 0.000000f } } for (output.s1.m, 0, M) { for (output.s1.n, 0, K) { for (output.s1.r_k, 0, K) { output(output.s1.m, output.s1.n) = ReductionUpdate((output(output.s1.m, output.s1.n) + (A(output.s1.m, output.s1.r_k)* B(output.s1.r_k, output.s1.n)))) } } } ループ範囲がTC言語のレンジ制約と対応している
  • 35.  tc::polyhedral::Scop – ISL (Integer Set Library) /Polyhedral Compilation Libraryを用いて計算されたスケジューリング – RAW, WAR, WAW 依存関係 – メモリ配置 – Halide IR関連 • パラメータ • 入出力 • Stmt Polyhedral IR Polyhedral IR Halide IR パラメータ/ 入出力 Polyhedral IR変換 Stmt of Provide A Stmt of Provide B パラメータ/ 入出力 スケジューリング 依存関係 メモリ配置
  • 36. Polyhedral IR 上でのLowering 1. ループ融合 2. タイリング – パラメータ: タイリング戦略 3. スレッドマッピング – パラメータ: スレッドサイズ 4. ブロックマッピング – パラメータ: ブロックサイズ 5. メモリマッピング – パラメータ: 有効/無効、マッピング先、マッピング量 Polyhedral Transformationとは、 ループ構造をPolytope=多面体と見立ててアフィン変換を施すことで Legalなループ変形を行う最適化手法
  • 37. domain( [K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N } [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <= output_s1_n < N and 0 <= output_s1_r_k < K }) sequence() filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] }) band(n(1) permutable(0) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] } ----------------------------------------------------------------------- band(n(1) permutable(0) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] } ----------------------------------------------------------------------- filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] }) band(n(1) permutable(0) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] } ----------------------------------------------------------------------- band(n(1) permutable(0) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] } ----------------------------------------------------------------------- band(n(1) permutable(0) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] } ----------------------------------------------------------------------- 初期状態のスケジューリングツリー 雑な解説: band=ループ、 filter=ステートメント
  • 38. domain( [K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N } [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <= output_s1_n < N and 0 <= output_s1_r_k < K }) band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] } ----------------------------------------------------------------------- sequence() filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] }) filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] }) ループ融合後
  • 39. domain( [K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N } [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <= output_s1_n < N and 0 <= output_s1_r_k < K }) band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] } ----------------------------------------------------------------------- band(n(2) permutable(1) coincident(1, 0) unroll(0, 0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] } ----------------------------------------------------------------------- sequence() filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] }) filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] }) タイリング後
  • 40. domain( [K, M, N] -> { S_0[output_s0_m, output_s0_n] : K = 4 and M = 3 and N = 5 and 0 <= output_s0_m <= 2 and 0 <= output_s0_n <= 4 } [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : K = 4 and M = 3 and N = 5 and 0 <= output_s1_m <= 2 and 0 <= output_s1_n <= 4 and 0 <= output_s1_r_k <= 3 }) context([K, M, N, t1, t0, t2, b2, b1, b0] -> { [] : t1 = 0 and t2 = 0 and b2 = 0 and b1 = 0 and 0 <= t0 <= 127 and 0 <= b0 <= 127 }) band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] } ----------------------------------------------------------------------- mapping_filter(ids(t0, ) [K, M, N, t0] -> { S_0[output_s0_m, output_s0_n] : (-t0 + output_s0_n) mod 128 = 0 and 0 <= t0 <= 127 } [K, M, N, t0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-t0 + output_s1_n) mod 128 = 0 and 0 <= t0 <= 127 }) band(n(1) permutable(1) coincident(1) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] } ----------------------------------------------------------------------- thread_specific() band(n(1) permutable(1) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] } ----------------------------------------------------------------------- sequence() filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] }) filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] }) スレッドマッピング後
  • 41. domain( [K, M, N] -> { S_0[output_s0_m, output_s0_n] : K = 4 and M = 3 and N = 5 and 0 <= output_s0_m <= 2 and 0 <= output_s0_n <= 4 } [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : K = 4 and M = 3 and N = 5 and 0 <= output_s1_m <= 2 and 0 <= output_s1_n <= 4 and 0 <= output_s1_r_k <= 3 }) context([K, M, N, t1, t0, t2, b2, b1, b0] -> { [] : t1 = 0 and t2 = 0 and b2 = 0 and b1 = 0 and 0 <= t0 <= 127 and 0 <= b0 <= 127 }) mapping_filter(ids(b0, ) [K, M, N, b0] -> { S_0[output_s0_m, output_s0_n] : (-b0 + output_s0_m) mod 128 = 0 and 0 <= b0 <= 127 } [K, M, N, b0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-b0 + output_s1_m) mod 128 = 0 and 0 <= b0 <= 127 }) band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] } ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] } ----------------------------------------------------------------------- mapping_filter(ids(t0, ) [K, M, N, t0] -> { S_0[output_s0_m, output_s0_n] : (-t0 + output_s0_n) mod 128 = 0 and 0 <= t0 <= 127 } [K, M, N, t0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-t0 + output_s1_n) mod 128 = 0 and 0 <= t0 <= 127 }) band(n(1) permutable(1) coincident(1) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] } ----------------------------------------------------------------------- thread_specific() band(n(1) permutable(1) coincident(0) unroll(0) ----------------------------------------------------------------------- | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] } | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] } ----------------------------------------------------------------------- sequence() filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] }) filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] }) ブロックマッピング後
  • 42. template<typename T> inline __device__ T floord(T n, T d) { return n < 0 ? - (-n + d - 1)/d : n / d; } #define if_then_else(cond,a,b) ((cond) ? (a) : (b)) // Can't include system dependencies with NVRTC // Can't include cuda_fp16.h with NVRTC due to transitive system dependencies // #include <cuda_fp16.h> // Halide type handling typedef char int8; typedef short int16; typedef int int32; typedef long int64; typedef unsigned char uint8; typedef unsigned short uint16; typedef unsigned int uint32; typedef unsigned long uint64; // typedef half float16; typedef float float32; typedef double float64; #define inff __int_as_float(0x7f800000) #define inf __longlong_as_double(0x7ff0000000000000LL) // Before CUDA 9, syncwarp is a noop since warps are always synchronized. #if __CUDACC_VER_MAJOR__ < 9 __device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {} #endif extern "C" { __global__ void matmul_4_3_5(int32 K, int32 M, int32 N, float32* poutput, const float32* pA, const float32* pB) { int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z; int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z; float32 (*output)[5] = reinterpret_cast<float32 (*)[5]>(poutput); const float32 (*A)[4] = reinterpret_cast<const float32 (*)[4]>(pA); const float32 (*B)[5] = reinterpret_cast<const float32 (*)[5]>(pB); output[b0][t0] = 0.000000f; for (int c4 = 0; c4 <= 3; c4 += 1) { output[b0][t0] = (output[b0][t0] + (A[b0][c4]*B[c4][t0])); } } } コード生成
  • 45. Halide IR良くできてる – 制約は正義 • Polyhedral Transformation等応用的な最適化手法を適用可能 • 解析時の計算量爆発がおきにくい – IRを操作する関数の実装が揃ってる • 自前の解析や変形を行う場合でも多くの機能を転用可能 – IRの変換が書きやすい • IRMutator/IRVisitorのクラスが割とシンプルで書きやすい – TC/TVM/Tiramisu等、 Halide IRを再利用する取り組みも出てきた High Level Compiler IRとしてのHalideに今後も要注目!