Skip to content

Commit

Permalink
Merge pull request #15 from roboflow/feature/polygon-zone
Browse files Browse the repository at this point in the history
feature/polygon-zone
  • Loading branch information
SkalskiP authored Feb 7, 2023
2 parents ceabb83 + a0eb6c3 commit 2e76e3d
Show file tree
Hide file tree
Showing 21 changed files with 420 additions and 114 deletions.
3 changes: 3 additions & 0 deletions docs/detection_core.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## Detections

:::supervision.detection.core.Detections
3 changes: 3 additions & 0 deletions docs/detection_utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## generate_2d_mask

:::supervision.detection.utils.generate_2d_mask
6 changes: 2 additions & 4 deletions docs/draw.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
Utilities for drawing on images.

## Draw Line
## draw_line

:::supervision.draw.utils.draw_line

## Draw Rectangle
## draw_rectangle

:::supervision.draw.utils.draw_rectangle
2 changes: 1 addition & 1 deletion docs/notebook.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Utilities to help you build computer vision projects in notebook environments.
## show_frame_in_notebook

:::supervision.notebook.utils.show_frame_in_notebook
9 changes: 0 additions & 9 deletions docs/tools.md

This file was deleted.

15 changes: 9 additions & 6 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ extra:
property: G-P7ZG0Y19G5

nav:
- Home 🏠: index.md
- Video 📷: video.md
- Notebook Helpers 📓: notebook.md
- Draw 🎨: draw.md
- Geometry 📐: geometry.md
- Tools 🛠: tools.md
- Home: index.md
- API reference:
- Video: video.md
- Detection:
- Core: detection_core.md
- Utils: detection_utils.md
- Draw: draw.md
- Geometry: geometry.md
- Notebook: notebook.md

