【pytorchで深層生成モデル#12】画像データセットの作成
記事の目的
画像データセットの作成方法について解説します。これまでの深層生成モデルのデータセットはMNISTを使用してきました。今回は、自分が用意した画像データを使用できるようにするため、”jpg”の画像データがたくさん入っているファイルのパスからデータセットを作成する方法を解説します。
目次
1 ライブラリ
# [1] # MNISTを読み込むライブラリ from torchvision import datasets # データセット作成ライブラリ import torchvision.transforms as transforms from torch.utils.data import DataLoader, Dataset # パスに必要なライブラリ from pathlib import Path # 画像読み込みに必要なライブラリ from PIL import Image
2 jpgデータの読み込み
# [2] !mkdir './data' dataset = datasets.MNIST(root='.', download=True) for idx, (img, _) in enumerate(dataset): img.save('./data/{:05d}.jpg'.format(idx)) # [3] !find /content/data/* -type f | wc -l
3 データセットの作成
# [4] class Image_dataset(Dataset): # パスとtransformの取得 def __init__(self, img_dir, transform=None): self.img_paths = self._get_img_paths(img_dir) self.transform = transform # データの取得 def __getitem__(self, index): path = self.img_paths[index] img = Image.open(path) if self.transform is not None: img = self.transform(img) return img # パスの取得 def _get_img_paths(self, img_dir): img_dir = Path(img_dir) img_paths = [p for p in img_dir.iterdir() if p.suffix in [".jpg", ".jpeg", ".png", ".bmp"]] return img_paths # ながさの取得 def __len__(self): return len(self.img_paths) # [5] transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) dataset = Image_dataset("/content/data", transform) dataloader = DataLoader(dataset, batch_size=100) # [6] for x in dataloader: print(x.shape) break