Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add possibility to merge several classes to dataset scripts #156

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions luminoth/tools/dataset/readers/object_detection/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,24 @@ def __init__(self, data_dir, split, year=DEFAULT_YEAR,
for annotation in annotations_json['annotations']:
image_id = annotation['image_id']
x, y, width, height = annotation['bbox']
if not self.merge_classes:
try:
label_id = self.classes.index(
category_to_name[annotation['category_id']]
)
except ValueError:
# Class may have gotten filtered by:
# --only-classes or --limit-classes
continue
else:
label_id = 0

self._image_to_bboxes.setdefault(image_id, []).append({
'xmin': x,
'ymin': y,
'xmax': x + width,
'ymax': y + height,
'label': self.classes.index(
category_to_name[annotation['category_id']]
),
'label': label_id,
})

self._image_to_details = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class ObjectDetectionReader(BaseReader):
Iterate over all records.
"""
def __init__(self, only_classes=None, only_images=None,
limit_examples=None, limit_classes=None, seed=None, **kwargs):
limit_examples=None, limit_classes=None, merge_classes=False,
seed=None, **kwargs):
"""
Args:
- only_classes: string or list of strings used as a class
Expand All @@ -47,6 +48,7 @@ def __init__(self, only_classes=None, only_images=None,

self._limit_examples = limit_examples
self._limit_classes = limit_classes
self.merge_classes = merge_classes
random.seed(seed)

self._total = None
Expand Down
13 changes: 9 additions & 4 deletions luminoth/tools/dataset/readers/object_detection/pascalvoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,15 @@ def iterate(self):
gt_boxes = []

for b in annotation['object']:
try:
label_id = self.classes.index(b['name'])
except ValueError:
continue
if not self.merge_classes:
try:
label_id = self.classes.index(b['name'])
except ValueError:
# Class may have gotten filtered by:
# --only-classes or --limit-classes
continue
else:
label_id = 0

gt_boxes.append({
'label': label_id,
Expand Down
17 changes: 10 additions & 7 deletions luminoth/tools/dataset/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def get_output_subfolder(only_classes, only_images, limit_examples,
Returns: subfolder name for records
"""
if only_classes is not None:
return 'classes-{}'.format(only_classes)
return 'classes-{}'.format('-'.join(only_classes))
elif only_images is not None:
return 'only-{}'.format(only_images)
elif limit_examples is not None and limit_classes is not None:
Expand All @@ -30,16 +30,17 @@ def get_output_subfolder(only_classes, only_images, limit_examples,
@click.option('--data-dir', help='Where to locate the original data.')
@click.option('--output-dir', help='Where to save the transformed data.')
@click.option('splits', '--split', required=True, multiple=True, help='Which splits to transform.') # noqa
@click.option('--only-classes', help='Whitelist of classes.')
@click.option('--only-classes', multiple=True, help='Whitelist of classes.')
@click.option('--merge-classes', help='Merge all classes into a single class')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we name it --single-class or --merge-all to make the fact that only a single class is created more explicit?

@click.option('--only-images', help='Create dataset with specific examples.')
@click.option('--limit-examples', type=int, help='Limit dataset with to the first `N` examples.') # noqa
@click.option('--limit-classes', type=int, help='Limit dataset with `N` random classes.') # noqa
@click.option('--seed', type=int, help='Seed used for picking random classes.')
@click.option('overrides', '--override', '-o', multiple=True, help='Custom parameters for readers.') # noqa
@click.option('--debug', is_flag=True, help='Set level logging to DEBUG.')
def transform(dataset_reader, data_dir, output_dir, splits, only_classes,
only_images, limit_examples, limit_classes, seed, overrides,
debug):
merge_classes, only_images, limit_examples, limit_classes, seed,
overrides, debug):
"""
Prepares dataset for ingestion.

Expand Down Expand Up @@ -67,16 +68,18 @@ def transform(dataset_reader, data_dir, output_dir, splits, only_classes,
# All splits must have a consistent set of classes.
classes = None

merge_classes = merge_classes in ('True', 'true', 'TRUE')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use click types instead.


reader_kwargs = parse_override(overrides)

try:
for split in splits:
# Create instance of reader.
split_reader = reader(
data_dir, split,
only_classes=only_classes, only_images=only_images,
limit_examples=limit_examples, limit_classes=limit_classes,
seed=seed, **reader_kwargs
only_classes=only_classes, merge_classes=merge_classes,
only_images=only_images, limit_examples=limit_examples,
limit_classes=limit_classes, seed=seed, **reader_kwargs
)

if classes is None:
Expand Down
7 changes: 6 additions & 1 deletion luminoth/tools/dataset/writers/object_detection_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ def save(self):

# Save classes in simple json format for later use.
classes_file = os.path.join(self._output_dir, CLASSES_FILENAME)
json.dump(self._reader.classes, tf.gfile.GFile(classes_file, 'w'))
if self._reader.merge_classes:
# Don't assign a name to the class if its a merge of several others
json.dump([''], tf.gfile.GFile(classes_file, 'w'))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this pull request is merged, please add an option to set the name of the merged class.

else:
json.dump(self._reader.classes, tf.gfile.GFile(classes_file, 'w'))

record_file = os.path.join(
self._output_dir, '{}.tfrecords'.format(self._split))
writer = tf.python_io.TFRecordWriter(record_file)
Expand Down