Torchvision Model

sahi.models.torchvision

TorchVision detection model wrapper for SAHI.

Provides integration with PyTorch's TorchVision library for object detection and instance segmentation models.
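
For a quick start, the wrapper can be constructed directly and driven through sahi.predict. A minimal sketch, assuming a hypothetical image path; omitting model_path/config_path falls back to pretrained fasterrcnn_resnet50_fpn weights with the default 91 COCO classes (see load_model below):

from sahi.models.torchvision import TorchVisionDetectionModel
from sahi.predict import get_sliced_prediction

# No model_path/config_path: load_model() falls back to pretrained
# fasterrcnn_resnet50_fpn with the default 91 COCO classes.
detection_model = TorchVisionDetectionModel(
    confidence_threshold=0.5,
    device="cpu",  # or "cuda:0"
)

result = get_sliced_prediction(
    "demo.jpg",  # hypothetical image path
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
print(f"{len(result.object_prediction_list)} objects detected")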

Classes

TorchVisionDetectionModel

Bases: DetectionModel

TorchVision object detection model.

Supports various TorchVision detection models like Faster R-CNN, Mask R-CNN, etc.

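The set of accepted model names is defined by the MODEL_NAME_TO_CONSTRUCTOR mapping referenced in load_model below. A sketch for inspecting it; the import location follows sahi's source layout, so verify it against your installed version:

from sahi.utils.torchvision import MODEL_NAME_TO_CONSTRUCTOR

# Keys are the values accepted for model_name in the YAML config,
# e.g. "fasterrcnn_resnet50_fpn" or "maskrcnn_resnet50_fpn".
print(sorted(MODEL_NAME_TO_CONSTRUCTOR.keys()))
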
Source code in sahi/models/torchvision.py
class TorchVisionDetectionModel(DetectionModel):
    """TorchVision object detection model.

    Supports various TorchVision detection models like Faster R-CNN, Mask R-CNN, etc.
    """

    def __init__(self, *args: object, **kwargs: object) -> None:
        """Initialize TorchVision detection model."""
        existing_packages = getattr(self, "required_packages", None) or []
        self.required_packages = [*list(existing_packages), "torch", "torchvision"]
        super().__init__(*args, **kwargs)  # type: ignore[misc, arg-type]

    def load_model(self) -> None:
        """Load TorchVision model from config and weights."""
        import torch

        # read config params
        model_name = None
        num_classes = None
        if self.config_path is not None:
            with open(self.config_path) as stream:
                try:
                    config = yaml.safe_load(stream)
                except yaml.YAMLError as exc:
                    raise RuntimeError(exc) from exc

            model_name = config.get("model_name", None)
            num_classes = config.get("num_classes", None)

        # complete params if not provided in config
        if not model_name:
            model_name = "fasterrcnn_resnet50_fpn"
            logger.warning(f"model_name not provided in config, using default model_type: {model_name}'")
        if num_classes is None:
            logger.warning("num_classes not provided in config, using default num_classes: 91")
            num_classes = 91
        if self.model_path is None:
            logger.warning("model_path not provided in config, using pretrained weights and default num_classes: 91.")
            weights = "DEFAULT"
            num_classes = 91
        else:
            weights = None

        # load model
        # Note: torchvision >= 0.13 is required for the 'weights' parameter
        model = MODEL_NAME_TO_CONSTRUCTOR[model_name](num_classes=num_classes, weights=weights)
        if self.model_path:
            try:
                model.load_state_dict(torch.load(self.model_path))
            except Exception as e:
                logger.error(f"Invalid {self.model_path=}")
                raise TypeError(f"model_path is not a valid torchvision model path: {e}") from e

        self.set_model(model)

    def set_model(self, model: Any, **kwargs: Any) -> None:
        """Sets the underlying TorchVision model.

        Args:
            model: Any
                A TorchVision model
            **kwargs: Any
                Additional keyword arguments for model setup.
        """
        model.eval()  # type: ignore[attr-defined]
        self.model = model.to(self.device)  # type: ignore[attr-defined]

        # set category_mapping

        if self.category_mapping is None:
            category_names = {str(i): COCO_CLASSES[i] for i in range(len(COCO_CLASSES))}
            self.category_mapping = category_names

    def perform_inference(self, image: np.ndarray, image_size: int | None = None) -> None:
        """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
            image_size: int
                Inference input size.
        """
        from sahi.utils.torch_utils import to_float_tensor

        # arrange model input size
        assert self.model is not None
        if self.image_size is not None:
            # get min and max of image height and width
            min_shape, max_shape = min(image.shape[:2]), max(image.shape[:2])
            # torchvision resize transform scales the shorter dimension to the target size
            # we want to scale the longer dimension to the target size
            image_size = int(self.image_size * min_shape / max_shape)
            self.model.transform.min_size = (image_size,)  # default is (800,)
            self.model.transform.max_size = self.image_size  # default is 1333

        image_tensor = to_float_tensor(image)
        image_tensor = image_tensor.to(self.device)
        prediction_result = self.model([image_tensor])

        self._original_predictions = prediction_result

    @property
    def num_categories(self) -> int:
        """Returns number of categories."""
        assert self.category_mapping is not None
        return len(self.category_mapping)

    @property
    def has_mask(self) -> bool:
        """Returns if model output contains segmentation mask."""
        return hasattr(self.model, "roi_heads") and hasattr(self.model.roi_heads, "mask_predictor")  # type: ignore[attr-defined]

    @property
    def category_names(self) -> list:
        """Return category names from mapping."""
        assert self.category_mapping is not None
        return list(self.category_mapping.values())

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: list[list[int | float]] | None = [[0, 0]],
        full_shape_list: list[list[int | float]] | None = None,
    ) -> None:
        """Convert predictions to ObjectPrediction list.

        self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.

        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        original_predictions = self._original_predictions

        # compatibility with sahi v0.8.20
        if shift_amount_list is not None and isinstance(shift_amount_list[0], int):
            shift_amount_list = [shift_amount_list]  # type: ignore[list-item]
        if full_shape_list is not None and isinstance(full_shape_list[0], int):
            full_shape_list = [full_shape_list]  # type: ignore[list-item]

        object_prediction_list_per_image = []
        for image_predictions in original_predictions:

            # get indices of boxes with score > confidence_threshold
            scores = image_predictions["scores"].cpu().detach().numpy()
            selected_indices = np.where(scores > self.confidence_threshold)[0]

            # parse boxes, masks, scores, category_ids from predictions
            category_ids = list(image_predictions["labels"][selected_indices].cpu().detach().numpy())
            boxes = list(image_predictions["boxes"][selected_indices].cpu().detach().numpy())
            scores = scores[selected_indices]

            # check if predictions contain mask
            masks = image_predictions.get("masks", None)
            if masks is not None:
                masks = list(
                    (masks[selected_indices] > self.mask_threshold).cpu().detach().numpy()
                )

            # create object_prediction_list
            object_prediction_list = []

            shift_amount = shift_amount_list[0] if shift_amount_list else [0, 0]
            full_shape = None if full_shape_list is None else full_shape_list[0]

            for ind in range(len(boxes)):
                if masks is not None:
                    segmentation = get_coco_segmentation_from_bool_mask(np.array(masks[ind]))
                else:
                    segmentation = None

                object_prediction = ObjectPrediction(
                    bbox=boxes[ind],
                    segmentation=segmentation,
                    category_id=int(category_ids[ind]),
                    category_name=self.category_mapping[str(int(category_ids[ind]))] if self.category_mapping else "",  # type: ignore[index]
                    shift_amount=shift_amount,
                    score=scores[ind],
                    full_shape=full_shape,
                )
                object_prediction_list.append(object_prediction)
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image
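
The conversion above boils down to filtering the raw torchvision output dict by confidence and binarizing any masks. A standalone sketch of that filtering, with tensors fabricated purely for illustration:

import numpy as np
import torch

# One image's torchvision output dict (values fabricated for illustration).
image_predictions = {
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 60.0], [5.0, 5.0, 20.0, 20.0]]),
    "labels": torch.tensor([1, 2]),
    "scores": torch.tensor([0.9, 0.2]),
}
confidence_threshold = 0.3

scores = image_predictions["scores"].cpu().detach().numpy()
selected_indices = np.where(scores > confidence_threshold)[0]
boxes = image_predictions["boxes"][selected_indices].cpu().detach().numpy()
print(boxes)  # only the 0.9-score detection survives
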
Attributes
category_names property

Return category names from mapping.

has_mask property

Returns if model output contains segmentation mask.

num_categories property

Returns number of categories.
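
Continuing the earlier sketch, these properties can be inspected once a model is loaded; the printed values assume the default COCO category mapping:

print(detection_model.num_categories)      # 91 with the default COCO mapping
print(detection_model.has_mask)            # True only for Mask R-CNN variants
print(detection_model.category_names[:2])  # first entries of the mapping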

Functions
__init__(*args, **kwargs)

Initialize TorchVision detection model.

Source code in sahi/models/torchvision.py
def __init__(self, *args: object, **kwargs: object) -> None:
    """Initialize TorchVision detection model."""
    existing_packages = getattr(self, "required_packages", None) or []
    self.required_packages = [*list(existing_packages), "torch", "torchvision"]
    super().__init__(*args, **kwargs)  # type: ignore[misc, arg-type]
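
The constructor only registers torch/torchvision as required packages and defers everything else to the DetectionModel base class, so the usual base-class keyword arguments apply. A sketch with hypothetical paths (see the base class for the full parameter list):

from sahi.models.torchvision import TorchVisionDetectionModel

detection_model = TorchVisionDetectionModel(
    model_path="weights.pt",    # hypothetical state_dict saved via torch.save
    config_path="config.yaml",  # optional YAML with model_name/num_classes
    confidence_threshold=0.4,
    mask_threshold=0.5,
    device="cuda:0",
)
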
load_model()

Load TorchVision model from config and weights.

Source code in sahi/models/torchvision.py
def load_model(self) -> None:
    """Load TorchVision model from config and weights."""
    import torch

    # read config params
    model_name = None
    num_classes = None
    if self.config_path is not None:
        with open(self.config_path) as stream:
            try:
                config = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                raise RuntimeError(exc) from exc

        model_name = config.get("model_name", None)
        num_classes = config.get("num_classes", None)

    # complete params if not provided in config
    if not model_name:
        model_name = "fasterrcnn_resnet50_fpn"
        logger.warning(f"model_name not provided in config, using default model_type: {model_name}'")
    if num_classes is None:
        logger.warning("num_classes not provided in config, using default num_classes: 91")
        num_classes = 91
    if self.model_path is None:
        logger.warning("model_path not provided in config, using pretrained weights and default num_classes: 91.")
        weights = "DEFAULT"
        num_classes = 91
    else:
        weights = None

    # load model
    # Note: torchvision >= 0.13 is required for the 'weights' parameter
    model = MODEL_NAME_TO_CONSTRUCTOR[model_name](num_classes=num_classes, weights=weights)
    if self.model_path:
        try:
            model.load_state_dict(torch.load(self.model_path))
        except Exception as e:
            logger.error(f"Invalid {self.model_path=}")
            raise TypeError(f"model_path is not a valid torchvision model path: {e}") from e

    self.set_model(model)
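
The YAML read by load_model only needs the two keys queried above. A minimal sketch; file names are hypothetical:

# config.yaml consumed by load_model:
#
#   model_name: maskrcnn_resnet50_fpn
#   num_classes: 91
#
# With a matching state_dict saved via torch.save(model.state_dict(), ...):
from sahi.models.torchvision import TorchVisionDetectionModel

detection_model = TorchVisionDetectionModel(
    model_path="maskrcnn_weights.pt",  # hypothetical path
    config_path="config.yaml",
)
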
perform_inference(image, image_size=None)

Prediction is performed using self.model and the prediction result is set to self._original_predictions.

Parameters:

image (np.ndarray, required):
    A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

image_size (int | None, default: None):
    Inference input size.

Source code in sahi/models/torchvision.py
def perform_inference(self, image: np.ndarray, image_size: int | None = None) -> None:
    """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

    Args:
        image: np.ndarray
            A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
        image_size: int
            Inference input size.
    """
    from sahi.utils.torch_utils import to_float_tensor

    # arrange model input size
    assert self.model is not None
    if self.image_size is not None:
        # get min and max of image height and width
        min_shape, max_shape = min(image.shape[:2]), max(image.shape[:2])
        # torchvision resize transform scales the shorter dimension to the target size
        # we want to scale the longer dimension to the target size
        image_size = int(self.image_size * min_shape / max_shape)
        self.model.transform.min_size = (image_size,)  # default is (800,)
        self.model.transform.max_size = self.image_size  # default is 1333

    image_tensor = to_float_tensor(image)
    image_tensor = image_tensor.to(self.device)
    prediction_result = self.model([image_tensor])

    self._original_predictions = prediction_result
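
perform_inference is normally driven by sahi's prediction helpers, but it can be called directly. A sketch, assuming an RGB image on disk; per the docstring above, the raw result lands in _original_predictions:

import numpy as np
from PIL import Image

image = np.asarray(Image.open("demo.jpg").convert("RGB"))  # HWC, RGB order
detection_model.perform_inference(image)

# Each element is a torchvision output dict with boxes/labels/scores
# (and masks for Mask R-CNN variants).
print(detection_model._original_predictions[0].keys())
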
set_model(model, **kwargs)

Sets the underlying TorchVision model.

Parameters:

model (Any, required):
    A TorchVision model.

**kwargs (Any, default: {}):
    Additional keyword arguments for model setup.

Source code in sahi/models/torchvision.py
def set_model(self, model: Any, **kwargs: Any) -> None:
    """Sets the underlying TorchVision model.

    Args:
        model: Any
            A TorchVision model
        **kwargs: Any
            Additional keyword arguments for model setup.
    """
    model.eval()  # type: ignore[attr-defined]
    self.model = model.to(self.device)  # type: ignore[attr-defined]

    # set category_mapping

    if self.category_mapping is None:
        category_names = {str(i): COCO_CLASSES[i] for i in range(len(COCO_CLASSES))}
        self.category_mapping = category_names
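
set_model also allows wrapping an already-constructed torchvision model, bypassing the config/weights loading path entirely. A sketch, assuming the base-class load_at_init flag skips automatic loading while still configuring the device:

import torchvision

from sahi.models.torchvision import TorchVisionDetectionModel

tv_model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

detection_model = TorchVisionDetectionModel(load_at_init=False, device="cpu")
detection_model.set_model(tv_model)
print(detection_model.has_mask)  # True: roi_heads.mask_predictor is present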
