diff --git a/docs/detection_core.md b/docs/detection_core.md new file mode 100644 index 000000000..fc3814fa9 --- /dev/null +++ b/docs/detection_core.md @@ -0,0 +1,3 @@ +## Detections + +:::supervision.detection.core.Detections diff --git a/docs/detection_utils.md b/docs/detection_utils.md new file mode 100644 index 000000000..500f8b95b --- /dev/null +++ b/docs/detection_utils.md @@ -0,0 +1,3 @@ +## generate_2d_mask + +:::supervision.detection.utils.generate_2d_mask \ No newline at end of file diff --git a/docs/draw.md b/docs/draw.md index 8ba499b16..e94509e30 100644 --- a/docs/draw.md +++ b/docs/draw.md @@ -1,9 +1,7 @@ -Utilities for drawing on images. - -## Draw Line +## draw_line :::supervision.draw.utils.draw_line -## Draw Rectangle +## draw_rectangle :::supervision.draw.utils.draw_rectangle \ No newline at end of file diff --git a/docs/notebook.md b/docs/notebook.md index ce650120a..b910496d9 100644 --- a/docs/notebook.md +++ b/docs/notebook.md @@ -1,3 +1,3 @@ -Utilities to help you build computer vision projects in notebook environments. +## show_frame_in_notebook :::supervision.notebook.utils.show_frame_in_notebook \ No newline at end of file diff --git a/docs/tools.md b/docs/tools.md deleted file mode 100644 index 9274cb0ef..000000000 --- a/docs/tools.md +++ /dev/null @@ -1,9 +0,0 @@ -Useful utilities for common computer vision tasks. 
- -## Helper for Processing Model Detections - -:::supervision.tools.detections.Detections - -## Count Objects That Pass a Line - -:::supervision.tools.line_counter.LineCounter \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 3212b2aaf..b788d0ce4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -22,12 +22,15 @@ extra: property: G-P7ZG0Y19G5 nav: - - Home 🏠: index.md - - Video 📷: video.md - - Notebook Helpers 📓: notebook.md - - Draw 🎨: draw.md - - Geometry 📐: geometry.md - - Tools 🛠: tools.md + - Home: index.md + - API reference: + - Video: video.md + - Detection: + - Core: detection_core.md + - Utils: detection_utils.md + - Draw: draw.md + - Geometry: geometry.md + - Notebook: notebook.md theme: name: 'material' diff --git a/setup.py b/setup.py index 6afc4eaa5..d91926459 100644 --- a/setup.py +++ b/setup.py @@ -24,8 +24,9 @@ def get_version(): long_description_content_type='text/markdown', url='https://github.com/roboflow/supervision', install_requires=[ - 'numpy', - 'opencv-python' + 'numpy', + 'opencv-python', + 'matplotlib' ], packages=find_packages(exclude=("tests",)), extras_require={ diff --git a/supervision/__init__.py b/supervision/__init__.py index 3dc1f76bc..c449f7733 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -1 +1,16 @@ -__version__ = "0.1.0" +__version__ = "0.2.0" + +from supervision.detection.core import BoxAnnotator, Detections +from supervision.detection.polygon_zone import PolygonZone, PolygonZoneAnnotator +from supervision.detection.utils import generate_2d_mask +from supervision.draw.color import Color, ColorPalette +from supervision.draw.utils import draw_filled_rectangle, draw_polygon, draw_text +from supervision.geometry.core import Point, Position, Rect +from supervision.geometry.utils import get_polygon_center +from supervision.notebook.utils import show_frame_in_notebook +from supervision.video import ( + VideoInfo, + VideoSink, + get_video_frames_generator, + process_video, +) diff --git 
a/supervision/tools/__init__.py b/supervision/detection/__init__.py similarity index 100% rename from supervision/tools/__init__.py rename to supervision/detection/__init__.py diff --git a/supervision/tools/detections.py b/supervision/detection/core.py similarity index 57% rename from supervision/tools/detections.py rename to supervision/detection/core.py index 4faaf6cf6..0863e64fc 100644 --- a/supervision/tools/detections.py +++ b/supervision/detection/core.py @@ -1,33 +1,33 @@ +from __future__ import annotations + +from dataclasses import dataclass from typing import List, Optional, Union import cv2 import numpy as np from supervision.draw.color import Color, ColorPalette +from supervision.geometry.core import Position +@dataclass class Detections: - def __init__( - self, - xyxy: np.ndarray, - confidence: np.ndarray, - class_id: np.ndarray, - tracker_id: Optional[np.ndarray] = None, - ): - """ - Data class containing information about the detections in a video frame. + """ + Data class containing information about the detections in a video frame. - Attributes: - xyxy (np.ndarray): An array of shape (n, 4) containing the bounding boxes coordinates in format [x1, y1, x2, y2] - confidence (np.ndarray): An array of shape (n,) containing the confidence scores of the detections. - class_id (np.ndarray): An array of shape (n,) containing the class ids of the detections. - tracker_id (Optional[np.ndarray]): An array of shape (n,) containing the tracker ids of the detections. - """ - self.xyxy: np.ndarray = xyxy - self.confidence: np.ndarray = confidence - self.class_id: np.ndarray = class_id - self.tracker_id: Optional[np.ndarray] = tracker_id + Attributes: + xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes coordinates in format `[x1, y1, x2, y2]` + confidence (np.ndarray): An array of shape `(n,)` containing the confidence scores of the detections. + class_id (np.ndarray): An array of shape `(n,)` containing the class ids of the detections. 
+ tracker_id (Optional[np.ndarray]): An array of shape `(n,)` containing the tracker ids of the detections. + """ + xyxy: np.ndarray + confidence: np.ndarray + class_id: np.ndarray + tracker_id: Optional[np.ndarray] = None + + def __post_init__(self): n = len(self.xyxy) validators = [ (isinstance(self.xyxy, np.ndarray) and self.xyxy.shape == (n, 4)), @@ -55,7 +55,7 @@ def __len__(self): def __iter__(self): """ - Iterates over the Detections object and yield a tuple of (xyxy, confidence, class_id, tracker_id) for each detection. + Iterates over the Detections object and yield a tuple of `(xyxy, confidence, class_id, tracker_id)` for each detection. """ for i in range(len(self.xyxy)): yield ( @@ -66,37 +66,68 @@ def __iter__(self): ) @classmethod - def from_yolov5(cls, yolov5_output: np.ndarray): + def from_yolov5(cls, yolov5_detections): """ - Creates a Detections instance from a YOLOv5 output tensor + Creates a Detections instance from a YOLOv5 output Detections Attributes: - yolov5_output (np.ndarray): The output tensor from YOLOv5 + yolov5_detections (yolov5.models.common.Detections): The output Detections instance from YOLOv5 Returns: Example: ```python - >>> from supervision.tools.detections import Detections + >>> import torch + >>> from supervision import Detections - >>> detections = Detections.from_yolov5(yolov5_output) + >>> model = torch.hub.load('ultralytics/yolov5', 'yolov5s') + >>> results = model(frame) + >>> detections = Detections.from_yolov5(results) ``` """ - xyxy = yolov5_output[:, :4] - confidence = yolov5_output[:, 4] - class_id = yolov5_output[:, 5].astype(int) - return cls(xyxy, confidence, class_id) + yolov5_detections_predictions = yolov5_detections.pred[0].cpu().cpu().numpy() + return cls( + xyxy=yolov5_detections_predictions[:, :4], + confidence=yolov5_detections_predictions[:, 4], + class_id=yolov5_detections_predictions[:, 5].astype(int), + ) - def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[np.ndarray]: + 
@classmethod + def from_yolov8(cls, yolov8_results): + """ + Creates a Detections instance from a YOLOv8 output Results + + Attributes: + yolov8_results (ultralytics.yolo.engine.results.Results): The output Results instance from YOLOv8 + + Returns: + + Example: + ```python + >>> from ultralytics import YOLO + >>> from supervision import Detections + + >>> model = YOLO('yolov8s.pt') + >>> results = model(frame)[0] + >>> detections = Detections.from_yolov8(results) + ``` + """ + return cls( + xyxy=yolov8_results.boxes.xyxy.cpu().numpy(), + confidence=yolov8_results.boxes.conf.cpu().numpy(), + class_id=yolov8_results.boxes.cls.cpu().numpy().astype(int), + ) + + def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[Detections]: """ Filter the detections by applying a mask. Attributes: - mask (np.ndarray): A mask of shape (n,) containing a boolean value for each detection indicating if it should be included in the filtered detections + mask (np.ndarray): A mask of shape `(n,)` containing a boolean value for each detection indicating if it should be included in the filtered detections inplace (bool): If True, the original data will be modified and self will be returned. Returns: - Optional[np.ndarray]: A new instance of Detections with the filtered detections, if inplace is set to False. None otherwise. + Optional[Detections]: A new instance of Detections with the filtered detections, if inplace is set to `False`. `None` otherwise. """ if inplace: self.xyxy = self.xyxy[mask] @@ -116,11 +147,49 @@ def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[np.ndarray else None, ) + def get_anchor_coordinates(self, anchor: Position) -> np.ndarray: + """ + Returns the bounding box coordinates for a specific anchor. + + Parameters: + anchor (Position): Position of bounding box anchor for which to return the coordinates. + + Returns: + np.ndarray: An array of shape `(n, 2)` containing the bounding box anchor coordinates in format `[x, y]`. 
+ """ + if anchor == Position.CENTER: + return np.array( + [ + (self.xyxy[:, 0] + self.xyxy[:, 2]) / 2, + (self.xyxy[:, 1] + self.xyxy[:, 3]) / 2, + ] + ).transpose() + elif anchor == Position.BOTTOM_CENTER: + return np.array( + [(self.xyxy[:, 0] + self.xyxy[:, 2]) / 2, self.xyxy[:, 3]] + ).transpose() + + raise ValueError(f"{anchor} is not supported.") + + def __getitem__(self, index: np.ndarray) -> Detections: + if isinstance(index, np.ndarray) and index.dtype == np.bool: + return Detections( + xyxy=self.xyxy[index], + confidence=self.confidence[index], + class_id=self.class_id[index], + tracker_id=self.tracker_id[index] + if self.tracker_id is not None + else None, + ) + raise TypeError( + f"Detections.__getitem__ not supported for index of type {type(index)}." + ) + class BoxAnnotator: def __init__( self, - color: Union[Color, ColorPalette], + color: Union[Color, ColorPalette] = ColorPalette.default(), thickness: int = 2, text_color: Color = Color.black(), text_scale: float = 0.5, @@ -148,35 +217,46 @@ def __init__( def annotate( self, - frame: np.ndarray, + scene: np.ndarray, detections: Detections, labels: Optional[List[str]] = None, + skip_label: bool = False, ) -> np.ndarray: """ Draws bounding boxes on the frame using the detections provided. - Attributes: - frame (np.ndarray): The image on which the bounding boxes will be drawn + Parameters: + scene (np.ndarray): The image on which the bounding boxes will be drawn detections (Detections): The detections for which the bounding boxes will be drawn labels (Optional[List[str]]): An optional list of labels corresponding to each detection. If labels is provided, the confidence score of the detection will be replaced with the label. - + skip_label (bool): If set to True, skips bounding box label annotation. 
Returns: np.ndarray: The image with the bounding boxes drawn on it """ font = cv2.FONT_HERSHEY_SIMPLEX for i, (xyxy, confidence, class_id, tracker_id) in enumerate(detections): + x1, y1, x2, y2 = xyxy.astype(int) color = ( self.color.by_idx(class_id) if isinstance(self.color, ColorPalette) else self.color ) + cv2.rectangle( + img=scene, + pt1=(x1, y1), + pt2=(x2, y2), + color=color.as_bgr(), + thickness=self.thickness, + ) + if skip_label: + continue + text = ( f"{confidence:0.2f}" if (labels is None or len(detections) != len(labels)) else labels[i] ) - x1, y1, x2, y2 = xyxy.astype(int) text_width, text_height = cv2.getTextSize( text=text, fontFace=font, @@ -194,21 +274,14 @@ def annotate( text_background_y2 = y1 cv2.rectangle( - img=frame, - pt1=(x1, y1), - pt2=(x2, y2), - color=color.as_bgr(), - thickness=self.thickness, - ) - cv2.rectangle( - img=frame, + img=scene, pt1=(text_background_x1, text_background_y1), pt2=(text_background_x2, text_background_y2), color=color.as_bgr(), thickness=cv2.FILLED, ) cv2.putText( - img=frame, + img=scene, text=text, org=(text_x, text_y), fontFace=font, @@ -217,4 +290,4 @@ def annotate( thickness=self.text_thickness, lineType=cv2.LINE_AA, ) - return frame + return scene diff --git a/supervision/tools/line_counter.py b/supervision/detection/line_counter.py similarity index 95% rename from supervision/tools/line_counter.py rename to supervision/detection/line_counter.py index f791c9a6f..e40ead8ec 100644 --- a/supervision/tools/line_counter.py +++ b/supervision/detection/line_counter.py @@ -3,12 +3,12 @@ import cv2 import numpy as np +from supervision.detection.core import Detections from supervision.draw.color import Color -from supervision.geometry.dataclasses import Point, Rect, Vector -from supervision.tools.detections import Detections +from supervision.geometry.core import Point, Rect, Vector -class LineCounter: +class LineZone: """ Count the number of objects that cross a line. 
""" @@ -27,7 +27,7 @@ def __init__(self, start: Point, end: Point): self.in_count: int = 0 self.out_count: int = 0 - def update(self, detections: Detections): + def trigger(self, detections: Detections): """ Update the in_count and out_count for the detections that cross the line. @@ -71,7 +71,7 @@ def update(self, detections: Detections): self.out_count += 1 -class LineCounterAnnotator: +class LineZoneAnnotator: def __init__( self, thickness: float = 2, @@ -103,7 +103,7 @@ def __init__( self.text_offset: float = text_offset self.text_padding: int = text_padding - def annotate(self, frame: np.ndarray, line_counter: LineCounter) -> np.ndarray: + def annotate(self, frame: np.ndarray, line_counter: LineZone) -> np.ndarray: """ Draws the line on the frame using the line_counter provided. diff --git a/supervision/detection/polygon_zone.py b/supervision/detection/polygon_zone.py new file mode 100644 index 000000000..54bbbf625 --- /dev/null +++ b/supervision/detection/polygon_zone.py @@ -0,0 +1,80 @@ +from typing import Optional, Tuple + +import cv2 +import numpy as np + +from supervision import Detections +from supervision.detection.utils import generate_2d_mask +from supervision.draw.color import Color +from supervision.draw.utils import draw_polygon, draw_text +from supervision.geometry.core import Position +from supervision.geometry.utils import get_polygon_center + + +class PolygonZone: + def __init__( + self, + polygon: np.ndarray, + frame_resolution_wh: Tuple[int, int], + triggering_position: Position = Position.BOTTOM_CENTER, + ): + self.polygon = polygon + self.frame_resolution_wh = frame_resolution_wh + self.triggering_position = triggering_position + self.mask = generate_2d_mask(polygon=polygon, resolution_wh=frame_resolution_wh) + self.current_count = 0 + + def trigger(self, detections: Detections) -> np.ndarray: + anchors = ( + np.ceil( + detections.get_anchor_coordinates(anchor=self.triggering_position) + ).astype(int) + - 1 + ) + is_in_zone = 
self.mask[anchors[:, 1], anchors[:, 0]] + self.current_count = np.sum(is_in_zone) + return is_in_zone.astype(bool) + + +class PolygonZoneAnnotator: + def __init__( + self, + zone: PolygonZone, + color: Color, + thickness: int = 2, + text_color: Color = Color.black(), + text_scale: float = 0.5, + text_thickness: int = 1, + text_padding: int = 10, + ): + self.zone = zone + self.color = color + self.thickness = thickness + self.text_color = text_color + self.text_scale = text_scale + self.text_thickness = text_thickness + self.text_padding = text_padding + self.font = cv2.FONT_HERSHEY_SIMPLEX + self.center = get_polygon_center(polygon=zone.polygon) + + def annotate(self, scene: np.ndarray, label: Optional[str] = None) -> np.ndarray: + annotated_frame = draw_polygon( + scene=scene, + polygon=self.zone.polygon, + color=self.color, + thickness=self.thickness, + ) + + annotated_frame = draw_text( + scene=annotated_frame, + text=str(self.zone.current_count) if label is None else label, + text_anchor=self.center, + background_color=self.color, + text_color=self.text_color, + text_scale=self.text_scale, + text_thickness=self.text_thickness, + text_padding=self.text_padding, + text_font=self.font, + ) + + return annotated_frame diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py new file mode 100644 index 000000000..a3a33490f --- /dev/null +++ b/supervision/detection/utils.py @@ -0,0 +1,20 @@ +from typing import Tuple + +import cv2 +import numpy as np + + +def generate_2d_mask(polygon: np.ndarray, resolution_wh: Tuple[int, int]) -> np.ndarray: + """Generate a 2D mask from a polygon. + + Properties: + polygon (np.ndarray): The polygon for which the mask should be generated, given as a list of vertices. + resolution_wh (Tuple[int, int]): The width and height of the desired resolution. + + Returns: + np.ndarray: The generated 2D mask, where the polygon is marked with `1`'s and the rest is filled with `0`'s. 
+ """ + width, height = resolution_wh + mask = np.zeros((height, width), dtype=np.uint8) + cv2.fillPoly(mask, [polygon], color=1) + return mask diff --git a/supervision/draw/color.py b/supervision/draw/color.py index 842dd2be1..b597831ea 100644 --- a/supervision/draw/color.py +++ b/supervision/draw/color.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import List, Tuple DEFAULT_COLOR_PALETTE = [ @@ -94,11 +94,11 @@ def blue(cls) -> Color: @dataclass class ColorPalette: - colors: List[Color] = field( - default_factory=lambda: [ - Color.from_hex(color_hex) for color_hex in DEFAULT_COLOR_PALETTE - ] - ) + colors: List[Color] + + @classmethod + def default(cls) -> ColorPalette: + return ColorPalette.from_hex(color_hex_list=DEFAULT_COLOR_PALETTE) @classmethod def from_hex(cls, color_hex_list: List[str]): diff --git a/supervision/draw/utils.py b/supervision/draw/utils.py index ff6c050c7..e8bcea212 100644 --- a/supervision/draw/utils.py +++ b/supervision/draw/utils.py @@ -1,8 +1,10 @@ +from typing import Optional + import cv2 import numpy as np -from supervision.commons.dataclasses import Point, Rect from supervision.draw.color import Color +from supervision.geometry.core import Point, Rect def draw_line( @@ -11,8 +13,7 @@ def draw_line( """ Draws a line on a given scene. 
- Attributes: - + Parameters: scene (np.ndarray): The scene on which the line will be drawn start (Point): The starting point of the line end (Point): The end point of the line @@ -46,11 +47,6 @@ def draw_rectangle( Returns: np.ndarray: The scene with the rectangle drawn on it - - Example: - ```python - >>> # TODO: Add example - ``` """ cv2.rectangle( scene, @@ -78,12 +74,6 @@ def draw_filled_rectangle(scene: np.ndarray, rect: Rect, color: Color) -> np.nda Returns: np.ndarray: The scene with the rectangle drawn on it - - Example: - ```python - >>> # TODO: Add example - ``` - """ cv2.rectangle( scene, @@ -93,3 +83,91 @@ def draw_filled_rectangle(scene: np.ndarray, rect: Rect, color: Color) -> np.nda -1, ) return scene + + +def draw_polygon( + scene: np.ndarray, polygon: np.ndarray, color: Color, thickness: int = 2 +) -> np.ndarray: + """Draw a polygon on a scene. + + Attributes: + scene (np.ndarray): The scene to draw the polygon on. + polygon (np.ndarray): The polygon to be drawn, given as a list of vertices. + color (Color): The color of the polygon. + thickness (int, optional): The thickness of the polygon lines, by default 2. + + Returns: + np.ndarray: The scene with the polygon drawn on it. + """ + cv2.polylines( + scene, [polygon], isClosed=True, color=color.as_bgr(), thickness=thickness + ) + return scene + + +def draw_text( + scene: np.ndarray, + text: str, + text_anchor: Point, + text_color: Color = Color.black(), + text_scale: float = 0.5, + text_thickness: int = 1, + text_padding: int = 10, + text_font: int = cv2.FONT_HERSHEY_SIMPLEX, + background_color: Optional[Color] = None, +) -> np.ndarray: + """ + Draw text on a scene. + + This function takes in a 2-dimensional numpy ndarray representing an image or scene, and draws text on it using OpenCV's putText function. The text is anchored at a specified Point, and its appearance can be customized using arguments such as color, scale, and font. 
An optional background color and padding can be specified to draw a rectangle behind the text. + + Parameters: + scene (np.ndarray): A 2-dimensional numpy ndarray representing an image or scene. + text (str): The text to be drawn. + text_anchor (Point): The anchor point for the text, represented as a Point object with x and y attributes. + text_color (Color, optional): The color of the text. Defaults to black. + text_scale (float, optional): The scale of the text. Defaults to 0.5. + text_thickness (int, optional): The thickness of the text. Defaults to 1. + text_padding (int, optional): The amount of padding to add around the text when drawing a rectangle in the background. Defaults to 10. + text_font (int, optional): The font to use for the text. Defaults to cv2.FONT_HERSHEY_SIMPLEX. + background_color (Color, optional): The color of the background rectangle, if one is to be drawn. Defaults to None. + + Returns: + np.ndarray: The input scene with the text drawn on it. + + Examples: + ```python + >>> scene = np.zeros((100, 100, 3), dtype=np.uint8) + >>> text_anchor = Point(x=50, y=50) + >>> scene = draw_text(scene=scene, text="Hello, world!", text_anchor=text_anchor) + ``` + """ + text_width, text_height = cv2.getTextSize( + text=text, + fontFace=text_font, + fontScale=text_scale, + thickness=text_thickness, + )[0] + text_rect = Rect( + x=text_anchor.x - text_width // 2, + y=text_anchor.y - text_height // 2, + width=text_width, + height=text_height, + ).pad(text_padding) + + if background_color is not None: + scene = draw_filled_rectangle( + scene=scene, rect=text_rect, color=background_color + ) + + cv2.putText( + img=scene, + text=text, + org=(text_anchor.x - text_width // 2, text_anchor.y + text_height // 2), + fontFace=text_font, + fontScale=text_scale, + color=text_color.as_bgr(), + thickness=text_thickness, + lineType=cv2.LINE_AA, + ) + return scene diff --git a/supervision/geometry/dataclasses.py b/supervision/geometry/core.py similarity index 86% rename 
from supervision/geometry/dataclasses.py rename to supervision/geometry/core.py index b647ec096..fc352f558 100644 --- a/supervision/geometry/dataclasses.py +++ b/supervision/geometry/core.py @@ -1,9 +1,19 @@ from __future__ import annotations from dataclasses import dataclass +from enum import Enum from typing import Tuple +class Position(Enum): + CENTER = "CENTER" + BOTTOM_CENTER = "BOTTOM_CENTER" + + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) + + @dataclass class Point: x: float diff --git a/supervision/geometry/utils.py b/supervision/geometry/utils.py new file mode 100644 index 000000000..400e4d755 --- /dev/null +++ b/supervision/geometry/utils.py @@ -0,0 +1,28 @@ +import numpy as np + +from supervision.geometry.core import Point + + +def get_polygon_center(polygon: np.ndarray) -> Point: + """ + Calculate the center of a polygon. + + This function takes in a polygon as a 2-dimensional numpy ndarray and returns the center of the polygon as a Point object. The center is calculated as the mean of the polygon's vertices along each axis, and is rounded down to the nearest integer. + + Parameters: + polygon (np.ndarray): A 2-dimensional numpy ndarray representing the vertices of the polygon. + + Returns: + Point: The center of the polygon, represented as a Point object with x and y attributes. 
+ + Examples: + ```python + >>> from supervision.geometry.utils import get_polygon_center + + >>> vertices = np.array([[0, 0], [0, 1], [1, 1], [1, 0]]) + >>> get_polygon_center(vertices) + Point(x=0, y=0) + ``` + """ + center = np.mean(polygon, axis=0).astype(int) + return Point(x=center[0], y=center[1]) diff --git a/supervision/notebook/utils.py b/supervision/notebook/utils.py index c50c98428..a750daa17 100644 --- a/supervision/notebook/utils.py +++ b/supervision/notebook/utils.py @@ -18,8 +18,10 @@ def show_frame_in_notebook( Examples: ```python - >>> from supervision.notebook import show_frame_in_notebook + >>> from supervision.notebook.utils import show_frame_in_notebook + %matplotlib inline + show_frame_in_notebook(frame, (16, 16)) ``` """ if frame.ndim == 2: diff --git a/supervision/video.py b/supervision/video.py index 92ad0a6e2..5dd807eb8 100644 --- a/supervision/video.py +++ b/supervision/video.py @@ -1,11 +1,13 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Callable, Generator, Optional, Tuple import cv2 import numpy as np +@dataclass class VideoInfo: """ A class to store video information, including width, height, fps and total number of frames. 
@@ -18,22 +20,22 @@ class VideoInfo: Examples: ```python - >>> from supervision.video import VideoInfo + >>> from supervision import VideoInfo >>> video_info = VideoInfo.from_video_path(video_path='video.mp4') >>> video_info VideoInfo(width=3840, height=2160, fps=25, total_frames=538) + + >>> video_info.resolution_wh + (3840, 2160) ``` """ - def __init__( - self, width: int, height: int, fps: int, total_frames: Optional[int] = None - ): - self.width = width - self.height = height - self.fps = fps - self.total_frames = total_frames + width: int + height: int + fps: int + total_frames: Optional[int] = None @classmethod def from_video_path(cls, video_path: str) -> VideoInfo: @@ -49,7 +51,7 @@ def from_video_path(cls, video_path: str) -> VideoInfo: return VideoInfo(width, height, fps, total_frames) @property - def resolution(self) -> Tuple[int, int]: + def resolution_wh(self) -> Tuple[int, int]: return self.width, self.height @@ -63,8 +65,7 @@ class VideoSink: Examples: ```python - >>> from supervision.video import VideoInfo - >>> from supervision.video import VideoSink + >>> from supervision import VideoInfo, VideoSink >>> video_info = VideoInfo.from_video_path(video_path='source_video.mp4') @@ -85,7 +86,7 @@ def __enter__(self): self.target_path, self.__fourcc, self.video_info.fps, - self.video_info.resolution, + self.video_info.resolution_wh, ) return self @@ -108,7 +109,7 @@ def get_video_frames_generator(source_path: str) -> Generator[np.ndarray, None, Examples: ```python - >>> from supervision.video import get_video_frames_generator + >>> from supervision import get_video_frames_generator >>> for frame in get_video_frames_generator(source_path='source_video.mp4'): ... ... @@ -139,9 +140,9 @@ def process_video( Examples: ```python - >>> from supervision.video import process_video + >>> from supervision import process_video - >>> def process_frame(frame: np.ndarray) -> np.ndarray: + >>> def process_frame(scene: np.ndarray) -> np.ndarray: ... ... 
>>> process_video( diff --git a/test/detection/__init__.py b/test/detection/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/geometry/test_dataclasses.py b/test/geometry/test_dataclasses.py index 348218846..711407123 100644 --- a/test/geometry/test_dataclasses.py +++ b/test/geometry/test_dataclasses.py @@ -1,6 +1,6 @@ import pytest -from supervision.geometry.dataclasses import Vector, Point +from supervision.geometry.core import Vector, Point @pytest.mark.parametrize(