
【pythonでCNN#5】col2im関数
記事の目的
pythonでCNN(畳み込みニューラルネットワーク)を実装する上で必要になるcol2iml関数を実装していきます。ここにある全てのコードは、コピペで再現することが可能です。
目次
1 col2im関数の概要

2 im2col関数とcol2im関数

# In[1]
import numpy as np
# In[2]
def im2col(x, fil_size, y_size, stride, pad):
x_b, x_c, x_h, x_w = x.shape
fil_h, fil_w = fil_size, fil_size
y_h, y_w = y_size, y_size
index = -1
x_pad = np.pad(x, [(0, 0), (0, 0), (pad, pad), (pad, pad)], "constant")
x_col = np.zeros((fil_h*fil_w, x_b, x_c, y_h, y_w))
for h in range(fil_h):
h2 = h + y_h*stride
for w in range(fil_w):
index += 1
w2 = w + y_w*stride
x_col[index,:,:,:,:] = x_pad[:,:,h:h2:stride,w:w2:stride]
x_col = x_col.transpose(2,0,1,3,4).reshape(x_c*fil_h*fil_w, x_b*y_h*y_w)
return x_col
# In[3]
def col2im(dx_col, x_shape, fil_size, y_size, stride, pad):
x_b, x_c, x_h, x_w = x_shape
fil_h, fil_w = fil_size, fil_size
y_h, y_w = y_size, y_size
index = -1
dx_col = dx_col.reshape(x_c, fil_h*fil_w, x_b, y_h, y_w).transpose(1,2,0,3,4)
dx = np.zeros((x_b, x_c, x_h+2*pad+stride-1, x_w+2*pad+stride-1))
for h in range(fil_h):
h2 = h + y_h*stride
for w in range(fil_w):
index += 1
w2 = w + y_w*stride
#dx[:,:,h:h2:stride,w:w2:stride] = dx_col[index,:,:,:,:]
dx[:,:,h:h2:stride,w:w2:stride] += dx_col[index,:,:,:,:]
return dx[:,:,pad:x_h+pad, pad:x_w+pad]
3 im2col関数とcol2im関数の例

# In[4] x = np.array([[1,1,1,1],[2,2,2,2],[3,3,3,3],[4,4,4,4]]).reshape(1,1,4,4) x.shape, x # In[5] x_col = im2col(x, 2, 3, 1, 0) x_col # In[6] dx_col = np.ones(x_col.shape) dx_col # In[7] dx = col2im(dx_col, x.shape, 2, 3, 1, 0) dx