
【pytorchでニューラルネットワーク#5】DatasetとDataLoader
記事の目的
pytorchでニューラルネットワークを実装する上で必要になるDatasetとDataLoaderについて実装していきます。ここにある全てのコードは、コピペで再現することが可能です。
目次
1 ライブラリとデータのダウンロード

# In[1] import torch from torchvision import datasets import torchvision.transforms as transforms from torch.utils.data import DataLoader torch.manual_seed(1) # In[2] transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0),(1)),lambda x: x.view(-1)]) root = './data' mnist_train = datasets.MNIST(root=root,download=True,train=True,transform=transform) mnist_test = datasets.MNIST(root=root,download=True,train=False,transform=transform) train_dataloader = DataLoader(mnist_train,batch_size=100,shuffle=True) test_dataloader = DataLoader(mnist_test,batch_size=100,shuffle=False) # In[3] mnist_train # In[4] train_dataloader
2 データの参照

# In[5]
x, t = next(iter(train_dataloader))
x.shape, t.shape
# In[6]
for x,t in train_dataloader:
print(x.shape, t.shape)
3 自作Datasetクラス

# In[7]
class Dataset:
def __init__(self):
self.data = [1,2,3,4,5]
self.label = [0,0,0,1,1]
def __getitem__(self, index):
return self.data[index], self.label[index]
def __len__(self):
return len(self.data)
# In[8]
dataset = Dataset()
# In[9]
dataset
# In[10]
print(dataset[0], dataset[1], dataset[2], dataset[3], dataset[4])
print(len(dataset))
# In[11]
class Dataset2:
def __init__(self, transform_data=None, transform_label=None):
self.transform_data = transform_data
self.transform_label = transform_label
self.data = [1,2,3,4,5,6]
self.label = [0,0,0,1,1,1]
def __getitem__(self, index):
x = self.data[index]
t = self.label[index]
if self.transform_data:
x = self.transform_data(self.data[index])
if self.transform_label:
t = self.transform_label(self.label[index])
return x,t
def __len__(self):
return len(self.data)
# In[12]
dataset2 = Dataset2()
# In[13]
print(dataset2[0], dataset2[1], dataset2[2], dataset2[3], dataset2[4], dataset2[5])
print(len(dataset2))
# In[14]
transform = lambda x: x+10
dataset3 = Dataset2(transform_data=transform)
# In[15]
print(dataset3[0], dataset3[1], dataset3[2], dataset3[3], dataset3[4], dataset3[5])
print(len(dataset3))
# In[16]
dataloader = DataLoader(dataset3,batch_size=2,shuffle=True)
# In[17]
dataloader
# In[18]
for x,t in dataloader:
print(x,t)