GPUのメモリからCPUのメモリに逃がすテク「TFLMS」の論文を読んだ

シェアする

TFLMSという技術に関する論文を読んだので、軽くまとめます。(きっかけは第1回メディカルAI学会でIBMの方が発表されていたからです。)

今後もこんな感じのを書いていきたいと思うのですが、あくまで自分の理解に基づいて書いているので、間違っている点などがあるかもしれません。実際に論文を利用する場合は必ず原著を参照してください。(また、コメントやTwitterなどで指摘いただけるとありがたいです)

スポンサーリンク

原著

Tung D. Le, Haruki Imai, Yasushi Negishi, Kiyokuni Kawachiya. 2018. TFLMS: Large Model Support in TensorFlow by Graph Rewriting (arxiv)

概要

大きい画像や3Dのタスクにdeep learningを適用する際に、GPUメモリが制約となることがある。CPUのメモリに適切なタイミングで逃してやれば、その制約をなくせると考え、計算グラフに自動的にそのような処理を追加するtensorflow(とkeras)のライブラリを開発した、らしい。すごい。

こちらでコードが公開されている。

https://github.com/IBM/tensorflow-large-model-support

用語

swap-out: GPUのメモリからCPUのメモリにデータを逃がす

swap-in: CPUのメモリからGPUのメモリへデータを取り戻す

だから必ず、swap-outの後にswap-inが来る。

課題

swap-out, swap-inの操作はやはりロスが生じるので回数を減らしたい。また、GPUのメモリを節約するのが目的なので、出来るだけ早くswap outさせたいし、できるだけ遅くswap inさせたい

fusing

同じtensorを複数回使うなら、swap outを複数回する必要はなくて、まとめてやれば良い。(swap-out operationをfuseする)

あるtensorが大きく、それを使うタイミングが近い場合、swap-inをfuseさせることも考えられる。デメリットは、swap-in operationが早くなる(=GPUメモリを専有する)こと。

タイミング

ここでは、swap-inのタイミングを考える。swap-inはできるだけ遅くしたいが、遅すぎると待機時間が生じる(overheadになる)。そこで、consuming operation(tensorを使うvertex)とswap-in vertexのcontrol edge(swap-inを起動する辺みたいな)のtopological order(vertexに付与される実行の順番みたいな)の差kが一定範囲内に入るようにすることを考える。(この範囲はハイパーパラメータとなる)

どのようにしてこのようなcontrol edgeを決めるかは2つの方法が提案されている。

direct-order strategy

kの下限から探索を始め、topological orderがその値となる全てのverticesを取ってきて、consuming operationにreachableだったら、そこからcontrol edgeを生やす。

chain-rule strategy

こちらは、”chain-rule”から感じられるように具体的にneural network(NN)などを意識している。単にグラフとして考えているのではなく、もっと具体的に、forwardとbackwardの存在するNNのcomputation graphとして考え、その性質に着目して良いcontrol edgeを生成する。

NNのgraphを考えると、forwardのvertexから出るedgeは次のlayerに相当するforwardのvertexへ行くものと、対応するbackwardのvertexに行くものがあるはず。ここで、forwardに行くものについては、すぐ計算に使われていらなくなる。しかし、backwardに使われるものについては、これ以降のlayerのforwardとbackwardの計算の後にようやく使われることになる。その間、GPUメモリを専有することとなり勿体無い。

だから、forwardのvertexからbreadth-first searchをしていき、対応するbackwardのvertexについて、kが範囲内にあれば、そこからcontrol edgeを生やす。

どうやらTFLMSではchain-rule strategyを採用している模様。

まとめ

実用性については、experimentsを見る限りは上手く行っているようです。やっぱり、patchに分けず、そのままモデルに入れられれば、精度は上がりそうなので自然です。今度自分で試してみたいと思います。(太字にすることで自分にプレッシャーを掛けていく)

GitHubのドキュメントを見るとkerasでも簡単に使えるということなので、嬉しいですね。

スポンサーリンク

シェアする

フォローする