【pytorchでニューラルネットワーク#5】DatasetとDataLoader

記事の目的

pytorchでニューラルネットワークを実装する上で必要になるDatasetとDataLoaderについて実装していきます。ここにある全てのコードは、コピペで再現することが可能です。

 

目次

  1. ライブラリとデータのダウンロード
  2. データの参照
  3. 自作Datasetクラス

 

1 ライブラリとデータのダウンロード

# In[1]
import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
torch.manual_seed(1)

# In[2]
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0),(1)),lambda x: x.view(-1)])

root = './data'
mnist_train = datasets.MNIST(root=root,download=True,train=True,transform=transform)
mnist_test = datasets.MNIST(root=root,download=True,train=False,transform=transform)

train_dataloader = DataLoader(mnist_train,batch_size=100,shuffle=True)
test_dataloader = DataLoader(mnist_test,batch_size=100,shuffle=False)

# In[3]
mnist_train

# In[4]
train_dataloader

 

2 データの参照

# In[5]
x, t = next(iter(train_dataloader))
x.shape, t.shape

# In[6]
for x,t in train_dataloader:
    print(x.shape, t.shape)

 

3 自作Datasetクラス

# In[7]
class Dataset:
    def __init__(self):
        self.data = [1,2,3,4,5]
        self.label = [0,0,0,1,1]
    def __getitem__(self, index):
        return self.data[index], self.label[index]
    def __len__(self):
        return len(self.data)

# In[8]
dataset = Dataset()

# In[9]
dataset

# In[10]
print(dataset[0], dataset[1], dataset[2], dataset[3], dataset[4])
print(len(dataset))

# In[11]
class Dataset2:
    def __init__(self, transform_data=None, transform_label=None):
        self.transform_data = transform_data
        self.transform_label = transform_label
        self.data = [1,2,3,4,5,6]
        self.label = [0,0,0,1,1,1]
    def __getitem__(self, index):
        x = self.data[index]
        t = self.label[index]
        if self.transform_data:
            x = self.transform_data(self.data[index])
        if self.transform_label:
            t = self.transform_label(self.label[index])
        return x,t
    def __len__(self):
        return len(self.data)

# In[12]
dataset2 = Dataset2()

# In[13]
print(dataset2[0], dataset2[1], dataset2[2], dataset2[3], dataset2[4], dataset2[5])
print(len(dataset2))

# In[14]
transform = lambda x: x+10
dataset3 = Dataset2(transform_data=transform)

# In[15]
print(dataset3[0], dataset3[1], dataset3[2], dataset3[3], dataset3[4], dataset3[5])
print(len(dataset3))

# In[16]
dataloader = DataLoader(dataset3,batch_size=2,shuffle=True)

# In[17]
dataloader

# In[18]
for x,t in dataloader:
    print(x,t)