【pythonでCNN#1】im2col関数

記事の目的

pythonでCNN(畳み込みニューラルネットワーク)を実装する上で必要になるim2col関数を実装していきます。ここにある全てのコードは、コピペで再現することが可能です。

 

目次

  1. 畳み込み演算とim2col
  2. im2col関数

 

1 畳み込み演算とim2col

 

2 im2col関数

2.1 列方向の繰り返し

# In[1]
import numpy as np

# In[2]
x = np.array([[1,1,1,1],[2,2,2,2],[3,3,3,3],[4,4,4,4]])
x

# In[3]
x_col = np.zeros((2*2,3*3))
x_col

# In[4]
x_col[:,0] = x[0:2,0:2].reshape(-1)
x_col[:,1] = x[0:2,1:3].reshape(-1)
x_col

# In[5]
def im2col(x, fil_size, y_size):
    fil_h, fil_w = fil_size, fil_size
    y_h, y_w = y_size, y_size
    index = -1
    
    x_col = np.zeros((fil_h*fil_w, y_h*y_w))
    for h in range(y_h):
        h2 = h + fil_h
        for w in range(y_w):
            index += 1
            w2 = w + fil_w
            x_col[:,index] = x[h:h2,w:w2].reshape(-1)
            
    return x_col

# In[6]
im2col(x,2,3)

 

2.2 行方向の繰り返し

# In[7]
x = np.array([[1,1,1,1],[2,2,2,2],[3,3,3,3],[4,4,4,4]])
x

# In[8]
x_col = np.zeros((2*2,3*3))
x_col

# In[9]
x_col[0,:] = x[0:3,0:3].reshape(-1)
x_col[1,:] = x[0:3,1:4].reshape(-1)
x_col

# In[10]
def im2col(x, fil_size, y_size):
    fil_h, fil_w = fil_size, fil_size
    y_h, y_w = y_size, y_size
    index = -1
    
    x_col = np.zeros((fil_h*fil_w, y_h*y_w))
    for h in range(fil_h):
        h2 = h + y_h
        for w in range(fil_w):
            index += 1 
            w2 = w + y_w
            x_col[index,:] = x[h:h2,w:w2].reshape(-1)
    return x_col

# In[11]
im2col(x,2,3)