Visual Prompting
Description
Visual prompting and zero-shot visual prompting segment objects in images using weak supervision such as point prompts. Standard visual prompting generates masks from prompts within the same image. Zero-shot visual prompting captures prompt-supervised features on one image and then segments other images with those features without additional prompts.
Models
The visual prompting pipeline uses two models: an encoder and a decoder. The encoder consumes an image and produces features (image embeddings). The decoder consumes these embeddings together with prepared prompt inputs and outputs segmentation masks plus auxiliary results such as IoU predictions.
Encoder parameters
The following parameters can be provided via the Python API or as RT Info embedded in an OpenVINO model:
image_size (int): Encoder native input resolution. The input is expected to have a 1:1 aspect ratio.
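Because the encoder expects a square image_size x image_size input, non-square images have to be fitted into it. The sketch below assumes the longest-side resize followed by bottom/right padding used in the original SAM preprocessing; the helper name `fit_to_square` is hypothetical, and the wrapper's actual preprocessing may differ.

```python
def fit_to_square(height, width, image_size):
    """Return (new_h, new_w, pad_bottom, pad_right) for a square encoder input.

    Assumption: the longest side is scaled to image_size, the other side by the
    same factor (keeping the aspect ratio), and the remainder is padded.
    """
    scale = image_size / max(height, width)
    new_h = int(height * scale + 0.5)
    new_w = int(width * scale + 0.5)
    return new_h, new_w, image_size - new_h, image_size - new_w


# A 480 x 640 (H x W) image fitted into a 1024 x 1024 encoder input.
print(fit_to_square(480, 640, 1024))  # (768, 1024, 256, 0)
```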
Decoder parameters
The following parameters can be provided via the Python API or as RT Info embedded in an OpenVINO model:
image_size (int): Encoder native input resolution. The input is expected to have a 1:1 aspect ratio.
mask_threshold (float): Threshold for generating hard predictions from output soft masks.
embed_dim (int): Size of the output embedding. This should match the actual output size.
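mask_threshold turns the decoder's soft masks into hard ones. A minimal sketch, assuming a strict elementwise `value > mask_threshold` comparison on mask logits (the original SAM uses 0.0 for logits); `binarize` is a hypothetical helper, and nested lists stand in for ndarrays so the example has no dependencies.

```python
def binarize(soft_mask, mask_threshold):
    """Turn a soft H x W mask into a hard boolean mask (assumed strict '>')."""
    return [[value > mask_threshold for value in row] for row in soft_mask]


soft = [[-2.5, 0.3], [1.7, -0.1]]
print(binarize(soft, mask_threshold=0.0))  # [[False, True], [True, False]]
```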
OpenVINO Model Specifications
Encoder inputs
A single NCHW tensor representing a batch of images.
Encoder outputs
A single NDHW tensor, where N is the batch size, D is the embedding dimension, and HW is the spatial resolution of the output features.
Decoder inputs
The decoder OpenVINO model should have the following named inputs:
image_embeddings (B, D, H, W): Embeddings obtained with the encoder.
point_coords (B, N, 2): 2D input prompts in XY format.
point_labels (B, N): Integer labels of input point prompts.
mask_input (B, 1, H, W): Mask for input embeddings.
has_mask_input (B, 1): 0/1 flag enabling or disabling applying mask_input.
ori_shape (B, 2): Resolution of the original image used as input to the encoder wrapper.
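The prompt-related decoder inputs above can be assembled for a single image (B = 1) roughly as follows. This sketch assumes that point coordinates are rescaled into the resized encoder frame by the same longest-side factor as the image (as in the original SAM), that label 1 marks a foreground point, that ori_shape is stored as (height, width), and that mask_input is a zero placeholder disabled via has_mask_input = 0. image_embeddings is omitted because it comes straight from the encoder. The function name and the `mask_hw` argument are hypothetical, and real inputs would be ndarrays with the model's dtypes rather than lists.

```python
def prepare_point_inputs(points, labels, orig_hw, image_size, mask_hw):
    """Assemble hypothetical decoder prompt inputs for one image (B = 1).

    points: (x, y) pairs in original-image pixels; labels: one int per point.
    orig_hw: (height, width) of the original image (assumed order).
    mask_hw: spatial size of the mask_input placeholder (assumed parameter).
    """
    if len(points) != len(labels):
        raise ValueError("one label per point is required")
    orig_h, orig_w = orig_hw
    # Assumption: prompts are scaled by the same factor as the longest image side.
    scale = image_size / max(orig_h, orig_w)
    coords = [[x * scale, y * scale] for x, y in points]
    mask_h, mask_w = mask_hw
    return {
        "point_coords": [coords],                                   # (1, N, 2)
        "point_labels": [list(labels)],                             # (1, N)
        "mask_input": [[[[0.0] * mask_w for _ in range(mask_h)]]],  # (1, 1, H, W)
        "has_mask_input": [[0.0]],                                  # (1, 1): 0 disables mask_input
        "ori_shape": [[orig_h, orig_w]],                            # (1, 2)
    }


inputs = prepare_point_inputs([(100, 50)], [1], (256, 512), 1024, (4, 4))
print(inputs["point_coords"], inputs["ori_shape"])  # [[[200.0, 100.0]]] [[256, 512]]
```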
Decoder outputs
upscaled_masks (B, N, H, W): Masks upscaled to ori_shape.
iou_predictions (B, N): IoU predictions for the output masks.
low_res_masks (B, N, H, W): Masks in feature resolution.
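When the decoder returns several candidate masks, a common choice in SAM-style pipelines is to keep the candidate with the highest iou_predictions score. A minimal sketch for one image, assuming that selection rule (the wrapper may select differently); `best_mask` is a hypothetical helper and nested lists stand in for ndarrays.

```python
def best_mask(masks, iou_predictions):
    """Return (index, mask) of the candidate with the highest predicted IoU.

    masks and iou_predictions describe one image: N candidate masks and N scores.
    """
    if not masks or len(masks) != len(iou_predictions):
        raise ValueError("need one IoU score per mask and at least one mask")
    best = max(range(len(iou_predictions)), key=iou_predictions.__getitem__)
    return best, masks[best]


candidates = [[[1, 0]], [[1, 1]], [[0, 1]]]  # three tiny 1 x 2 masks
scores = [0.71, 0.93, 0.88]
print(best_mask(candidates, scores))  # (1, [[1, 1]])
```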
Examples
- class model_api.models.visual_prompting.Prompt(data, label)
Bases: NamedTuple
Create new instance of Prompt(data, label)
- data: ndarray
Alias for field number 0
- label: int | ndarray
Alias for field number 1
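Since Prompt is a NamedTuple, data and label are accessible by name or by index 0 and 1. The sketch below uses a local stand-in with the same two fields, not the real class, so it runs without model_api installed; it also adds a hypothetical `xywh_to_xyxy` helper for converting top-left-plus-size boxes into the XYXY format that box prompts expect. With the library available, Prompt would be imported from model_api.models.visual_prompting and data would be an ndarray.

```python
from typing import NamedTuple, Union


class Prompt(NamedTuple):
    """Local stand-in mirroring model_api's Prompt(data, label); not the real class."""

    data: list  # the real class documents an ndarray here
    label: Union[int, list]


def xywh_to_xyxy(x, y, w, h):
    """Hypothetical helper: convert a top-left corner plus size to XYXY corners."""
    return [x, y, x + w, y + h]


point = Prompt([320, 240], 1)                     # one XY point with label 1
box = Prompt(xywh_to_xyxy(100, 80, 300, 220), 0)  # XYXY box with label 0
print(box.data, box[1])  # [100, 80, 400, 300] 0
```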
- class model_api.models.visual_prompting.SAMLearnableVisualPrompter(encoder_model, decoder_model, reference_features=None, threshold=0.65)
Bases: object
A wrapper that provides the zero-shot learning (ZSL) visual prompting workflow. To obtain segmentation results, first run learn() to obtain the reference features, or pass previously generated ones.
Initializes the ZSL pipeline.
- Parameters:
encoder_model (SAMImageEncoder) – initialized encoder wrapper
decoder_model (SAMDecoder) – initialized decoder wrapper
reference_features (VisualPromptingFeatures | None) – Previously generated reference features. Once the features are passed, one can skip the learn() method and start predicting masks right away. Defaults to None.
threshold (float) – Threshold for matching against reference features in infer(). A greater value means stricter matching. Defaults to 0.65.
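The threshold controls how closely target-image features must match the learned reference. The matching procedure itself is not described here; the sketch below assumes cosine similarity between the reference feature and each location of the target feature map, keeping locations that exceed the threshold, which illustrates why a larger value means stricter matching. The function names are hypothetical and nested lists stand in for ndarrays.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is zero)."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def match_locations(target_features, reference, threshold=0.65):
    """Return (row, col) cells whose feature vector matches the reference.

    target_features is an H x W grid of D-dimensional vectors; a higher
    threshold keeps fewer cells, i.e. matching is stricter.
    """
    return [
        (r, c)
        for r, row in enumerate(target_features)
        for c, feat in enumerate(row)
        if cosine(feat, reference) > threshold
    ]


grid = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[0.8, 0.6], [-1.0, 0.0]],
]
print(match_locations(grid, [1.0, 0.0]))                 # [(0, 0), (1, 0)]
print(match_locations(grid, [1.0, 0.0], threshold=0.9))  # [(0, 0)]
```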
- __call__(image, reference_features=None, apply_masks_refinement=True)
A wrapper of the SAMLearnableVisualPrompter.infer() method.
- Return type:
ZSLVisualPromptingResult
- has_reference_features()
Checks if reference features are stored in the object state.
- Return type:
bool
- infer(image, reference_features=None, apply_masks_refinement=True)
Obtains masks using already prepared reference features.
Reference features can be obtained with SAMLearnableVisualPrompter.learn() and passed as an argument. If the features are not passed, the instance's internal state is used as the source of the features.
- Parameters:
image (ndarray) – HWC-shaped image
reference_features (VisualPromptingFeatures | None) – Reference features object obtained during previous learn() calls. If not passed, the object's internal state is used, which reflects the last learn() call. Defaults to None.
apply_masks_refinement (bool) – Flag controlling an additional refinement stage at inference. Once enabled, the decoder is launched 2 extra times to refine the masks obtained with the first decoder call. Defaults to True.
- Returns:
Mapping label -> predicted mask. Each mask object contains a list of binary masks and a list of related prompts. Each binary mask corresponds to one prompt point. The class mask can be obtained by applying an OR operation to all masks corresponding to one label.
- Return type:
ZSLVisualPromptingResult
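The class mask mentioned under Returns is the elementwise OR of the per-prompt binary masks for one label. A dependency-free sketch with a hypothetical helper name and nested lists in place of ndarrays; reading the list of masks out of ZSLVisualPromptingResult is not shown, since that object's attributes are not documented here.

```python
def class_mask(binary_masks):
    """OR together equally sized H x W binary masks into one class mask."""
    if not binary_masks:
        raise ValueError("at least one mask is required")
    height, width = len(binary_masks[0]), len(binary_masks[0][0])
    return [
        [any(mask[r][c] for mask in binary_masks) for c in range(width)]
        for r in range(height)
    ]


masks_for_label = [
    [[1, 0], [0, 0]],  # mask from the first prompt point
    [[0, 0], [0, 1]],  # mask from the second prompt point
]
print(class_mask(masks_for_label))  # [[True, False], [False, True]]
```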
- learn(image, boxes=None, points=None, polygons=None, reset_features=False)
Executes the learn stage of the SAM ZSL pipeline.
Reference features are updated according to newly arrived prompts. Features corresponding to the same labels are overridden during subsequent learn() calls.
- Parameters:
image (ndarray) – HWC-shaped image
boxes (list[Prompt] | None) – Prompts containing bounding boxes (in XYXY torchvision format) and their labels (ints, one per box). Defaults to None.
points (list[Prompt] | None) – Prompts containing points (in XY format) and their labels (ints, one per point). Defaults to None.
polygons (list[Prompt] | None) – Prompts containing polygons (a sequence of points in XY format) and their labels (ints, one per polygon). Polygon prompts are used to mask out the source features without involving the decoder. Defaults to None.
reset_features (bool) – Forces learning from scratch. Defaults to False.
- Returns:
The updated VPT reference features and the reference masks.
The shape of the reference mask is N_labels x H x W, where H and W are the same as in the input image.
- Return type:
tuple[VisualPromptingFeatures,ndarray]
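Polygon prompts mask out source features rather than going through the decoder. One way to turn an XY polygon into a binary mask is even-odd ray casting evaluated at pixel centers, sketched below. This is an assumption for illustration only (the wrapper may rasterize polygons differently, for example with an image library), and both helper names are hypothetical.

```python
def point_in_polygon(x, y, polygon):
    """Even-odd rule: count edge crossings of a ray cast to the right of (x, y)."""
    inside = False
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n]
        # Only edges that straddle the horizontal line through y can be crossed.
        if (y1 > y) != (y2 > y):
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def polygon_mask(polygon, height, width):
    """Rasterize an XY polygon into an H x W boolean mask by testing pixel centers."""
    return [
        [point_in_polygon(c + 0.5, r + 0.5, polygon) for c in range(width)]
        for r in range(height)
    ]


square = [(1, 1), (3, 1), (3, 3), (1, 3)]  # XY vertices
for row in polygon_mask(square, 4, 4):
    print("".join("#" if v else "." for v in row))
```

For the square above, the four pixels whose centers lie inside it (rows 1 to 2, columns 1 to 2) are set.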
- reset_reference_info()
Initializes the reference information.
- Return type:
None
- property reference_features: VisualPromptingFeatures
The stored reference features. An exception is raised if the property is accessed when the features are not present in the internal object state.
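The reference-feature lifecycle described above (labels overridden across learn() calls, reset_features, has_reference_features(), the reference_features property raising when nothing is stored, and reset_reference_info()) can be modelled with a per-label dictionary. This hypothetical ReferenceStore only mimics the documented semantics: the real class stores VisualPromptingFeatures, and the exception type used here (RuntimeError) is an assumption.

```python
class ReferenceStore:
    """Hypothetical per-label store mimicking the documented learn()/reset semantics."""

    def __init__(self):
        self._features = {}

    def learn(self, label_to_feature, reset_features=False):
        if reset_features:
            self._features.clear()  # learn from scratch
        # A newer feature for an existing label overrides the older one.
        self._features.update(label_to_feature)
        return dict(self._features)

    def has_reference_features(self):
        return bool(self._features)

    @property
    def reference_features(self):
        if not self._features:
            # Assumed exception type; the documentation only says an exception is raised.
            raise RuntimeError("reference features are not set; call learn() first")
        return dict(self._features)

    def reset_reference_info(self):
        self._features.clear()


store = ReferenceStore()
store.learn({0: [0.1, 0.2]})
store.learn({0: [0.3, 0.4], 1: [0.5, 0.6]})  # label 0 is overridden
print(store.reference_features)  # {0: [0.3, 0.4], 1: [0.5, 0.6]}
```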
- class model_api.models.visual_prompting.SAMVisualPrompter(encoder_model, decoder_model)
Bases: object
A wrapper that implements the SAM visual prompter.
Segmentation results can be obtained by calling the infer() method with the corresponding parameters.
- __call__(image, boxes=None, points=None)
A wrapper of the SAMVisualPrompter.infer() method.
- Return type:
VisualPromptingResult
- infer(image, boxes=None, points=None)
Obtains segmentation masks using the given prompts.
- Parameters:
image (ndarray) – HWC-shaped image
boxes (list[Prompt] | None) – Prompts containing bounding boxes (in XYXY torchvision format) and their labels (ints, one per box). Defaults to None.
points (list[Prompt] | None) – Prompts containing points (in XY format) and their labels (ints, one per point). Defaults to None.
- Returns:
A result object containing the predicted masks and auxiliary information.
- Return type:
VisualPromptingResult
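The decoder's named inputs carry only points (point_coords and point_labels), so box prompts have to be encoded as points before decoding. A sketch assuming the convention from the original SAM ONNX export: a box becomes its two corners with labels 2 and 3, and when no box is given a padding point (0, 0) with label -1 is appended. This document does not state that model_api uses exactly this encoding, and `merge_prompts` and `box_to_corner_points` are hypothetical.

```python
def box_to_corner_points(box_xyxy):
    """Encode an XYXY box as its two corners with SAM-style labels 2 and 3."""
    x1, y1, x2, y2 = box_xyxy
    return [[x1, y1], [x2, y2]], [2, 3]


def merge_prompts(points, point_labels, box=None):
    """Combine point prompts and an optional box into one (coords, labels) pair.

    Assumption: follows the original SAM ONNX convention, which is not
    confirmed for model_api's decoder.
    """
    coords, labels = [list(p) for p in points], list(point_labels)
    if box is not None:
        box_coords, box_labels = box_to_corner_points(box)
        coords += box_coords
        labels += box_labels
    else:
        # Without a box, SAM's ONNX decoder expects a padding point labelled -1.
        coords.append([0.0, 0.0])
        labels.append(-1)
    return coords, labels


print(merge_prompts([(10, 20)], [1], box=(5, 6, 50, 60)))
# ([[10, 20], [5, 6], [50, 60]], [1, 2, 3])
```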