Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate SAM 2.1 #8610

Open
wants to merge 23 commits into
base: develop
Choose a base branch
from
Open

Integrate SAM 2.1 #8610

wants to merge 23 commits into from

Conversation

hashJoe
Copy link

@hashJoe hashJoe commented Oct 29, 2024

Linked Issues:

Linked Pull Requests:

Summary

To integrate SAM 2.1 into CVAT, I

  1. forked cvat,
  2. created branch feature/sam2,
  3. merged forked repo of @jeanchristopheruel and merged branch develop into feature/sam2,
  4. updated SAM 2.0 to SAM 2.1,
  5. converted decoder part to ONNX using script export_sam21_cvat.py,
  6. split code into backend and frontend,
  7. added CLIENT_PLUGINS argument to pass plugin cvat-ui/plugins/sam2,
  8. and tested the model successfully within CVAT on CPU and GPU.

Motivation and context

This pull request builds upon a #8243 that aimed to integrate SAM 2 into CVAT. The progress on that contribution has been stalled, and this request serves as a continuation of integrating SAM 2.

Main enhancements:

  • Update to SAM 2.1
  • Backend encodes image
  • Decoder converted to ONNX with postprocessing steps
  • Frontend does the decoding

This way the structure of integrating SAM is maintained in SAM 2 with minimal changes.

How has this been tested?

Using the following commands:

CLIENT_PLUGINS=plugins/sam2 CVAT_HOST=localhost CVAT_VERSION=v2.21.2 docker compose -f docker-compose.yml -f docker-compose.dev.yml -f components/serverless/docker-compose.serverless.yml -p cvat up -d --build

# on cpu
./serverless/deploy_cpu.sh serverless/pytorch/facebookresearch/sam2

# on gpu
./serverless/deploy_gpu.sh serverless/pytorch/facebookresearch/sam2

, and applying the model on several images.

Checklist

  • I submit my changes into the develop branch
  • I have created a changelog fragment
  • I have updated the documentation accordingly
  • I have added tests to cover my changes
  • I have linked related issues (see GitHub docs)
  • I have increased versions of npm packages if it is necessary
    (cvat-canvas,
    cvat-core,
    cvat-data and
    cvat-ui)

License

  • I submit my code changes under the same MIT License that covers the project.
    Feel free to contact the maintainers if that's a concern.

Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced documentation for the Computer Vision Annotation Tool (CVAT) with updated sections on automatic labeling algorithms and user support channels.
    • Introduced a new plugin for the Segment Anything Model 2.1 (SAM2) to facilitate interactive segmentation tasks.
    • Added serverless function configurations for interactive object segmentation using SAM2.
  • Bug Fixes

    • Corrected formatting issues in the documentation.
  • Chores

    • Improved configurability of the CVAT UI service with new build arguments.

jeanchristopheruel and others added 23 commits July 30, 2024 23:58
Segment Anything 2.0 require to compile a .cu file with nvcc at build time. Hence, a cuda devel baseImage is required to build the nuclio container.
…l as required in Dockerfile.ui for extra plugins, adjust function.yaml and function-gpu.yaml to accommodate SAM2.1 and frontend plugin, add index.tsx and inference.worker.ts, update main.py and model_handler.py accordingly
…since not included in self.predictor.get_image_embedding() and needed in decoder and a new encoder class is added in image_encoder.py
…changes in index.tsx, and accommodate those inputs in inference.worker.ts and add sam2.1_hiera_large.decoder where postprocessing steps are added to undergo minimal changes in sam2 plugin
Copy link
Contributor

coderabbitai bot commented Oct 29, 2024

Walkthrough

The changes in this pull request involve updates to the documentation and the introduction of new functionality for the Computer Vision Annotation Tool (CVAT). The README.md file has been enhanced with new sections and clarifications. Additionally, several new files have been added, including a plugin for the Segment Anything Model 2.1, web worker for inference tasks, and configurations for serverless functions. These changes collectively improve the tool's capabilities, particularly in object segmentation and model inference.

