【pytorchで深層生成モデル#12】画像データセットの作成

記事の目的

画像データセットの作成方法について解説します。これまでの深層生成モデルのデータセットはMNISTを使用してきました。今回は、自分が用意した画像データを使用できるようにするため、”jpg”の画像データがたくさん入っているファイルのパスからデータセットを作成する方法を解説します。

 

目次

  1. ライブラリ
  2. jpgデータの読み込み
  3. データセットの作成

 

1 ライブラリ

# [1]
# MNISTを読み込むライブラリ
from torchvision import datasets

# データセット作成ライブラリ
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

# パスに必要なライブラリ
from pathlib import Path

# 画像読み込みに必要なライブラリ
from PIL import Image

 

2 jpgデータの読み込み

# [2]
!mkdir './data'

dataset = datasets.MNIST(root='.', download=True)

for idx, (img, _) in enumerate(dataset):
    img.save('./data/{:05d}.jpg'.format(idx))

# [3]
!find /content/data/* -type f | wc -l

 

3 データセットの作成

# [4]
class Image_dataset(Dataset):

  # パスとtransformの取得
  def __init__(self, img_dir, transform=None):
        self.img_paths = self._get_img_paths(img_dir)
        self.transform = transform

  # データの取得
  def __getitem__(self, index):
      path = self.img_paths[index]
      img = Image.open(path)
      if self.transform is not None:
          img = self.transform(img)
      return img
  
  # パスの取得
  def _get_img_paths(self, img_dir):
      img_dir = Path(img_dir)
      img_paths = [p for p in img_dir.iterdir() if p.suffix in [".jpg", ".jpeg", ".png", ".bmp"]]
      return img_paths

  # ながさの取得
  def __len__(self):
      return len(self.img_paths)

# [5]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = Image_dataset("/content/data", transform)
dataloader = DataLoader(dataset, batch_size=100)

# [6]
for x in dataloader:
  print(x.shape)
  break