# Source code for qumphy.data.attractor

"""
File: qumphy/data/attractor.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Attractor image dataset and Lightning data module.
"""

import pathlib

import lightning as L
import numpy as np
import PIL
import PIL.Image
import torch
import torchvision


class AttractorDataset(torch.utils.data.Dataset):
    """Dataset of attractor images stored as PNG files in per-class folders.

    Files are looked up at ``<data_directory>/<subset>/<class>/*.png`` where
    ``<class>`` is one of the keys of ``id_map`` ("AF" or "NOAF").
    """

    def __init__(self, data_directory: str, subset: str):
        """Initialize the attractor image dataset.

        Parameters
        ----------
        data_directory : str
            Path to the root data directory.
        subset : str
            Dataset subset to load, such as "train", "val", or "test".
        """
        self.data_directory = pathlib.Path(data_directory)
        self.subset = subset
        self.img_size = (224, 224)
        self.transforms = None
        self.id_map = {"AF": 0.0, "NOAF": 1.0}

        # Collect the image paths class by class, in id_map order. Each
        # listing is sorted because glob() yields files in filesystem order,
        # which would make the index -> sample mapping differ between machines.
        self.paths = []
        for cls_name in self.id_map:
            cls_dir = self.data_directory / subset / cls_name
            self.paths.extend(sorted(cls_dir.glob("*.png")))

    def __len__(self) -> int:
        """Return the total number of samples."""
        return len(self.paths)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return one sample as an ``(image, label)`` pair.

        Parameters
        ----------
        index : int
            Position of the sample in ``self.paths``.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The image as a channels-first float tensor resized to
            ``self.img_size``, and the class id (looked up in ``id_map`` from
            the name of the image's parent folder) as a float tensor of
            shape ``(1,)``.

        Raises
        ------
        KeyError
            If the parent folder name of the image is not a key of ``id_map``.
        """
        img_path = self.paths[index]

        # The class label is encoded in the name of the containing folder.
        cls_name = img_path.parent.name
        cls_id = torch.tensor(self.id_map[cls_name]).unsqueeze(-1)

        # Read the pixel data inside the context manager so the underlying
        # file handle is released as soon as the array has been built.
        with PIL.Image.open(img_path) as pil_img:
            img = np.asarray(pil_img)

        # transpose(2, 0, 1) requires a 3-D (H, W, C) array and moves the
        # channel axis to the front.
        # NOTE(review): single-channel PNGs decode to 2-D arrays and would
        # fail here; the /255 scaling presumably assumes 8-bit images.
        img_t = torch.tensor(img.transpose(2, 0, 1)).to(torch.float) / 255.0
        img_t = torchvision.transforms.functional.resize(
            img_t, size=self.img_size, antialias=True
        )
        return img_t, cls_id
class AttractorDataModule(L.LightningDataModule):
    """LightningDataModule implementation for the attractor image dataset."""

    def __init__(
        self,
        data_directory,
        batch_size,
        num_workers,
    ):
        """Initialize the attractor data module.

        Parameters
        ----------
        data_directory : str
            Path to the root data directory.
        batch_size : int
            Number of samples per batch.
        num_workers : int
            Number of worker processes used by the dataloaders.
        """
        super().__init__()
        self.data_directory = data_directory
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _subset_dataset(self, subset):
        """Build the AttractorDataset for one subset of the data directory."""
        return AttractorDataset(
            data_directory=self.data_directory,
            subset=subset,
        )

    def _build_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this module's settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    def setup(self, stage):
        """Create the datasets required for ``stage``.

        "fit" creates the training and validation datasets, "validate" the
        validation dataset and "test" the test dataset. Nothing is rebuilt
        when the stage's primary dataset attribute already exists; any other
        stage value is ignored.

        Parameters
        ----------
        stage : str
            Stage being set up.
        """
        if stage == "fit":
            # Only the presence of train_ds is checked; val_ds is (re)built
            # together with it.
            if not hasattr(self, "train_ds"):
                self.train_ds = self._subset_dataset("train")
                self.val_ds = self._subset_dataset("val")
        elif stage == "validate":
            if not hasattr(self, "val_ds"):
                self.val_ds = self._subset_dataset("val")
        elif stage == "test":
            if not hasattr(self, "test_ds"):
                self.test_ds = self._subset_dataset("test")

    def train_dataloader(self):
        """Create the training dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            Shuffling dataloader for the training dataset.
        """
        return self._build_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Create the validation dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            Non-shuffling dataloader for the validation dataset.
        """
        return self._build_loader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """Create the test dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            Non-shuffling dataloader for the test dataset.
        """
        return self._build_loader(self.test_ds, shuffle=False)