Changes

File Path Change Summary
README.md Added new sections and updated existing content, including expanded details on algorithms and licensing clarifications.
cvat-ui/plugins/sam2/src/ts/index.tsx Introduced SAM2Plugin interface, added methods for plugin functionality, and error handling.
cvat-ui/plugins/sam2/src/ts/inference.worker.ts Added web worker for inference tasks, defined actions and interfaces for communication with the main thread.
docker-compose.dev.yml Added CLIENT_PLUGINS argument to cvat_ui service for enhanced configurability.
serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml Defined serverless function configuration for interactive object segmentation using SAM2.
serverless/pytorch/facebookresearch/sam2/nuclio/function.yaml Introduced configuration for a serverless function for interactive segmentation with SAM2.
serverless/pytorch/facebookresearch/sam2/nuclio/image_encoder.py Added SAM2Encoder class for processing image tensors through the SAM2 model.
serverless/pytorch/facebookresearch/sam2/nuclio/main.py Introduced main serverless function for image processing, including context initialization and event handling.
serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py Added ModelHandler class for managing SAM2 model loading and processing of images.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant CVAT
    participant SAM2Plugin
    participant InferenceWorker

    User->>CVAT: Upload Image
    CVAT->>SAM2Plugin: Process Image
    SAM2Plugin->>InferenceWorker: Send Image for Inference
    InferenceWorker->>InferenceWorker: Run Model
    InferenceWorker-->>SAM2Plugin: Return Inference Results
    SAM2Plugin-->>CVAT: Update with Results
    CVAT-->>User: Display Segmentation Results
Loading

🐇 "In the garden where algorithms bloom,
New sections added to dispel the gloom.
Plugins and workers, all in a row,
Helping us segment, watch the magic flow!
With each little change, our tool grows bright,
Hopping along, bringing joy and delight!" 🐇


Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link

Quality Gate Failed Quality Gate failed

Failed conditions
32.3% Duplication on New Code (required ≤ 3%)

See analysis details on SonarCloud

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 19

🧹 Outside diff range and nitpick comments (11)
serverless/pytorch/facebookresearch/sam2/nuclio/image_encoder.py (2)

1-13: LGTM! Consider adding class docstring.

The class structure and initialization look good. Type hints are properly used, and the inheritance from torch.nn.Module is appropriate.

Add a docstring to describe the purpose and usage of the SAM2Encoder class:

 class SAM2Encoder(torch.nn.Module):
+    """SAM 2.1 image encoder that processes input images through the backbone and prepares features for the mask decoder.
+
+    Args:
+        sam2_model (SAM2Base): Base SAM 2.1 model containing the image encoder and mask decoder components.
+    """
     def __init__(self, sam2_model: SAM2Base) -> None:

14-34: Improve type hints and documentation.

The method could benefit from more specific type hints and comprehensive documentation.

