Mais conteúdo relacionado Semelhante a Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model (MuZero) (12) Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model (MuZero)1. Mastering Atari, Go, Chess
and Shogi by Planning with a
Learned Model
調和系工学研究室 B4 織田智矢
Julian Schrittwieser,1 Ioannis Antonoglou,1;2 Thomas Hubert,1
Karen Simonyan,1 Laurent Sifre,1 Simon Schmitt,1 Arthur Guez,1
Edward Lockhart,1 Demis Hassabis,1 Thore Graepel,1;2 Timothy Lillicrap,1
David Silver1;2
1DeepMind, 6 Pancras Square, London N1C 4AG.
2University College London, Gower Street, London WC1E 6BT.
4. 強化学習おさらい
強化学習の枠組み
• s : 状態 (state)
• a : 行動 (action)
• r : 報酬 (reward)
• エージェントがある行動atを行って,
環境から次の状態st+1と報酬rt+1を受け取る
という枠組みが基本である
出典 : https://qiita.com/Hironsan/items/56f6c0b2f4cfd28dd906
5. 強化学習の大まかな分類
Model-Free RL
• 環境に対する経験から直接価値関数や方策の学習
を行う
• 一般的なQ学習や前ページのスライドはModel-
Free
Model-Based RL
• まず環境に対する経験から環境のモデルを学習
• そのモデルからnステップのサンプルを取得して価
値関数や方策を学習する
(これをプランニングという)
出典 : https://shirakonotempura.hatenablog.com/entry/2019/02/08/162541, Reinforcement
Learning: An Introduction, R Sutton & A Barto, p.162
8. (A)モデルを使用したプランニング詳しく
• MCTS(Monte Carlo Tree Search)を一般化した手法を使用
MCTSのノードは を保持する
• N:訪れた数, Q:評価値, P:方策(次行動の確率分布), R:報酬, S:状態
MCTSは次の3ステップ
1. Selection
• ルートノード(𝑠0)から葉ノードまで評価値が高いノードを選択(葉ノード到達をlとおく)
2. Expansion
• ダイナミクス関数から 𝑟 𝑙
, 𝑠 𝑙を受け取り, 𝑅 𝑠 𝑙−1
, 𝑎 𝑙
= 𝑟 𝑙
, S 𝑠 𝑙−1
, 𝑎 𝑙
= 𝑠 𝑙
を現在のノードへ
• {𝑁 = 0, 𝑄 = 0, 𝑃 = 𝒑𝒍}で新しい葉ノード作成
3. Backup
• 親ノードのQを計算,N+=1,
𝑣 𝑡, 𝜋 𝑡 = 𝑀𝐶𝑇𝑆 𝑠𝑡
𝑎 𝑡~𝜋 𝑡
探索方策πはルートノードからの各
行動のノードNに比例
10. モデルの概要
(C) モデルの訓練
• 履歴はReplay Bufferからサンプリング
• 表現関数ℎにサンプリングした履歴から過去
の観測𝑜1, ⋯ , 𝑜𝑡を入力
• モデルはKステップにアンロールされる
• 各ステップkでダイナミクス関数𝑔に𝑠 𝑘−1と
実際にとった行動𝑎 𝑡+𝑘
を入力
これらのℎ, 𝑔, 𝑓の関数(表現,ダイナミク
ス,予測)は3つの量を予測するため
BPTTで同時にend-to-endで学習
3つの量: 𝒑 𝑘
≈ 𝜋 𝑡+𝑘, 𝑣 𝑘
≈ 𝑧𝑡+𝑘, 𝑟𝑡+𝑘 ≈ 𝑢 𝑡+𝑘
𝑧𝑡+𝑘とは最終報酬(board game) or n-step
return(Atari)
隠れ状態 : 𝑠
行動 : 𝑎
報酬 : 𝑟
価値関数 : 𝑣
次行動の確率分布 : 𝑝
12. MuZeroアルゴリズム
1. MCTSによって履歴を貯める (自己対戦)
2. 履歴からサンプルしモデルをK-stepアンロールして学習
• 予測方策 と探索方策(実際にとる行動)の誤差を最小化 𝒑 𝑘 ≈ 𝜋 𝑡+𝑘
• 予測価値と目標価値の誤差を最小化 𝑣 𝑘
≈ 𝑧 𝑡+𝑘
• 予測報酬と環境からの報酬の誤差を最小化 𝑟𝑡+𝑘 ≈ 𝑢 𝑡+𝑘
具体的には以下の誤差関数𝑙 𝑡(𝜃)を最小化(最後の項はL2正則化)