数据集 (Dataset)
In [ ]:
Copied!
from torch.utils.data import Dataset
from PIL import Image
import os
# class Dataset(typing.Generic)
# | An abstract class representing a :class:`Dataset`.
# |
# | All datasets that represent a map from keys to data samples should subclass
# | it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
# | data sample for a given key. Subclasses could also optionally overwrite
# | :meth:`__len__`, which is expected to return the size of the dataset by many
# | :class:`~torch.utils.data.Sampler` implementations and the default options
# | of :class:`~torch.utils.data.DataLoader`. Subclasses could also
# | optionally implement :meth:`__getitems__`, for speedup batched samples
# | loading. This method accepts list of indices of samples of batch and returns
# | list of samples.
from torch.utils.data import Dataset
from PIL import Image
import os
# class Dataset(typing.Generic)
# | An abstract class representing a :class:`Dataset`.
# |
# | All datasets that represent a map from keys to data samples should subclass
# | it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
# | data sample for a given key. Subclasses could also optionally overwrite
# | :meth:`__len__`, which is expected to return the size of the dataset by many
# | :class:`~torch.utils.data.Sampler` implementations and the default options
# | of :class:`~torch.utils.data.DataLoader`. Subclasses could also
# | optionally implement :meth:`__getitems__`, for speedup batched samples
# | loading. This method accepts list of indices of samples of batch and returns
# | list of samples.
In [ ]:
Copied!
class MyDataset(Dataset):
    """Images from a single class folder, labelled with that folder's name.

    Written for the hymenoptera layout ``<root_dir>/<label_dir>/<files>``
    (e.g. ``data/hymenoptera_data/train/ants``). Every entry listed in
    ``<root_dir>/<label_dir>`` is treated as one sample; entries are not
    filtered, so a non-image file in that folder will raise when indexed.

    Args:
        root_dir: Directory containing the per-class sub-folders.
        label_dir: Name of the class sub-folder; also returned as the label
            of every sample.
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # Full path of the class folder, reused by __getitem__.
        self.path = os.path.join(self.root_dir, self.label_dir)
        # File names (not full paths) of the samples. os.listdir returns
        # entries in arbitrary order, so sort them to keep the
        # index -> image mapping reproducible across runs and machines.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the PIL image opened from the class folder and
        ``label`` is ``self.label_dir``.
        """
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        # Image.open is lazy and leaves the file handle open until pixel data
        # is read. Loading now lets Pillow close the file it opened (it does
        # so for single-frame images), so iterating over many samples does
        # not accumulate open file handles.
        img.load()
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries listed in the class folder."""
        return len(self.img_path)
class MyDataset(Dataset):
    """Images from a single class folder, labelled with that folder's name.

    Written for the hymenoptera layout ``<root_dir>/<label_dir>/<files>``
    (e.g. ``data/hymenoptera_data/train/ants``). Every entry listed in
    ``<root_dir>/<label_dir>`` is treated as one sample; entries are not
    filtered, so a non-image file in that folder will raise when indexed.

    Args:
        root_dir: Directory containing the per-class sub-folders.
        label_dir: Name of the class sub-folder; also returned as the label
            of every sample.
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # Full path of the class folder, reused by __getitem__.
        self.path = os.path.join(self.root_dir, self.label_dir)
        # File names (not full paths) of the samples. os.listdir returns
        # entries in arbitrary order, so sort them to keep the
        # index -> image mapping reproducible across runs and machines.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the PIL image opened from the class folder and
        ``label`` is ``self.label_dir``.
        """
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        # Image.open is lazy and leaves the file handle open until pixel data
        # is read. Loading now lets Pillow close the file it opened (it does
        # so for single-frame images), so iterating over many samples does
        # not accumulate open file handles.
        img.load()
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries listed in the class folder."""
        return len(self.img_path)
In [2]:
Copied!
# Training split: one MyDataset per class sub-folder, then joined into a
# single dataset.
root_dir = 'data/hymenoptera_data/train'

label_dir = 'ants'
ants_dataset = MyDataset(root_dir, label_dir)
# Uncomment to inspect the first (image, label) sample:
# print(ants_dataset[0])

label_dir = 'bees'
bees_dataset = MyDataset(root_dir, label_dir)

# torch's Dataset defines `+`, producing a dataset that yields all ants
# samples followed by all bees samples.
train_dataset = ants_dataset + bees_dataset
# Training split: one MyDataset per class sub-folder, then joined into a
# single dataset.
root_dir = 'data/hymenoptera_data/train'

label_dir = 'ants'
ants_dataset = MyDataset(root_dir, label_dir)
# Uncomment to inspect the first (image, label) sample:
# print(ants_dataset[0])

label_dir = 'bees'
bees_dataset = MyDataset(root_dir, label_dir)

# torch's Dataset defines `+`, producing a dataset that yields all ants
# samples followed by all bees samples.
train_dataset = ants_dataset + bees_dataset