【pytorchでニューラルネットワーク#5】DatasetとDataLoader
記事の目的
pytorchでニューラルネットワークを実装する上で必要になるDatasetとDataLoaderについて実装していきます。ここにある全てのコードは、コピペで再現することが可能です。
目次
1 ライブラリとデータのダウンロード
# In[1] import torch from torchvision import datasets import torchvision.transforms as transforms from torch.utils.data import DataLoader torch.manual_seed(1) # In[2] transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0),(1)),lambda x: x.view(-1)]) root = './data' mnist_train = datasets.MNIST(root=root,download=True,train=True,transform=transform) mnist_test = datasets.MNIST(root=root,download=True,train=False,transform=transform) train_dataloader = DataLoader(mnist_train,batch_size=100,shuffle=True) test_dataloader = DataLoader(mnist_test,batch_size=100,shuffle=False) # In[3] mnist_train # In[4] train_dataloader
2 データの参照
# In[5] x, t = next(iter(train_dataloader)) x.shape, t.shape # In[6] for x,t in train_dataloader: print(x.shape, t.shape)
3 自作Datasetクラス
# In[7] class Dataset: def __init__(self): self.data = [1,2,3,4,5] self.label = [0,0,0,1,1] def __getitem__(self, index): return self.data[index], self.label[index] def __len__(self): return len(self.data) # In[8] dataset = Dataset() # In[9] dataset # In[10] print(dataset[0], dataset[1], dataset[2], dataset[3], dataset[4]) print(len(dataset)) # In[11] class Dataset2: def __init__(self, transform_data=None, transform_label=None): self.transform_data = transform_data self.transform_label = transform_label self.data = [1,2,3,4,5,6] self.label = [0,0,0,1,1,1] def __getitem__(self, index): x = self.data[index] t = self.label[index] if self.transform_data: x = self.transform_data(self.data[index]) if self.transform_label: t = self.transform_label(self.label[index]) return x,t def __len__(self): return len(self.data) # In[12] dataset2 = Dataset2() # In[13] print(dataset2[0], dataset2[1], dataset2[2], dataset2[3], dataset2[4], dataset2[5]) print(len(dataset2)) # In[14] transform = lambda x: x+10 dataset3 = Dataset2(transform_data=transform) # In[15] print(dataset3[0], dataset3[1], dataset3[2], dataset3[3], dataset3[4], dataset3[5]) print(len(dataset3)) # In[16] dataloader = DataLoader(dataset3,batch_size=2,shuffle=True) # In[17] dataloader # In[18] for x,t in dataloader: print(x,t)