Skip to content

Commit

Permalink
Migrate to simple_parsing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668933314
  • Loading branch information
fineguy authored and The TensorFlow Datasets Authors committed Aug 30, 2024
1 parent 7312de7 commit cb88725
Show file tree
Hide file tree
Showing 11 changed files with 352 additions and 563 deletions.
171 changes: 86 additions & 85 deletions tensorflow_datasets/scripts/cli/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

"""`tfds build` command."""

import argparse
import dataclasses
import functools
import importlib
import itertools
Expand All @@ -25,84 +25,84 @@
from typing import Any, Dict, Iterator, Optional, Tuple, Type, Union

from absl import logging
import simple_parsing
import tensorflow_datasets as tfds
from tensorflow_datasets.scripts.cli import cli_utils

# pylint: disable=logging-fstring-interpolation


def register_subparser(parsers: argparse._SubParsersAction) -> None: # pylint: disable=protected-access
"""Add subparser for `build` command."""
build_parser = parsers.add_parser(
'build', help='Commands for downloading and preparing datasets.'
)
build_parser.add_argument(
'datasets', # Positional arguments
type=str,
nargs='*',
help=(
'Name(s) of the dataset(s) to build. Default to current dir. '
'See https://www.tensorflow.org/datasets/cli for accepted values.'
),
)
build_parser.add_argument( # Also accept keyword arguments
'--datasets',
type=str,
nargs='+',
dest='datasets_keyword',
help='Datasets can also be provided as keyword argument.',
)
@dataclasses.dataclass(frozen=True, kw_only=True)
class _AutomationGroup:
"""Used by automated scripts.
cli_utils.add_debug_argument_group(build_parser)
cli_utils.add_path_argument_group(build_parser)
cli_utils.add_generation_argument_group(build_parser)
cli_utils.add_publish_argument_group(build_parser)
Attributes:
exclude_datasets: If set, generate all datasets except the one defined here.
Comma separated list of datasets to exclude.
experimental_latest_version: Build the latest Version(experiments=...)
available rather than default version.
"""

# **** Automation options ****
automation_group = build_parser.add_argument_group(
'Automation', description='Used by automated scripts.'
)
automation_group.add_argument(
'--exclude_datasets',
type=str,
help=(
'If set, generate all datasets except the one defined here. '
'Comma separated list of datasets to exclude. '
),
exclude_datasets: list[str] = cli_utils.comma_separated_list_field()
experimental_latest_version: bool = False


@dataclasses.dataclass(frozen=True, kw_only=True)
class CmdArgs:
"""Commands for downloading and preparing datasets.
Attributes:
datasets: Name(s) of the dataset(s) to build. Default to current dir. See
https://www.tensorflow.org/datasets/cli for accepted values.
datasets_keyword: Datasets can also be provided as keyword argument.
debug: Debug & tests options.
path: Paths options.
generation: Generation options.
publish: Publishing options.
automation: Automation options.
"""

datasets: list[str] = simple_parsing.field(
positional=True, default_factory=list, nargs='*'
)
automation_group.add_argument(
'--experimental_latest_version',
action='store_true',
help=(
'Build the latest Version(experiments=...) available rather than '
'default version.'
),
datasets_keyword: list[str] = simple_parsing.field(
alias='datasets', default_factory=list, nargs='*'
)
debug: cli_utils.DebugGroup = simple_parsing.field(prefix='')
path: cli_utils.PathGroup = simple_parsing.field(prefix='')
generation: cli_utils.GenerationGroup = simple_parsing.field(prefix='')
publish: cli_utils.PublishGroup = simple_parsing.field(prefix='')
automation: _AutomationGroup = simple_parsing.field(prefix='')

build_parser.set_defaults(subparser_fn=_build_datasets)
def execute(self):
_build_datasets(self)


def _build_datasets(args: argparse.Namespace) -> None:
def _build_datasets(args: CmdArgs) -> None:
"""Build the given datasets."""
# Eventually register additional datasets imports
if args.imports:
list(importlib.import_module(m) for m in args.imports.split(','))
if args.generation.imports:
list(importlib.import_module(m) for m in args.generation.imports)

# Select datasets to generate
datasets = (args.datasets or []) + (args.datasets_keyword or [])
if args.exclude_datasets: # Generate all datasets if `--exclude_datasets` set
datasets = args.datasets + args.datasets_keyword
if (
args.automation.exclude_datasets
): # Generate all datasets if `--exclude_datasets` set
if datasets:
raise ValueError("--exclude_datasets can't be used with `datasets`")
datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
args.exclude_datasets.split(',')
args.automation.exclude_datasets
)
datasets = sorted(datasets) # `set` is not deterministic
else:
datasets = datasets or [''] # Empty string for default

# Import builder classes
builders_cls_and_kwargs = [
_get_builder_cls_and_kwargs(dataset, has_imports=bool(args.imports))
_get_builder_cls_and_kwargs(
dataset, has_imports=bool(args.generation.imports)
)
for dataset in datasets
]

Expand All @@ -112,19 +112,20 @@ def _build_datasets(args: argparse.Namespace) -> None:
for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
))
process_builder_fn = functools.partial(
_download if args.download_only else _download_and_prepare, args
_download if args.generation.download_only else _download_and_prepare,
args,
)

if args.num_processes == 1:
if args.generation.num_processes == 1:
for builder in builders:
process_builder_fn(builder)
else:
with multiprocessing.Pool(args.num_processes) as pool:
with multiprocessing.Pool(args.generation.num_processes) as pool:
pool.map(process_builder_fn, builders)


