【pythonでCNN#3】im2col関数(チャンネル数とバッチサイズ)

記事の目的

pythonでCNN(畳み込みニューラルネットワーク)を実装する上で必要になるim2col関数を、チャンネル数とバッチサイズを考慮して実装していきます。ここにある全てのコードは、コピペで再現することが可能です。

 

目次

  1. チャンネル数3
  2. バッチサイズ3
  3. im2col関数

 

1 チャンネル数3

 

2 バッチサイズ3

 

3 im2col関数

# In[1]
import numpy as np

# In[2]
def im2col(x, fil_size, y_size, stride, pad):
    x_b, x_c, x_h, x_w = x.shape
    fil_h, fil_w = fil_size, fil_size
    y_h, y_w = y_size, y_size
    index = -1
    
    x_pad = np.pad(x, [(0, 0), (0, 0), (pad, pad), (pad, pad)], "constant")
    x_col = np.zeros((fil_h*fil_w, x_b, x_c, y_h, y_w))
    
    for h in range(fil_h):
        h2 = h + y_h*stride
        for w in range(fil_w):
            index += 1
            w2 = w + y_w*stride
            x_col[index,:,:,:,:] = x_pad[:,:,h:h2:stride,w:w2:stride]
    x_col = x_col.transpose(2,0,1,3,4).reshape(x_c*fil_h*fil_w, x_b*y_h*y_w)
    
    return x_col

# In[3]
x = np.arange(48).reshape(1,3,4,4)
print(x)

# In[4]
x1 = im2col(x,3,2,1,0)
print(x1)

# In[5]
x = np.arange(144).reshape(3,3,4,4)
print(x)

# In[6]
x2 = im2col(x,3,2,1,0)
print(x2)