『ゼロから作るDeep Learning』で(御多分にもれず)導出やアルゴリズムに関する部分が省略されている、畳み込み層の演算を解読していきます。具体的には、Convolutionレイヤのim2col
、col2im
関数、偏微分の部分です。
ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装
- 作者: 斎藤康毅
- 出版社/メーカー: オライリージャパン
- 発売日: 2016/09/24
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (18件) を見る
本を持っていない方は、以下からcommon/layers.py
、common/util.py
をあわせて見てください。
順伝播
バッチ版入力データに対する畳み込みは、以下の図で表されます。
添え字の簡単化のため、はすでにパディングされた状態とします。実際のコードでもまずim2col
が呼び出され、そこであらかじめパディングしています。また、書中の表記から以下の置き換えを行っています。
- N→
- H→
- W→
- FN→
- FH→
- FW→
- OH→
- OW→
ストライドが一般化された畳み込みを理解しやすくするために、、入力サイズ5 x 5、フィルタサイズ3 x 3の場合を以下の図に表します。
くどいですが、本来の畳み込みはフィルタを各軸に対して反転(あるいは180°回転)させるのでCNN(や画像分野)の畳み込みは正確には相関です。この辺りの主張は以下の過去記事をご参照ください。
余談おわり。図の関係をもとに考えると、出力データの成分は以下のように定式化されます。
\begin{align}
y_{u, c', m', n'} = \sum_{c = 0}^{C - 1} \sum_{i = 0}^{I - 1} \sum_{j = 0}^{J - 1} x_{u, c, i + Sm', j + Sn'}w_{c', c, i, j} + b_{c'}
\end{align}
バイアスを除いた部分は線型結合ですから行列の積として書けそうです。総和は、、に対して行っていますから、それらすべてを行として並べることで、は2次元配列として以下の図のように格納できます。
の添え字は6個ですから、まず6次元配列を用意します。
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))
2重のfor文は、各フィルタ座標について全フィルタポジションにわたって先に格納しています。の畳み込みに関する軸は、のみだからです。実装した人、頭いいですね。
y_max = y + stride*out_h
x_max = x + stride*out_w
や
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
に、、が表れていますね。畳み込みですからは異なるフィルタ位置で重複しますが、それらは別の行に格納することで重複して計算しています。
最後に、軸を入れ替え、形状をの行列にします。
col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
順伝播の計算は、以下の図で表せます。
col_W
はにしたのち、転置します。
col_W = self.W.reshape(FN, -1).T
あとは、ブロードキャストを利用してバイアスを含めて計算後、整形するだけです。
out = np.dot(col, col_W) + self.b
out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)
計算結果はですから、軸を以上のような順番に入れ替えています。
逆伝播
まず、立式します。上のの式を利用して、損失関数を、について偏微分すると、それぞれ以下のようになります。
\begin{align}
\frac{\partial L}{\partial w_{c', c, i, j}} &= \sum_{u = 0}^{U - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}\frac{\partial y_{u, c', m', n'}}{w_{c', c, i, j}} \\
&= \sum_{u = 0}^{U - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}x_{u, c, i + Sm', j + Sn'}
\end{align}
\begin{align}
\frac{\partial L}{\partial b_{c'}} &= \sum_{u = 0}^{U - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}\frac{\partial y_{u, c', m', n'}}{b_{c'}} \\
&= \sum_{u = 0}^{U - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}
\end{align}
ここは、多変数の合成関数の偏微分を思い出してください。つまり、、に対し、は以下となるからです。
\begin{align}
\frac{\partial L}{\partial w_j} = \sum_i \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial w_j}
\end{align}
(テクニック的には、分子の添え字のうち、分母にもある添え字は軸が固定されるので総和をとりません。)
ここで、の配列(=dout
)は明らかにout
と同サイズですから、以下のコードでの2次元配列にします。
dout = dout.transpose(0,2,3,1).reshape(-1, FN)
あとは、順伝播のときと同様に、行列として計算です。
self.db = np.sum(dout, axis=0)
self.dW = np.dot(self.col.T, dout)
self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)
3行目は、元の形状に戻しているだけです。
ここまで書いて気づきました… 順伝播時に畳み込み層を行列の計算として書けたのですから、Affineレイヤと同じ計算で済むんですね。畳み込み層の偏微分を必ずしも理解する必要はないようです。
それでも、col2im
の理解には次式が必要です。
\begin{align}
\frac{\partial L}{\partial x_{u, c, m, n}} = \sum_{c' = 0}^{C' - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}\frac{\partial y_{u, c', m', n'}}{\partial x_{u, c, m, n}}
\end{align}
ここで、について考えると、以下のようになります。
\begin{align}
\frac{\partial y_{u, c', m', n'}}{\partial x_{u, c, m, n}} =
\begin{cases}
\displaystyle \sum_{i = 0}^{I - 1} \sum_{j = 0}^{J - 1} w_{c', c, i, j}, & m = i + Sm' \text{ and } n = j + Sn' \\
0, & \text{otherwise}
\end{cases}
\end{align}
したがって、以下のように変形できます。
\begin{align}
\frac{\partial L}{\partial x_{u, c, m, n}} &= \sum_{c' = 0}^{C' - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}\left(\sum_{i = 0}^{I - 1} \sum_{j = 0}^{J - 1} w_{c', c, i, j}\right) \\
&= \sum_{i = 0}^{I - 1} \sum_{j = 0}^{J - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} \left(\sum_{c' = 0}^{C' - 1} \frac{\partial L}{\partial y_{u, c', m', n'}}w_{c', c, i, j}\right) \\
&= \sum_{i = 0}^{I - 1} \sum_{j = 0}^{J - 1} \sum_{m' = 0}^{M' - 1} \sum_{n' = 0}^{N' - 1} dout_{u, c, i, j, m', n'}, \\
&\text{where } m = i + Sm' \text{ and } n = j + Sn'
\end{align}
これは、コードで書くと次のようになります。
for y in range(filter_h):
for x in range(filter_w):
for mp in range(out_h):
for np in range(out_w):
img[:, :, y + stride*mp, x + stride*np] += col[:, :, y, x, mp, np]
つまりは、以下の実際のコードと同じことです。
for y in range(filter_h):
y_max = y + stride*out_h
for x in range(filter_w):
x_max = x + stride*out_w
img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]
単純にim2col
の逆じゃなくて+=
なのが不思議に思っていたのですが、こういうことだったんですね。
ところで、試してないので勘ですが、その直前のこの部分は+ stride - 1
がいらないんじゃないかと思います。
img = np.zeros((N, C, H + 2*pad + stride - 1, W + 2*pad + stride - 1))
stride - 1
は常に非負なので問題なく動いている気がします。