deconvolution(transposed convolution)はsegmentationやGANなどで多用される、convolutionの逆操作のようなものです。しかし、具体的な操作はやや難解です。実は、transposed convolutionはconvolutionの逆伝播と共通しており、convolutionの逆伝播は別のconvolutionで表現できます。つまり、transposed convolutionはある意味convolutionなんだというのが今回の主題です。
前提
この記事を理解するには以下の知識があると良いです。
- backpropagation(全結合層のみで大丈夫)
- convolution (数学的な定義ではなく、グラフィカルなイメージだけで大丈夫)
途中でTensorflow実装を見る節がありますが、Tensorflow特有の話ではなく、有名なライブラリではtransposed convolutionをどのように捉えているのかを見るだけなので、Tensorflowの知識は不要です。
名称の整理
次はすべて同じものです。
- transposed convolution
- deconvolution
- up-convolution
これ以降は、transposed convolutionで統一します。他に比べると長いですが、より正確だと考えるからです。また、convolutionを略してconv.と表記している部分もあります。
なぜtransposed convolutionを使うのか
既にtransposed convolutionを使っている人は読み飛ばして大丈夫です。
なぜconvolutionの逆操作が必要なのでしょうか。それは、convolutionにより小さくなったfeature mapの大きさを元に戻すためです(入出力が画像のモデルでは畳み込んだのをもとの大きさに戻す必要がありますよね)。画像を大きくするには、色々な補間方法がありますが、そうではなくnetworkの中で、parameterを持ったtrainableな方法で大きさをもとに戻したいのです。
したがって、transposed convolutionが、値の意味でconvolutionの逆変換になって欲しいというわけではなく、shapeの意味での逆操作になってほしいのです。
transposed convolutionは意外と難しい
convolutionの逆操作と聞くと、特に難しいことはないように思えますが、(私にとっては)そこまで簡単ではありませんでした。
例えば、transposed convolutionのアニメーションを見てみてください。いくつか疑問が浮かぶと思います。まず第一に、transposed convolutionをしているはずなのに、convolutionとそっくりです。また、paddingに注目してみてください。一番最初の例ではNo paddingと書いてあるのに、アニメーションではpaddingがあります。
しかし、これらは全く正しいアニメーションです。この記事を最後まで読めば、これらの疑問は解けると思います。
Tensorflowでの実装
さて、Tensorflowでの実装を見てみましょう。Tensorflowの知識がなくても大丈夫です。
transpose convolution層を定義するのは Conv2DTranspose
クラスです。そのメソッドである call()
はforward propagationのときに呼ばれます。call()
内では次のように処理が受け渡されていきます。
backend.conv2d_transpose()
nn.conv2d_transpose
nn.conv2d_transpose_v2
gen_nn_ops.conv2d_backprop_input
本質的な処理は最後の gen_nn_ops.conv2d_backprop_input
で行われており、途中の関数ではテンソルの順番調整などの細かい処理が行われます。 gen_nn_ops.conv2d_backprop_input
はC++コードのwrapperで、convolutionに対するinputの勾配を求める関数です。C++のコードまでは降りませんが、この関数がどのような入力を受け取り、何を計算して出力するのかを見ておきます。
name | 渡す引数 | 説明 |
---|---|---|
input_sizes |
output_shape |
An integer vector representing the shape of input |
filter |
filters |
A Tensor . Must be one of the following types:half , bfloat16 , float32 , float64 . |
out_backprop |
input |
A Tensor . Must have the same type as filter . Gradients w.r.t. the output of the convolution. |
strides |
strides |
A list of ints . The stride of the sliding window for each dimension of the input of the convolution. Must be in the same order as the dimension specified with format. |
padding |
padding |
Either the string "SAME" or "VALID" |
引数の名前と説明に加え、 nn.conv2d_transpose_v2
から呼び出す際に渡している引数の列も加えました。というのも、よく見ると、一部の引数は明らかに本来渡すべき値とは異なる値を渡していることが分かるからです。
input_sizes
に対して output_shape
を、 out_backprop
に対して、 input
を渡していますね。これを頭の片隅に置きつつ、まずは、 conv2d_backprop_input
の本来の役割である、通常のconvolution層における逆伝播について見ていきましょう。
通常のconvolution層における逆伝播
端的に言うと、通常のconv層は逆伝播においてもconvolutionを行います。しかしそれは、順伝播(forward)のときとは異なるfilterとinputを持ちます。
この点についてより丁寧に理解したい方は、https://medium.com/@pavisj/convolutions-and-backpropagations-46026a8f5d2c を参照してください。(この記事では入力の勾配のみを扱っています)
今、convolution層を含むneural networkを考えます。ネットワーク全体のlossは、convolutionへの入力は、出力を、フィルタサイズをとします。また、入力を、フィルタをとし、strideを2、paddingはなしとします。よって、出力はとなります。
conv2d_backprop_input
の役割は、 (out_backprop
)と、 (filter
)を受け取り、を求めることです。
busyな図ですみません。しかし、この図にconv.のbackpropを理解するのに必要な知識を詰め込めたと思います。
最終的には、Xの黄色に塗られたピクセルの勾配を求めたいのですが、まずはforwardの計算から見ていきましょう。
forward
左上から見てください。通常のconv.のforwardです。出力の各ピクセルを4色で塗ってみました。それぞれの対応するfilter位置はに枠線で示しました。
次に、ここがちょっと難しいかもしれませんが、フィルタがそれぞれ枠線の位置にあるとき、黄色のピクセルがフィルタ上のどこに位置するかが上に示されています。枠線の色と対応しています。
この塗り分けにどのような意味があるのでしょうか。例えば出力の緑を計算するときは、フィルタが緑枠の位置にあります。このとき、黄色ピクセルの値はの緑ピクセルの値と掛けられます。すなわち、を表しています。
黄色ピクセルが、緑色ピクセルの出力を通してlossに与える影響は、です。
求めるべき値は、すべての出力についてそれぞれを経由するときの勾配を足したものとなりますから、図中に示したで始まる式になります。
これで1つのピクセルについて、求めるべき勾配は求まったのですが、今度は今行った計算をconvolutionで表せないかと考えてみます。話は中央の灰色線を右へ超えてbackwardと書かれた領域に移ります。
backward
まずは概略を示します。との畳み込みにより、が得られます。しかし、単純に畳み込んだだけでは辻褄が合わないので、ちょっとした変形を加える必要があります。
まず、についてですが、forwardのときにstridesを使っており、出力テンソルが小さく圧縮されているため、backwardではpaddingを入れて辻褄を合わせる必要があります。
次に、についてですが、180度回転させてから畳み込みます。これは上で同じ位置にあるピクセルについて、フィルタを右にずらすとき、そのピクセルは、フィルタ上では左に移動するといった関係があるからです。
最後に、full paddingをします。図には示されていませんがzero paddingを行い、全ピクセルのフィルタがかかる回数を等しくします。今回の場合は4辺から2つずつpaddingしてあげれば良いです。
他のピクセルについても同じ
今回はについて考えましたが、他のピクセルについて考えるときは、茶色のエリアの色のみが変わります。つまり、フィルタの色付けのみが変わります。フィルタ自体が変わるわけではありません。出力との対応関係が変わるだけです。(茶色のNOTEの内容)
例えば、(1,1)について考えてみましょう。一番左上です。まず、forwardについてですが、はにしか影響しないので、の色付けは左上だけが緑になります。が水色であれ緑色であれ値は変わりません。変わるのは対応関係のみです。
続いて、backwardについて考えます。180 rotationにより、の色付けは右下だけとなります。の色付けは変わりません。今、(1,1)について考えていたので、について考えます。フィルタ位置はpaddingの領域まではみ出しますが、ちょうど緑が重なることになります。
transposed convolutionは一体何なのか
convolutionの逆伝播について理解できれば、transposed convolutionで何が行われているかを理解するのは簡単です。
この図は、convolutionの逆伝播で使用した図から不要な部分を除き書き換えたものです。色については、直接的な意味はないものの、ベースにconv.の逆伝播があったことを意識するために残してあります。(下の灰色NOTEの内容)
Tensorflowの実装で、本来conv.の逆伝播を行うはずの関数に、どのようなパラメータを渡していたのかを思い出してください。
まず、out_backprop
を渡すべきところに input
を渡していました。out_backprop
はのことです。つまり、図の左下がInputとなります。
filterはそのままですね。
そして、input_sizes
を渡すべきところに output_shape
を渡していました。これはさすがに逆では? と思うかもしれません。しかし、この関数が想定している input_sizes
というのはforwardのときのinput、つまりのサイズであり、のサイズです。transposed convolutionではがOutputになりますから、input_sizes
に output_shape
を渡すのは妥当です。
そして、これらを用いてfull padding convolutionを行った出力が、transposed convolutionの出力なのです。
なぜno paddingなのにpaddingがあったのか
画像の一番上にある青色のNOTEの内容です。“no padding”のtransposed conv.なのに、実際にはpaddingを行っています。これは誤りでしょうか。いいえ。
input_sizes
の理屈と同じです。この関数が想定しているpaddingというのは、conv.のforwardにおけるpaddingです。前節ではforward時にpaddingはありませんでしたよね。ですから、no paddingで正しいです。transposed conv.の過程で(あるいはbackprop計算の過程で)paddingを行うのは、forwardのときにstridesがあるためです。strides=1
であれば、このpaddingは必要ありません。
結論
transposed convolutionの順伝播計算とconvolutionの逆伝播計算は共通部分が多いです。Tensorflowの実装ではconvolutionの逆伝播計算を用いて、transposed convolutionの順伝播計算を行っています。
参考文献
- https://medium.com/@pavisj/convolutions-and-backpropagations-46026a8f5d2c
- https://github.com/vdumoulin/conv_arithmetic
- https://medium.com/activating-robotic-minds/up-sampling-with-transposed-convolution-9ae4f2df52d0
- http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf
あとがき
こういった理論寄りの話は、他の人と被るのであまり書いていなかったのですが、今回のテーマはあまりカバーされていないようだったので、書いてみました。結構時間がかかってしまいましたが、自分の勉強にもなったと思います。今後はこういうのを増やしていきたいという気持ちになりました。
内容に誤りを見つけた場合、TwitterやKeybaseなどで報告いただけると助かります。