theme:
name: 'material'
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ def get_version():
long_description_content_type='text/markdown',
url='https://github.com/roboflow/supervision',
install_requires=[
'numpy',
'opencv-python'
'numpy',
'opencv-python',
'matplotlib'
],
packages=find_packages(exclude=("tests",)),
extras_require={
Expand Down
17 changes: 16 additions & 1 deletion supervision/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,16 @@
__version__ = "0.1.0"
__version__ = "0.2.0"

from supervision.detection.core import BoxAnnotator, Detections
from supervision.detection.polygon_zone import PolygonZone, PolygonZoneAnnotator
from supervision.detection.utils import generate_2d_mask
from supervision.draw.color import Color, ColorPalette
from supervision.draw.utils import draw_filled_rectangle, draw_polygon, draw_text
from supervision.geometry.core import Point, Position, Rect
from supervision.geometry.utils import get_polygon_center
from supervision.notebook.utils import show_frame_in_notebook
from supervision.video import (
VideoInfo,
VideoSink,
get_video_frames_generator,
process_video,
)
File renamed without changes.
169 changes: 121 additions & 48 deletions supervision/tools/detections.py → supervision/detection/core.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Union

import cv2
import numpy as np

from supervision.draw.color import Color, ColorPalette
from supervision.geometry.core import Position


@dataclass
class Detections:
def __init__(
self,
xyxy: np.ndarray,
confidence: np.ndarray,
class_id: np.ndarray,
tracker_id: Optional[np.ndarray] = None,
):
"""
Data class containing information about the detections in a video frame.
"""
Data class containing information about the detections in a video frame.
Attributes:
xyxy (np.ndarray): An array of shape (n, 4) containing the bounding boxes coordinates in format [x1, y1, x2, y2]
confidence (np.ndarray): An array of shape (n,) containing the confidence scores of the detections.
class_id (np.ndarray): An array of shape (n,) containing the class ids of the detections.
tracker_id (Optional[np.ndarray]): An array of shape (n,) containing the tracker ids of the detections.
"""
self.xyxy: np.ndarray = xyxy
self.confidence: np.ndarray = confidence
self.class_id: np.ndarray = class_id
self.tracker_id: Optional[np.ndarray] = tracker_id
Attributes:
xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`
confidence (np.ndarray): An array of shape `(n,)` containing the confidence scores of the detections.
class_id (np.ndarray): An array of shape `(n,)` containing the class ids of the detections.
tracker_id (Optional[np.ndarray]): An array of shape `(n,)` containing the tracker ids of the detections.
"""

xyxy: np.ndarray
confidence: np.ndarray
class_id: np.ndarray
tracker_id: Optional[np.ndarray] = None

def __post_init__(self):
n = len(self.xyxy)
validators = [
(isinstance(self.xyxy, np.ndarray) and self.xyxy.shape == (n, 4)),
Expand Down Expand Up @@ -55,7 +55,7 @@ def __len__(self):

def __iter__(self):
"""
Iterates over the Detections object and yield a tuple of (xyxy, confidence, class_id, tracker_id) for each detection.
Iterates over the Detections object and yield a tuple of `(xyxy, confidence, class_id, tracker_id)` for each detection.
"""
for i in range(len(self.xyxy)):
yield (
Expand All @@ -66,37 +66,68 @@ def __iter__(self):
)

@classmethod
def from_yolov5(cls, yolov5_output: np.ndarray):
def from_yolov5(cls, yolov5_detections):
"""
Creates a Detections instance from a YOLOv5 output tensor
Creates a Detections instance from a YOLOv5 output Detections
Attributes:
yolov5_output (np.ndarray): The output tensor from YOLOv5
yolov5_detections (yolov5.models.common.Detections): The output Detections instance from YOLOv5
Returns:
Example:
```python
>>> from supervision.tools.detections import Detections
>>> import torch
>>> from supervision import Detections
>>> detections = Detections.from_yolov5(yolov5_output)
>>> model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
>>> results = model(frame)
>>> detections = Detections.from_yolov5(results)
```
"""
xyxy = yolov5_output[:, :4]
confidence = yolov5_output[:, 4]
class_id = yolov5_output[:, 5].astype(int)
return cls(xyxy, confidence, class_id)
yolov5_detections_predictions = yolov5_detections.pred[0].cpu().cpu().numpy()
return cls(
xyxy=yolov5_detections_predictions[:, :4],
confidence=yolov5_detections_predictions[:, 4],
class_id=yolov5_detections_predictions[:, 5].astype(int),
)

def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[np.ndarray]:
@classmethod
def from_yolov8(cls, yolov8_results):
"""
Creates a Detections instance from a YOLOv8 output Results
Attributes:
yolov8_results (ultralytics.yolo.engine.results.Results): The output Results instance from YOLOv8
Returns:
Example:
```python
>>> from ultralytics import YOLO
>>> from supervision import Detections
>>> model = YOLO('yolov8s.pt')
>>> results = model(frame)
>>> detections = Detections.from_yolov8(results)
```
"""
return cls(
xyxy=yolov8_results.boxes.xyxy.cpu().numpy(),
confidence=yolov8_results.boxes.conf.cpu().numpy(),
class_id=yolov8_results.boxes.cls.cpu().numpy().astype(int),
)

def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[Detections]:
"""
Filter the detections by applying a mask.
Attributes:
mask (np.ndarray): A mask of shape (n,) containing a boolean value for each detection indicating if it should be included in the filtered detections
mask (np.ndarray): A mask of shape `(n,)` containing a boolean value for each detection indicating if it should be included in the filtered detections
inplace (bool): If True, the original data will be modified and self will be returned.
Returns:
Optional[np.ndarray]: A new instance of Detections with the filtered detections, if inplace is set to False. None otherwise.
Optional[np.ndarray]: A new instance of Detections with the filtered detections, if inplace is set to `False`. `None` otherwise.
"""
if inplace:
self.xyxy = self.xyxy[mask]
Expand All @@ -116,11 +147,49 @@ def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[np.ndarray
else None,
)

def get_anchor_coordinates(self, anchor: Position) -> np.ndarray:
"""
Returns the bounding box coordinates for a specific anchor.
Properties:
anchor (Position): Position of bounding box anchor for which to return the coordinates.
Returns:
np.ndarray: An array of shape `(n, 2)` containing the bounding box anchor coordinates in format `[x, y]`.
"""
if anchor == Position.CENTER:
return np.array(
[
(self.xyxy[:, 0] + self.xyxy[:, 2]) / 2,
(self.xyxy[:, 1] + self.xyxy[:, 3]) / 2,
]
).transpose()
elif anchor == Position.BOTTOM_CENTER:
return np.array(
[(self.xyxy[:, 0] + self.xyxy[:, 2]) / 2, self.xyxy[:, 3]]
).transpose()

raise ValueError(f"{anchor} is not supported.")

def __getitem__(self, index: np.ndarray) -> Detections:
if isinstance(index, np.ndarray) and index.dtype == np.bool:
return Detections(
xyxy=self.xyxy[index],
confidence=self.confidence[index],
class_id=self.class_id[index],
tracker_id=self.tracker_id[index]
if self.tracker_id is not None
else None,
)
raise TypeError(
f"Detections.__getitem__ not supported for index of type {type(index)}."
)


class BoxAnnotator:
def __init__(
self,
color: Union[Color, ColorPalette],
color: Union[Color, ColorPalette] = ColorPalette.default(),
thickness: int = 2,
text_color: Color = Color.black(),
text_scale: float = 0.5,
Expand Down Expand Up @@ -148,35 +217,46 @@ def __init__(

def annotate(
self,
frame: np.ndarray,
scene: np.ndarray,
detections: Detections,
labels: Optional[List[str]] = None,
skip_label: bool = False,
) -> np.ndarray:
"""
Draws bounding boxes on the frame using the detections provided.
Attributes:
frame (np.ndarray): The image on which the bounding boxes will be drawn
Parameters:
scene (np.ndarray): The image on which the bounding boxes will be drawn
detections (Detections): The detections for which the bounding boxes will be drawn
labels (Optional[List[str]]): An optional list of labels corresponding to each detection. If labels is provided, the confidence score of the detection will be replaced with the label.
skip_label (bool): Is set to True, skips bounding box label annotation.
Returns:
np.ndarray: The image with the bounding boxes drawn on it
"""
font = cv2.FONT_HERSHEY_SIMPLEX
for i, (xyxy, confidence, class_id, tracker_id) in enumerate(detections):
x1, y1, x2, y2 = xyxy.astype(int)
color = (
self.color.by_idx(class_id)
if isinstance(self.color, ColorPalette)
else self.color
)
cv2.rectangle(
img=scene,
pt1=(x1, y1),
pt2=(x2, y2),
color=color.as_bgr(),
thickness=self.thickness,
)
if skip_label:
continue

text = (
f"{confidence:0.2f}"
if (labels is None or len(detections) != len(labels))
else labels[i]
)

x1, y1, x2, y2 = xyxy.astype(int)
text_width, text_height = cv2.getTextSize(
text=text,
fontFace=font,
Expand All @@ -194,21 +274,14 @@ def annotate(
text_background_y2 = y1

cv2.rectangle(
img=frame,
pt1=(x1, y1),
pt2=(x2, y2),
color=color.as_bgr(),
thickness=self.thickness,
)
cv2.rectangle(
img=frame,
img=scene,
pt1=(text_background_x1, text_background_y1),
pt2=(text_background_x2, text_background_y2),
color=color.as_bgr(),
thickness=cv2.FILLED,
)
cv2.putText(
img=frame,
img=scene,
text=text,
org=(text_x, text_y),
fontFace=font,
Expand All @@ -217,4 +290,4 @@ def annotate(
thickness=self.text_thickness,
lineType=cv2.LINE_AA,
)
return frame
return scene
Loading

0 comments on commit 2e76e3d

Please sign in to comment.