
Datasets

copick-torch provides PyTorch Dataset classes that extract 3D subvolumes around copick picks, with optional background sampling, caching, and on-the-fly augmentation. Each dataset yields (volume, label) tuples ready for a DataLoader; CopickDataset returns a label dictionary in place of a bare label.

MinimalCopickDataset

The simplest dataset — no caching or augmentation, minimal dependencies.

copick_torch.minimal_dataset.MinimalCopickDataset

Bases: Dataset

A minimal PyTorch dataset for working with copick data that returns (image, label) pairs.

Unlike SimpleCopickDataset, this implementation:

  1. Does not use disk caching (subvolumes are either preloaded into memory or loaded on the fly, depending on preload)
  2. Does not include augmentation
  3. Has minimal dependencies
  4. Focuses on correct label mapping

This dataset can be saved to disk and loaded later for reproducibility.

__init__

__init__(proj=None, dataset_id=None, overlay_root=None, boxsize=(48, 48, 48), voxel_spacing=10.012, include_background=False, background_ratio=0.2, min_background_distance=None, preload=True)

Initialize a MinimalCopickDataset.

Parameters:

  • proj

    A copick project object. If provided, dataset_id and overlay_root are ignored.

  • dataset_id

    Dataset ID from the CZ cryoET Data Portal. Only used if proj is None.

  • overlay_root

    Root directory for the overlay storage. Only used if proj is None.

  • boxsize

    Size of the subvolumes to extract (z, y, x)

  • voxel_spacing

    Voxel spacing to use for extraction

  • include_background

    Whether to include background samples

  • background_ratio

    Ratio of background to particle samples

  • min_background_distance

    Minimum distance from particles for background samples

  • preload

    Whether to preload all subvolumes into memory (faster, but more memory-intensive)
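The include_background, background_ratio and min_background_distance options are only named above. The sketch below shows one plausible way they could interact, and is not copick-torch's implementation: it assumes background centers are drawn uniformly inside the tomogram, that round(background_ratio × number of picks) of them are requested, and that candidates closer than min_background_distance to any pick are rejected. The helper sample_background_points and its max_attempts and seed arguments are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, float, float]


def sample_background_points(
    picks: Sequence[Point],
    volume_extent: Point,
    background_ratio: float = 0.2,
    min_background_distance: Optional[float] = None,
    max_attempts: int = 10_000,
    seed: int = 0,
) -> List[Point]:
    """Hypothetical helper: draw background centers that stay away from all picks.

    Assumes picks, volume_extent (x, y, z size of the tomogram) and the
    distance threshold all share the same units.
    """
    rng = random.Random(seed)
    n_wanted = int(round(background_ratio * len(picks)))
    # Assumption: None disables the exclusion check; the library's own default may differ.
    min_dist = min_background_distance if min_background_distance is not None else 0.0

    points: List[Point] = []
    attempts = 0
    while len(points) < n_wanted and attempts < max_attempts:
        attempts += 1
        candidate = tuple(rng.uniform(0.0, extent) for extent in volume_extent)
        # Reject candidates that fall within min_dist of any particle pick.
        if all(math.dist(candidate, p) >= min_dist for p in picks):
            points.append(candidate)
    return points
```

Rejection sampling keeps the sketch short; with densely packed picks it may return fewer than n_wanted points once max_attempts is exhausted.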

extract_subvolume

extract_subvolume(point, tomogram_idx=0)

Extract a cubic subvolume centered around a point.

Parameters:

  • point

    (x, y, z) coordinates

  • tomogram_idx

    Index of the tomogram to use

Returns:

  • Extracted subvolume as a numpy array
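Note that the point is given as (x, y, z) while boxsize is (z, y, x). A minimal sketch of centered extraction follows, under these assumptions: the tomogram is a NumPy array indexed (z, y, x), the point is in physical units that divide by voxel_spacing to give voxel indices (pass voxel_spacing=1.0 for voxel coordinates), and parts of the box outside the tomogram are zero-padded. The function name is hypothetical, and the library may clip or shift the box at borders instead of padding.

```python
from typing import Sequence, Tuple

import numpy as np


def extract_centered_subvolume(
    tomogram: np.ndarray,
    point_xyz: Sequence[float],
    boxsize: Tuple[int, int, int] = (48, 48, 48),
    voxel_spacing: float = 1.0,
) -> np.ndarray:
    """Return a (z, y, x) box centered on point_xyz, zero-padded at the borders (assumed)."""
    # Convert to voxel indices and reorder (x, y, z) into the array's (z, y, x) order.
    center_zyx = [int(round(c / voxel_spacing)) for c in reversed(point_xyz)]
    out = np.zeros(boxsize, dtype=tomogram.dtype)

    src, dst = [], []
    for axis, (center, size) in enumerate(zip(center_zyx, boxsize)):
        start = center - size // 2
        stop = start + size
        # Clip the requested window to the tomogram and record where it lands in the box.
        lo, hi = max(start, 0), min(stop, tomogram.shape[axis])
        if lo >= hi:
            return out  # box lies entirely outside the tomogram
        src.append(slice(lo, hi))
        dst.append(slice(lo - start, hi - start))

    out[tuple(dst)] = tomogram[tuple(src)]
    return out
```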

__len__

__len__()

Get the length of the dataset.

__getitem__

__getitem__(idx)

Get an item from the dataset.

Parameters:

  • idx

    Index of the sample to retrieve

Returns:

  • Tuple of (subvolume, label)

keys

keys()

Get the list of class names.

get_class_distribution

get_class_distribution()

Get the distribution of classes in the dataset.

get_sample_weights

get_sample_weights()

Compute sample weights for balanced sampling.

Returns:

  • List of weights for each sample
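A common way to compute weights for balanced sampling is inverse class frequency: each sample gets 1 / (size of its class), so every class carries the same total weight. The sketch below shows that standard technique; whether get_sample_weights normalizes in exactly this way is an assumption, and inverse_frequency_weights is a hypothetical name.

```python
from collections import Counter
from typing import Hashable, List, Sequence


def inverse_frequency_weights(labels: Sequence[Hashable]) -> List[float]:
    """Weight each sample by 1 / (number of samples sharing its label)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]
```

A list like this can be passed to torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True), which then draws each class with roughly equal probability.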

save

save(save_dir)

Save the dataset to disk for later reloading.

Parameters:

  • save_dir

    Directory to save the dataset

load

load(save_dir, proj=None)

Load a previously saved dataset.

Parameters:

  • save_dir

    Directory where the dataset was saved

  • proj

    Optional copick project object. If provided, tomograms will be loaded from it.

Returns:

  • Loaded MinimalCopickDataset instance
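The on-disk format written by save is not documented here. As an illustration of why saving helps reproducibility, the sketch below persists only sample metadata (pick centers, labels, class names, extraction settings) as JSON, so the same samples can be re-extracted from the tomograms later. The file name dataset_metadata.json, the field names, and both helper functions are hypothetical and do not describe the library's actual format.

```python
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple

Sample = Tuple[Tuple[float, float, float], int]


def save_metadata(save_dir: str, samples: List[Sample], class_names: List[str],
                  boxsize: Tuple[int, int, int], voxel_spacing: float) -> Path:
    """Write sample centers, labels, and extraction settings to JSON (hypothetical format)."""
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    payload = {
        "class_names": class_names,
        "boxsize": list(boxsize),
        "voxel_spacing": voxel_spacing,
        "samples": [{"point": list(p), "label": label} for p, label in samples],
    }
    path = out_dir / "dataset_metadata.json"
    path.write_text(json.dumps(payload, indent=2))
    return path


def load_metadata(save_dir: str) -> Dict[str, Any]:
    """Read the JSON written by save_metadata and restore tuple-typed fields."""
    payload = json.loads((Path(save_dir) / "dataset_metadata.json").read_text())
    payload["boxsize"] = tuple(payload["boxsize"])
    payload["samples"] = [(tuple(s["point"]), s["label"]) for s in payload["samples"]]
    return payload
```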

SimpleCopickDataset

Adds disk caching (pickle or parquet), augmentation, and class-balancing helpers.

copick_torch.dataset.SimpleCopickDataset

Bases: SimpleDatasetMixin, Dataset

A simplified PyTorch dataset for working with copick data that returns (image, label) pairs.

This implementation is a wrapper around the original CopickDataset that modifies the __getitem__ method to return a simpler format suitable for standard training pipelines.

__init__

__init__(config_path: Union[str, Any] = None, copick_root: Optional[Any] = None, boxsize: Tuple[int, int, int] = (32, 32, 32), augment: bool = False, cache_dir: Optional[str] = None, cache_format: str = 'parquet', seed: Optional[int] = 1717, max_samples: Optional[int] = None, voxel_spacing: float = 10.0, include_background: bool = False, background_ratio: float = 0.2, min_background_distance: Optional[float] = None, patch_strategy: str = 'centered', debug_mode: bool = False, dataset_id: Optional[int] = None, overlay_root: str = '/tmp/test/')

Initialize a SimpleCopickDataset.

Parameters:

  • config_path (Union[str, Any], default: None ) –

    Path to the copick config file or CopickConfig object

  • copick_root (Optional[Any], default: None ) –

    Copick root object (alternative to config_path)

  • boxsize (Tuple[int, int, int], default: (32, 32, 32) ) –

    Size of the subvolumes to extract (z, y, x)

  • augment (bool, default: False ) –

    Whether to apply data augmentation

  • cache_dir (Optional[str], default: None ) –

    Directory to cache extracted subvolumes

  • cache_format (str, default: 'parquet' ) –

    Format for caching ('pickle' or 'parquet')

  • seed (Optional[int], default: 1717 ) –

    Random seed for reproducibility

  • max_samples (Optional[int], default: None ) –

    Maximum number of samples to use

  • voxel_spacing (float, default: 10.0 ) –

    Voxel spacing to use for extraction

  • include_background (bool, default: False ) –

    Whether to include background samples

  • background_ratio (float, default: 0.2 ) –

    Ratio of background to particle samples

  • min_background_distance (Optional[float], default: None ) –

    Minimum distance from particles for background samples

  • patch_strategy (str, default: 'centered' ) –

    Strategy for extracting patches ('centered', 'random', or 'jittered')

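  • debug_mode (bool, default: False ) –

    Whether to enable debug mode

  • dataset_id (Optional[int], default: None ) –

    Dataset ID from the CZ cryoET Data Portal

  • overlay_root (str, default: '/tmp/test/' ) –

    Root directory for the overlay storage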
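The three patch_strategy values are only named above. One plausible reading, sketched here purely as an assumption: 'centered' uses the pick as the box center, 'jittered' shifts it by a small random offset, and 'random' moves the center anywhere that still keeps the pick inside the box. The sketch assumes pick coordinates are in voxels; choose_patch_center and max_jitter are hypothetical names.

```python
import random
from typing import Optional, Tuple

Point = Tuple[float, float, float]


def choose_patch_center(pick: Point, boxsize: Tuple[int, int, int],
                        strategy: str = "centered", max_jitter: float = 2.0,
                        rng: Optional[random.Random] = None) -> Point:
    """Return the box center for a pick under a hypothetical reading of patch_strategy."""
    rng = rng or random.Random()
    if strategy == "centered":
        return pick
    if strategy == "jittered":
        # Small uniform shift per axis, in voxels (assumed).
        return tuple(c + rng.uniform(-max_jitter, max_jitter) for c in pick)
    if strategy == "random":
        # Any offset up to half the box keeps the pick inside it.
        # boxsize is (z, y, x) while pick is (x, y, z), so reverse it.
        half = [s / 2 for s in reversed(boxsize)]
        return tuple(c + rng.uniform(-h, h) for c, h in zip(pick, half))
    raise ValueError(f"unknown patch_strategy: {strategy!r}")
```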

__len__

__len__()

Get the total number of items in the dataset.

get_sample_weights

get_sample_weights()

Return sample weights for use in a WeightedRandomSampler.

keys

keys()

Get pickable object keys.

get_class_distribution

get_class_distribution()

Get distribution of classes in the dataset.

SplicedMixupDataset

Combines experimental and synthetic tomograms with Gaussian-blended splicing and mixup.

copick_torch.dataset.SplicedMixupDataset

Bases: SimpleCopickDataset

A dataset that loads zarr arrays into memory and performs balanced sampling with mixup splicing.

This dataset extends SimpleCopickDataset to add experimental-synthetic data splicing capabilities, keeping zarr arrays in memory for faster loading and using balanced sampling by default.

__init__

__init__(exp_dataset_id: int, synth_dataset_id: int, synth_run_id: str = '16487', overlay_root: str = '/tmp/test/', boxsize: Tuple[int, int, int] = (48, 48, 48), augment: bool = True, cache_dir: Optional[str] = None, cache_format: str = 'parquet', seed: Optional[int] = 1717, max_samples: Optional[int] = None, voxel_spacing: float = 10.0, include_background: bool = False, background_ratio: float = 0.2, min_background_distance: Optional[float] = None, blend_sigma: float = 2.0, mixup_alpha: float = 0.2, debug_mode: bool = False)

Initialize the SplicedMixupDataset.

Parameters:

  • exp_dataset_id (int) –

    Dataset ID for the experimental dataset

  • synth_dataset_id (int) –

    Dataset ID for the synthetic dataset

  • synth_run_id (str, default: '16487' ) –

    Run ID for the synthetic dataset (default: "16487")

  • overlay_root (str, default: '/tmp/test/' ) –

    Root directory for the overlay storage (default: "/tmp/test/")

  • boxsize (Tuple[int, int, int], default: (48, 48, 48) ) –

    Size of the subvolumes to extract (z, y, x)

  • augment (bool, default: True ) –

    Whether to apply data augmentation

  • cache_dir (Optional[str], default: None ) –

    Directory to cache extracted subvolumes

  • cache_format (str, default: 'parquet' ) –

    Format for caching ('pickle' or 'parquet')

  • seed (Optional[int], default: 1717 ) –

    Random seed for reproducibility

  • max_samples (Optional[int], default: None ) –

    Maximum number of samples to use

  • voxel_spacing (float, default: 10.0 ) –

    Voxel spacing to use for extraction

  • include_background (bool, default: False ) –

    Whether to include background samples

  • background_ratio (float, default: 0.2 ) –

    Ratio of background to particle samples

  • min_background_distance (Optional[float], default: None ) –

    Minimum distance from particles for background samples

  • blend_sigma (float, default: 2.0 ) –

    Standard deviation of the Gaussian used to blend the boundaries between spliced regions

  • mixup_alpha (float, default: 0.2 ) –

    Alpha parameter for mixup augmentation

  • debug_mode (bool, default: False ) –

    Whether to enable debug mode
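blend_sigma is described only as the standard deviation of the Gaussian blending at boundaries. The sketch below shows one way Gaussian-blended splicing can work, under stated assumptions: a box-shaped region of the experimental volume is pasted into the synthetic volume, and the mixing weight falls off as a Gaussian of each voxel's distance outside that box, so the seam has no hard edge. The function name, the box-shaped region, and the distance-based falloff are hypothetical choices, not taken from copick-torch.

```python
import numpy as np


def gaussian_splice(exp_vol: np.ndarray, synth_vol: np.ndarray,
                    region: tuple, blend_sigma: float = 2.0) -> np.ndarray:
    """Paste exp_vol[region] into synth_vol with a Gaussian falloff at the edges.

    region holds one (start, stop) pair per axis, in (z, y, x) order.
    """
    if exp_vol.shape != synth_vol.shape:
        raise ValueError("volumes must have the same shape")
    grids = np.indices(exp_vol.shape)
    dist_sq = np.zeros(exp_vol.shape, dtype=float)
    for axis, (start, stop) in enumerate(region):
        idx = grids[axis]
        # Per-axis distance (in voxels) from each index to the region; 0 inside it.
        d = np.maximum(start - idx, 0) + np.maximum(idx - (stop - 1), 0)
        dist_sq += d.astype(float) ** 2
    # Weight is 1 inside the region and decays as a Gaussian of the distance outside.
    weight = np.exp(-dist_sq / (2.0 * blend_sigma ** 2))
    return weight * exp_vol + (1.0 - weight) * synth_vol
```

Larger blend_sigma values widen the transition zone; as it approaches zero the splice approaches a hard cut.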

__getitem__

__getitem__(idx)

Get an item with spliced mixup augmentation.

CopickDataset

The original, full-featured dataset implementation.

copick_torch.copick.CopickDataset

Bases: Dataset

A PyTorch dataset for working with copick data for particle picking tasks.

This implementation focuses on extracting subvolumes around pick coordinates with support for data augmentation, caching, and class balancing.

__len__

__len__()

Get the total number of items in the dataset.

__getitem__

__getitem__(idx)

Get an item from the dataset, with mixup handling and augmentation tracking.

Returns:

  • Tuple of (subvolume, label_dict), where label_dict contains:

    • 'class_idx': Original class index (or primary class index if mixed)
    • 'is_mixed': Boolean indicating if mixup was applied
    • 'mix_lambda': Lambda value for mixup (1.0 if no mixup)
    • 'mix_class_idx': Secondary class index for mixup (None if no mixup)
    • 'applied_augmentations': List of applied augmentations (if debug_mode=True)
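The mixup fields in label_dict match the standard mixup recipe: a coefficient λ is drawn from Beta(α, α) and the output is λ·x_a + (1 − λ)·x_b. The sketch below builds a label_dict with the same keys under assumptions: mixup is applied with a hypothetical probability mixup_prob, and the sample with the larger weight is kept as the primary class_idx. The function name, mixup_prob, and that primary-class convention are not taken from copick-torch.

```python
import random
from typing import Any, Dict, Optional, Tuple

import numpy as np


def mixup_pair(vol_a: np.ndarray, class_a: int, vol_b: np.ndarray, class_b: int,
               alpha: float = 0.2, mixup_prob: float = 0.5,
               rng: Optional[random.Random] = None) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Optionally mix two subvolumes and return a label_dict with mixup metadata."""
    rng = rng or random.Random()
    if alpha <= 0 or rng.random() >= mixup_prob:
        # No mixup: lambda is 1.0 and there is no secondary class.
        return vol_a, {"class_idx": class_a, "is_mixed": False,
                       "mix_lambda": 1.0, "mix_class_idx": None}
    lam = rng.betavariate(alpha, alpha)
    # Keep the dominant sample as the primary class (an assumed convention).
    if lam < 0.5:
        vol_a, vol_b, class_a, class_b, lam = vol_b, vol_a, class_b, class_a, 1.0 - lam
    mixed = lam * vol_a + (1.0 - lam) * vol_b
    return mixed, {"class_idx": class_a, "is_mixed": True,
                   "mix_lambda": lam, "mix_class_idx": class_b}
```

A training loss can then combine the two targets as mix_lambda × loss(class_idx) + (1 − mix_lambda) × loss(mix_class_idx) whenever is_mixed is true.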

get_sample_weights

get_sample_weights()

Return sample weights for use in a WeightedRandomSampler.

keys

keys()

Get pickable object keys.

examples

examples()

Get example volumes for each class.

get_class_distribution

get_class_distribution()

Get distribution of classes in the dataset.

stratified_split

stratified_split(train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=None)

Split the dataset into train, validation, and test sets while preserving class distributions.

Parameters:

  • train_ratio

    Proportion of data to use for training

  • val_ratio

    Proportion of data to use for validation

  • test_ratio

    Proportion of data to use for testing

  • seed

    Random seed for reproducibility

Returns:

  • Tuple of (train_dataset, val_dataset, test_dataset) as Subset objects
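A stratified split works per class: each class's indices are shuffled and cut by the three ratios, so every subset keeps roughly the original class proportions. The sketch below shows that generic technique, not copick-torch's code; how the library rounds per-class counts is an assumption, and stratified_indices is a hypothetical name.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Optional, Sequence, Tuple


def stratified_indices(labels: Sequence[Hashable], train_ratio: float = 0.7,
                       val_ratio: float = 0.15, test_ratio: float = 0.15,
                       seed: Optional[int] = None) -> Tuple[List[int], List[int], List[int]]:
    """Split sample indices per class so each subset keeps the class proportions."""
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError("ratios must sum to 1")
    rng = random.Random(seed)
    by_class: Dict[Hashable, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train: List[int] = []
    val: List[int] = []
    test: List[int] = []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_train = int(round(train_ratio * len(indices)))
        n_val = int(round(val_ratio * len(indices)))
        train += indices[:n_train]
        val += indices[n_train:n_train + n_val]
        # Whatever remains (including rounding leftovers) goes to the test split.
        test += indices[n_train + n_val:]
    return train, val, test
```

Wrapping each list in torch.utils.data.Subset(dataset, indices) gives objects comparable to the Subset triple that stratified_split returns.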

balance_classes

balance_classes(method='oversample', target_ratio=1.0, exclude_background=False)

Balance class distribution in the dataset.

Parameters:

  • method

    Balancing method to use ('oversample' or 'undersample')

  • target_ratio

    Target degree of balance, for partial balancing (1.0 = perfect balance)

  • exclude_background

    Whether to exclude background class from balancing

Returns:

  • A new CopickDataset instance with balanced classes
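The two balancing methods can be illustrated on index lists. In this sketch, target_ratio is read (as an assumption) as the minimum allowed ratio between the smallest and largest class after balancing: oversampling draws extra minority-class indices with replacement up to target_ratio × the largest count, and undersampling trims each class to at most the smallest count ÷ target_ratio. background_label is a hypothetical stand-in for however the library marks background samples, and balanced_indices is a hypothetical name.

```python
import math
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Optional, Sequence


def balanced_indices(labels: Sequence[Hashable], method: str = "oversample",
                     target_ratio: float = 1.0, exclude_background: bool = False,
                     background_label: Hashable = -1, seed: Optional[int] = None) -> List[int]:
    """Return sample indices whose class counts are (partially) balanced."""
    if not 0 < target_ratio <= 1:
        raise ValueError("target_ratio must be in (0, 1]")
    rng = random.Random(seed)
    by_class: Dict[Hashable, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Optionally leave the (assumed) background class untouched.
    untouched: List[int] = []
    if exclude_background and background_label in by_class:
        untouched = by_class.pop(background_label)
    if not by_class:
        return untouched

    counts = [len(v) for v in by_class.values()]
    result = list(untouched)
    for indices in by_class.values():
        if method == "oversample":
            target = math.ceil(target_ratio * max(counts))
            extra = max(target - len(indices), 0)
            # Draw the extra samples with replacement from the same class.
            result += indices + [rng.choice(indices) for _ in range(extra)]
        elif method == "undersample":
            target = math.floor(min(counts) / target_ratio)
            result += rng.sample(indices, min(target, len(indices)))
        else:
            raise ValueError(f"unknown method: {method!r}")
    return result
```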

extract_grid_patches

extract_grid_patches(patch_size, overlap=0.25, normalize=True, run_index=0, tomo_type='raw')

Extract a grid of patches from a tomogram.

Parameters:

  • patch_size

    Int or tuple (z, y, x) for patch dimensions

  • overlap

    Overlap ratio between adjacent patches (0-1)

  • normalize

    Whether to normalize patches

  • run_index

    Index of the run to extract from

  • tomo_type

    Type of tomogram to extract from ('raw' or 'filtered')

Returns:

  • List of extracted patches and their coordinates (z, y, x)
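Grid extraction with overlap comes down to choosing patch start positions along each axis. A minimal sketch, assuming the stride is patch size × (1 − overlap) rounded down, and that one extra patch is shifted to end flush with the volume edge so no voxels are skipped; the library's exact edge handling is not documented here, and both function names are hypothetical.

```python
from itertools import product
from typing import List, Sequence, Tuple, Union


def axis_starts(length: int, patch: int, overlap: float) -> List[int]:
    """Start indices along one axis for patches of size `patch` with the given overlap."""
    if patch > length:
        return [0]  # assumed: a single (partial) patch when the volume is too small
    stride = max(int(patch * (1.0 - overlap)), 1)
    starts = list(range(0, length - patch + 1, stride))
    if starts[-1] + patch < length:
        starts.append(length - patch)  # last patch ends flush with the edge
    return starts


def grid_patch_origins(shape_zyx: Sequence[int], patch_size: Union[int, Tuple[int, int, int]],
                       overlap: float = 0.25) -> List[Tuple[int, int, int]]:
    """Return the (z, y, x) origin of every patch in the grid."""
    if isinstance(patch_size, int):
        patch_size = (patch_size,) * 3
    per_axis = [axis_starts(n, p, overlap) for n, p in zip(shape_zyx, patch_size)]
    return list(product(*per_axis))
```

Each origin (z, y, x) then selects tomogram[z:z + pz, y:y + py, x:x + px] for a patch of size (pz, py, px).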

extract_from_region

extract_from_region(x_range, y_range, z_range, tomo_type='raw')

Extract a specific region from a tomogram.

Parameters:

  • x_range

    Tuple of (min_x, max_x) in voxel space

  • y_range

    Tuple of (min_y, max_y) in voxel space

  • z_range

    Tuple of (min_z, max_z) in voxel space

  • tomo_type

    Type of tomogram to extract from ('raw' or 'filtered')

Returns:

  • A numpy array containing the extracted region
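The ranges are given as x, y, z, while tomogram arrays are indexed in the same (z, y, x) order used for boxsize, so the slices are applied in reverse. A minimal sketch, assuming a NumPy array in (z, y, x) order and half-open [min, max) voxel ranges; whether extract_from_region treats max as inclusive is not stated above, and slice_region is a hypothetical name.

```python
from typing import Tuple

import numpy as np


def slice_region(tomogram: np.ndarray, x_range: Tuple[int, int],
                 y_range: Tuple[int, int], z_range: Tuple[int, int]) -> np.ndarray:
    """Return tomogram[z0:z1, y0:y1, x0:x1] for half-open voxel ranges."""
    (x0, x1), (y0, y1), (z0, z1) = x_range, y_range, z_range
    # Array axes are (z, y, x), so the ranges are applied in reverse order.
    return tomogram[z0:z1, y0:y1, x0:x1]
```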