From 290ee7e0a856849e8bd557fe81641412ab0328bf Mon Sep 17 00:00:00 2001
From: Pierre Marcenac
Date: Mon, 29 Jul 2024 08:38:52 -0700
Subject: [PATCH] Stream from Hugging Face instead of downloading and
 preparing everything.

PiperOrigin-RevId: 657212303
---
 .../huggingface_dataset_builder.py            | 66 +++++++++++++------
 .../huggingface_dataset_builder_test.py       |  7 +-
 2 files changed, 48 insertions(+), 25 deletions(-)

diff --git a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
index 9d1ee6b57ee..d3159f9ada4 100644
--- a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
+++ b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
@@ -32,6 +32,7 @@
 import itertools
 import multiprocessing
 import os
+import time
 from typing import Any, Dict, Optional, Union
 
 from absl import logging
@@ -108,9 +109,24 @@ class _ShardInfo:
   num_exceptions: int
 
 
+def _load_dataset(
+    hf_builder: hf_datasets.DatasetBuilder,
+    split: str,
+) -> hf_datasets.Dataset:
+  """Efficiently loads a HuggingFace iterable dataset from its builder."""
+  if hf_builder.repo_id is None:
+    return hf_builder.as_dataset(split=split)
+  return hf_datasets.load_dataset(
+      hf_builder.repo_id or hf_builder.cache_dir,
+      hf_builder.config_id,
+      split=split,
+      streaming=True,
+  )
+
+
 def _write_shard(
     shard_spec: _ShardSpec,
-    hf_builder,
+    hf_builder: hf_datasets.DatasetBuilder,
     example_writer,
     features: feature_lib.FeaturesDict,
     ignore_hf_errors: bool,
@@ -136,12 +152,19 @@ def _write_shard(
   def get_serialized_examples_iter():
     nonlocal num_bytes
     nonlocal num_exceptions
-    dataset = hf_builder.as_dataset(
-        split=shard_spec.shard_split, run_post_process=False
+    dataset = _load_dataset(
+        hf_builder,
+        shard_spec.hf_split,
     )
-    for i in range(shard_spec.num_examples):
+    dataset = iter(dataset)
+    # Skipping the first `start_index` examples. `streaming=True` returns an
+    # iterable dataset, so we cannot jump to a specific index. This is not too
+    # costly because it takes <0.5 ms/element in the wikipedia dataset.
+    for _ in range(shard_spec.start_index):
+      next(dataset)
+    for _ in range(shard_spec.num_examples):
       try:
-        hf_value = dataset[i]
+        hf_value = next(dataset)
       except Exception:  # pylint: disable=broad-exception-caught
         num_exceptions += 1
         if ignore_hf_errors:
@@ -155,6 +178,7 @@ def get_serialized_examples_iter():
       num_bytes += len(serialized_example)
       yield serialized_example
 
+  start = time.time()
   example_writer.write(
       os.fspath(shard_spec.path),
       tqdm_utils.tqdm(
@@ -166,6 +190,11 @@ def get_serialized_examples_iter():
           mininterval=1.0,
       ),
   )
+  logging.info(
+      'Generated %s examples in %s seconds',
+      shard_spec.num_examples,
+      time.time() - start,
+  )
 
   return _ShardInfo(
       num_bytes=num_bytes,
@@ -247,6 +276,7 @@ def __init__(
     self._builder_config = self._converted_builder_config
     self.generation_errors = []
     self._ignore_hf_errors = ignore_hf_errors
+    login_to_hf(self._hf_hub_token)
 
   @property
   def builder_config(self) -> Optional[Any]:
@@ -257,14 +287,6 @@ def _create_builder_config(
   ) -> Optional[dataset_builder.BuilderConfig]:
     return self._converted_builder_config
 
-  @functools.lru_cache(maxsize=1)
-  def _hf_download_and_prepare(self):
-    login_to_hf(self._hf_hub_token)
-    self._hf_builder.download_and_prepare(
-        num_proc=self._hf_num_proc,
-        verification_mode=self._verification_mode,
-    )
-
   @property
   def _hf_info(self) -> hf_datasets.DatasetInfo:
     """Retrieves the dataset info from the HuggingFace Datasets."""
@@ -278,11 +300,18 @@ def _hf_hub_info(self) -> huggingface_hub.hf_api.DatasetInfo:
     )
 
   def _hf_features(self) -> hf_datasets.Features:
-    if not self._hf_info.features:
-      # We need to download and prepare the data to know its features.
-      self._hf_download_and_prepare()
-
-    return self._hf_info.features
+    # Return the features from the builder info.
+    if self._hf_info.features:
+      return self._hf_info.features
+    # Return the features from the first split.
+    for split in self._hf_info.splits:
+      ds = _load_dataset(
+          self._hf_builder,
+          split,
+      )
+      if hasattr(ds, 'info') and ds.info.features:
+        return ds.info.features
+    raise ValueError('No features found in the dataset.')
 
   def _info(self) -> dataset_info_lib.DatasetInfo:
     return dataset_info_lib.DatasetInfo(
@@ -309,7 +338,6 @@ def _generate_splits(
   ) -> Sequence[splits_lib.SplitInfo]:
     """Prepares the dataset by writing to shards directly."""
     del dl_manager, download_config  # Unused.
-    self._hf_download_and_prepare()
 
     shard_specs_by_split: dict[str, Sequence[_ShardSpec]] = {}
     for hf_split, hf_split_info in self._hf_info.splits.items():
diff --git a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py
index da903654ee3..92ee6ec84fb 100644
--- a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py
+++ b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py
@@ -62,6 +62,7 @@ def mock_load_dataset_builder(tmp_path):
   with mock.patch.object(
       hf_datasets, 'load_dataset_builder', return_value=hf_builder
   ) as load_dataset_builder:
+    hf_builder.download_and_prepare()
     yield load_dataset_builder
 
 
@@ -133,12 +134,6 @@ def test_download_and_prepare(builder):
   assert len(ds['train_clean']) == 2
 
 
-def test_all_parameters_are_passed_down_to_hf(builder):
-  builder._hf_builder.download_and_prepare.assert_called_once_with(
-      verification_mode='no_checks', num_proc=100
-  )
-
-
 def test_hf_features(builder):
   assert builder._hf_features() == {
       'number': hf_datasets.Value('int64'),
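
Note: the sketch below illustrates the streaming access pattern that `_load_dataset` and `_write_shard` adopt in this patch. It is a minimal sketch using only the public `datasets` API; the dataset name, config, and indices are illustrative assumptions, not values taken from the patch.

# Hedged sketch of the streaming pattern adopted above. Assumptions:
# 'wikimedia/wikipedia' and '20231101.en' are illustrative placeholders.
import itertools

import datasets as hf_datasets

# streaming=True avoids download_and_prepare entirely: it returns an
# IterableDataset whose examples are fetched lazily from the Hub.
ds = hf_datasets.load_dataset(
    'wikimedia/wikipedia',
    '20231101.en',
    split='train',
    streaming=True,
)

# An IterableDataset has no random access (ds[i] is unsupported), so a
# shard covering [start_index, start_index + num_examples) must skip
# forward one element at a time, which is what the next(dataset) loops
# in _write_shard do; itertools.islice expresses the same idea.
start_index, num_examples = 100, 5
for example in itertools.islice(ds, start_index, start_index + num_examples):
  print(example['title'])  # 'title' is a field of this particular dataset.

The same streaming handle also motivates the new `_hf_features` fallback: when the builder info carries no features, the patch inspects `ds.info.features` on a streamed split instead of materializing the whole dataset.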