Introduction
In this tutorial, we will go over the basics of writing custom PyTorch datasets. We will go over the Dataset and DataLoader classes and how they are used to fetch data. We will also cover how to write custom datasets that can be used with PyTorch’s DataLoader class.
The Dataset Class
The Dataset class (torch.utils.data.Dataset) is an abstract class that represents a dataset. It is the base class for PyTorch's datasets, and every map-style dataset, i.e. one that maps indices to samples, subclasses it. A subclass overrides two methods:
- __len__, which takes no arguments and returns the size of the dataset. The base class does not provide it, and DataLoader's default samplers rely on it, so in practice it should always be overridden.
- __getitem__, which takes an index as an argument and returns the item at that index. Overriding it is required.
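A minimal sketch of a map-style dataset that overrides both methods. SquaresDataset is a hypothetical toy class, not part of PyTorch: item i is the pair (i, i²) stored as float tensors.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of items in the dataset.
        return self.n

    def __getitem__(self, index):
        # Reject out-of-range indices with IndexError, the exception
        # Python's iteration protocol expects at the end of a sequence.
        if not 0 <= index < self.n:
            raise IndexError(f"index {index} out of range for size {self.n}")
        x = torch.tensor(index, dtype=torch.float32)
        return x, x ** 2


ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (tensor(3.), tensor(9.))
```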
The Dataset class does not define an __iter__ method, yet a map-style dataset can still be iterated over. Python falls back to the sequence protocol for any object that has __getitem__ but no __iter__: iter(dataset) returns an iterator that calls __getitem__(0), __getitem__(1), and so on, returning the items one at a time, and it raises StopIteration as soon as __getitem__ raises an IndexError. __getitem__ should therefore raise IndexError for an out-of-range index; otherwise iteration does not stop cleanly at the end of the data. The __getitem__ method also implements indexing: it is called whenever an item is retrieved with the dataset[i] syntax. For example, dataset[0] retrieves the first item from the dataset.
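Python's fallback iteration protocol can be seen in plain Python, without PyTorch: for an object that defines __getitem__ but no __iter__, iter() calls __getitem__ with 0, 1, 2, ... until IndexError is raised. Both classes below are hypothetical illustrations. OnlyGetItem behaves like a map-style Dataset subclass; WrapAround is a deliberately buggy variant that wraps the index with a modulo, never raises IndexError, and therefore never stops on its own (itertools.islice takes only the first five items).

```python
from itertools import islice


class OnlyGetItem:
    """Defines __getitem__ but no __iter__, like a map-style Dataset subclass."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        # List indexing raises IndexError once index >= len(self.items).
        return self.items[index]


class WrapAround(OnlyGetItem):
    """Buggy variant: wraps out-of-range indices instead of raising IndexError."""

    def __getitem__(self, index):
        return self.items[index % len(self.items)]


data = OnlyGetItem(["a", "b", "c"])
print(data[0])     # a  (indexing calls __getitem__(0))
print(list(data))  # ['a', 'b', 'c']  (iteration stops at the IndexError)

looping = WrapAround(["a", "b", "c"])
print(list(islice(looping, 5)))  # ['a', 'b', 'c', 'a', 'b']  (would never stop by itself)
```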
The __getitem__ method is also called when the dataset is iterated over, for example in a for loop.
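Roughly speaking, a for loop over a dataset without its own __iter__ is equivalent to a manual loop that calls __getitem__ with increasing indices until IndexError is raised. ToyDataset is a hypothetical stand-in for any Dataset subclass whose __getitem__ raises IndexError past the end; the manual version only illustrates the protocol and is not code you would normally write.

```python
import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset wrapping a 1-D tensor."""

    def __init__(self, values):
        self.values = values

    def __len__(self):
        return len(self.values)

    def __getitem__(self, index):
        # Tensor indexing raises IndexError for an out-of-range index.
        return self.values[index]


ds = ToyDataset(torch.tensor([10, 20, 30]))

# 1) The usual for loop.
from_loop = [int(item) for item in ds]

# 2) What the loop does under the hood (simplified).
from_manual = []
index = 0
while True:
    try:
        item = ds[index]      # calls ds.__getitem__(index)
    except IndexError:        # end of the data: the loop stops here
        break
    from_manual.append(int(item))
    index += 1

print(from_loop, from_manual)  # [10, 20, 30] [10, 20, 30]
```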
In summary, the dataset[i] syntax calls __getitem__(i) directly. The built-in iter(dataset) returns an iterator over the dataset, and each call to next() on that iterator calls __getitem__ with the next index, starting from 0. Consequently, next(iter(dataset)) returns the first item, the same as dataset[0].
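A short check of these relationships, using PyTorch's built-in TensorDataset as a convenient map-style dataset (item i is the tuple (features[i], targets[i])); the small tensors are made up for illustration. A fresh iterator always starts at index 0, and each further next() on the same iterator moves to the following index.

```python
import torch
from torch.utils.data import TensorDataset

features = torch.arange(6.0).reshape(3, 2)  # 3 samples with 2 features each
targets = torch.tensor([0, 1, 0])
dataset = TensorDataset(features, targets)

first = next(iter(dataset))  # a fresh iterator starts at index 0
print(first)                 # (tensor([0., 1.]), tensor(0))

it = iter(dataset)
x0, y0 = next(it)            # same item as dataset[0]
x1, y1 = next(it)            # same item as dataset[1]
print(torch.equal(x1, dataset[1][0]), int(y1))  # True 1
```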
Dummy Dataset
Suppose we have a dataset $\mathcal{D} = \{(x_1, y_1), \dots, (x_m, y_m)\}$ that contains pairs of features $x_i \in \mathbb{R}^4$ and targets $y_i \in \mathbb{R}$. We can create a dataset class that represents this dataset, i.e., one for which dataset[i] returns the pair $(x_i, y_i)$.
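A minimal sketch of such a class. The data is synthetic and purely illustrative: features drawn with torch.randn and targets from a made-up linear function plus noise. The class name DummyDataset, its num_samples and seed parameters, and the weight vector are assumptions, not taken from the original. The tensors are generated once in __init__, so __getitem__ only has to index into them.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DummyDataset(Dataset):
    """Synthetic dataset of m pairs (x_i, y_i) with x_i in R^4 and y_i in R."""

    def __init__(self, num_samples, seed=0):
        generator = torch.Generator().manual_seed(seed)  # reproducible data
        # Features: an (m, 4) matrix, one row per sample.
        self.x = torch.randn(num_samples, 4, generator=generator)
        # Targets: a hypothetical linear function of the features plus noise.
        weights = torch.tensor([1.0, -2.0, 0.5, 3.0])
        noise = 0.1 * torch.randn(num_samples, generator=generator)
        self.y = self.x @ weights + noise

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        # Returns the pair (x_i, y_i); tensor indexing raises IndexError
        # for out-of-range indices, which ends iteration cleanly.
        return self.x[index], self.y[index]


dataset = DummyDataset(num_samples=100)
x, y = dataset[0]
print(x.shape, y.shape)  # torch.Size([4]) torch.Size([])

# DataLoader stacks individual pairs into mini-batches.
loader = DataLoader(dataset, batch_size=16, shuffle=True)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([16, 4]) torch.Size([16])
```

Storing everything in two tensors keeps __getitem__ trivial; for data too large to fit in memory, the same interface works if __getitem__ loads each item lazily instead.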
Cats-Dogs Kaggle Dataset
The Cats-vs-Dogs dataset contains 25,000 images of cats and dogs. It is available for download on Kaggle.
After downloading the dataset and extracting the archive file (archive.zip), a PetImages folder is extracted, containing two subfolders, Cat and Dog. The Cat folder contains 12,500 images of cats and the Dog folder contains 12,500 images of dogs. In both folders, the images are named 0.jpg, 1.jpg, …, 12499.jpg.
Next, we create a CSV file containing the metadata and the labels for the images. The file has two columns: path, which contains the path to the image, and label, which contains the label for the image. The label is 0 for cats and 1 for dogs. The file is saved as metadata.csv.
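A sketch of one way to build this file using only the standard library. It assumes the extracted PetImages folder sits in the working directory; the function create_metadata and its root and out_file parameters are hypothetical names chosen for this example. It writes one row per .jpg file, with the path stored as found under root.

```python
import csv
from pathlib import Path

# Label convention from the text: 0 for cats, 1 for dogs.
LABELS = {"Cat": 0, "Dog": 1}


def create_metadata(root="PetImages", out_file="metadata.csv"):
    """Hypothetical helper: write a CSV with columns path,label for all images under root."""
    rows = []
    for folder, label in LABELS.items():
        # Only *.jpg files are listed, so stray non-image files are skipped;
        # sorting gives a deterministic (lexicographic) order.
        for image_path in sorted(Path(root, folder).glob("*.jpg")):
            rows.append((str(image_path), label))

    with open(out_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["path", "label"])  # header row
        writer.writerows(rows)
    return len(rows)  # number of image rows written
```

Calling create_metadata() from the directory that contains PetImages writes metadata.csv there. Relative paths keep the file portable as long as the dataset class is used from the same working directory.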
Once the metadata file is created, we can write a dataset class that represents the Cats-vs-Dogs dataset. Its __getitem__ method should return, for the given index, the image as a torch.Tensor and its label as a torch.Tensor.
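A sketch of such a class, under a few assumptions that are ours rather than the original's: the metadata is read with the standard csv module; images are opened with Pillow and converted to RGB so that every image has three channels; they are resized to a fixed size (the image_size parameter, 128×128 by default, is hypothetical) so that a DataLoader can stack them into batches; and pixel values are scaled to float tensors in [0, 1] with shape (3, H, W).

```python
import csv

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class CatsDogsDataset(Dataset):
    """Cats-vs-Dogs images listed in a metadata CSV with columns path,label."""

    def __init__(self, metadata_file="metadata.csv", image_size=(128, 128)):
        with open(metadata_file, newline="") as f:
            self.records = [(row["path"], int(row["label"])) for row in csv.DictReader(f)]
        self.image_size = image_size  # (width, height), as expected by PIL's resize

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        # List indexing raises IndexError past the end, ending iteration cleanly.
        path, label = self.records[index]
        with Image.open(path) as raw:
            # Force 3 channels and a fixed size so images can be batched together.
            img = raw.convert("RGB").resize(self.image_size)
        # (H, W, 3) uint8 array -> (3, H, W) float tensor in [0, 1].
        array = np.asarray(img).astype(np.float32) / 255.0
        image = torch.from_numpy(array).permute(2, 0, 1)
        return image, torch.tensor(label, dtype=torch.long)
```

With the default image_size, DataLoader(CatsDogsDataset("metadata.csv"), batch_size=32, shuffle=True) should yield image batches of shape (32, 3, 128, 128) and label batches of shape (32,), with a smaller final batch if the dataset size is not a multiple of 32. If a file in your copy of the archive cannot be decoded, Pillow raises an exception inside __getitem__, so such files are best filtered out when building metadata.csv.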