Make tfds.data_source picklable.
PiperOrigin-RevId: 636824581
marcenacp authored and The TensorFlow Datasets Authors committed May 27, 2024
1 parent 6bbba45 commit 52f4d72
Showing 7 changed files with 94 additions and 40 deletions.
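For context: the point of this change is that a data source returned by tfds.data_source (or builder.as_data_source) survives a pickle round-trip, which PyGrain needs in order to ship the source to its worker processes. Below is a minimal sketch of that round-trip, modeled on the test added to base_test.py in this commit; it assumes the tfds.testing.DummyDataset helper and a writable temporary directory.

import pickle

import tensorflow_datasets as tfds
from tensorflow_datasets.core import file_adapters

# Prepare a tiny dataset on disk in a random-access file format, then open it
# as a data source (mirrors test_data_source_is_picklable_after_use).
with tfds.testing.tmp_dir() as data_dir:
  builder = tfds.testing.DummyDataset(data_dir=data_dir)
  builder.download_and_prepare(
      file_format=file_adapters.FileFormat.ARRAY_RECORD
  )
  source = builder.as_data_source(split='train')

  # The source must survive pickling so PyGrain workers can rebuild it.
  restored = pickle.loads(pickle.dumps(source))
  assert restored[0] == source[0]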
18 changes: 2 additions & 16 deletions tensorflow_datasets/core/data_sources/array_record.py
@@ -20,13 +20,8 @@
"""

import dataclasses
from typing import Any, Optional

from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core.data_sources import base
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source


@@ -42,18 +37,9 @@ class ArrayRecordDataSource(base.BaseDataSource):
source.
"""

dataset_info: dataset_info_lib.DatasetInfo
split: splits_lib.Split = None
decoders: Optional[type_utils.TreeDict[decode.partial_decode.DecoderArg]] = (
None
)
# In order to lazy load array_record, we don't load
# `array_record_data_source.ArrayRecordDataSource` here.
data_source: Any = dataclasses.field(init=False)
length: int = dataclasses.field(init=False)

def __post_init__(self):
file_instructions = base.file_instructions(self.dataset_info, self.split)
dataset_info = self.dataset_builder.info
file_instructions = base.file_instructions(dataset_info, self.split)
self.data_source = array_record_data_source.ArrayRecordDataSource(
file_instructions
)
32 changes: 24 additions & 8 deletions tensorflow_datasets/core/data_sources/base.py
@@ -17,12 +17,14 @@

from collections.abc import MappingView, Sequence
import dataclasses
import functools
import typing
from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar

from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core.features import top_level_feature
from tensorflow_datasets.core.utils import shard_utils
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import tree
@@ -54,6 +56,14 @@ def file_instructions(
return split_dict[split].file_instructions


class _DatasetBuilder(Protocol):
"""Protocol for the DatasetBuilder to avoid cyclic imports."""

@property
def info(self) -> dataset_info_lib.DatasetInfo:
...


@dataclasses.dataclass
class BaseDataSource(MappingView, Sequence):
"""Base DataSource to override all dunder methods with the deserialization.
@@ -64,22 +74,28 @@ class BaseDataSource(MappingView, Sequence):
deserialization/decoding.
Attributes:
dataset_info: The DatasetInfo of the
dataset_builder: The dataset builder.
split: The split to load in the data source.
decoders: Optional decoders for decoding.
data_source: The underlying data source to initialize in the __post_init__.
"""

dataset_info: dataset_info_lib.DatasetInfo
dataset_builder: _DatasetBuilder
split: splits_lib.Split | None = None
decoders: type_utils.TreeDict[decode.partial_decode.DecoderArg] | None = None
data_source: DataSource[Any] = dataclasses.field(init=False)

@functools.cached_property
def _features(self) -> top_level_feature.TopLevelFeature:
"""Caches features because we log the use of dataset_builder.info."""
features = self.dataset_builder.info.features
if not features:
raise ValueError('No feature defined in the dataset builder.')
return features

def __getitem__(self, key: SupportsIndex) -> Any:
record = self.data_source[key.__index__()]
return self.dataset_info.features.deserialize_example_np(
record, decoders=self.decoders
)
return self._features.deserialize_example_np(record, decoders=self.decoders)

def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
"""Retrieves items by batch.
@@ -98,24 +114,24 @@ def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
if not keys:
return []
records = self.data_source.__getitems__(keys)
features = self.dataset_info.features
if len(keys) != len(records):
raise IndexError(
f'Requested {len(keys)} records but got'
f' {len(records)} records.'
f'{keys=}, {records=}'
)
return [
features.deserialize_example_np(record, decoders=self.decoders)
self._features.deserialize_example_np(record, decoders=self.decoders)
for record in records
]

def __repr__(self) -> str:
decoders_repr = (
tree.map_structure(type, self.decoders) if self.decoders else None
)
name = self.dataset_builder.info.name
return (
f'{self.__class__.__name__}(name={self.dataset_info.name}, '
f'{self.__class__.__name__}(name={name}, '
f'split={self.split!r}, '
f'decoders={decoders_repr})'
)
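An aside on the _DatasetBuilder Protocol introduced above: it is structural, so any object exposing an info property satisfies it, and base.py no longer has to import the real DatasetBuilder (which would create a cyclic import). A self-contained sketch of the same pattern, with illustrative names that are not taken from the codebase:

from typing import Protocol


class _HasInfo(Protocol):
  """Anything with an `info` property satisfies this protocol."""

  @property
  def info(self) -> str:  # The real protocol returns a DatasetInfo.
    ...


class FakeBuilder:
  """No inheritance from _HasInfo is required: the typing is structural."""

  @property
  def info(self) -> str:
    return 'dataset_name'


def dataset_name(builder: _HasInfo) -> str:
  return builder.info


assert dataset_name(FakeBuilder()) == 'dataset_name'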
49 changes: 39 additions & 10 deletions tensorflow_datasets/core/data_sources/base_test.py
@@ -15,13 +15,15 @@

"""Tests for all data sources."""

import pickle
from unittest import mock

import cloudpickle
from etils import epath
import pytest
import tensorflow_datasets as tfds
from tensorflow_datasets import testing
from tensorflow_datasets.core import dataset_builder
from tensorflow_datasets.core import dataset_builder as dataset_builder_lib
from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import file_adapters
@@ -77,7 +79,7 @@ def mocked_parquet_dataset():
)
def test_read_write(
tmp_path: epath.Path,
builder_cls: dataset_builder.DatasetBuilder,
builder_cls: dataset_builder_lib.DatasetBuilder,
file_format: file_adapters.FileFormat,
):
builder = builder_cls(data_dir=tmp_path, file_format=file_format)
@@ -106,28 +108,36 @@ def test_read_write(
]


def create_dataset_info(file_format: file_adapters.FileFormat):
def create_dataset_builder(
file_format: file_adapters.FileFormat,
) -> dataset_builder_lib.DatasetBuilder:
with mock.patch.object(splits_lib, 'SplitInfo') as split_mock:
split_mock.return_value.name = 'train'
split_mock.return_value.file_instructions = _FILE_INSTRUCTIONS
dataset_info = mock.create_autospec(dataset_info_lib.DatasetInfo)
dataset_info.file_format = file_format
dataset_info.splits = {'train': split_mock()}
dataset_info.name = 'dataset_name'
return dataset_info

dataset_builder = mock.create_autospec(dataset_builder_lib.DatasetBuilder)
dataset_builder.info = dataset_info

return dataset_builder


@pytest.mark.parametrize(
'data_source_cls',
_DATA_SOURCE_CLS,
)
def test_missing_split_raises_error(data_source_cls):
dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
dataset_builder = create_dataset_builder(
file_adapters.FileFormat.ARRAY_RECORD
)
with pytest.raises(
ValueError,
match="Unknown split 'doesnotexist'.",
):
data_source_cls(dataset_info, split='doesnotexist')
data_source_cls(dataset_builder, split='doesnotexist')


@pytest.mark.usefixtures(*_FIXTURES)
@@ -136,8 +146,10 @@ def test_missing_split_raises_error(data_source_cls):
_DATA_SOURCE_CLS,
)
def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
source = data_source_cls(dataset_info, split='train')
dataset_builder = create_dataset_builder(
file_adapters.FileFormat.ARRAY_RECORD
)
source = data_source_cls(dataset_builder, split='train')
name = data_source_cls.__name__
assert (
repr(source) == f"{name}(name=dataset_name, split='train', decoders=None)"
@@ -150,9 +162,11 @@ def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
_DATA_SOURCE_CLS,
)
def test_repr_returns_meaningful_string_with_decoders(data_source_cls):
dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
dataset_builder = create_dataset_builder(
file_adapters.FileFormat.ARRAY_RECORD
)
source = data_source_cls(
dataset_info,
dataset_builder,
split='train',
decoders={'my_feature': decode.SkipDecoding()},
)
@@ -181,3 +195,18 @@ def test_data_source_is_sliceable():
file_instructions = mock_array_record_data_source.call_args_list[1].args[0]
assert file_instructions[0].skip == 0
assert file_instructions[0].take == 30000


# PyGrain requires that data sources are picklable.
@pytest.mark.parametrize(
'file_format',
file_adapters.FileFormat.with_random_access(),
)
@pytest.mark.parametrize('pickle_module', [pickle, cloudpickle])
def test_data_source_is_picklable_after_use(file_format, pickle_module):
with tfds.testing.tmp_dir() as data_dir:
builder = tfds.testing.DummyDataset(data_dir=data_dir)
builder.download_and_prepare(file_format=file_format)
data_source = builder.as_data_source(split='train')
assert data_source[0] == {'id': 0}
assert pickle_module.loads(pickle_module.dumps(data_source))[0] == {'id': 0}
3 changes: 2 additions & 1 deletion tensorflow_datasets/core/data_sources/parquet.py
@@ -57,7 +57,8 @@ class ParquetDataSource(base.BaseDataSource):
"""ParquetDataSource to read from a ParquetDataset."""

def __post_init__(self):
file_instructions = base.file_instructions(self.dataset_info, self.split)
dataset_info = self.dataset_builder.info
file_instructions = base.file_instructions(dataset_info, self.split)
filenames = [
file_instruction.filename for file_instruction in file_instructions
]
21 changes: 17 additions & 4 deletions tensorflow_datasets/core/dataset_builder.py
@@ -333,10 +333,23 @@ def code_path(cls) -> Optional[epath.Path]:
return epath.Path(filepath)

def __getstate__(self):
return self._original_state
state = {"original_state": self._original_state}
features = self.info.features
if hasattr(features, "deserialize_example_np"):
# See the comment in __setstate__ to understand why we do this.
state["deserialize_example_np"] = features.deserialize_example_np
return state

def __setstate__(self, state):
self.__init__(**state)
self.__init__(**state["original_state"])

# This is a hack. We explicitly set deserialize_example_np to propagate any
# mock on this function to PyGrain workers in multiprocessing. Indeed,
# mock.patch cannot be used in multiprocessing since the builder is created
# in a totally different process.
deserialize_example_np = state.get("deserialize_example_np")
if deserialize_example_np:
self.info.features.deserialize_example_np = deserialize_example_np

@functools.cached_property
def canonical_version(self) -> utils.Version:
@@ -774,13 +787,13 @@ def build_single_data_source(
file_format = self.info.file_format
if file_format == file_adapters.FileFormat.ARRAY_RECORD:
return array_record.ArrayRecordDataSource(
self.info,
self,
split=split,
decoders=decoders,
)
elif file_format == file_adapters.FileFormat.PARQUET:
return parquet.ParquetDataSource(
self.info,
self,
split=split,
decoders=decoders,
)
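A note on the __getstate__/__setstate__ pair above: the builder is rebuilt from its original constructor arguments, and the (possibly mocked) deserialize_example_np function is carried through the pickle state, so a mock.patch applied in the parent process still takes effect after unpickling in a PyGrain worker. A stripped-down sketch of the same idea, with illustrative names:

import pickle


def default_decode(x):
  return x


def patched_decode(x):  # Stands in for a function swapped in by mock.patch.
  return {'patched': x}


class Rebuildable:
  """Rebuilds itself from its constructor args, like DatasetBuilder."""

  def __init__(self, name):
    self._original_state = {'name': name}
    self.name = name
    self.decode = default_decode

  def __getstate__(self):
    # Persist the constructor args plus the (possibly patched) function.
    return {'original_state': self._original_state, 'decode': self.decode}

  def __setstate__(self, state):
    self.__init__(**state['original_state'])
    # Re-attach the function so the patch survives the process boundary.
    self.decode = state['decode']


builder = Rebuildable('dummy')
builder.decode = patched_decode
restored = pickle.loads(pickle.dumps(builder))
assert restored.decode is patched_decode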
2 changes: 1 addition & 1 deletion tensorflow_datasets/testing/mocking.py
@@ -399,7 +399,7 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):

def build_single_data_source(split):
single_data_source = array_record.ArrayRecordDataSource(
dataset_info=self.info, split=split, decoders=decoders
dataset_builder=self, split=split, decoders=decoders
)
return single_data_source

9 changes: 9 additions & 0 deletions tensorflow_datasets/testing/mocking_test.py
@@ -392,3 +392,12 @@ def test_as_data_source_fn():
assert imagenet[0] == 'foo'
assert imagenet[1] == 'bar'
assert imagenet[2] == 'baz'


# PyGrain requires that data sources are picklable.
def test_mocked_data_source_is_pickable():
with tfds.testing.mock_data(num_examples=2):
data_source = tfds.data_source('imagenet2012', split='train')
pickled_and_unpickled_data_source = pickle.loads(pickle.dumps(data_source))
assert len(pickled_and_unpickled_data_source) == 2
assert isinstance(pickled_and_unpickled_data_source[0]['image'], np.ndarray)