-    def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Process input image through SAM encoder and prepare multi-scale features.
+
+        Args:
+            x (torch.Tensor): Input image tensor of shape (1, 3, H, W)
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Multi-scale features at three
+                different resolutions, ordered from highest to lowest resolution.
+                Each tensor has shape (1, C, H_i, W_i) where H_i, W_i are the
+                spatial dimensions at scale i.
+        """
serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py (2)

1-10: Remove unnecessary empty line

There's an extra empty line at line 11 that can be removed to maintain consistent spacing.


12-38: Consider architectural improvements for serverless environment

As this is running in a serverless environment, consider these architectural improvements:

  1. Implement async processing using async/await for better scalability
  2. Add a health check method to verify model status
  3. Consider implementing model warmup to improve cold start performance

Example health check implementation:

async def health_check(self):
    try:
        # Verify model and device status
        if self.device.type == 'cuda':
            assert torch.cuda.is_available()
        # Run a small inference to ensure model is responsive
        dummy_input = torch.zeros((1, 3, 64, 64), device=self.device)
        with torch.inference_mode():
            _ = self.sam2_encoder(dummy_input)
        return True
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return False
serverless/pytorch/facebookresearch/sam2/nuclio/function.yaml (2)

12-17: Update help message and example GIF for SAM 2.1

The help message and example GIF should be specific to SAM 2.1:

  1. The help message could better describe SAM 2.1's unique capabilities and improvements over SAM 2.0
  2. The animated GIF currently shows an HRNet example (hrnet_example.gif) which should be replaced with a SAM 2.1-specific demonstration

20-23: Consider updating Python runtime and event timeout

  1. Python 3.8 is approaching end-of-life. Consider upgrading to Python 3.10+ for better performance and security updates.
  2. The 30-second event timeout might be insufficient for processing large images or batch requests. Consider increasing based on your performance testing results.
serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml (1)

16-17: Update demo GIF and help message for SAM 2.1

The animated GIF currently points to an HRNet example. Consider updating it with a SAM 2.1-specific demonstration. Additionally, the help message could be more descriptive about SAM 2.1's specific capabilities and limitations.

-    animated_gif: https://raw.githubusercontent.com/cvat-ai/cvat/develop/site/content/en/images/hrnet_example.gif
-    help_message: The interactor allows to get a mask of an object using at least one positive, and any negative points inside it
+    animated_gif: https://raw.githubusercontent.com/cvat-ai/cvat/develop/site/content/en/images/sam2_example.gif
+    help_message: SAM 2.1 interactor generates high-quality object masks from user-provided prompts. Supports positive/negative points and optional bounding box input for improved accuracy.
docker-compose.dev.yml (1)

101-101: Consider adding a default value for CLIENT_PLUGINS.

The addition of CLIENT_PLUGINS build argument is correct, but consider providing a default empty value to ensure backward compatibility and prevent build failures when the variable is not set.

-        CLIENT_PLUGINS: ${CLIENT_PLUGINS}
+        CLIENT_PLUGINS: ${CLIENT_PLUGINS:-}
README.md (1)

194-195: Consider adding ONNX conversion documentation.

Since the PR objectives mention that the decoder part was converted to ONNX format using a specific export script, it would be helpful to add documentation about this process, either in the serverless function's README or by adding a note in the table.

Example addition:

| [Segment Anything 2.1](/serverless/pytorch/facebookresearch/sam2/nuclio/)                               | interactor | PyTorch    | ✔️  | ✔️  |
+| [Segment Anything 2.1 (ONNX)](/serverless/pytorch/facebookresearch/sam2/nuclio/onnx)                   | interactor | ONNX       | ✔️  | ✔️  |
serverless/pytorch/facebookresearch/sam2/nuclio/main.py (1)

27-27: Remove Unused Variable image_

The variable image_ is assigned but not used afterward. Consider removing it to clean up the code.

Apply this diff to remove the unused variable:

-high_res_feats_0, high_res_feats_1, image_embed, image_ = context.user_data.model.handle(image)
+high_res_feats_0, high_res_feats_1, image_embed = context.user_data.model.handle(image)
cvat-ui/plugins/sam2/src/ts/inference.worker.ts (1)

54-55: Avoid disabling ESLint rules without justification

The line // eslint-disable-next-line no-restricted-globals disables the ESLint rule no-restricted-globals. If possible, refactor the code to avoid using restricted globals, or provide a comment explaining the necessity of disabling this rule to improve code readability and maintainability.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 9982156 and 0a8d038.

📒 Files selected for processing (9)
  • README.md (1 hunks)
  • cvat-ui/plugins/sam2/src/ts/index.tsx (1 hunks)
  • cvat-ui/plugins/sam2/src/ts/inference.worker.ts (1 hunks)
  • docker-compose.dev.yml (1 hunks)
  • serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml (1 hunks)
  • serverless/pytorch/facebookresearch/sam2/nuclio/function.yaml (1 hunks)
  • serverless/pytorch/facebookresearch/sam2/nuclio/image_encoder.py (1 hunks)
  • serverless/pytorch/facebookresearch/sam2/nuclio/main.py (1 hunks)
  • serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py (1 hunks)
🔇 Additional comments (12)
serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py (1)

29-38: 🛠️ Refactor suggestion

Consider adding input size limits and memory management

The handle method could benefit from additional safeguards:

  1. Add maximum input size validation to prevent memory issues
  2. Consider adding memory cleanup after processing
  3. Add more comprehensive input validation

Let's check if there are any memory-related issues reported in similar implementations:

     def handle(self, image):
+        max_image_size = 1024 * 1024 * 4  # 4MP limit
         with torch.inference_mode():
             assert isinstance(image, Image)
+            # Validate image size
+            w, h = image.size
+            if w * h > max_image_size:
+                raise ValueError(f"Image too large: {w}x{h}. Maximum size allowed: {max_image_size} pixels")
+
             input_image = self._transforms(image)
             input_image = input_image[None, ...].to(self.device)
             assert (
                 len(input_image.shape) == 4 and input_image.shape[1] == 3
             ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
             high_res_feats_0, high_res_feats_1, image_embed = self.sam2_encoder(input_image)
-            return high_res_feats_0, high_res_feats_1, image_embed, input_image
+            # Clean up CUDA memory if needed
+            if self.device.type == 'cuda':
+                torch.cuda.empty_cache()
+            return high_res_feats_0, high_res_feats_1, image_embed, input_image
serverless/pytorch/facebookresearch/sam2/nuclio/function.yaml (2)

44-45: Verify CUDA configuration for SAM2 installation

The SAM2 installation has CUDA disabled (SAM2_BUILD_CUDA=0). If this is the CPU-only version, consider creating a separate GPU configuration file (function-gpu.yaml) with CUDA enabled for better performance on GPU instances.


53-58: Review resource limits and timeouts

Please verify these configuration values based on your performance testing:

  1. maxWorkers: 2 might need adjustment based on expected load and resource availability
  2. maxRequestBodySize: 33554432 (32MB) might be insufficient for high-resolution images
  3. workerAvailabilityTimeoutMilliseconds: 10000 (10s) seems short, consider increasing if worker initialization takes longer

Run the following script to check for similar configurations in other functions:

serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml (3)

69-74: LGTM: Platform configuration is well-defined

The platform settings with automatic restart policy and proper volume mounting are appropriate for a GPU-enabled service.


59-67: Review resource constraints for production workloads

The current configuration might be restrictive for production use:

  1. 32MB request body limit might not accommodate high-resolution images
  2. 10-second worker availability timeout is quite aggressive
  3. Single worker configuration could become a bottleneck

Let's check typical request sizes in the codebase:


21-23: Consider runtime and timeout adjustments

  1. Python 3.8 is aging - consider upgrading to Python 3.10+ for better performance and features.
  2. The 30-second timeout might be insufficient for processing large images or batch requests.

Let's check if Python 3.10+ is used elsewhere in the project:

docker-compose.dev.yml (1)

101-101: Verify documentation for the new build argument.

The new CLIENT_PLUGINS build argument needs to be documented to help users understand its purpose and usage, especially in the context of the SAM 2.1 plugin integration.

README.md (1)

194-195: LGTM! SAM 2.1 entry added correctly.

The addition of Segment Anything 2.1 to the serverless functions table is well-formatted and accurately reflects its capabilities as an interactive segmentation model supporting both CPU and GPU environments.

cvat-ui/plugins/sam2/src/ts/inference.worker.ts (2)

29-29: Clarify the use of the index signature in 'DecodeBody' interface

The index signature readonly [name: string]: Tensor; in the DecodeBody interface allows any string property to be a Tensor. Ensure this is intentional to accept dynamic keys, as it might mask unintended extra properties or typos in property names.


55-55: Confirm validity of worker context check

The condition if ((self as any).importScripts) is used to determine if the code is running in a web worker context. Verify that this check is reliable across all target browsers and environments, or consider using a more standard method such as typeof importScripts === 'function'.

cvat-ui/plugins/sam2/src/ts/index.tsx (2)

316-343: Avoid overwriting 'worker.onmessage'; centralize message handling

Overwriting the worker's onmessage handler here can interfere with other message handling logic. As previously mentioned, it's better to have a single onmessage handler for the worker.

Refer to the previous suggestion on centralizing the onmessage handler to handle different actions appropriately.


355-355: Verify the accessibility of the model file path

The model URL is set to /assets/sam2.1_hiera_large.decoder.onnx. Ensure that this file exists at the specified path and is accessible in the production environment.

Run the following script to check if the model file exists:

✅ Verification successful

Let me search for more context about how this model file is handled in the codebase.


Let me check the webpack configuration to understand how assets are handled during the build process.


Based on the webpack configuration, I can see that ONNX model files from plugin assets are handled by the CopyPlugin configuration:

{
    from: 'plugins/**/assets/*.(onnx|js)',
    to  : 'assets/[name][ext]',
}

Let me verify if the model file exists in the plugin's assets directory:


Model file is present and correctly configured

The model file sam2.1_hiera_large.decoder.onnx exists in the correct location at cvat-ui/plugins/sam2/assets/. The webpack configuration correctly handles copying ONNX files from plugin assets to the public assets directory during build. The model URL path /assets/sam2.1_hiera_large.decoder.onnx will be valid in production.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that the model file exists in the assets directory

# Test: Check for the existence of the model file
if [ -f ./assets/sam2.1_hiera_large.decoder.onnx ]; then
    echo "Model file exists."
else
    echo "Model file does not exist."
fi

Length of output: 116


Script:

#!/bin/bash
# Search for references to the model file name or related configuration
rg -g '!*.{lock,min.js}' "sam2.1_hiera_large.decoder.onnx" -C 2

# Search for any webpack or build configuration files that might handle assets
fd -e config.js -e webpack.config.js

# Look for documentation about model files
fd README -t f

Length of output: 1136


Script:

#!/bin/bash
# Check webpack configuration for asset handling
cat cvat-ui/webpack.config.js

# Look for any documentation about model setup in the plugin's directory
fd -t f -p "**/sam2/**/README*"

# Check if there are any build or asset related instructions in the main README
head -n 100 README.md

Length of output: 13083


Script:

#!/bin/bash
# Check if the model exists in the plugin's assets directory
fd -t f -e onnx . cvat-ui/plugins/sam2/

# Also check if there are any build instructions or documentation specific to the SAM2 plugin
rg -g '!*.{lock,min.js}' -C 3 "model.*\.onnx" cvat-ui/plugins/sam2/

Length of output: 770

self.image_encoder = sam2_model.image_encoder
self.no_mem_embed = sam2_model.no_mem_embed

def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add input validation and shape assertions.

The forward method should validate input tensor dimensions and device placement to fail fast with clear error messages.

     def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
+        if x.dim() != 4:
+            raise ValueError(f"Expected 4D input tensor (B,C,H,W), got {x.dim()}D")
+        if x.size(0) != 1:
+            raise ValueError(f"Expected batch size 1, got {x.size(0)}")
+        if x.size(1) != 3:
+            raise ValueError(f"Expected 3 channels (RGB), got {x.size(1)}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
if x.dim() != 4:
raise ValueError(f"Expected 4D input tensor (B,C,H,W), got {x.dim()}D")
if x.size(0) != 1:
raise ValueError(f"Expected batch size 1, got {x.size(0)}")
if x.size(1) != 3:
raise ValueError(f"Expected 3 channels (RGB), got {x.size(1)}")

Comment on lines +23 to +25
self.sam2_checkpoint = "/opt/nuclio/sam2/sam2.1_hiera_large.pt"
self.sam2_model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_model = build_sam2(self.sam2_model_cfg, self.sam2_checkpoint, device=self.device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider making model paths configurable

The model checkpoint and configuration paths are hardcoded. Consider making these configurable through environment variables or configuration files for better flexibility and maintainability.

-        self.sam2_checkpoint = "/opt/nuclio/sam2/sam2.1_hiera_large.pt"
-        self.sam2_model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+        self.sam2_checkpoint = os.getenv('SAM2_CHECKPOINT', "/opt/nuclio/sam2/sam2.1_hiera_large.pt")
+        self.sam2_model_cfg = os.getenv('SAM2_CONFIG', "configs/sam2.1/sam2.1_hiera_l.yaml")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.sam2_checkpoint = "/opt/nuclio/sam2/sam2.1_hiera_large.pt"
self.sam2_model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_model = build_sam2(self.sam2_model_cfg, self.sam2_checkpoint, device=self.device)
self.sam2_checkpoint = os.getenv('SAM2_CHECKPOINT', "/opt/nuclio/sam2/sam2.1_hiera_large.pt")
self.sam2_model_cfg = os.getenv('SAM2_CONFIG', "configs/sam2.1/sam2.1_hiera_l.yaml")
sam2_model = build_sam2(self.sam2_model_cfg, self.sam2_checkpoint, device=self.device)

Comment on lines +25 to +27
sam2_model = build_sam2(self.sam2_model_cfg, self.sam2_checkpoint, device=self.device)
self.sam2_encoder = SAM2Encoder(sam2_model) if torch.cuda.is_available() else SAM2Encoder(sam2_model).cpu()
self._transforms = SAM2Transforms(resolution=sam2_model.image_size, mask_threshold=0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add error handling and logging for model initialization

The model loading process lacks error handling and logging. Consider adding try-catch blocks and logging statements to handle potential failures gracefully.

+        import logging
+        logger = logging.getLogger(__name__)
+        
+        try:
             sam2_model = build_sam2(self.sam2_model_cfg, self.sam2_checkpoint, device=self.device)
             self.sam2_encoder = SAM2Encoder(sam2_model) if torch.cuda.is_available() else SAM2Encoder(sam2_model).cpu()
+            logger.info(f"SAM2 model loaded successfully on {self.device}")
+        except FileNotFoundError as e:
+            logger.error(f"Failed to load SAM2 model: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error during model initialization: {e}")
+            raise
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
sam2_model = build_sam2(self.sam2_model_cfg, self.sam2_checkpoint, device=self.device)
self.sam2_encoder = SAM2Encoder(sam2_model) if torch.cuda.is_available() else SAM2Encoder(sam2_model).cpu()
self._transforms = SAM2Transforms(resolution=sam2_model.image_size, mask_threshold=0.0)
import logging
logger = logging.getLogger(__name__)
try:
sam2_model = build_sam2(self.sam2_model_cfg, self.sam2_checkpoint, device=self.device)
self.sam2_encoder = SAM2Encoder(sam2_model) if torch.cuda.is_available() else SAM2Encoder(sam2_model).cpu()
logger.info(f"SAM2 model loaded successfully on {self.device}")
except FileNotFoundError as e:
logger.error(f"Failed to load SAM2 model: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error during model initialization: {e}")
raise
self._transforms = SAM2Transforms(resolution=sam2_model.image_size, mask_threshold=0.0)

Comment on lines +41 to +42
- kind: RUN
value: apt-get update && apt-get -y install curl git python3 python3-pip ffmpeg libsm6 libxext6
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Optimize Docker image size

Consider cleaning up apt cache after installation to reduce the image size:

- value: apt-get update && apt-get -y install curl git python3 python3-pip ffmpeg libsm6 libxext6
+ value: apt-get update && apt-get -y install curl git python3 python3-pip ffmpeg libsm6 libxext6 && \
+   apt-get clean && \
+   rm -rf /var/lib/apt/lists/*
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
- kind: RUN
value: apt-get update && apt-get -y install curl git python3 python3-pip ffmpeg libsm6 libxext6
- kind: RUN
value: apt-get update && apt-get -y install curl git python3 python3-pip ffmpeg libsm6 libxext6 && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

Comment on lines +47 to +48
- kind: RUN
value: curl -O https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add checksum verification for model weights

The model weights are downloaded without verifying their integrity. Consider adding SHA256 checksum verification:

- value: curl -O https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
+ value: |
+   curl -O https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt && \
+   echo "expected-sha256-checksum sam2.1_hiera_large.pt" | sha256sum -c

Committable suggestion was skipped due to low confidence.

Comment on lines +132 to +134
if (typeof query.jobID === 'number') {
[plugin.data.jobs[query.jobID]] = results;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix incorrect assignment to 'plugin.data.jobs[query.jobID]'

The assignment [plugin.data.jobs[query.jobID]] = results; is incorrect because square brackets are used for array destructuring, not for object property assignment. This can lead to unexpected behavior.

Apply this fix to correctly assign the result:

- [plugin.data.jobs[query.jobID]] = results;
+ plugin.data.jobs[query.jobID] = results[0];

Assuming that results is an array and you want to assign the first element to plugin.data.jobs[query.jobID].

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (typeof query.jobID === 'number') {
[plugin.data.jobs[query.jobID]] = results;
}
if (typeof query.jobID === 'number') {
plugin.data.jobs[query.jobID] = results[0];
}

Comment on lines +222 to +225
plugin.data.jobs = {
// we do not need to store old job instances
[job.id]: job,
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Avoid overwriting 'plugin.data.jobs' object with a single job

Overwriting the entire plugin.data.jobs object with a single job can lead to loss of previously stored jobs. This can cause issues if other parts of the code rely on accessing other job instances.

Modify the code to update the job entry without overwriting the entire object:

- plugin.data.jobs = {
-     // we do not need to store old job instances
-     [job.id]: job,
- };
+ plugin.data.jobs[job.id] = job;

If it's intentional to discard old job instances, consider resetting the object explicitly and documenting this behavior.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
plugin.data.jobs = {
// we do not need to store old job instances
[job.id]: job,
};
plugin.data.jobs[job.id] = job;

Comment on lines +281 to +282
const isLowResMaskSuitable = JSON
.stringify(clicks.slice(0, -1)) === JSON.stringify(plugin.data.lastClicks);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Handle edge cases when comparing 'clicks' arrays

Using clicks.slice(0, -1) when clicks has 0 or 1 elements may not behave as expected, since slicing with -1 can return an empty array.

Ensure that the comparison accounts for cases when clicks has fewer than two elements:

const isLowResMaskSuitable = clicks.length > 1 &&
    JSON.stringify(clicks.slice(0, -1)) === JSON.stringify(plugin.data.lastClicks);

This addition checks that there are at least two clicks before performing the comparison.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
const isLowResMaskSuitable = JSON
.stringify(clicks.slice(0, -1)) === JSON.stringify(plugin.data.lastClicks);
const isLowResMaskSuitable = clicks.length > 1 &&
JSON.stringify(clicks.slice(0, -1)) === JSON.stringify(plugin.data.lastClicks);

Comment on lines +268 to +271
if (obj_bbox.length) {
clicks.push({ clickType: 2, x: obj_bbox[0][0], y: obj_bbox[0][1] });
clicks.push({ clickType: 3, x: obj_bbox[1][0], y: obj_bbox[1][1] });
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add check for 'obj_bbox' length before accessing its elements

Accessing obj_bbox[1] without checking the length of obj_bbox can result in an error if it contains fewer than two elements.

Add a condition to ensure that obj_bbox has at least two elements:

if (obj_bbox.length >= 2) {
    clicks.push({ clickType: 2, x: obj_bbox[0][0], y: obj_bbox[0][1] });
    clicks.push({ clickType: 3, x: obj_bbox[1][0], y: obj_bbox[1][1] });
}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (obj_bbox.length) {
clicks.push({ clickType: 2, x: obj_bbox[0][0], y: obj_bbox[0][1] });
clicks.push({ clickType: 3, x: obj_bbox[1][0], y: obj_bbox[1][1] });
}
if (obj_bbox.length >= 2) {
clicks.push({ clickType: 2, x: obj_bbox[0][0], y: obj_bbox[0][1] });
clicks.push({ clickType: 3, x: obj_bbox[1][0], y: obj_bbox[1][1] });
}

Comment on lines +292 to +306
function toMatImage(input: number[], width: number, height: number): number[][] {
const image = Array(height).fill(0);
for (let i = 0; i < image.length; i++) {
image[i] = Array(width).fill(0);
}

for (let i = 0; i < input.length; i++) {
const row = Math.floor(i / width);
const col = i % width;
image[row][col] = input[i] > 0 ? 255 : 0;
}

return image;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Optimize 'toMatImage' function for better performance

The current implementation initializes a 2D array and fills it using nested loops, which can be inefficient for large images.

Consider flattening the data processing or using more efficient data structures:

function toMatImage(input: number[], width: number, height: number): number[][] {
-    const image = Array(height).fill(0);
-    for (let i = 0; i < image.length; i++) {
-        image[i] = Array(width).fill(0);
-    }
+    const image = new Array(height);
+    for (let i = 0; i < height; i++) {
+        image[i] = input.slice(i * width, (i + 1) * width).map(value => (value > 0 ? 255 : 0));
+    }
     return image;
}

This approach avoids initializing the array with zeros and directly maps the input values to the image array.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
function toMatImage(input: number[], width: number, height: number): number[][] {
const image = Array(height).fill(0);
for (let i = 0; i < image.length; i++) {
image[i] = Array(width).fill(0);
}
for (let i = 0; i < input.length; i++) {
const row = Math.floor(i / width);
const col = i % width;
image[row][col] = input[i] > 0 ? 255 : 0;
}
return image;
}
function toMatImage(input: number[], width: number, height: number): number[][] {
const image = new Array(height);
for (let i = 0; i < height; i++) {
image[i] = input.slice(i * width, (i + 1) * width).map(value => (value > 0 ? 255 : 0));
}
return image;
}

@jeanchristopheruel
Copy link

Very nice work @hashJoe ! I think the next big milestone would be to integrate encoder in frontend to unlock decentralized tracking capabilities for video annotation (this is mostly all SAM2 is about)!

@hashJoe
Copy link
Author

hashJoe commented Oct 29, 2024

Very nice work @hashJoe ! I think the next big milestone would be to integrate encoder in frontend to unlock decentralized tracking capabilities for video annotation (this is mostly all SAM2 is about)!

@jeanchristopheruel I need to look into the video prediction and check how to implement it in CVAT. How would integrating the encoder in frontend help? There is a tracking model example which is done in backend (both encoder and decoder), here. Any insights would be helpful!

@jeanchristopheruel
Copy link

@hashJoe I beleve Sam2 encoder is lightweight enough to be supported by the frontend (I think it is 1Gb or so). Porting it to the frontend would remove the need of an inference backend server. Also, it would reduce latency associated to the request containing the state and returning the embeddings. Video annotation faster than ever. Each user have its own internal tracking state, which is a memory embedding in Sam2.

@Youho99
Copy link

Youho99 commented Nov 4, 2024

I'm following this PR

@corkwing
Copy link

corkwing commented Nov 5, 2024

@jeanchristopheruel I need to look into the video prediction and check how to implement it in CVAT. How would integrating the encoder in frontend help? There is a tracking model example which is done in backend (both encoder and decoder), here. Any insights would be helpful!

SAM 2 for video tracking would be really transform our annotation workflow (fish monitoring)! I hope this is possible to implement in CVAT.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants