Stream from Hugging Face instead of downloading and preparing everything. #5539

Open

wants to merge 1 commit into base: master

Changes from all commits
@@ -32,6 +32,7 @@
 import itertools
 import multiprocessing
 import os
+import time
 from typing import Any, Dict, Optional, Union

 from absl import logging
@@ -108,9 +109,24 @@ class _ShardInfo:
   num_exceptions: int


+def _load_dataset(
+    hf_builder: hf_datasets.DatasetBuilder,
+    split: str,
+) -> hf_datasets.Dataset:
+  """Efficiently loads a HuggingFace iterable dataset from its builder."""
+  if hf_builder.repo_id is None:
+    return hf_builder.as_dataset(split=split)
+  return hf_datasets.load_dataset(
+      hf_builder.repo_id or hf_builder.cache_dir,
+      hf_builder.config_id,
+      split=split,
+      streaming=True,
+  )
+
+
 def _write_shard(
     shard_spec: _ShardSpec,
-    hf_builder,
+    hf_builder: hf_datasets.DatasetBuilder,
     example_writer,
     features: feature_lib.FeaturesDict,
     ignore_hf_errors: bool,
@@ -136,12 +152,19 @@ def get_serialized_examples_iter():
   def get_serialized_examples_iter():
     nonlocal num_bytes
     nonlocal num_exceptions
-    dataset = hf_builder.as_dataset(
-        split=shard_spec.shard_split, run_post_process=False
+    dataset = _load_dataset(
+        hf_builder,
+        shard_spec.hf_split,
     )
-    for i in range(shard_spec.num_examples):
+    dataset = iter(dataset)
+    # Skipping the first `start_index` examples. `streaming=True` returns an
+    # iterable dataset, so we cannot jump to a specific index. This is not too
+    # costly because it takes <0.5 ms/element in the wikipedia dataset.
+    for _ in range(shard_spec.start_index):
+      next(dataset)
+    for _ in range(shard_spec.num_examples):
       try:
-        hf_value = dataset[i]
+        hf_value = next(dataset)
       except Exception:  # pylint: disable=broad-exception-caught
         num_exceptions += 1
         if ignore_hf_errors:
@@ -155,6 +178,7 @@ def get_serialized_examples_iter():
       num_bytes += len(serialized_example)
       yield serialized_example

+  start = time.time()
   example_writer.write(
       os.fspath(shard_spec.path),
       tqdm_utils.tqdm(
@@ -166,6 +190,11 @@ def get_serialized_examples_iter():
           mininterval=1.0,
       ),
   )
+  logging.info(
+      'Generated %s examples in %s seconds',
+      shard_spec.num_examples,
+      time.time() - start,
+  )

   return _ShardInfo(
       num_bytes=num_bytes,
@@ -247,6 +276,7 @@ def __init__(
     self._builder_config = self._converted_builder_config
     self.generation_errors = []
     self._ignore_hf_errors = ignore_hf_errors
+    login_to_hf(self._hf_hub_token)

   @property
   def builder_config(self) -> Optional[Any]:
@@ -257,14 +287,6 @@ def _create_builder_config(
   ) -> Optional[dataset_builder.BuilderConfig]:
     return self._converted_builder_config

-  @functools.lru_cache(maxsize=1)
-  def _hf_download_and_prepare(self):
-    login_to_hf(self._hf_hub_token)
-    self._hf_builder.download_and_prepare(
-        num_proc=self._hf_num_proc,
-        verification_mode=self._verification_mode,
-    )
-
   @property
   def _hf_info(self) -> hf_datasets.DatasetInfo:
     """Retrieves the dataset info from the HuggingFace Datasets."""
@@ -278,11 +300,18 @@ def _hf_hub_info(self) -> huggingface_hub.hf_api.DatasetInfo:
     )

   def _hf_features(self) -> hf_datasets.Features:
-    if not self._hf_info.features:
-      # We need to download and prepare the data to know its features.
-      self._hf_download_and_prepare()
-
-    return self._hf_info.features
+    # Return the features from the builder info.
+    if self._hf_info.features:
+      return self._hf_info.features
+    # Return the features from the first split.
+    for split in self._hf_info.splits:
+      ds = _load_dataset(
+          self._hf_builder,
+          split,
+      )
+      if hasattr(ds, 'info') and ds.info.features:
+        return ds.info.features
+    raise ValueError('No features found in the dataset.')

   def _info(self) -> dataset_info_lib.DatasetInfo:
     return dataset_info_lib.DatasetInfo(
@@ -309,7 +338,6 @@ def _generate_splits(
   ) -> Sequence[splits_lib.SplitInfo]:
     """Prepares the dataset by writing to shards directly."""
     del dl_manager, download_config  # Unused.
-    self._hf_download_and_prepare()

     shard_specs_by_split: dict[str, Sequence[_ShardSpec]] = {}
     for hf_split, hf_split_info in self._hf_info.splits.items():
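The core change above replaces the eager download_and_prepare + as_dataset path with 🤗 Datasets streaming. A minimal, self-contained sketch of that pattern, for context only; the dataset name, config, and indices below are illustrative placeholders, not values taken from this PR:

# Illustrative sketch only: the dataset name/config and indices are placeholders.
import itertools

import datasets as hf_datasets  # the builder aliases `datasets` as `hf_datasets`

# streaming=True returns an IterableDataset: nothing is downloaded or prepared
# up front; examples are fetched lazily from the Hub.
dataset = hf_datasets.load_dataset(
    'wikipedia', '20220301.en', split='train', streaming=True
)

start_index, num_examples = 10, 3
iterator = iter(dataset)
# An IterableDataset has no random access, so the first `start_index` examples
# are skipped by advancing the iterator, as the new _write_shard code does.
for _ in range(start_index):
  next(iterator)
for example in itertools.islice(iterator, num_examples):
  print(sorted(example.keys()))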
@@ -62,6 +62,7 @@ def mock_load_dataset_builder(tmp_path):
   with mock.patch.object(
       hf_datasets, 'load_dataset_builder', return_value=hf_builder
   ) as load_dataset_builder:
+    hf_builder.download_and_prepare()
     yield load_dataset_builder


@@ -133,12 +134,6 @@ def test_download_and_prepare(builder):
   assert len(ds['train_clean']) == 2


-def test_all_parameters_are_passed_down_to_hf(builder):
-  builder._hf_builder.download_and_prepare.assert_called_once_with(
-      verification_mode='no_checks', num_proc=100
-  )
-
-
 def test_hf_features(builder):
   assert builder._hf_features() == {
       'number': hf_datasets.Value('int64'),
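As extra context for the _hf_features fallback in the builder diff, a hedged sketch of the same idea against the public 🤗 Datasets API; 'ag_news' is a placeholder dataset name, and the sketch assumes the Hub metadata already exposes split names:

# Illustrative sketch only: 'ag_news' is a placeholder dataset name.
import datasets as hf_datasets

builder = hf_datasets.load_dataset_builder('ag_news')
features = builder.info.features
if not features:
  # Mirror the fallback in _hf_features: stream one split and read the feature
  # spec from the streamed dataset's info instead of preparing everything.
  for split in builder.info.splits:
    streamed = hf_datasets.load_dataset('ag_news', split=split, streaming=True)
    if getattr(streamed, 'info', None) and streamed.info.features:
      features = streamed.info.features
      break
print(features)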