BaseModel

sahi.models.base

Classes

DetectionModel

Source code in sahi/models/base.py
class DetectionModel:
    def __init__(
        self,
        model_path: Optional[str] = None,
        model: Optional[Any] = None,
        config_path: Optional[str] = None,
        device: Optional[str] = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: Optional[Dict] = None,
        category_remapping: Optional[Dict] = None,
        load_at_init: bool = True,
        image_size: Optional[int] = None,
    ):
        """
        Init object detection/instance segmentation model.
        Args:
            model_path: str
                Path for the detection/instance segmentation model weights file
            model: Any
                Pre-loaded model object; when load_at_init is True it is registered via set_model() instead of calling load_model()
            config_path: str
                Path for the mmdetection instance segmentation model config file
            device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
            mask_threshold: float
                Value to threshold mask pixels, should be between 0 and 1
            confidence_threshold: float
                All predictions with score < confidence_threshold will be discarded
            category_mapping: dict: str to str
                Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
            category_remapping: dict: str to int
                Remap category ids based on category names, after performing inference e.g. {"car": 3}
            load_at_init: bool
                If True, automatically loads the model at initialization
            image_size: int
                Inference input size.
        """
        self.model_path = model_path
        self.config_path = config_path
        self.model = None
        self.mask_threshold = mask_threshold
        self.confidence_threshold = confidence_threshold
        self.category_mapping = category_mapping
        self.category_remapping = category_remapping
        self.image_size = image_size
        self._original_predictions = None
        self._object_prediction_list_per_image = None
        self.set_device(device)

        # automatically load model if load_at_init is True
        if load_at_init:
            if model:
                self.set_model(model)
            else:
                self.load_model()

    def check_dependencies(self) -> None:
        """
        This function can be implemented to ensure model dependencies are installed.
        """
        pass

    def load_model(self):
        """
        This function should be implemented in a way that detection model
        should be initialized and set to self.model.
        (self.model_path, self.config_path, and self.device should be utilized)
        """
        raise NotImplementedError()

    def set_model(self, model: Any, **kwargs):
        """
        This function should be implemented to instantiate a DetectionModel out of an already loaded model
        Args:
            model: Any
                Loaded model
        """
        raise NotImplementedError()

    def set_device(self, device: Optional[str] = None):
        """Sets the device pytorch should use for the model

        Args:
            device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
        """
        if has_torch:
            self.device = select_device(device)
        else:
            raise NotImplementedError(f"Could not set device {device}")

    def unload_model(self):
        """
        Unloads the model from CPU/GPU.
        """
        self.model = None
        empty_cuda_cache()

    def perform_inference(self, image: np.ndarray):
        """
        This function should be implemented in a way that prediction should be
        performed using self.model and the prediction result should be set to self._original_predictions.
        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted.
        """
        raise NotImplementedError()

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
        full_shape_list: Optional[List[List[int]]] = None,
    ):
        """
        This function should be implemented in a way that self._original_predictions should
        be converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list. self.mask_threshold can also be utilized.
        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        raise NotImplementedError()

    def _apply_category_remapping(self):
        """
        Applies category remapping based on mapping given in self.category_remapping
        """
        # confirm self.category_remapping is not None
        if self.category_remapping is None:
            raise ValueError("self.category_remapping cannot be None")
        # remap categories
        if not isinstance(self._object_prediction_list_per_image, list):
            logger.error(
                f"Unknown type for self._object_prediction_list_per_image: {type(self._object_prediction_list_per_image)}"
            )
            return
        for object_prediction_list in self._object_prediction_list_per_image:  # type: ignore
            for object_prediction in object_prediction_list:
                old_category_id_str = str(object_prediction.category.id)
                new_category_id_int = self.category_remapping[old_category_id_str]
                object_prediction.category.id = new_category_id_int

    def convert_original_predictions(
        self,
        shift_amount: Optional[List[List[int]]] = [[0, 0]],
        full_shape: Optional[List[List[int]]] = None,
    ):
        """
        Converts original predictions of the detection model to a list of
        prediction.ObjectPrediction object. Should be called after perform_inference().
        Args:
            shift_amount: list
                To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]
            full_shape: list
                Size of the full image after shifting, should be in the form of [height, width]
        """
        self._create_object_prediction_list_from_original_predictions(
            shift_amount_list=shift_amount,
            full_shape_list=full_shape,
        )
        if self.category_remapping:
            self._apply_category_remapping()

    @property
    def object_prediction_list(self) -> List[List[ObjectPrediction]]:
        if self._object_prediction_list_per_image is None:
            return []
        if len(self._object_prediction_list_per_image) == 0:
            return []
        return self._object_prediction_list_per_image[0]

    @property
    def object_prediction_list_per_image(self) -> List[List[ObjectPrediction]]:
        return self._object_prediction_list_per_image or []

    @property
    def original_predictions(self):
        return self._original_predictions
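DetectionModel is an abstract base class: concrete backends override load_model() or set_model(), perform_inference(), and _create_object_prediction_list_from_original_predictions(). The following is a minimal sketch of that contract using a made-up backend; DummyDetectionModel and its fake predictor are illustrative only and are not part of SAHI.

import numpy as np

from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction


class DummyDetectionModel(DetectionModel):
    """Hypothetical backend used only to illustrate the DetectionModel contract."""

    def load_model(self):
        # Normally the framework model would be built from self.model_path /
        # self.config_path and moved to self.device; here a fake predictor is used.
        self.set_model(lambda image: [([10, 10, 50, 50], 0.9, 0)])

    def set_model(self, model, **kwargs):
        self.model = model

    def perform_inference(self, image: np.ndarray):
        # Store the raw backend output; conversion to ObjectPrediction happens later.
        self._original_predictions = self.model(image)

    def _create_object_prediction_list_from_original_predictions(
        self, shift_amount_list=[[0, 0]], full_shape_list=None
    ):
        shift_amount = shift_amount_list[0]
        full_shape = None if full_shape_list is None else full_shape_list[0]
        object_predictions = [
            ObjectPrediction(
                bbox=bbox,
                category_id=category_id,
                score=score,
                shift_amount=shift_amount,
                full_shape=full_shape,
            )
            for bbox, score, category_id in self._original_predictions
            if score >= self.confidence_threshold
        ]
        self._object_prediction_list_per_image = [object_predictions]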
Functions
__init__(model_path=None, model=None, config_path=None, device=None, mask_threshold=0.5, confidence_threshold=0.3, category_mapping=None, category_remapping=None, load_at_init=True, image_size=None)

Init object detection/instance segmentation model.

Args:

    model_path (str): Path for the detection/instance segmentation model weights file
    model (Any): Pre-loaded model object; when load_at_init is True it is registered via set_model() instead of calling load_model()
    config_path (str): Path for the mmdetection instance segmentation model config file
    device (str): Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
    mask_threshold (float): Value to threshold mask pixels, should be between 0 and 1
    confidence_threshold (float): All predictions with score < confidence_threshold will be discarded
    category_mapping (dict, str to str): Mapping from category id (str) to category name (str), e.g. {"1": "pedestrian"}
    category_remapping (dict, str to int): Remap category ids based on category names, after performing inference, e.g. {"car": 3}
    load_at_init (bool): If True, automatically loads the model at initialization
    image_size (int): Inference input size

Source code in sahi/models/base.py
def __init__(
    self,
    model_path: Optional[str] = None,
    model: Optional[Any] = None,
    config_path: Optional[str] = None,
    device: Optional[str] = None,
    mask_threshold: float = 0.5,
    confidence_threshold: float = 0.3,
    category_mapping: Optional[Dict] = None,
    category_remapping: Optional[Dict] = None,
    load_at_init: bool = True,
    image_size: Optional[int] = None,
):
    """
    Init object detection/instance segmentation model.
    Args:
        model_path: str
            Path for the detection/instance segmentation model weights file
        model: Any
            Pre-loaded model object; when load_at_init is True it is registered via set_model() instead of calling load_model()
        config_path: str
            Path for the mmdetection instance segmentation model config file
        device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
        mask_threshold: float
            Value to threshold mask pixels, should be between 0 and 1
        confidence_threshold: float
            All predictions with score < confidence_threshold will be discarded
        category_mapping: dict: str to str
            Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
        category_remapping: dict: str to int
            Remap category ids based on category names, after performing inference e.g. {"car": 3}
        load_at_init: bool
            If True, automatically loads the model at initialization
        image_size: int
            Inference input size.
    """
    self.model_path = model_path
    self.config_path = config_path
    self.model = None
    self.mask_threshold = mask_threshold
    self.confidence_threshold = confidence_threshold
    self.category_mapping = category_mapping
    self.category_remapping = category_remapping
    self.image_size = image_size
    self._original_predictions = None
    self._object_prediction_list_per_image = None
    self.set_device(device)

    # automatically load model if load_at_init is True
    if load_at_init:
        if model:
            self.set_model(model)
        else:
            self.load_model()
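A minimal construction sketch, continuing the hypothetical DummyDetectionModel from above; the weights path, device, thresholds, and category dictionaries are illustrative:

detection_model = DummyDetectionModel(
    model_path="weights.pt",               # illustrative weights path
    device="cuda:0",                       # or "cpu" / "mps"
    confidence_threshold=0.4,              # predictions scoring below 0.4 are discarded
    category_mapping={"0": "pedestrian"},  # category id (str) -> category name
    category_remapping={"0": 1},           # keyed by the original id as str, applied after inference (see _apply_category_remapping)
    load_at_init=True,                     # calls load_model() (or set_model() if model= is given) immediately
)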
check_dependencies()

This function can be implemented to ensure model dependencies are installed.

Source code in sahi/models/base.py
def check_dependencies(self) -> None:
    """
    This function can be implemented to ensure model dependencies are installed.
    """
    pass
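In SAHI's bundled backends, check_dependencies() typically verifies that the required packages are importable before load_model() runs; a sketch for a subclass, assuming the check_requirements helper from sahi.utils.import_utils is available:

def check_dependencies(self) -> None:
    # Raise early if the required packages are missing (helper assumed available).
    from sahi.utils.import_utils import check_requirements

    check_requirements(["torch"])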
convert_original_predictions(shift_amount=[[0, 0]], full_shape=None)

Converts original predictions of the detection model to a list of prediction.ObjectPrediction objects. Should be called after perform_inference().

Args:

    shift_amount (list): To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]
    full_shape (list): Size of the full image after shifting, should be in the form of [height, width]

Source code in sahi/models/base.py
def convert_original_predictions(
    self,
    shift_amount: Optional[List[List[int]]] = [[0, 0]],
    full_shape: Optional[List[List[int]]] = None,
):
    """
    Converts original predictions of the detection model to a list of
    prediction.ObjectPrediction object. Should be called after perform_inference().
    Args:
        shift_amount: list
            To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]
        full_shape: list
            Size of the full image after shifting, should be in the form of [height, width]
    """
    self._create_object_prediction_list_from_original_predictions(
        shift_amount_list=shift_amount,
        full_shape_list=full_shape,
    )
    if self.category_remapping:
        self._apply_category_remapping()
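A typical single-slice flow, continuing the example above: run perform_inference() on a crop, then convert the raw output while shifting boxes back into full-image coordinates (the offsets and image size below are illustrative):

import numpy as np

full_image = np.zeros((1080, 1920, 3), dtype=np.uint8)  # stand-in for a real image
slice_ = full_image[200:712, 300:812]                    # 512x512 crop taken at (x=300, y=200)

detection_model.perform_inference(slice_)
detection_model.convert_original_predictions(
    shift_amount=[[300, 200]],   # one [shift_x, shift_y] pair per image, matching the signature
    full_shape=[[1080, 1920]],   # one [height, width] pair per image
)
predictions = detection_model.object_prediction_list     # ObjectPrediction list in full-image coordinates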
load_model()

This function should be implemented so that the detection model is initialized and assigned to self.model (self.model_path, self.config_path, and self.device should be utilized).

Source code in sahi/models/base.py
def load_model(self):
    """
    This function should be implemented in a way that detection model
    should be initialized and set to self.model.
    (self.model_path, self.config_path, and self.device should be utilized)
    """
    raise NotImplementedError()
perform_inference(image)

This function should be implemented so that prediction is performed using self.model and the result is set to self._original_predictions.

Args:

    image (np.ndarray): A numpy array that contains the image to be predicted.

Source code in sahi/models/base.py
def perform_inference(self, image: np.ndarray):
    """
    This function should be implemented in a way that prediction should be
    performed using self.model and the prediction result should be set to self._original_predictions.
    Args:
        image: np.ndarray
            A numpy array that contains the image to be predicted.
    """
    raise NotImplementedError()
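perform_inference() expects the image as a numpy array; a minimal call sketch (the file name is illustrative, and Pillow is used here only for loading):

import numpy as np
from PIL import Image

image = np.asarray(Image.open("demo.jpg").convert("RGB"))  # HWC uint8 array
detection_model.perform_inference(image)                    # raw output stored in self._original_predictions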
set_device(device=None)

Sets the device PyTorch should use for the model.

Parameters:

    device (Optional[str], default None):
        Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
Source code in sahi/models/base.py
def set_device(self, device: Optional[str] = None):
    """Sets the device pytorch should use for the model

    Args:
        device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
    """
    if has_torch:
        self.device = select_device(device)
    else:
        raise NotImplementedError(f"Could not set device {device}")
set_model(model, **kwargs)

This function should be implemented to instantiate a DetectionModel from an already loaded model.

Args:

    model (Any): Loaded model

Source code in sahi/models/base.py
def set_model(self, model: Any, **kwargs):
    """
    This function should be implemented to instantiate a DetectionModel out of an already loaded model
    Args:
        model: Any
            Loaded model
    """
    raise NotImplementedError()
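When the framework model has already been instantiated elsewhere, it can be handed over either through the model argument of the constructor (with load_at_init=True) or by calling set_model() on an existing instance; continuing the hypothetical example:

preloaded = lambda image: []                 # stand-in for an already loaded framework model
detection_model = DummyDetectionModel(model=preloaded, confidence_threshold=0.4)
# ...or on an existing instance:
detection_model.set_model(preloaded)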
unload_model()

Unloads the model from CPU/GPU.

Source code in sahi/models/base.py
def unload_model(self):
    """
    Unloads the model from CPU/GPU.
    """
    self.model = None
    empty_cuda_cache()
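unload_model() drops the reference to the weights and empties the CUDA cache (when torch with CUDA is available), which is useful before switching to a different checkpoint; the second weights path below is illustrative:

detection_model.unload_model()                       # self.model becomes None, CUDA cache is emptied
detection_model.model_path = "other_weights.pt"      # hypothetical alternative checkpoint
detection_model.load_model()                         # re-initialize with the new weights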
