cBlog

Tips for you.

畳み込みニューラルネットワークの畳み込み層の実装を数式で理解する

スポンサーリンク

『ゼロから作るDeep Learning』で(御多分にもれず)導出やアルゴリズムに関する部分が省略されている、畳み込み層の演算を解読していきます。具体的には、Convolutionレイヤのim2colcol2im関数、偏微分の部分です。

ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装

ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装

 

本を持っていない方は、以下からcommon/layers.pycommon/util.pyをあわせて見てください。

github.com

 

順伝播

バッチ版入力データ\mathbf{X}_uに対する畳み込みは、以下の図で表されます。

畳み込み層の処理フロー(バッチ版)

添え字の簡単化のため、\mathbf{X}_uはすでにパディングされた状態とします。実際のコードでもまずim2colが呼び出され、そこであらかじめパディングしています。また、書中の表記から以下の置き換えを行っています。

  • N→U
  • H→M
  • W→N
  • FN→C'
  • FH→I
  • FW→J
  • OH→M'
  • OW→N'

ストライドが一般化された畳み込みを理解しやすくするために、S=2、入力サイズ5 x 5、フィルタサイズ3 x 3の場合を以下の図に表します。

畳み込みの例

くどいですが、本来の畳み込みはフィルタを各軸に対して反転(あるいは180°回転)させるのでCNN(や画像分野)の畳み込みは正確には相関です。この辺りの主張は以下の過去記事をご参照ください。

yaritakunai.hatenablog.com

余談おわり。図の関係をもとに考えると、出力データ\mathbf{Y}_uの成分y_{u, c', m', n'}は以下のように定式化されます。

\begin{align}
y_{u, c', m', n'} = \sum_{c = 0}^{C - 1} \sum_{i = 0}^{I - 1} \sum_{j = 0}^{J - 1} x_{u, c, i + Sm', j + Sn'}w_{c', c, i, j} + b_{c'}
\end{align}

バイアスb_{c'}を除いた部分は線型結合ですから行列の積として書けそうです。総和はcijに対して行っていますから、それらすべてを行として並べることで、\mathbf{X}_uは2次元配列として以下の図のように格納できます。

im2colの出力行列

xの添え字は6個ですから、まず6次元配列を用意します。

col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

2重のfor文は、各フィルタ座標(i, j)について全フィルタポジションにわたって先に格納しています。wの畳み込みに関する軸はijのみだからです。実装した人、頭いいですね。

y_max = y + stride*out_h
x_max = x + stride*out_w

col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

に、i + Sm'j + Sn'が表れていますね。畳み込みですからxは異なるフィルタ位置で重複しますが、それらは別の行に格納することで重複して計算しています。

最後に、軸を入れ替え、形状をUM'N' \times CIJの行列にします。

col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)

順伝播の計算は、以下の図で表せます。

畳み込み層の行列演算

col_WCIJ \times C'にしたのち、転置します。

col_W = self.W.reshape(FN, -1).T

あとは、ブロードキャストを利用してバイアスを含めて計算後、整形するだけです。

out = np.dot(col, col_W) + self.b
out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)

計算結果はUM'N' \times C'ですから、軸を以上のような順番に入れ替えています。

 

逆伝播

まず、立式します。上のy_{u, c', m', n'}の式を利用して、損失関数Lw_{c', c, i, j}b_{c'}について偏微分すると、それぞれ以下のようになります。

\begin{align}
\frac{\partial L}{\partial w_{c', c, i, j}} &= \sum_{u = 0}^{U - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}\frac{\partial y_{u, c', m', n'}}{w_{c', c, i, j}} \\
&= \sum_{u = 0}^{U - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}x_{u, c, i + Sm', j + Sn'}
\end{align}

\begin{align}
\frac{\partial L}{\partial b_{c'}} &= \sum_{u = 0}^{U - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}\frac{\partial y_{u, c', m', n'}}{b_{c'}} \\
&= \sum_{u = 0}^{U - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}
\end{align}

ここは、多変数の合成関数の偏微分を思い出してください。つまり、L(y_1, y_2, \ldots)y_i(w_1, w_2, \ldots)に対し、\partial L/\partial w_jは以下となるからです。

\begin{align}
\frac{\partial L}{\partial w_j} = \sum_i \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial w_j}
\end{align}

(テクニック的には、分子の添え字のうち、分母にもある添え字は軸が固定されるので総和をとりません。)

ここで、{\partial L}/{\partial y_{u, c', m', n'}}の配列(=dout)は明らかにoutと同サイズですから、以下のコードでUM'N' \times C'の2次元配列にします。

dout = dout.transpose(0,2,3,1).reshape(-1, FN)

あとは、順伝播のときと同様に、行列として計算です。

self.db = np.sum(dout, axis=0)
self.dW = np.dot(self.col.T, dout)
self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)

3行目は、元の形状に戻しているだけです。

ここまで書いて気づきました… 順伝播時に畳み込み層を行列の計算として書けたのですから、Affineレイヤと同じ計算で済むんですね。畳み込み層の偏微分を必ずしも理解する必要はないようです。

それでも、col2imの理解には次式が必要です。

\begin{align}
\frac{\partial L}{\partial x_{u, c, m, n}} = \sum_{c' = 0}^{C' - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}\frac{\partial y_{u, c', m', n'}}{\partial x_{u, c, m, n}}
\end{align}

ここで、{\partial y_{u, c', m', n'}}/{\partial x_{u, c, m, n}}について考えると、以下のようになります。

\begin{align}
\frac{\partial y_{u, c', m', n'}}{\partial x_{u, c, m, n}} =
\begin{cases}
\displaystyle \sum_{i = 0}^{I - 1} \sum_{j = 0}^{J - 1} w_{c', c, i, j}, & m = i + Sm' \text{ and } n = j + Sn' \\
0, & \text{otherwise}
\end{cases}
\end{align}

したがって、以下のように変形できます。

\begin{align}
\frac{\partial L}{\partial x_{u, c, m, n}} &= \sum_{c' = 0}^{C' - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}\left(\sum_{i = 0}^{I - 1} \sum_{j = 0}^{J - 1} w_{c', c, i, j}\right) \\
&= \sum_{i = 0}^{I - 1} \sum_{j = 0}^{J - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \left(\sum_{c' = 0}^{C' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}w_{c', c, i, j}\right) \\
&= \sum_{i = 0}^{I - 1} \sum_{j = 0}^{J - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} dout_{u, c, i, j, m', n'}, \\
&\text{where } m = i + Sm' \text{ and } n = j + Sn'
\end{align}

これは、コードで書くと次のようになります。

for y in range(filter_h):
    for x in range(filter_w):
        for mp in range(out_h):
            for np in range(out_w):
                img[:, :, y + stride*mp, x + stride*np] += col[:, :, y, x, mp, np]

つまりは、以下の実際のコードと同じことです。

for y in range(filter_h):
    y_max = y + stride*out_h
    for x in range(filter_w):
        x_max = x + stride*out_w
        img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]

単純にim2colの逆じゃなくて+=なのが不思議に思っていたのですが、こういうことだったんですね。

ところで、試してないので勘ですが、その直前のこの部分は+ stride - 1がいらないんじゃないかと思います。

img = np.zeros((N, C, H + 2*pad + stride - 1, W + 2*pad + stride - 1))

stride - 1は常に非負なので問題なく動いている気がします。