weird output from SamPredictor #13

Open
rongjhan opened this issue Aug 22, 2024 · 1 comment

Comments

@rongjhan

I implemented prediction in two ways: one follows the demo in eval.py in the repository, and the other uses SamPredictor, but I get different outputs.
I assumed SamPredictor would simply transform the ndarray to a tensor first and then perform exactly the same operations afterwards (see the set_image sketch below), so I can't tell where I made a mistake.
(two screenshots comparing the differing mask outputs attached)
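
(For reference, a minimal sketch of what upstream SAM's SamPredictor.set_image does, assuming RobustSAM's fork keeps the same behavior. Note that the resize happens on the numpy/PIL side via apply_image, not via apply_image_torch as in sam_robust_predict below, so the two paths are not guaranteed to be pixel-identical.)

import numpy as np
import torch

# Sketch of SamPredictor.set_image; names follow upstream SAM and are
# assumed unchanged in RobustSAM.
def set_image_sketch(predictor, image: np.ndarray) -> None:
    # 1. Resize the longest side to the encoder's input size (PIL-based resize)
    input_image = predictor.transform.apply_image(image)
    # 2. HWC uint8 -> 1x3xHxW tensor on the predictor's device
    input_image_torch = torch.as_tensor(input_image, device=predictor.device)
    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
    # 3. set_torch_image normalizes, pads to a square, and runs the image encoder
    predictor.set_torch_image(input_image_torch, image.shape[:2])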

import numpy as np
import torch
from robust_segment_anything import SamPredictor, sam_model_registry
from robust_segment_anything.utils.transforms import ResizeLongestSide 
from PIL import Image
from typing import Literal



ckpt_cache = r"path\to\robustsam_checkpoint_b.pth"
img_path =  r"path\to\rain.jpg"   # rain.jpg can be downloaded from demo_images
input_boxes =[[2.47010520553918, 1.390010846867655, 267.41756993667735, 296.76731580624437]] 

def sam_robust_predict(
    raw_image: Image.Image,
    input_points = None,
    input_labels = None,
    input_boxes = None,
    model_size: Literal["base", "large", "huge"] = "large",
    device: Literal["cpu", "cuda"] = "cpu",
) -> Image.Image:


    # "base" -> "vit_b", "large" -> "vit_l", "huge" -> "vit_h"
    model = sam_model_registry[f"vit_{model_size[0]}"](None, checkpoint=ckpt_cache)
    sam_transform = ResizeLongestSide(model.image_encoder.img_size)
    model = model.to(device)

    if raw_image.mode != "RGB":
        raw_image = raw_image.convert("RGB")

    data_dict = {}

    # transform image
    image = np.array(raw_image, dtype=np.uint8)
    image_t = torch.tensor(image, dtype=torch.uint8).unsqueeze(0).to(device)
    image_t = torch.permute(image_t, (0, 3, 1, 2))
    image_t_transformed = sam_transform.apply_image_torch(image_t.float())
    
    data_dict['image'] = image_t_transformed


    # prompt
    np_input_boxes = np.array(input_boxes) if input_boxes else None
    np_input_points = np.array(input_points) if input_points else None
    np_input_labels = np.array(input_labels) if input_labels else None


    #handle box prompt
    if np_input_boxes is not None:
        box_t = torch.Tensor(np_input_boxes).unsqueeze(0).to(device)
        data_dict['boxes'] = sam_transform.apply_boxes_torch(box_t, image_t.shape[-2:]).unsqueeze(0)


    #handle point prompt
    if np_input_points is not None:
        input_label = torch.Tensor(np_input_labels).to(device)
        point_t = torch.Tensor(np_input_points).to(device)
        data_dict['point_coords'] = sam_transform.apply_coords_torch(point_t, image_t.shape[-2:]).unsqueeze(0)
        data_dict['point_labels'] = input_label.unsqueeze(0)

    data_dict['original_size'] = image_t.shape[-2:]
    with torch.no_grad():   
        batched_output = model.predict(None, [data_dict], multimask_output=False, return_logits=False) 

    output_mask = batched_output[0]['masks']
    h, w = output_mask.shape[-2:]
    img = Image.fromarray(output_mask.reshape(h, w).cpu().numpy().astype(np.uint8) * 255)
    img.show()
    return img


def sam_robust_predict2(
    raw_image: Image.Image,
    input_points = None,
    input_labels = None,
    input_boxes = None,
    model_size: Literal["base", "large", "huge"] = "base",
    device: Literal["cpu", "cuda"] = "cpu",
) -> Image.Image:

    sam = sam_model_registry[f"vit_{model_size[0]}"](None, checkpoint=ckpt_cache)

    sam.eval()
    predictor = SamPredictor(sam)

    if raw_image.mode != "RGB":
        raw_image = raw_image.convert("RGB")

    predictor.set_image(np.array(raw_image, dtype=np.uint8))

    # prompt
    np_input_boxes = np.array(input_boxes) if input_boxes else None
    np_input_points = np.array(input_points) if input_points else None
    np_input_labels = np.array(input_labels) if input_labels else None

    masks, scores, logits = predictor.predict(
        point_coords=np_input_points,
        point_labels=np_input_labels,
        box=np_input_boxes,
        multimask_output=False,
        # return_logits=True
    )

    h, w = masks.shape[-2:] 
    img = Image.fromarray(masks.reshape(h, w).numpy().astype(np.uint8)*255)
    img.show()
    return img


sam_robust_predict(
    Image.open(img_path),
    None,
    None,
    input_boxes=input_boxes,
    model_size="base",
)

sam_robust_predict2(
    Image.open(img_path),
    None,
    None,
    input_boxes=input_boxes,
    model_size="base",
)
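
If the discrepancy comes from the two resize paths, it should already be visible in raw pixels. A minimal check, reusing the imports and img_path from the script above, and assuming ResizeLongestSide here matches upstream SAM (apply_image resizes through PIL, apply_image_torch through F.interpolate):

transform = ResizeLongestSide(1024)  # 1024 = SAM's ViT input size
image = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8)

# Path taken by SamPredictor.set_image: numpy/PIL-side resize
resized_np = transform.apply_image(image).astype(np.float32)

# Path taken by sam_robust_predict above: torch-side resize
image_t = torch.from_numpy(image).permute(2, 0, 1)[None].float()
resized_t = transform.apply_image_torch(image_t)[0].permute(1, 2, 0).numpy()

print("max abs pixel difference:", np.abs(resized_np - resized_t).max())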
@robustsam
Owner

Will the same issue occur when using robustsam_checkpoint_l.pth and robustsam_checkpoint_h.pth?
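
(To check, the script above can be re-run against the other checkpoints; only the module-level checkpoint path and the model_size argument change. The checkpoint path below is hypothetical:)

ckpt_cache = r"path\to\robustsam_checkpoint_l.pth"  # hypothetical path

sam_robust_predict(Image.open(img_path), None, None,
                   input_boxes=input_boxes, model_size="large")
sam_robust_predict2(Image.open(img_path), None, None,
                    input_boxes=input_boxes, model_size="large")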
