Skip to content

Commit

Permalink
Merge pull request #7 from jaspersiebring/cli_fix
Browse files Browse the repository at this point in the history
Added cli options for (super)category names and ids
  • Loading branch information
jaspersiebring authored Sep 20, 2023
2 parents 1274288 + f823318 commit 84f44c1
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 81 deletions.
24 changes: 11 additions & 13 deletions geococo/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
import typer_cloup as typer
import pathlib
from datetime import datetime
Expand All @@ -14,7 +15,9 @@ def build_coco(
output_dir: pathlib.Path,
width: int,
height: int,
category_attribute: str = "category_id",
category_id_col: Optional[str] = "category_id",
category_name_col: Optional[str] = None,
supercategory_col: Optional[str] = None,
) -> None:
"""Transform your GIS annotations into a COCO dataset.
Expand All @@ -38,14 +41,14 @@ def build_coco(
:param output_dir: Path to the output directory for image subsets
:param width: Width of the output images
:param height: Height of the output images
:param category_attribute: Column that contains category_id values per annotation
feature
:param category_id_col: Column containing category_id values
:param category_name_col: Column containing category_name values
:param supercategory_col: Column containing supercategory values
"""

if isinstance(json_path, pathlib.Path) and json_path.exists():
# Create and populate instance of CocoDataset from json_path
dataset = load_dataset(json_path=json_path)

else:
# Create instance of CocoDataset from user input
print("Creating new dataset..")
Expand All @@ -62,16 +65,8 @@ def build_coco(
date_created=date_created,
)

# Loading and validating GIS data
# Loading GIS data
labels = gpd.read_file(labels_path)
if not labels.is_valid.all():
raise ValueError("Invalid geometry found, exiting..")
elif category_attribute not in labels.columns:
raise ValueError(
f"User-specified category attribute (={category_attribute}) not found in "
"input labels, exiting.."
)

with rasterio.open(image_path) as src:
# Appending Annotation instances and clipping the image part that contains them
dataset = labels_to_dataset(
Expand All @@ -80,6 +75,9 @@ def build_coco(
src=src,
labels=labels,
window_bounds=[(width, height)],
category_id_col=category_id_col,
category_name_col=category_name_col,
supercategory_col=supercategory_col,
)

# Encode CocoDataset instance as JSON and save to json_path
Expand Down
34 changes: 21 additions & 13 deletions geococo/coco_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pathlib
import pandas as pd
from pandas import Series
from datetime import datetime
from typing import List, Optional, Dict, Any
from typing_extensions import TypedDict
Expand Down Expand Up @@ -50,9 +51,9 @@ def add_source(self, source_path: pathlib.Path) -> None:

def add_categories(
self,
category_ids: Optional[np.ndarray],
category_names: Optional[np.ndarray],
supercategory_names: Optional[np.ndarray],
category_ids: Optional[Series],
category_names: Optional[Series],
super_names: Optional[Series],
) -> None:
# initializing values
super_default = "1"
Expand All @@ -65,13 +66,15 @@ def add_categories(
)

# checking if names can be assigned to uid_array (used to check duplicates)
if isinstance(category_names, np.ndarray):
if category_names is not None:
category_names: np.ndarray = category_names.to_numpy()
uid_array = category_names
uid_attribute = "name"
names_present = True

# checking if ids can be assigned to uid_array (used to check duplicates)
if isinstance(category_ids, np.ndarray):
if category_ids is not None:
category_ids: np.ndarray = category_ids.to_numpy()
uid_array = category_ids # overrides existing array because ids are leading
uid_attribute = "id"
ids_present = True
Expand All @@ -89,11 +92,15 @@ def add_categories(
return

# creating default supercategory names if not given
if not isinstance(supercategory_names, np.ndarray):
supercategory_names = np.full(shape=new_shape, fill_value=super_default)
if super_names is None:
super_names = np.full(
shape=new_shape,
fill_value=super_default
) # type: ignore[assignment]
else:
assert supercategory_names.shape == original_shape
supercategory_names = supercategory_names[indices][~member_mask]
super_names: np.ndarray = super_names.to_numpy()
assert super_names.shape == original_shape
super_names = super_names[indices][~member_mask]

# creating default category_names if not given (str version of ids)
if ids_present and not names_present:
Expand All @@ -105,16 +112,17 @@ def add_categories(
max_id = category_pd.loc[pandas_mask, "id"].max()
start = np.nansum([max_id, 1])
end = start + new_members.size
category_ids = np.arange(start, end)
category_ids = np.arange(start, end) # type: ignore[assignment]
category_names = new_members
# ensuring equal size for category names and ids (if given)
else:
assert category_names.shape == original_shape # type: ignore
category_names = category_names[indices][~member_mask] # type: ignore
assert category_names.shape == original_shape # type: ignore[union-attr]
category_names = category_names[indices][~member_mask] # type: ignore[index]
category_ids = new_members

# instantiating and appending a Category per id, name and supercategory
for cid, name, super in zip(category_ids, category_names, supercategory_names):
cip = zip(category_ids, category_names, super_names)
for cid, name, super in cip:
category = Category(id=cid, name=name, supercategory=super)
self.categories.append(category)

Expand Down
21 changes: 4 additions & 17 deletions geococo/coco_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import geopandas as gpd
import numpy as np
import rasterio
from pandas import Series
from typing import Optional
from rasterio.io import DatasetReader
from rasterio.mask import mask as riomask
Expand Down Expand Up @@ -66,23 +65,11 @@ def labels_to_dataset(
supercategory_col=supercategory_col,
)

# dumping series to array (if present)
category_ids = labels.get(category_id_col)
category_ids = category_ids.values if isinstance(category_ids, Series) else None
category_names = labels.get(category_name_col)
category_names = (
category_names.values if isinstance(category_names, Series) else None
)
supercategory_names = labels.get(supercategory_col)
supercategory_names = (
supercategory_names.values if isinstance(supercategory_names, Series) else None
)

# adding new Category instances (if any)
# adding new Category instances from labels (if any)
dataset.add_categories(
category_ids=category_ids,
category_names=category_names,
supercategory_names=supercategory_names,
category_ids=labels.get(category_id_col),
category_names=labels.get(category_name_col),
super_names=labels.get(supercategory_col),
)

# updating labels with validated COCO keys (i.e. 'name', 'id', 'supercategory')
Expand Down
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "geococo"
version = "0.4.0"
version = "0.4.1"
description = "Converts GIS annotations to Microsoft's Common Objects In Context (COCO) dataset format"
authors = ["Jasper <[email protected]>"]
readme = "README.md"
Expand Down Expand Up @@ -42,7 +42,12 @@ module = [
"geopandas",
"shapely.*"
]

[tool.mypy]
ignore_missing_imports = true
allow_redefinition = true
strict_optional = false
disable_error_code = 'no-redef'

[tool.ruff]
line-length = 88
Expand Down
Loading

0 comments on commit 84f44c1

Please sign in to comment.