def _make_builders(
args: argparse.Namespace,
args: CmdArgs,
builder_cls: Type[tfds.core.DatasetBuilder],
builder_kwargs: Dict[str, Any],
) -> Iterator[tfds.core.DatasetBuilder]:
Expand All @@ -139,7 +140,7 @@ def _make_builders(
Initialized dataset builders.
"""
# Eventually overwrite version
if args.experimental_latest_version:
if args.automation.experimental_latest_version:
if 'version' in builder_kwargs:
raise ValueError(
"Can't have both `--experimental_latest` and version set (`:1.0.0`)"
Expand All @@ -150,19 +151,19 @@ def _make_builders(
builder_kwargs['config'] = _get_config_name(
builder_cls=builder_cls,
config_kwarg=builder_kwargs.get('config'),
config_name=args.config,
config_idx=args.config_idx,
config_name=args.generation.config,
config_idx=args.generation.config_idx,
)

if args.file_format:
builder_kwargs['file_format'] = args.file_format
if args.generation.file_format:
builder_kwargs['file_format'] = args.generation.file_format

make_builder = functools.partial(
_make_builder,
builder_cls,
overwrite=args.overwrite,
fail_if_exists=args.fail_if_exists,
data_dir=args.data_dir,
overwrite=args.debug.overwrite,
fail_if_exists=args.debug.fail_if_exists,
data_dir=args.path.data_dir,
**builder_kwargs,
)

Expand Down Expand Up @@ -301,7 +302,7 @@ def _make_builder(


def _download(
args: argparse.Namespace,
args: CmdArgs,
builder: tfds.core.DatasetBuilder,
) -> None:
"""Downloads all files of the given builder."""
Expand All @@ -323,7 +324,7 @@ def _download(
if builder.MAX_SIMULTANEOUS_DOWNLOADS is not None:
max_simultaneous_downloads = builder.MAX_SIMULTANEOUS_DOWNLOADS

download_dir = args.download_dir or os.path.join(
download_dir = args.path.download_dir or os.path.join(
builder._data_dir_root, 'downloads' # pylint: disable=protected-access
)
dl_manager = tfds.download.DownloadManager(
Expand All @@ -345,51 +346,51 @@ def _download(


def _download_and_prepare(
args: argparse.Namespace,
args: CmdArgs,
builder: tfds.core.DatasetBuilder,
) -> None:
"""Generate a single builder."""
cli_utils.download_and_prepare(
builder=builder,
download_config=_make_download_config(args, dataset_name=builder.name),
download_dir=args.download_dir,
publish_dir=args.publish_dir,
skip_if_published=args.skip_if_published,
overwrite=args.overwrite,
download_dir=args.path.download_dir,
publish_dir=args.publish.publish_dir,
skip_if_published=args.publish.skip_if_published,
overwrite=args.debug.overwrite,
)


def _make_download_config(
args: argparse.Namespace,
args: CmdArgs,
dataset_name: str,
) -> tfds.download.DownloadConfig:
"""Generate the download and prepare configuration."""
# Load the download config
manual_dir = args.manual_dir
if args.add_name_to_manual_dir:
manual_dir = args.path.manual_dir
if args.path.add_name_to_manual_dir:
manual_dir = manual_dir / dataset_name

kwargs = {}
if args.max_shard_size_mb:
kwargs['max_shard_size'] = args.max_shard_size_mb << 20
if args.download_config:
kwargs.update(json.loads(args.download_config))
if args.generation.max_shard_size_mb:
kwargs['max_shard_size'] = args.generation.max_shard_size_mb << 20
if args.generation.download_config:
kwargs.update(json.loads(args.generation.download_config))

if 'download_mode' in kwargs:
kwargs['download_mode'] = tfds.download.GenerateMode(
kwargs['download_mode']
)
else:
kwargs['download_mode'] = tfds.download.GenerateMode.REUSE_DATASET_IF_EXISTS
if args.update_metadata_only:
if args.generation.update_metadata_only:
kwargs['download_mode'] = tfds.download.GenerateMode.UPDATE_DATASET_INFO

dl_config = tfds.download.DownloadConfig(
extract_dir=args.extract_dir,
extract_dir=args.path.extract_dir,
manual_dir=manual_dir,
max_examples_per_split=args.max_examples_per_split,
register_checksums=args.register_checksums,
force_checksums_validation=args.force_checksums_validation,
max_examples_per_split=args.debug.max_examples_per_split,
register_checksums=args.generation.register_checksums,
force_checksums_validation=args.generation.force_checksums_validation,
**kwargs,
)

Expand All @@ -400,9 +401,9 @@ def _make_download_config(
beam = None

if beam is not None:
if args.beam_pipeline_options:
if args.generation.beam_pipeline_options:
dl_config.beam_options = beam.options.pipeline_options.PipelineOptions(
flags=[f'--{opt}' for opt in args.beam_pipeline_options.split(',')]
flags=[f'--{opt}' for opt in args.generation.beam_pipeline_options]
)

return dl_config
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_datasets/scripts/cli/build_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_download_only():
)
def test_make_download_config(args: str, download_config_kwargs):
args = main._parse_flags(f'tfds build x {args}'.split())
actual = build_lib._make_download_config(args, dataset_name='x')
actual = build_lib._make_download_config(args.command, dataset_name='x')
# Ignore the beam runner
actual = actual.replace(beam_runner=None)
expected = tfds.download.DownloadConfig(**download_config_kwargs)
Expand Down
Loading

0 comments on commit cb88725

Please sign in to comment.