"""
File: qumphy/data/attractor.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Attractor image dataset and Lightning data module.
"""
import pathlib

import lightning as L
import numpy as np
import PIL
import PIL.Image
import torch
import torchvision
class AttractorDataset(torch.utils.data.Dataset):
    """Image dataset of attractor plots labelled "AF" or "NOAF".

    PNG files are collected from ``<data_directory>/<subset>/AF`` and
    ``<data_directory>/<subset>/NOAF``. The name of the directory that holds
    an image determines its label through ``id_map``.
    """

    def __init__(self, data_directory: str, subset: str):
        """Initialize the attractor image dataset.

        Parameters
        ----------
        data_directory : str
            Path to the root data directory.
        subset : str
            Dataset subset to load, such as "train", "val", or "test".
        """
        self.data_directory = pathlib.Path(data_directory)
        self.subset = subset
        self.img_size = (224, 224)
        # Placeholder for optional augmentations; not applied anywhere yet.
        self.transforms = None
        # Label assigned to each class directory name.
        self.id_map = {"AF": 0.0, "NOAF": 1.0}
        # Collect the image paths class by class ("AF" first, then "NOAF").
        # ``glob`` yields entries in an arbitrary, filesystem-dependent order,
        # so every class is sorted to keep the index -> sample mapping
        # reproducible across runs and machines.
        subset_directory = self.data_directory / subset
        self.paths = [
            path
            for cls_name in self.id_map
            for path in sorted((subset_directory / cls_name).glob("*.png"))
        ]

    def __len__(self) -> int:
        """Return the total number of samples."""
        return len(self.paths)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return one sample as an ``(image, label)`` pair.

        Parameters
        ----------
        index : int
            Position of the sample in ``self.paths``.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The image as a float tensor in channels-first layout, divided by
            255 and resized to ``self.img_size``, and the class label as a
            tensor of shape ``(1,)``.
        """
        img_path = self.paths[index]
        # The parent directory name ("AF" / "NOAF") encodes the class.
        cls_id = self.id_map[img_path.parent.name]
        # The context manager releases the file handle deterministically once
        # the pixel data has been copied into the array.
        with PIL.Image.open(img_path) as image:
            img = np.asarray(image)
        # (H, W, C) -> (C, H, W); this requires an image with a channel axis.
        img_t = torch.tensor(img.transpose(2, 0, 1)).to(torch.float) / 255.0
        img_t = torchvision.transforms.functional.resize(
            img_t, size=self.img_size, antialias=True
        )
        return img_t, torch.tensor(cls_id).unsqueeze(-1)
class AttractorDataModule(L.LightningDataModule):
    """LightningDataModule implementation for the attractor image dataset."""

    # Dataset subsets that have to exist for each stage handled by ``setup``.
    _SUBSETS_BY_STAGE = {
        "fit": ("train", "val"),
        "validate": ("val",),
        "test": ("test",),
    }

    def __init__(
        self,
        data_directory: str,
        batch_size: int,
        num_workers: int,
    ):
        """Initialize the attractor data module.

        Parameters
        ----------
        data_directory : str
            Path to the root data directory.
        batch_size : int
            Number of samples per batch.
        num_workers : int
            Number of worker processes used by the dataloaders.
        """
        super().__init__()
        self.data_directory = data_directory
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage: str) -> None:
        """Create the datasets needed for ``stage``.

        "fit" provides ``train_ds`` and ``val_ds``, "validate" provides
        ``val_ds`` and "test" provides ``test_ds``. Any other stage is a
        no-op. A dataset that already exists on the module is reused, so
        calling ``setup`` repeatedly does not rescan the data directory.

        Parameters
        ----------
        stage : str
            Stage that is about to run.
        """
        for subset in self._SUBSETS_BY_STAGE.get(stage, ()):
            attribute = f"{subset}_ds"
            if hasattr(self, attribute):
                continue
            dataset = AttractorDataset(
                data_directory=self.data_directory,
                subset=subset,
            )
            setattr(self, attribute, dataset)

    def _dataloader(self, dataset, shuffle: bool) -> torch.utils.data.DataLoader:
        """Build a dataloader for ``dataset`` with the module-wide settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        """Create the training dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            Shuffling dataloader for the training dataset.
        """
        return self._dataloader(self.train_ds, shuffle=True)

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """Create the validation dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            Non-shuffling dataloader for the validation dataset.
        """
        return self._dataloader(self.val_ds, shuffle=False)

    def test_dataloader(self) -> torch.utils.data.DataLoader:
        """Create the test dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            Non-shuffling dataloader for the test dataset.
        """
        return self._dataloader(self.test_ds, shuffle=False)