Skip to content

sahi

Classes

AutoDetectionModel

Source code in sahi/auto_model.py
class AutoDetectionModel:
    @staticmethod
    def from_pretrained(
        model_type: str,
        model_path: str | None = None,
        model: Any | None = None,
        config_path: str | None = None,
        device: str | None = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: dict | None = None,
        category_remapping: dict | None = None,
        load_at_init: bool = True,
        image_size: int | None = None,
        **kwargs,
    ) -> DetectionModel:
        """Instantiate the DetectionModel subclass matching the requested framework.

        Args:
            model_type: Name of the detection framework (example: "ultralytics", "huggingface", "torchvision").
            model_path: Path of the detection model (ex. 'model.pt').
            model: A pre-initialized model instance, if available.
            config_path: Path of the config file (ex. 'mmdet/configs/cascade_rcnn_r50_fpn_1x.py').
            device: Device, "cpu" or "cuda:0".
            mask_threshold: Value to threshold mask pixels, should be between 0 and 1.
            confidence_threshold: All predictions with score < confidence_threshold will be discarded.
            category_mapping: Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}.
            category_remapping: Remap category ids based on category names, after performing inference
                e.g. {"car": 3}.
            load_at_init: If True, automatically loads the model at initialization.
            image_size: Inference input size.
            **kwargs: Extra keyword arguments forwarded to the resolved model class.

        Returns:
            An instance of the resolved DetectionModel subclass.

        Raises:
            ImportError: If given {model_type} framework is not installed.
        """
        # Every ultralytics-family alias is served by the single "ultralytics" loader.
        resolved_type = "ultralytics" if model_type in ULTRALYTICS_MODEL_NAMES else model_type
        class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[resolved_type]
        model_class = import_model_class(resolved_type, class_name)

        return model_class(
            model_path=model_path,
            model=model,
            config_path=config_path,
            device=device,
            mask_threshold=mask_threshold,
            confidence_threshold=confidence_threshold,
            category_mapping=category_mapping,
            category_remapping=category_remapping,
            load_at_init=load_at_init,
            image_size=image_size,
            **kwargs,
        )
Functions
from_pretrained(model_type, model_path=None, model=None, config_path=None, device=None, mask_threshold=0.5, confidence_threshold=0.3, category_mapping=None, category_remapping=None, load_at_init=True, image_size=None, **kwargs) staticmethod

Loads a DetectionModel from given path.

Parameters:

Name Type Description Default
model_type str

str Name of the detection framework (example: "ultralytics", "huggingface", "torchvision")

required
model_path str | None

str Path of the detection model (ex. 'model.pt')

None
model Any | None

Any A pre-initialized model instance, if available

None
config_path str | None

str Path of the config file (ex. 'mmdet/configs/cascade_rcnn_r50_fpn_1x.py')

None
device str | None

str Device, "cpu" or "cuda:0"

None
mask_threshold float

float Value to threshold mask pixels, should be between 0 and 1

0.5
confidence_threshold float

float All predictions with score < confidence_threshold will be discarded

0.3
category_mapping dict | None

dict: str to str Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}

None
category_remapping dict | None

dict: str to int Remap category ids based on category names, after performing inference e.g. {"car": 3}

None
load_at_init bool

bool If True, automatically loads the model at initialization

True
image_size int | None

int Inference input size.

None

Returns:

Type Description
DetectionModel

Returns an instance of a DetectionModel

Raises:

Type Description
ImportError

If given {model_type} framework is not installed

Source code in sahi/auto_model.py
@staticmethod
def from_pretrained(
    model_type: str,
    model_path: str | None = None,
    model: Any | None = None,
    config_path: str | None = None,
    device: str | None = None,
    mask_threshold: float = 0.5,
    confidence_threshold: float = 0.3,
    category_mapping: dict | None = None,
    category_remapping: dict | None = None,
    load_at_init: bool = True,
    image_size: int | None = None,
) -> DetectionModel:
    """Instantiate the DetectionModel subclass matching the requested framework.

    Args:
        model_type: Name of the detection framework (example: "ultralytics", "huggingface", "torchvision").
        model_path: Path of the detection model (ex. 'model.pt').
        model: A pre-initialized model instance, if available.
        config_path: Path of the config file (ex. 'mmdet/configs/cascade_rcnn_r50_fpn_1x.py').
        device: Device, "cpu" or "cuda:0".
        mask_threshold: Value to threshold mask pixels, should be between 0 and 1.
        confidence_threshold: All predictions with score < confidence_threshold will be discarded.
        category_mapping: Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}.
        category_remapping: Remap category ids based on category names, after performing inference
            e.g. {"car": 3}.
        load_at_init: If True, automatically loads the model at initialization.
        image_size: Inference input size.
        **kwargs: Extra keyword arguments forwarded to the resolved model class.

    Returns:
        An instance of the resolved DetectionModel subclass.

    Raises:
        ImportError: If given {model_type} framework is not installed.
    """
    # Every ultralytics-family alias is served by the single "ultralytics" loader.
    resolved_type = "ultralytics" if model_type in ULTRALYTICS_MODEL_NAMES else model_type
    class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[resolved_type]
    model_class = import_model_class(resolved_type, class_name)

    return model_class(
        model_path=model_path,
        model=model,
        config_path=config_path,
        device=device,
        mask_threshold=mask_threshold,
        confidence_threshold=confidence_threshold,
        category_mapping=category_mapping,
        category_remapping=category_remapping,
        load_at_init=load_at_init,
        image_size=image_size,
        **kwargs,
    )

BoundingBox dataclass

BoundingBox represents a rectangular region in 2D space, typically used for object detection annotations.

Attributes:

Name Type Description
box Tuple[float, float, float, float]

The bounding box coordinates in the format (minx, miny, maxx, maxy). - minx (float): Minimum x-coordinate (left). - miny (float): Minimum y-coordinate (top). - maxx (float): Maximum x-coordinate (right). - maxy (float): Maximum y-coordinate (bottom).

shift_amount Tuple[int, int]

The amount to shift the bounding box in the x and y directions. Defaults to (0, 0).

BoundingBox Usage Example

bbox = BoundingBox((10.0, 20.0, 50.0, 80.0))
area = bbox.area
expanded_bbox = bbox.get_expanded_box(ratio=0.2)
shifted_bbox = bbox.get_shifted_box()
coco_format = bbox.to_coco_bbox()
Source code in sahi/annotation.py
@dataclass(frozen=True)
class BoundingBox:
    """BoundingBox represents a rectangular region in 2D space, typically used for object detection annotations.

    Attributes:
        box (Tuple[float, float, float, float]): The bounding box coordinates in the format (minx, miny, maxx, maxy).
            - minx (float): Minimum x-coordinate (left).
            - miny (float): Minimum y-coordinate (top).
            - maxx (float): Maximum x-coordinate (right).
            - maxy (float): Maximum y-coordinate (bottom).
        shift_amount (Tuple[int, int], optional): The amount to shift the bounding box in the x and y directions.
            Defaults to (0, 0).

    !!! example "BoundingBox Usage Example"
        ```python
        bbox = BoundingBox((10.0, 20.0, 50.0, 80.0))
        area = bbox.area
        expanded_bbox = bbox.get_expanded_box(ratio=0.2)
        shifted_bbox = bbox.get_shifted_box()
        coco_format = bbox.to_coco_bbox()
        ```
    """

    box: tuple[float, float, float, float] | list[float]
    shift_amount: tuple[int, int] = (0, 0)

    def __post_init__(self):
        if len(self.box) != 4 or any(coord < 0 for coord in self.box):
            raise ValueError("box must be 4 non-negative floats: [minx, miny, maxx, maxy]")
        if len(self.shift_amount) != 2:
            raise ValueError("shift_amount must be 2 integers: [shift_x, shift_y]")

    @property
    def minx(self):
        return self.box[0]

    @property
    def miny(self):
        return self.box[1]

    @property
    def maxx(self):
        return self.box[2]

    @property
    def maxy(self):
        return self.box[3]

    @property
    def shift_x(self):
        return self.shift_amount[0]

    @property
    def shift_y(self):
        return self.shift_amount[1]

    @property
    def area(self):
        return (self.maxx - self.minx) * (self.maxy - self.miny)

    def get_expanded_box(self, ratio: float = 0.1, max_x: int | None = None, max_y: int | None = None):
        """Returns an expanded bounding box by increasing its size by a given ratio. The expansion is applied equally in
        all directions. Optionally, the expanded box can be clipped to maximum x and y boundaries.

        Args:
            ratio (float, optional): The proportion by which to expand the box size.
                Default is 0.1 (10%).
            max_x (int, optional): The maximum allowed x-coordinate for the expanded box.
                If None, no maximum is applied.
            max_y (int, optional): The maximum allowed y-coordinate for the expanded box.
                If None, no maximum is applied.

        Returns:
            BoundingBox: A new BoundingBox instance representing the expanded box.
        """

        w = self.maxx - self.minx
        h = self.maxy - self.miny
        y_mar = int(h * ratio)
        x_mar = int(w * ratio)
        maxx = min(max_x, self.maxx + x_mar) if max_x else self.maxx + x_mar
        minx = max(0, self.minx - x_mar)
        maxy = min(max_y, self.maxy + y_mar) if max_y else self.maxy + y_mar
        miny = max(0, self.miny - y_mar)
        box: list[float] = [minx, miny, maxx, maxy]
        return BoundingBox(box)

    def to_xywh(self):
        """Returns [xmin, ymin, width, height]

        Returns:
            List[float]: A list containing the bounding box in the format [xmin, ymin, width, height].
        """

        return [self.minx, self.miny, self.maxx - self.minx, self.maxy - self.miny]

    def to_coco_bbox(self):
        """
        Returns the bounding box in COCO format: [xmin, ymin, width, height]

        Returns:
            List[float]: A list containing the bounding box in COCO format.
        """
        return self.to_xywh()

    def to_xyxy(self):
        """
        Returns: [xmin, ymin, xmax, ymax]

        Returns:
            List[float]: A list containing the bounding box in the format [xmin, ymin, xmax, ymax].
        """
        return [self.minx, self.miny, self.maxx, self.maxy]

    def to_voc_bbox(self):
        """
        Returns the bounding box in VOC format: [xmin, ymin, xmax, ymax]

        Returns:
            List[float]: A list containing the bounding box in VOC format.
        """
        return self.to_xyxy()

    def get_shifted_box(self):
        """Returns shifted BoundingBox.

        Returns:
            BoundingBox: A new BoundingBox instance representing the shifted box.
        """
        box = [
            self.minx + self.shift_x,
            self.miny + self.shift_y,
            self.maxx + self.shift_x,
            self.maxy + self.shift_y,
        ]
        return BoundingBox(box)

    def __repr__(self):
        return (
            f"BoundingBox: <{(self.minx, self.miny, self.maxx, self.maxy)}, "
            f"w: {self.maxx - self.minx}, h: {self.maxy - self.miny}>"
        )
Functions
get_expanded_box(ratio=0.1, max_x=None, max_y=None)

Returns an expanded bounding box by increasing its size by a given ratio. The expansion is applied equally in all directions. Optionally, the expanded box can be clipped to maximum x and y boundaries.

Parameters:

Name Type Description Default
ratio float

The proportion by which to expand the box size. Default is 0.1 (10%).

0.1
max_x int

The maximum allowed x-coordinate for the expanded box. If None, no maximum is applied.

None
max_y int

The maximum allowed y-coordinate for the expanded box. If None, no maximum is applied.

None

Returns:

Name Type Description
BoundingBox

A new BoundingBox instance representing the expanded box.

Source code in sahi/annotation.py
def get_expanded_box(self, ratio: float = 0.1, max_x: int | None = None, max_y: int | None = None):
    """Returns an expanded bounding box by increasing its size by a given ratio. The expansion is applied equally in
    all directions. Optionally, the expanded box can be clipped to maximum x and y boundaries.

    Args:
        ratio (float, optional): The proportion by which to expand the box size.
            Default is 0.1 (10%).
        max_x (int, optional): The maximum allowed x-coordinate for the expanded box.
            If None, no maximum is applied.
        max_y (int, optional): The maximum allowed y-coordinate for the expanded box.
            If None, no maximum is applied.

    Returns:
        BoundingBox: A new BoundingBox instance representing the expanded box.
    """
    w = self.maxx - self.minx
    h = self.maxy - self.miny
    # Margins are truncated to whole pixels.
    y_mar = int(h * ratio)
    x_mar = int(w * ratio)
    # Compare against None explicitly: a limit of 0 is falsy but is still a
    # valid clip boundary and must not be mistaken for "no limit".
    maxx = self.maxx + x_mar if max_x is None else min(max_x, self.maxx + x_mar)
    minx = max(0, self.minx - x_mar)
    maxy = self.maxy + y_mar if max_y is None else min(max_y, self.maxy + y_mar)
    miny = max(0, self.miny - y_mar)
    box: list[float] = [minx, miny, maxx, maxy]
    return BoundingBox(box)
get_shifted_box()

Returns shifted BoundingBox.

Returns:

Name Type Description
BoundingBox

A new BoundingBox instance representing the shifted box.

Source code in sahi/annotation.py
def get_shifted_box(self):
    """Return a new BoundingBox translated by this box's shift amount.

    Returns:
        BoundingBox: A new BoundingBox instance representing the shifted box.
    """
    dx = self.shift_x
    dy = self.shift_y
    shifted_coords = [
        self.minx + dx,
        self.miny + dy,
        self.maxx + dx,
        self.maxy + dy,
    ]
    return BoundingBox(shifted_coords)
to_coco_bbox()

Returns the bounding box in COCO format: [xmin, ymin, width, height]

Returns:

Type Description

List[float]: A list containing the bounding box in COCO format.

Source code in sahi/annotation.py
def to_coco_bbox(self):
    """Return the bounding box in COCO format: [xmin, ymin, width, height].

    Returns:
        List[float]: A list containing the bounding box in COCO format.
    """
    # COCO uses the same xywh layout, so delegate directly.
    return self.to_xywh()
to_voc_bbox()

Returns the bounding box in VOC format: [xmin, ymin, xmax, ymax]

Returns:

Type Description

List[float]: A list containing the bounding box in VOC format.

Source code in sahi/annotation.py
def to_voc_bbox(self):
    """Return the bounding box in VOC format: [xmin, ymin, xmax, ymax].

    Returns:
        List[float]: A list containing the bounding box in VOC format.
    """
    # VOC uses the same corner layout, so delegate directly.
    return self.to_xyxy()
to_xywh()

Returns [xmin, ymin, width, height]

Returns:

Type Description

List[float]: A list containing the bounding box in the format [xmin, ymin, width, height].

Source code in sahi/annotation.py
def to_xywh(self):
    """Return the bounding box as [xmin, ymin, width, height].

    Returns:
        List[float]: A list containing the bounding box in the format [xmin, ymin, width, height].
    """
    width = self.maxx - self.minx
    height = self.maxy - self.miny
    return [self.minx, self.miny, width, height]
to_xyxy()

Returns: [xmin, ymin, xmax, ymax]

Returns:

Type Description

List[float]: A list containing the bounding box in the format [xmin, ymin, xmax, ymax].

Source code in sahi/annotation.py
def to_xyxy(self):
    """Return the bounding box as [xmin, ymin, xmax, ymax].

    Returns:
        List[float]: A list containing the bounding box in the format [xmin, ymin, xmax, ymax].
    """
    corners = (self.minx, self.miny, self.maxx, self.maxy)
    return list(corners)

Category dataclass

Category of the annotation.

Attributes:

Name Type Description
id int

Unique identifier for the category.

name str

Name of the category.

Source code in sahi/annotation.py
@dataclass(frozen=True)
class Category:
    """Immutable category (label) of an annotation.

    Attributes:
        id (int): Unique identifier for the category.
        name (str): Name of the category.
    """

    id: int
    name: str

    def __post_init__(self):
        # Fail fast on wrongly-typed fields so bad categories never circulate.
        if not isinstance(self.id, int):
            raise TypeError("id should be integer")
        if not isinstance(self.name, str):
            raise TypeError("name should be string")

    def __repr__(self):
        return "Category: <id: {}, name: {}>".format(self.id, self.name)

DetectionModel

Source code in sahi/models/base.py
class DetectionModel:
    """Framework-agnostic base class for detection / instance segmentation models.

    Holds shared configuration (thresholds, category mappings, device) and the
    post-processing pipeline; subclasses implement loading and inference.
    """

    # Packages checked by check_dependencies() by default; subclasses override.
    required_packages: list[str] | None = None

    def __init__(
        self,
        model_path: str | None = None,
        model: Any | None = None,
        config_path: str | None = None,
        device: str | None = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: dict | None = None,
        category_remapping: dict | None = None,
        load_at_init: bool = True,
        image_size: int | None = None,
    ):
        """Init object detection/instance segmentation model.

        Args:
            model_path: str
                Path for the instance segmentation model weight
            model: Any
                A pre-initialized model instance; when provided and load_at_init
                is True, it is passed to set_model() instead of loading from model_path.
            config_path: str
                Path for the mmdetection instance segmentation model config file
            device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
            mask_threshold: float
                Value to threshold mask pixels, should be between 0 and 1
            confidence_threshold: float
                All predictions with score < confidence_threshold will be discarded
            category_mapping: dict: str to str
                Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
            category_remapping: dict: str to int
                Remap category ids based on category names, after performing inference e.g. {"car": 3}
            load_at_init: bool
                If True, automatically loads the model at initialization
            image_size: int
                Inference input size.
        """

        self.model_path = model_path
        self.config_path = config_path
        self.model = None
        self.mask_threshold = mask_threshold
        self.confidence_threshold = confidence_threshold
        self.category_mapping = category_mapping
        self.category_remapping = category_remapping
        self.image_size = image_size
        self._original_predictions = None
        self._object_prediction_list_per_image = None
        self.set_device(device)

        # automatically ensure dependencies
        self.check_dependencies()

        # automatically load model if load_at_init is True
        if load_at_init:
            # Compare against None explicitly: truthiness of arbitrary model
            # objects is unreliable (frameworks may override __bool__/__len__).
            if model is not None:
                self.set_model(model)
            else:
                self.load_model()

    def check_dependencies(self, packages: list[str] | None = None) -> None:
        """Ensures required dependencies are installed.

        If 'packages' is None, uses self.required_packages. Subclasses may still call with a custom list for dynamic
        needs.
        """
        pkgs = packages if packages is not None else getattr(self, "required_packages", [])
        if pkgs:
            check_requirements(pkgs)

    def load_model(self):
        """Initialize the detection model and assign it to self.model.

        Abstract: subclasses must override.
        (self.model_path, self.config_path, and self.device should be utilized)

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def set_model(self, model: Any, **kwargs):
        """Configure this DetectionModel from an already loaded framework model.

        Abstract: subclasses must override.

        Args:
            model: Any
                Loaded model
        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def set_device(self, device: str | None = None):
        """Sets the device pytorch should use for the model.

        Args:
            device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
        """
        # Delegate resolution of the (possibly None) device string to select_device.
        self.device = select_device(device)

    def unload_model(self):
        """Unloads the model from CPU/GPU."""
        self.model = None
        # Free any cached GPU memory held by the dropped model.
        empty_cuda_cache()

    def perform_inference(self, image: np.ndarray):
        """Run the model on `image` and store raw output in self._original_predictions.

        Abstract: subclasses must override.

        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted.
        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: list[list[int]] | None = None,
        full_shape_list: list[list[int]] | None = None,
    ):
        """Convert self._original_predictions into per-image ObjectPrediction lists.

        Abstract: subclasses must override and set self._object_prediction_list_per_image.
        self.mask_threshold can also be utilized. A None shift_amount_list should be
        treated as [[0, 0]] (no shift); the None default replaces the previous
        mutable default list [[0, 0]].

        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def _apply_category_remapping(self):
        """Applies category remapping based on mapping given in self.category_remapping."""
        # confirm self.category_remapping is not None
        if self.category_remapping is None:
            raise ValueError("self.category_remapping cannot be None")
        # remap categories
        if not isinstance(self._object_prediction_list_per_image, list):
            logger.error(
                f"Unknown type for self._object_prediction_list_per_image: "
                f"{type(self._object_prediction_list_per_image)}"
            )
            return
        for object_prediction_list in self._object_prediction_list_per_image:  # type: ignore
            for object_prediction in object_prediction_list:
                old_category_id_str = str(object_prediction.category.id)
                new_category_id_int = self.category_remapping[old_category_id_str]
                object_prediction.category = Category(id=new_category_id_int, name=object_prediction.category.name)

    def convert_original_predictions(
        self,
        shift_amount: list[list[int]] | None = None,
        full_shape: list[list[int]] | None = None,
    ):
        """Converts original predictions of the detection model to a list of prediction.ObjectPrediction object.

        Should be called after perform_inference().

        Args:
            shift_amount: list
                To shift the box and mask predictions from sliced image to full sized image,
                    should be in the form of [shift_x, shift_y]; None means no shift ([[0, 0]])
            full_shape: list
                Size of the full image after shifting, should be in the form of [height, width]
        """
        # None sentinel instead of a shared mutable default list; resolve to the
        # historical default here so subclass implementations see a concrete value.
        if shift_amount is None:
            shift_amount = [[0, 0]]
        self._create_object_prediction_list_from_original_predictions(
            shift_amount_list=shift_amount,
            full_shape_list=full_shape,
        )
        if self.category_remapping:
            self._apply_category_remapping()

    @property
    def object_prediction_list(self) -> list[ObjectPrediction]:
        # Convenience accessor for single-image inference: the first image's
        # predictions (flat list, not list-of-lists — annotation fixed accordingly).
        if self._object_prediction_list_per_image is None:
            return []
        if len(self._object_prediction_list_per_image) == 0:
            return []
        return self._object_prediction_list_per_image[0]

    @property
    def object_prediction_list_per_image(self) -> list[list[ObjectPrediction]]:
        return self._object_prediction_list_per_image or []

    @property
    def original_predictions(self):
        # Raw framework output captured by the last perform_inference() call (or None).
        return self._original_predictions
Functions
__init__(model_path=None, model=None, config_path=None, device=None, mask_threshold=0.5, confidence_threshold=0.3, category_mapping=None, category_remapping=None, load_at_init=True, image_size=None)

Init object detection/instance segmentation model.

Parameters:

Name Type Description Default
model_path str | None

str Path for the instance segmentation model weight

None
config_path str | None

str Path for the mmdetection instance segmentation model config file

None
device str | None

Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.

None
mask_threshold float

float Value to threshold mask pixels, should be between 0 and 1

0.5
confidence_threshold float

float All predictions with score < confidence_threshold will be discarded

0.3
category_mapping dict | None

dict: str to str Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}

None
category_remapping dict | None

dict: str to int Remap category ids based on category names, after performing inference e.g. {"car": 3}

None
load_at_init bool

bool If True, automatically loads the model at initialization

True
image_size int | None

int Inference input size.

None
Source code in sahi/models/base.py
def __init__(
    self,
    model_path: str | None = None,
    model: Any | None = None,
    config_path: str | None = None,
    device: str | None = None,
    mask_threshold: float = 0.5,
    confidence_threshold: float = 0.3,
    category_mapping: dict | None = None,
    category_remapping: dict | None = None,
    load_at_init: bool = True,
    image_size: int | None = None,
):
    """Init object detection/instance segmentation model.

    Args:
        model_path: str
            Path for the instance segmentation model weight
        model: Any
            A pre-initialized model instance; when provided and load_at_init
            is True, it is passed to set_model() instead of loading from model_path.
        config_path: str
            Path for the mmdetection instance segmentation model config file
        device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
        mask_threshold: float
            Value to threshold mask pixels, should be between 0 and 1
        confidence_threshold: float
            All predictions with score < confidence_threshold will be discarded
        category_mapping: dict: str to str
            Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
        category_remapping: dict: str to int
            Remap category ids based on category names, after performing inference e.g. {"car": 3}
        load_at_init: bool
            If True, automatically loads the model at initialization
        image_size: int
            Inference input size.
    """

    self.model_path = model_path
    self.config_path = config_path
    self.model = None
    self.mask_threshold = mask_threshold
    self.confidence_threshold = confidence_threshold
    self.category_mapping = category_mapping
    self.category_remapping = category_remapping
    self.image_size = image_size
    self._original_predictions = None
    self._object_prediction_list_per_image = None
    self.set_device(device)

    # automatically ensure dependencies
    self.check_dependencies()

    # automatically load model if load_at_init is True
    if load_at_init:
        # Compare against None explicitly: truthiness of arbitrary model
        # objects is unreliable (frameworks may override __bool__/__len__).
        if model is not None:
            self.set_model(model)
        else:
            self.load_model()
check_dependencies(packages=None)

Ensures required dependencies are installed.

If 'packages' is None, uses self.required_packages. Subclasses may still call with a custom list for dynamic needs.

Source code in sahi/models/base.py
def check_dependencies(self, packages: list[str] | None = None) -> None:
    """Ensures required dependencies are installed.

    If 'packages' is None, uses self.required_packages. Subclasses may still call with a custom list for dynamic
    needs.
    """
    pkgs = packages if packages is not None else getattr(self, "required_packages", [])
    if pkgs:
        check_requirements(pkgs)
convert_original_predictions(shift_amount=[[0, 0]], full_shape=None)

Converts original predictions of the detection model to a list of prediction.ObjectPrediction object.

Should be called after perform_inference().

Parameters: shift_amount (list) — shifts the box and mask predictions from the sliced image to the full-sized image; given as [shift_x, shift_y]. full_shape (list) — size of the full image after shifting; given as [height, width].

Source code in sahi/models/base.py
def convert_original_predictions(
    self,
    shift_amount: list[list[int]] | None = [[0, 0]],
    full_shape: list[list[int]] | None = None,
):
    """Converts original predictions of the detection model to a list of prediction.ObjectPrediction object.

    Should be called after perform_inference().
    Args:
        shift_amount: list
            To shift the box and mask predictions from sliced image to full sized image,
                should be in the form of [shift_x, shift_y]
        full_shape: list
            Size of the full image after shifting, should be in the form of [height, width]
    """
    self._create_object_prediction_list_from_original_predictions(
        shift_amount_list=shift_amount,
        full_shape_list=full_shape,
    )
    if self.category_remapping:
        self._apply_category_remapping()
load_model()

This function should be implemented in a way that detection model should be initialized and set to self.model.

(self.model_path, self.config_path, and self.device should be utilized)

Source code in sahi/models/base.py
def load_model(self):
    """Initialize the detection model and assign it to self.model.

    Abstract: subclasses must override.
    (self.model_path, self.config_path, and self.device should be utilized)

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
perform_inference(image)

This function should be implemented in a way that prediction should be performed using self.model and the prediction result should be set to self._original_predictions.

Parameters:

Name Type Description Default
image ndarray

np.ndarray A numpy array that contains the image to be predicted.

required
Source code in sahi/models/base.py
def perform_inference(self, image: np.ndarray):
    """Run the detection model on `image` and store the raw framework output in self._original_predictions.

    Abstract: subclasses must override.

    Args:
        image: np.ndarray
            A numpy array that contains the image to be predicted.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
set_device(device=None)

Sets the device pytorch should use for the model.

Parameters:

Name Type Description Default
device str | None

Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.

None
Source code in sahi/models/base.py
def set_device(self, device: str | None = None):
    """Sets the device pytorch should use for the model.

    Args:
        device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
    """

    # Delegate resolution of the (possibly None) device string to select_device.
    self.device = select_device(device)
set_model(model, **kwargs)

This function should be implemented to instantiate a DetectionModel out of an already loaded model.

Parameters: model (Any) — the loaded model instance.

Source code in sahi/models/base.py
def set_model(self, model: Any, **kwargs):
    """Adopt an already-instantiated model into this DetectionModel.

    Concrete subclasses must override this.

    Args:
        model: Any
            The loaded model object to wrap.
    """
    raise NotImplementedError()
unload_model()

Unloads the model from CPU/GPU.

Source code in sahi/models/base.py
def unload_model(self):
    """Drop the model reference and release cached CUDA memory."""
    # Clearing the reference first lets the GC reclaim the model before
    # the CUDA cache is emptied.
    self.model = None
    empty_cuda_cache()

Mask

Init Mask from coco segmentation representation.

Parameters:

Name Type Description Default
segmentation

List[List] [ [x1, y1, x2, y2, x3, y3, ...], [x1, y1, x2, y2, x3, y3, ...], ... ]

required
full_shape
list[int]

List[int] Size of the full image, should be in the form of [height, width]

required
shift_amount
list

List[int] To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
Source code in sahi/annotation.py
class Mask:
    """Init Mask from coco segmentation representation.

    Args:
        segmentation : List[List]
            [
                [x1, y1, x2, y2, x3, y3, ...],
                [x1, y1, x2, y2, x3, y3, ...],
                ...
            ]
        full_shape: List[int]
            Size of the full image, should be in the form of [height, width]
        shift_amount: List[int]
            To shift the box and mask predictions from sliced image to full
            sized image, should be in the form of [shift_x, shift_y]
    """

    def __init__(
        self,
        segmentation: list[list[float]],
        full_shape: list[int],
        shift_amount: list | None = None,
    ):
        """Store segmentation polygons plus shift/full-shape metadata.

        Raises:
            ValueError: If full_shape is None.
        """
        if full_shape is None:
            raise ValueError("full_shape must be provided")  # pyright: ignore[reportUnreachable]

        # None sentinel avoids the shared mutable-default-argument pitfall;
        # behavior is identical to the previous [0, 0] default.
        if shift_amount is None:
            shift_amount = [0, 0]
        self.shift_x = shift_amount[0]
        self.shift_y = shift_amount[1]
        self.full_shape_height = full_shape[0]
        self.full_shape_width = full_shape[1]
        self.segmentation = segmentation

    @classmethod
    def from_float_mask(
        cls,
        mask: np.ndarray,
        full_shape: list[int],
        mask_threshold: float = 0.5,
        shift_amount: list | None = None,
    ):
        """
        Args:
            mask: np.ndarray of np.float elements
                Mask values between 0 and 1 (should have a shape of height*width)
            mask_threshold: float
                Value to threshold mask pixels between 0 and 1
            shift_amount: List
                To shift the box and mask predictions from sliced image
                to full sized image, should be in the form of [shift_x, shift_y]
            full_shape: List[int]
                Size of the full image after shifting, should be in the form of [height, width]
        """
        bool_mask = mask > mask_threshold
        # None shift_amount falls through to the [0, 0] default in __init__.
        return cls(
            segmentation=get_coco_segmentation_from_bool_mask(bool_mask),
            shift_amount=shift_amount,
            full_shape=full_shape,
        )

    @classmethod
    def from_bool_mask(
        cls,
        bool_mask: np.ndarray,
        full_shape: list[int],
        shift_amount: list | None = None,
    ):
        """
        Args:
            bool_mask: np.ndarray with bool elements
                2D mask of object, should have a shape of height*width
            full_shape: List[int]
                Size of the full image, should be in the form of [height, width]
            shift_amount: List[int]
                To shift the box and mask predictions from sliced image to full
                sized image, should be in the form of [shift_x, shift_y]
        """
        # None shift_amount falls through to the [0, 0] default in __init__.
        return cls(
            segmentation=get_coco_segmentation_from_bool_mask(bool_mask),
            shift_amount=shift_amount,
            full_shape=full_shape,
        )

    @property
    def bool_mask(self) -> np.ndarray:
        # Rasterizes the stored polygons onto a full_shape-sized boolean grid.
        return get_bool_mask_from_coco_segmentation(
            self.segmentation, width=self.full_shape[1], height=self.full_shape[0]
        )

    @property
    def shape(self) -> list[int]:
        """Returns mask shape as [height, width]"""
        return [self.bool_mask.shape[0], self.bool_mask.shape[1]]

    @property
    def full_shape(self) -> list[int]:
        """Returns full mask shape after shifting as [height, width]"""
        return [self.full_shape_height, self.full_shape_width]

    @property
    def shift_amount(self):
        """Returns the shift amount of the mask slice as [shift_x, shift_y]"""
        return [self.shift_x, self.shift_y]

    def get_shifted_mask(self) -> Mask:
        """Return a copy shifted by (shift_x, shift_y), clipped to full_shape."""
        # Confirm full_shape is specified
        if (self.full_shape_height is None) or (self.full_shape_width is None):
            raise ValueError("full_shape is None")
        shifted_segmentation = []
        for s in self.segmentation:
            # Even indices are x coordinates, odd indices are y coordinates.
            xs = [min(self.shift_x + s[i], self.full_shape_width) for i in range(0, len(s) - 1, 2)]
            ys = [min(self.shift_y + s[i], self.full_shape_height) for i in range(1, len(s), 2)]
            shifted_segmentation.append([j for i in zip(xs, ys) for j in i])
        return Mask(
            segmentation=shifted_segmentation,
            shift_amount=[0, 0],
            full_shape=self.full_shape,
        )
Attributes
full_shape property

Returns full mask shape after shifting as [height, width]

shape property

Returns mask shape as [height, width]

shift_amount property

Returns the shift amount of the mask slice as [shift_x, shift_y]

Functions
from_bool_mask(bool_mask, full_shape, shift_amount=[0, 0]) classmethod

Parameters:

Name Type Description Default
bool_mask ndarray

np.ndarray with bool elements 2D mask of object, should have a shape of height*width

required
full_shape list[int]

List[int] Size of the full image, should be in the form of [height, width]

required
shift_amount list

List[int] To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
Source code in sahi/annotation.py
@classmethod
def from_bool_mask(
    cls,
    bool_mask: np.ndarray,
    full_shape: list[int],
    shift_amount: list = [0, 0],
):
    """Build a Mask from a 2D boolean array.

    Args:
        bool_mask: 2D (height x width) np.ndarray of bools marking object pixels.
        full_shape: Full image size as [height, width].
        shift_amount: [shift_x, shift_y] offset mapping sliced-image
            predictions back onto the full image.
    """
    # Convert the raster mask to COCO polygon form before constructing.
    segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
    return cls(
        segmentation=segmentation,
        shift_amount=shift_amount,
        full_shape=full_shape,
    )
from_float_mask(mask, full_shape, mask_threshold=0.5, shift_amount=[0, 0]) classmethod

Parameters:

Name Type Description Default
mask ndarray

np.ndarray of np.float elements Mask values between 0 and 1 (should have a shape of height*width)

required
mask_threshold float

float Value to threshold mask pixels between 0 and 1

0.5
shift_amount list

List To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
full_shape list[int]

List[int] Size of the full image after shifting, should be in the form of [height, width]

required
Source code in sahi/annotation.py
@classmethod
def from_float_mask(
    cls,
    mask: np.ndarray,
    full_shape: list[int],
    mask_threshold: float = 0.5,
    shift_amount: list = [0, 0],
):
    """Build a Mask from a float probability map by thresholding.

    Args:
        mask: 2D (height x width) np.ndarray of floats in [0, 1].
        mask_threshold: Pixels strictly above this value become part of the mask.
        shift_amount: [shift_x, shift_y] offset mapping sliced-image
            predictions back onto the full image.
        full_shape: Full image size after shifting, as [height, width].
    """
    # Threshold first, then convert the boolean raster to COCO polygons.
    thresholded = mask > mask_threshold
    return cls(
        segmentation=get_coco_segmentation_from_bool_mask(thresholded),
        shift_amount=shift_amount,
        full_shape=full_shape,
    )

ObjectPrediction

Bases: ObjectAnnotation

Class for handling detection model predictions.

Source code in sahi/prediction.py
class ObjectPrediction(ObjectAnnotation):
    """Class for handling detection model predictions."""

    def __init__(
        self,
        bbox: list[int] | None = None,
        category_id: int | None = None,
        category_name: str | None = None,
        segmentation: list[list[float]] | None = None,
        score: float = 0.0,
        shift_amount: list[int] | None = [0, 0],
        full_shape: list[int] | None = None,
    ):
        """Create an ObjectPrediction from bbox, score, category and segmentation data.

        Args:
            bbox: [minx, miny, maxx, maxy] pixel coordinates.
            score: Prediction confidence between 0 and 1.
            category_id: ID of the object category.
            category_name: Name of the object category.
            segmentation: COCO-style polygon list
                [[x1, y1, x2, y2, ...], [x1, y1, x2, y2, ...], ...].
            shift_amount: [shift_x, shift_y] offset mapping sliced-image
                predictions back onto the full image.
            full_shape: [height, width] of the full image after shifting.
        """
        # Score is wrapped first; the remaining fields are handled by the base class.
        self.score = PredictionScore(score)
        super().__init__(
            bbox=bbox,
            category_id=category_id,
            segmentation=segmentation,
            category_name=category_name,
            shift_amount=shift_amount,
            full_shape=full_shape,
        )

    def get_shifted_object_prediction(self):
        """Return a copy with bbox/mask coordinates shifted onto the full image.

        Used for mapping sliced predictions back over the full image.
        """
        shifted_mask = self.mask.get_shifted_mask() if self.mask else None
        return ObjectPrediction(
            bbox=self.bbox.get_shifted_box().to_xyxy(),
            category_id=self.category.id,
            score=self.score.value,
            segmentation=shifted_mask.segmentation if shifted_mask else None,
            category_name=self.category.name,
            shift_amount=[0, 0],
            full_shape=shifted_mask.full_shape if shifted_mask else None,
        )

    def to_coco_prediction(self, image_id=None):
        """Return the sahi.utils.coco.CocoPrediction representation."""
        shared = dict(
            category_id=self.category.id,
            category_name=self.category.name,
            score=self.score.value,
            image_id=image_id,
        )
        # Prefer the polygon form when a mask is available.
        if self.mask:
            return CocoPrediction.from_coco_segmentation(segmentation=self.mask.segmentation, **shared)
        return CocoPrediction.from_coco_bbox(bbox=self.bbox.to_xywh(), **shared)

    def to_fiftyone_detection(self, image_height: int, image_width: int):
        """Return the fiftyone.Detection representation of this prediction."""
        try:
            import fiftyone as fo
        except ImportError:
            raise ImportError('Please run "pip install -U fiftyone" to install fiftyone first for fiftyone conversion.')

        x1, y1, x2, y2 = self.bbox.to_xyxy()
        # fiftyone expects [x, y, w, h] relative to the image dimensions.
        rel_box = [x1 / image_width, y1 / image_height, (x2 - x1) / image_width, (y2 - y1) / image_height]
        return fo.Detection(label=self.category.name, bounding_box=rel_box, confidence=self.score.value)

    def __repr__(self):
        return f"""ObjectPrediction<
    bbox: {self.bbox},
    mask: {self.mask},
    score: {self.score},
    category: {self.category}>"""
Functions
__init__(bbox=None, category_id=None, category_name=None, segmentation=None, score=0.0, shift_amount=[0, 0], full_shape=None)

Creates ObjectPrediction from bbox, score, category_id, category_name, segmentation.

Parameters:

Name Type Description Default
bbox list[int] | None

list [minx, miny, maxx, maxy]

None
score float

float Prediction score between 0 and 1

0.0
category_id int | None

int ID of the object category

None
category_name str | None

str Name of the object category

None
segmentation list[list[float]] | None

List[List] [ [x1, y1, x2, y2, x3, y3, ...], [x1, y1, x2, y2, x3, y3, ...], ... ]

None
shift_amount list[int] | None

list To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
full_shape list[int] | None

list Size of the full image after shifting, should be in the form of [height, width]

None
Source code in sahi/prediction.py
def __init__(
    self,
    bbox: list[int] | None = None,
    category_id: int | None = None,
    category_name: str | None = None,
    segmentation: list[list[float]] | None = None,
    score: float = 0.0,
    shift_amount: list[int] | None = [0, 0],
    full_shape: list[int] | None = None,
):
    """Create an ObjectPrediction from bbox, score, category and segmentation data.

    Args:
        bbox: [minx, miny, maxx, maxy] pixel coordinates.
        score: Prediction confidence between 0 and 1.
        category_id: ID of the object category.
        category_name: Name of the object category.
        segmentation: COCO-style polygon list
            [[x1, y1, x2, y2, ...], [x1, y1, x2, y2, ...], ...].
        shift_amount: [shift_x, shift_y] offset mapping sliced-image
            predictions back onto the full image.
        full_shape: [height, width] of the full image after shifting.
    """
    # Score is wrapped first; the remaining fields are handled by the base class.
    self.score = PredictionScore(score)
    super().__init__(
        bbox=bbox,
        category_id=category_id,
        category_name=category_name,
        segmentation=segmentation,
        shift_amount=shift_amount,
        full_shape=full_shape,
    )
get_shifted_object_prediction()

Returns shifted version ObjectPrediction.

Shifts bbox and mask coords. Used for mapping sliced predictions over full image.

Source code in sahi/prediction.py
def get_shifted_object_prediction(self):
    """Return a copy with bbox/mask coordinates shifted onto the full image.

    Used for mapping sliced predictions back over the full image.
    """
    shifted_mask = self.mask.get_shifted_mask() if self.mask else None
    return ObjectPrediction(
        bbox=self.bbox.get_shifted_box().to_xyxy(),
        category_id=self.category.id,
        score=self.score.value,
        segmentation=shifted_mask.segmentation if shifted_mask else None,
        category_name=self.category.name,
        shift_amount=[0, 0],
        full_shape=shifted_mask.full_shape if shifted_mask else None,
    )
to_coco_prediction(image_id=None)

Returns sahi.utils.coco.CocoPrediction representation of ObjectAnnotation.

Source code in sahi/prediction.py
def to_coco_prediction(self, image_id=None):
    """Return the sahi.utils.coco.CocoPrediction representation."""
    shared = dict(
        category_id=self.category.id,
        category_name=self.category.name,
        score=self.score.value,
        image_id=image_id,
    )
    # Prefer the polygon form when a mask is available.
    if self.mask:
        return CocoPrediction.from_coco_segmentation(segmentation=self.mask.segmentation, **shared)
    return CocoPrediction.from_coco_bbox(bbox=self.bbox.to_xywh(), **shared)
to_fiftyone_detection(image_height, image_width)

Returns fiftyone.Detection representation of ObjectPrediction.

Source code in sahi/prediction.py
def to_fiftyone_detection(self, image_height: int, image_width: int):
    """Return the fiftyone.Detection representation of this prediction."""
    try:
        import fiftyone as fo
    except ImportError:
        raise ImportError('Please run "pip install -U fiftyone" to install fiftyone first for fiftyone conversion.')

    x1, y1, x2, y2 = self.bbox.to_xyxy()
    # fiftyone expects [x, y, w, h] relative to the image dimensions.
    rel_box = [x1 / image_width, y1 / image_height, (x2 - x1) / image_width, (y2 - y1) / image_height]
    return fo.Detection(label=self.category.name, bounding_box=rel_box, confidence=self.score.value)

Modules

annotation

Classes
BoundingBox dataclass

BoundingBox represents a rectangular region in 2D space, typically used for object detection annotations.

Attributes:

Name Type Description
box Tuple[float, float, float, float]

The bounding box coordinates in the format (minx, miny, maxx, maxy). - minx (float): Minimum x-coordinate (left). - miny (float): Minimum y-coordinate (top). - maxx (float): Maximum x-coordinate (right). - maxy (float): Maximum y-coordinate (bottom).

shift_amount Tuple[int, int]

The amount to shift the bounding box in the x and y directions. Defaults to (0, 0).

BoundingBox Usage Example

bbox = BoundingBox((10.0, 20.0, 50.0, 80.0))
area = bbox.area
expanded_bbox = bbox.get_expanded_box(ratio=0.2)
shifted_bbox = bbox.get_shifted_box()
coco_format = bbox.to_coco_bbox()
Source code in sahi/annotation.py
@dataclass(frozen=True)
class BoundingBox:
    """Axis-aligned rectangle used for object-detection annotations.

    Attributes:
        box (Tuple[float, float, float, float]): Coordinates as (minx, miny, maxx, maxy).
            - minx (float): Minimum x-coordinate (left).
            - miny (float): Minimum y-coordinate (top).
            - maxx (float): Maximum x-coordinate (right).
            - maxy (float): Maximum y-coordinate (bottom).
        shift_amount (Tuple[int, int], optional): (shift_x, shift_y) offset applied by
            get_shifted_box(). Defaults to (0, 0).

    !!! example "BoundingBox Usage Example"
        ```python
        bbox = BoundingBox((10.0, 20.0, 50.0, 80.0))
        area = bbox.area
        expanded_bbox = bbox.get_expanded_box(ratio=0.2)
        shifted_bbox = bbox.get_shifted_box()
        coco_format = bbox.to_coco_bbox()
        ```
    """

    box: tuple[float, float, float, float] | list[float]
    shift_amount: tuple[int, int] = (0, 0)

    def __post_init__(self):
        # Validate eagerly so a malformed box fails at construction time.
        if len(self.box) != 4 or any(value < 0 for value in self.box):
            raise ValueError("box must be 4 non-negative floats: [minx, miny, maxx, maxy]")
        if len(self.shift_amount) != 2:
            raise ValueError("shift_amount must be 2 integers: [shift_x, shift_y]")

    @property
    def minx(self):
        return self.box[0]

    @property
    def miny(self):
        return self.box[1]

    @property
    def maxx(self):
        return self.box[2]

    @property
    def maxy(self):
        return self.box[3]

    @property
    def shift_x(self):
        return self.shift_amount[0]

    @property
    def shift_y(self):
        return self.shift_amount[1]

    @property
    def area(self):
        # Width times height.
        return (self.maxx - self.minx) * (self.maxy - self.miny)

    def get_expanded_box(self, ratio: float = 0.1, max_x: int | None = None, max_y: int | None = None):
        """Return a new box grown by *ratio* in every direction.

        Args:
            ratio (float, optional): Proportion by which to expand. Default is 0.1 (10%).
            max_x (int, optional): Clip the expanded box at this x-coordinate, if given.
            max_y (int, optional): Clip the expanded box at this y-coordinate, if given.

        Returns:
            BoundingBox: A new BoundingBox instance representing the expanded box.
        """
        width = self.maxx - self.minx
        height = self.maxy - self.miny
        margin_x = int(width * ratio)
        margin_y = int(height * ratio)
        new_minx = max(0, self.minx - margin_x)
        new_miny = max(0, self.miny - margin_y)
        new_maxx = self.maxx + margin_x
        new_maxy = self.maxy + margin_y
        # Clipping is skipped when max_x/max_y is falsy, matching the original
        # truthiness-based behavior.
        if max_x:
            new_maxx = min(max_x, new_maxx)
        if max_y:
            new_maxy = min(max_y, new_maxy)
        expanded: list[float] = [new_minx, new_miny, new_maxx, new_maxy]
        return BoundingBox(expanded)

    def to_xywh(self):
        """Return the box as [xmin, ymin, width, height].

        Returns:
            List[float]: The box in [xmin, ymin, width, height] form.
        """
        return [self.minx, self.miny, self.maxx - self.minx, self.maxy - self.miny]

    def to_coco_bbox(self):
        """Return the box in COCO format: [xmin, ymin, width, height].

        Returns:
            List[float]: The box in COCO format.
        """
        return self.to_xywh()

    def to_xyxy(self):
        """Return the box as [xmin, ymin, xmax, ymax].

        Returns:
            List[float]: The box in [xmin, ymin, xmax, ymax] form.
        """
        return [self.minx, self.miny, self.maxx, self.maxy]

    def to_voc_bbox(self):
        """Return the box in VOC format: [xmin, ymin, xmax, ymax].

        Returns:
            List[float]: The box in VOC format.
        """
        return self.to_xyxy()

    def get_shifted_box(self):
        """Return a copy translated by (shift_x, shift_y).

        Returns:
            BoundingBox: The shifted box (carrying a fresh zero shift_amount).
        """
        translated = [
            self.minx + self.shift_x,
            self.miny + self.shift_y,
            self.maxx + self.shift_x,
            self.maxy + self.shift_y,
        ]
        return BoundingBox(translated)

    def __repr__(self):
        return (
            f"BoundingBox: <{(self.minx, self.miny, self.maxx, self.maxy)}, "
            f"w: {self.maxx - self.minx}, h: {self.maxy - self.miny}>"
        )
Functions
get_expanded_box(ratio=0.1, max_x=None, max_y=None)

Returns an expanded bounding box by increasing its size by a given ratio. The expansion is applied equally in all directions. Optionally, the expanded box can be clipped to maximum x and y boundaries.

Parameters:

Name Type Description Default
ratio float

The proportion by which to expand the box size. Default is 0.1 (10%).

0.1
max_x int

The maximum allowed x-coordinate for the expanded box. If None, no maximum is applied.

None
max_y int

The maximum allowed y-coordinate for the expanded box. If None, no maximum is applied.

None

Returns:

Name Type Description
BoundingBox

A new BoundingBox instance representing the expanded box.

Source code in sahi/annotation.py
def get_expanded_box(self, ratio: float = 0.1, max_x: int | None = None, max_y: int | None = None):
    """Return a new box grown by *ratio* in every direction.

    Args:
        ratio (float, optional): Proportion by which to expand. Default is 0.1 (10%).
        max_x (int, optional): Clip the expanded box at this x-coordinate, if given.
        max_y (int, optional): Clip the expanded box at this y-coordinate, if given.

    Returns:
        BoundingBox: A new BoundingBox instance representing the expanded box.
    """
    width = self.maxx - self.minx
    height = self.maxy - self.miny
    margin_x = int(width * ratio)
    margin_y = int(height * ratio)
    new_minx = max(0, self.minx - margin_x)
    new_miny = max(0, self.miny - margin_y)
    new_maxx = self.maxx + margin_x
    new_maxy = self.maxy + margin_y
    # Clipping is skipped when max_x/max_y is falsy, matching the original
    # truthiness-based behavior.
    if max_x:
        new_maxx = min(max_x, new_maxx)
    if max_y:
        new_maxy = min(max_y, new_maxy)
    expanded: list[float] = [new_minx, new_miny, new_maxx, new_maxy]
    return BoundingBox(expanded)
get_shifted_box()

Returns shifted BoundingBox.

Returns:

Name Type Description
BoundingBox

A new BoundingBox instance representing the shifted box.

Source code in sahi/annotation.py
def get_shifted_box(self):
    """Return a copy translated by (shift_x, shift_y).

    Returns:
        BoundingBox: The shifted box (carrying a fresh zero shift_amount).
    """
    translated = [
        self.minx + self.shift_x,
        self.miny + self.shift_y,
        self.maxx + self.shift_x,
        self.maxy + self.shift_y,
    ]
    return BoundingBox(translated)
to_coco_bbox()

Returns the bounding box in COCO format: [xmin, ymin, width, height]

Returns:

Type Description

List[float]: A list containing the bounding box in COCO format.

Source code in sahi/annotation.py
def to_coco_bbox(self):
    """Return the box in COCO format: [xmin, ymin, width, height].

    Returns:
        List[float]: The box in COCO format.
    """
    # COCO's [x, y, w, h] is exactly the xywh form.
    return self.to_xywh()
to_voc_bbox()

Returns the bounding box in VOC format: [xmin, ymin, xmax, ymax]

Returns:

Type Description

List[float]: A list containing the bounding box in VOC format.

Source code in sahi/annotation.py
def to_voc_bbox(self):
    """Return the box in VOC format: [xmin, ymin, xmax, ymax].

    Returns:
        List[float]: The box in VOC format.
    """
    # VOC's corner form is exactly the xyxy form.
    return self.to_xyxy()
to_xywh()

Returns [xmin, ymin, width, height]

Returns:

Type Description

List[float]: A list containing the bounding box in the format [xmin, ymin, width, height].

Source code in sahi/annotation.py
def to_xywh(self):
    """Return the box as [xmin, ymin, width, height].

    Returns:
        List[float]: The box in [xmin, ymin, width, height] form.
    """
    width = self.maxx - self.minx
    height = self.maxy - self.miny
    return [self.minx, self.miny, width, height]
to_xyxy()

Returns: [xmin, ymin, xmax, ymax]

Returns:

Type Description

List[float]: A list containing the bounding box in the format [xmin, ymin, xmax, ymax].

Source code in sahi/annotation.py
def to_xyxy(self):
    """Return the box as [xmin, ymin, xmax, ymax].

    Returns:
        List[float]: The box in [xmin, ymin, xmax, ymax] form.
    """
    corners = [self.minx, self.miny, self.maxx, self.maxy]
    return corners
Category dataclass

Category of the annotation.

Attributes:

Name Type Description
id int

Unique identifier for the category.

name str

Name of the category.

Source code in sahi/annotation.py
@dataclass(frozen=True)
class Category:
    """Immutable annotation category.

    Attributes:
        id (int): Unique identifier for the category.
        name (str): Human-readable name of the category.
    """

    id: int
    name: str

    def __post_init__(self):
        # Enforce concrete types eagerly; dataclass annotations alone are not checked.
        if not isinstance(self.id, int):
            raise TypeError("id should be integer")
        if not isinstance(self.name, str):
            raise TypeError("name should be string")

    def __repr__(self):
        return f"Category: <id: {self.id}, name: {self.name}>"
Mask

Init Mask from coco segmentation representation.

Parameters:

Name Type Description Default
segmentation

List[List] [ [x1, y1, x2, y2, x3, y3, ...], [x1, y1, x2, y2, x3, y3, ...], ... ]

required
full_shape list[int]

List[int] Size of the full image, should be in the form of [height, width]

required
shift_amount list

List[int] To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
Source code in sahi/annotation.py
class Mask:
    """Init Mask from coco segmentation representation.

    Args:
        segmentation : List[List]
            [
                [x1, y1, x2, y2, x3, y3, ...],
                [x1, y1, x2, y2, x3, y3, ...],
                ...
            ]
        full_shape: List[int]
            Size of the full image, should be in the form of [height, width]
        shift_amount: List[int]
            To shift the box and mask predictions from sliced image to full
            sized image, should be in the form of [shift_x, shift_y]
    """

    def __init__(
        self,
        segmentation: list[list[float]],
        full_shape: list[int],
        shift_amount: list | None = None,
    ):
        """Store segmentation polygons plus shift/full-shape metadata.

        Raises:
            ValueError: If full_shape is None.
        """
        if full_shape is None:
            raise ValueError("full_shape must be provided")  # pyright: ignore[reportUnreachable]

        # None sentinel avoids the shared mutable-default-argument pitfall;
        # behavior is identical to the previous [0, 0] default.
        if shift_amount is None:
            shift_amount = [0, 0]
        self.shift_x = shift_amount[0]
        self.shift_y = shift_amount[1]
        self.full_shape_height = full_shape[0]
        self.full_shape_width = full_shape[1]
        self.segmentation = segmentation

    @classmethod
    def from_float_mask(
        cls,
        mask: np.ndarray,
        full_shape: list[int],
        mask_threshold: float = 0.5,
        shift_amount: list | None = None,
    ):
        """
        Args:
            mask: np.ndarray of np.float elements
                Mask values between 0 and 1 (should have a shape of height*width)
            mask_threshold: float
                Value to threshold mask pixels between 0 and 1
            shift_amount: List
                To shift the box and mask predictions from sliced image
                to full sized image, should be in the form of [shift_x, shift_y]
            full_shape: List[int]
                Size of the full image after shifting, should be in the form of [height, width]
        """
        bool_mask = mask > mask_threshold
        # None shift_amount falls through to the [0, 0] default in __init__.
        return cls(
            segmentation=get_coco_segmentation_from_bool_mask(bool_mask),
            shift_amount=shift_amount,
            full_shape=full_shape,
        )

    @classmethod
    def from_bool_mask(
        cls,
        bool_mask: np.ndarray,
        full_shape: list[int],
        shift_amount: list | None = None,
    ):
        """
        Args:
            bool_mask: np.ndarray with bool elements
                2D mask of object, should have a shape of height*width
            full_shape: List[int]
                Size of the full image, should be in the form of [height, width]
            shift_amount: List[int]
                To shift the box and mask predictions from sliced image to full
                sized image, should be in the form of [shift_x, shift_y]
        """
        # None shift_amount falls through to the [0, 0] default in __init__.
        return cls(
            segmentation=get_coco_segmentation_from_bool_mask(bool_mask),
            shift_amount=shift_amount,
            full_shape=full_shape,
        )

    @property
    def bool_mask(self) -> np.ndarray:
        # Rasterizes the stored polygons onto a full_shape-sized boolean grid.
        return get_bool_mask_from_coco_segmentation(
            self.segmentation, width=self.full_shape[1], height=self.full_shape[0]
        )

    @property
    def shape(self) -> list[int]:
        """Returns mask shape as [height, width]"""
        return [self.bool_mask.shape[0], self.bool_mask.shape[1]]

    @property
    def full_shape(self) -> list[int]:
        """Returns full mask shape after shifting as [height, width]"""
        return [self.full_shape_height, self.full_shape_width]

    @property
    def shift_amount(self):
        """Returns the shift amount of the mask slice as [shift_x, shift_y]"""
        return [self.shift_x, self.shift_y]

    def get_shifted_mask(self) -> Mask:
        """Return a copy shifted by (shift_x, shift_y), clipped to full_shape."""
        # Confirm full_shape is specified
        if (self.full_shape_height is None) or (self.full_shape_width is None):
            raise ValueError("full_shape is None")
        shifted_segmentation = []
        for s in self.segmentation:
            # Even indices are x coordinates, odd indices are y coordinates.
            xs = [min(self.shift_x + s[i], self.full_shape_width) for i in range(0, len(s) - 1, 2)]
            ys = [min(self.shift_y + s[i], self.full_shape_height) for i in range(1, len(s), 2)]
            shifted_segmentation.append([j for i in zip(xs, ys) for j in i])
        return Mask(
            segmentation=shifted_segmentation,
            shift_amount=[0, 0],
            full_shape=self.full_shape,
        )
Attributes
full_shape property

Returns full mask shape after shifting as [height, width]

shape property

Returns mask shape as [height, width]

shift_amount property

Returns the shift amount of the mask slice as [shift_x, shift_y]

Functions
from_bool_mask(bool_mask, full_shape, shift_amount=[0, 0]) classmethod

Parameters:

Name Type Description Default
bool_mask ndarray

np.ndarray with bool elements 2D mask of object, should have a shape of height*width

required
full_shape list[int]

List[int] Size of the full image, should be in the form of [height, width]

required
shift_amount list

List[int] To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
Source code in sahi/annotation.py
@classmethod
def from_bool_mask(
    cls,
    bool_mask: np.ndarray,
    full_shape: list[int],
    shift_amount: list = [0, 0],
):
    """Build a Mask from a 2D boolean array by converting it to COCO polygons.

    Args:
        bool_mask: np.ndarray with bool elements
            2D mask of object, should have a shape of height*width
        full_shape: List[int]
            Size of the full image, should be in the form of [height, width]
        shift_amount: List[int]
            To shift the box and mask predictions from sliced image to full
            sized image, should be in the form of [shift_x, shift_y]
    """
    coco_segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
    return cls(
        segmentation=coco_segmentation,
        full_shape=full_shape,
        shift_amount=shift_amount,
    )
from_float_mask(mask, full_shape, mask_threshold=0.5, shift_amount=[0, 0]) classmethod

Parameters:

Name Type Description Default
mask ndarray

np.ndarray of np.float elements Mask values between 0 and 1 (should have a shape of height*width)

required
mask_threshold float

float Value to threshold mask pixels between 0 and 1

0.5
shift_amount list

List To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
full_shape list[int]

List[int] Size of the full image after shifting, should be in the form of [height, width]

required
Source code in sahi/annotation.py
@classmethod
def from_float_mask(
    cls,
    mask: np.ndarray,
    full_shape: list[int],
    mask_threshold: float = 0.5,
    shift_amount: list = [0, 0],
):
    """Build a Mask by thresholding a float-valued mask into a boolean one.

    Args:
        mask: np.ndarray of float elements
            Mask values between 0 and 1 (should have a shape of height*width)
        full_shape: List[int]
            Size of the full image after shifting, should be in the form of [height, width]
        mask_threshold: float
            Value to threshold mask pixels between 0 and 1
        shift_amount: List
            To shift the box and mask predictions from sliced image
            to full sized image, should be in the form of [shift_x, shift_y]
    """
    # Pixels strictly above the threshold are kept as foreground.
    thresholded = mask > mask_threshold
    return cls(
        segmentation=get_coco_segmentation_from_bool_mask(thresholded),
        full_shape=full_shape,
        shift_amount=shift_amount,
    )
ObjectAnnotation

All about an annotation such as Mask, Category, BoundingBox.

Source code in sahi/annotation.py
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
class ObjectAnnotation:
    """All about an annotation such as Mask, Category, BoundingBox."""

    def __init__(
        self,
        bbox: list[int] | None = None,
        segmentation: list | None = None,
        category_id: int | None = None,
        category_name: str | None = None,
        shift_amount: list[int] | None = [0, 0],
        full_shape: list[int] | None = None,
    ):
        """
        Args:
            bbox: List
                [minx, miny, maxx, maxy]
            segmentation: List[List]
                COCO-style polygons:
                [
                    [x1, y1, x2, y2, x3, y3, ...],
                    [x1, y1, x2, y2, x3, y3, ...],
                    ...
                ]
            category_id: int
                ID of the object category
            category_name: str
                Name of the object category
            shift_amount: List
                To shift the box and mask predictions from sliced image
                to full sized image, should be in the form of [shift_x, shift_y]
            full_shape: List
                Size of the full image after shifting, should be in
                the form of [height, width]

        Raises:
            ValueError: if category_id is not an int, if neither bbox nor
                segmentation is given, or if the segmentation yields no bbox.
        """
        if not isinstance(category_id, int):
            raise ValueError("category_id must be an integer")
        if (bbox is None) and (segmentation is None):
            raise ValueError("you must provide a bbox or segmentation")

        self.mask: Mask | None = None
        if segmentation is not None:
            self.mask = Mask(
                segmentation=segmentation,
                shift_amount=shift_amount,
                full_shape=full_shape,
            )
            # When a segmentation is given, the bbox is always recomputed from
            # it so the two stay consistent.
            # https://github.com/obss/sahi/issues/235
            bbox_from_segmentation = get_bbox_from_coco_segmentation(segmentation)
            if bbox_from_segmentation is not None:
                bbox = bbox_from_segmentation
            else:
                raise ValueError("Invalid segmentation mask.")

        # if bbox is a numpy object, convert it to python List[float]
        if type(bbox).__module__ == "numpy":
            bbox = copy.deepcopy(bbox).tolist()

        # make sure bbox coords lie inside [0, image_size]; the upper bound can
        # only be enforced when the full image size is known
        xmin = max(bbox[0], 0)
        ymin = max(bbox[1], 0)
        if full_shape:
            xmax = min(bbox[2], full_shape[1])
            ymax = min(bbox[3], full_shape[0])
        else:
            xmax = bbox[2]
            ymax = bbox[3]
        bbox = [xmin, ymin, xmax, ymax]
        # set bbox
        self.bbox = BoundingBox(bbox, shift_amount)

        # fall back to the stringified id when no explicit name is given
        category_name = category_name if category_name else str(category_id)
        self.category = Category(
            id=category_id,
            name=category_name,
        )

        self.merged = None

    @classmethod
    def from_bool_mask(
        cls,
        bool_mask,
        category_id: int | None = None,
        category_name: str | None = None,
        shift_amount: list[int] | None = [0, 0],
        full_shape: list[int] | None = None,
    ):
        """Creates ObjectAnnotation from bool_mask (2D np.ndarray)

        Args:
            bool_mask: np.ndarray with bool elements
                2D mask of object, should have a shape of height*width
            category_id: int
                ID of the object category
            category_name: str
                Name of the object category
            full_shape: List
                Size of the full image, should be in the form of [height, width]
            shift_amount: List
                To shift the box and mask predictions from sliced image to full
                sized image, should be in the form of [shift_x, shift_y]
        """
        segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
        return cls(
            category_id=category_id,
            segmentation=segmentation,
            category_name=category_name,
            shift_amount=shift_amount,
            full_shape=full_shape,
        )

    @classmethod
    def from_coco_segmentation(
        cls,
        segmentation,
        full_shape: list[int],
        category_id: int | None = None,
        category_name: str | None = None,
        shift_amount: list[int] | None = [0, 0],
    ):
        """
        Creates ObjectAnnotation from coco segmentation:
        [
            [x1, y1, x2, y2, x3, y3, ...],
            [x1, y1, x2, y2, x3, y3, ...],
            ...
        ]

        Args:
            segmentation: List[List]
                [
                    [x1, y1, x2, y2, x3, y3, ...],
                    [x1, y1, x2, y2, x3, y3, ...],
                    ...
                ]
            category_id: int
                ID of the object category
            category_name: str
                Name of the object category
            full_shape: List
                Size of the full image, should be in the form of [height, width]
            shift_amount: List
                To shift the box and mask predictions from sliced image to full
                sized image, should be in the form of [shift_x, shift_y]
        """
        return cls(
            category_id=category_id,
            segmentation=segmentation,
            category_name=category_name,
            shift_amount=shift_amount,
            full_shape=full_shape,
        )

    @classmethod
    def from_coco_bbox(
        cls,
        bbox: list[int],
        category_id: int | None = None,
        category_name: str | None = None,
        shift_amount: list[int] | None = [0, 0],
        full_shape: list[int] | None = None,
    ):
        """Creates ObjectAnnotation from coco bbox [minx, miny, width, height]

        Args:
            bbox: List
                [minx, miny, width, height]
            category_id: int
                ID of the object category
            category_name: str
                Name of the object category
            full_shape: List
                Size of the full image, should be in the form of [height, width]
            shift_amount: List
                To shift the box and mask predictions from sliced image to full
                sized image, should be in the form of [shift_x, shift_y]
        """
        # convert COCO xywh to the xyxy form __init__ expects
        xmin = bbox[0]
        ymin = bbox[1]
        xmax = bbox[0] + bbox[2]
        ymax = bbox[1] + bbox[3]
        bbox = [xmin, ymin, xmax, ymax]
        return cls(
            category_id=category_id,
            bbox=bbox,
            category_name=category_name,
            shift_amount=shift_amount,
            full_shape=full_shape,
        )

    @classmethod
    def from_coco_annotation_dict(
        cls,
        annotation_dict: dict,
        full_shape: list[int],
        category_name: str | None = None,
        shift_amount: list[int] | None = [0, 0],
    ):
        """Creates ObjectAnnotation object from category name and COCO formatted annotation dict (with fields "bbox",
        "segmentation", "category_id").

        Args:
            annotation_dict: dict
                COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id")
            category_name: str
                Category name of the annotation
            full_shape: List
                Size of the full image, should be in the form of [height, width]
            shift_amount: List
                To shift the box and mask predictions from sliced image to full
                sized image, should be in the form of [shift_x, shift_y]
        """
        # Prefer the segmentation when present (non-empty); otherwise fall back
        # to the bbox-only constructor.
        if annotation_dict["segmentation"]:
            return cls.from_coco_segmentation(
                segmentation=annotation_dict["segmentation"],
                category_id=annotation_dict["category_id"],
                category_name=category_name,
                shift_amount=shift_amount,
                full_shape=full_shape,
            )
        else:
            return cls.from_coco_bbox(
                bbox=annotation_dict["bbox"],
                category_id=annotation_dict["category_id"],
                category_name=category_name,
                shift_amount=shift_amount,
                full_shape=full_shape,
            )

    @classmethod
    def from_shapely_annotation(
        cls,
        annotation: ShapelyAnnotation,
        full_shape: list[int],
        category_id: int | None = None,
        category_name: str | None = None,
        shift_amount: list[int] | None = [0, 0],
    ):
        """Creates ObjectAnnotation from shapely_utils.ShapelyAnnotation.

        Args:
            annotation: shapely_utils.ShapelyAnnotation
            category_id: int
                ID of the object category
            category_name: str
                Name of the object category
            full_shape: List
                Size of the full image, should be in the form of [height, width]
            shift_amount: List
                To shift the box and mask predictions from sliced image to full
                sized image, should be in the form of [shift_x, shift_y]
        """
        return cls(
            category_id=category_id,
            segmentation=annotation.to_coco_segmentation(),
            category_name=category_name,
            shift_amount=shift_amount,
            full_shape=full_shape,
        )

    @classmethod
    def from_imantics_annotation(
        cls,
        annotation,
        shift_amount: list[int] | None = [0, 0],
        full_shape: list[int] | None = None,
    ):
        """Creates ObjectAnnotation from imantics.annotation.Annotation.

        Args:
            annotation: imantics.annotation.Annotation
            shift_amount: List
                To shift the box and mask predictions from sliced image to full
                sized image, should be in the form of [shift_x, shift_y]
            full_shape: List
                Size of the full image, should be in the form of [height, width]
        """
        # __init__ accepts `segmentation`, not `bool_mask` (the previous
        # bool_mask=... call raised TypeError). Convert the imantics boolean
        # mask to COCO polygons first, mirroring from_bool_mask().
        return cls(
            category_id=annotation.category.id,
            segmentation=get_coco_segmentation_from_bool_mask(annotation.mask.array),
            category_name=annotation.category.name,
            shift_amount=shift_amount,
            full_shape=full_shape,
        )

    def to_coco_annotation(self) -> CocoAnnotation:
        """Returns sahi.utils.coco.CocoAnnotation representation of ObjectAnnotation."""
        if self.mask:
            coco_annotation = CocoAnnotation.from_coco_segmentation(
                segmentation=self.mask.segmentation,
                category_id=self.category.id,
                category_name=self.category.name,
            )
        else:
            coco_annotation = CocoAnnotation.from_coco_bbox(
                bbox=self.bbox.to_xywh(),
                category_id=self.category.id,
                category_name=self.category.name,
            )
        return coco_annotation

    def to_coco_prediction(self) -> CocoPrediction:
        """Returns sahi.utils.coco.CocoPrediction representation of ObjectAnnotation.

        Annotations are treated as ground truth, hence the fixed score of 1.
        """
        if self.mask:
            coco_prediction = CocoPrediction.from_coco_segmentation(
                segmentation=self.mask.segmentation,
                category_id=self.category.id,
                category_name=self.category.name,
                score=1,
            )
        else:
            coco_prediction = CocoPrediction.from_coco_bbox(
                bbox=self.bbox.to_xywh(),
                category_id=self.category.id,
                category_name=self.category.name,
                score=1,
            )
        return coco_prediction

    def to_shapely_annotation(self) -> ShapelyAnnotation:
        """Returns sahi.utils.shapely.ShapelyAnnotation representation of ObjectAnnotation."""
        if self.mask:
            shapely_annotation = ShapelyAnnotation.from_coco_segmentation(
                segmentation=self.mask.segmentation,
            )
        else:
            shapely_annotation = ShapelyAnnotation.from_coco_bbox(
                bbox=self.bbox.to_xywh(),
            )
        return shapely_annotation

    def to_imantics_annotation(self):
        """Returns imantics.annotation.Annotation representation of ObjectAnnotation.

        Raises:
            ImportError: if the optional `imantics` package is not installed.
        """
        try:
            import imantics
        except ImportError:
            raise ImportError('Please run "pip install -U imantics" to install imantics first for imantics conversion.')

        imantics_category = imantics.Category(id=self.category.id, name=self.category.name)
        if self.mask is not None:
            imantics_mask = imantics.Mask.create(self.mask.bool_mask)
            imantics_annotation = imantics.annotation.Annotation.from_mask(
                mask=imantics_mask, category=imantics_category
            )
        else:
            imantics_bbox = imantics.BBox.create(self.bbox.to_xyxy())
            imantics_annotation = imantics.annotation.Annotation.from_bbox(
                bbox=imantics_bbox, category=imantics_category
            )
        return imantics_annotation

    def deepcopy(self):
        """
        Returns: deepcopy of current ObjectAnnotation instance
        """
        return copy.deepcopy(self)

    @classmethod
    def get_empty_mask(cls):
        # NOTE(review): the Mask constructor visible in this module takes
        # `segmentation`/`full_shape`/`shift_amount`, not `bool_mask`, so this
        # call looks like it raises TypeError — kept as-is; confirm against
        # sahi.annotation.Mask before relying on it.
        return Mask(bool_mask=None)

    def get_shifted_object_annotation(self):
        """Return a copy of this annotation with bbox/mask moved to full-image coordinates."""
        if self.mask:
            shifted_mask = self.mask.get_shifted_mask()
            return ObjectAnnotation(
                bbox=self.bbox.get_shifted_box().to_xyxy(),
                category_id=self.category.id,
                segmentation=shifted_mask.segmentation,
                category_name=self.category.name,
                shift_amount=[0, 0],
                full_shape=shifted_mask.full_shape,
            )
        else:
            # __init__ has no `bool_mask` parameter; passing bool_mask=None
            # here raised TypeError, so only bbox-related fields are forwarded.
            return ObjectAnnotation(
                bbox=self.bbox.get_shifted_box().to_xyxy(),
                category_id=self.category.id,
                category_name=self.category.name,
                shift_amount=[0, 0],
                full_shape=None,
            )

    def __repr__(self):
        return f"""ObjectAnnotation<
    bbox: {self.bbox},
    mask: {self.mask},
    category: {self.category}>"""
Functions
__init__(bbox=None, segmentation=None, category_id=None, category_name=None, shift_amount=[0, 0], full_shape=None)

Parameters:

Name Type Description Default
bbox list[int] | None

List [minx, miny, maxx, maxy]

None
segmentation ndarray | None

List[List] [ [x1, y1, x2, y2, x3, y3, ...], [x1, y1, x2, y2, x3, y3, ...], ... ]

None
category_id int | None

int ID of the object category

None
category_name str | None

str Name of the object category

None
shift_amount list[int] | None

List To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
full_shape list[int] | None

List Size of the full image after shifting, should be in the form of [height, width]

None
Source code in sahi/annotation.py
def __init__(
    self,
    bbox: list[int] | None = None,
    segmentation: list | None = None,  # COCO-style polygons, not an ndarray
    category_id: int | None = None,
    category_name: str | None = None,
    shift_amount: list[int] | None = [0, 0],
    full_shape: list[int] | None = None,
):
    """
    Args:
        bbox: List
            [minx, miny, maxx, maxy]
        segmentation: List[List]
            [
                [x1, y1, x2, y2, x3, y3, ...],
                [x1, y1, x2, y2, x3, y3, ...],
                ...
            ]
        category_id: int
            ID of the object category
        category_name: str
            Name of the object category
        shift_amount: List
            To shift the box and mask predictions from sliced image
            to full sized image, should be in the form of [shift_x, shift_y]
        full_shape: List
            Size of the full image after shifting, should be in
            the form of [height, width]

    Raises:
        ValueError: if category_id is not an int, if neither bbox nor
            segmentation is given, or if the segmentation yields no bbox.
    """
    if not isinstance(category_id, int):
        raise ValueError("category_id must be an integer")
    if (bbox is None) and (segmentation is None):
        raise ValueError("you must provide a bbox or segmentation")

    self.mask: Mask | None = None
    if segmentation is not None:
        self.mask = Mask(
            segmentation=segmentation,
            shift_amount=shift_amount,
            full_shape=full_shape,
        )
        # When a segmentation is given, the bbox is always recomputed from it
        # so the two stay consistent.
        bbox_from_segmentation = get_bbox_from_coco_segmentation(segmentation)
        # https://github.com/obss/sahi/issues/235
        if bbox_from_segmentation is not None:
            bbox = bbox_from_segmentation
        else:
            raise ValueError("Invalid segmentation mask.")

    # if bbox is a numpy object, convert it to python List[float]
    if type(bbox).__module__ == "numpy":
        bbox = copy.deepcopy(bbox).tolist()

    # make sure bbox coords lie inside [0, image_size]; the upper bound can
    # only be clamped when the full image size is known
    xmin = max(bbox[0], 0)
    ymin = max(bbox[1], 0)
    if full_shape:
        xmax = min(bbox[2], full_shape[1])
        ymax = min(bbox[3], full_shape[0])
    else:
        xmax = bbox[2]
        ymax = bbox[3]
    bbox = [xmin, ymin, xmax, ymax]
    # set bbox
    self.bbox = BoundingBox(bbox, shift_amount)

    # fall back to the stringified id when no explicit name is given
    category_name = category_name if category_name else str(category_id)
    self.category = Category(
        id=category_id,
        name=category_name,
    )

    self.merged = None
deepcopy()

Returns: deepcopy of current ObjectAnnotation instance

Source code in sahi/annotation.py
def deepcopy(self):
    """Return a deep copy of this ObjectAnnotation, sharing no mutable state with the original."""
    duplicate = copy.deepcopy(self)
    return duplicate
from_bool_mask(bool_mask, category_id=None, category_name=None, shift_amount=[0, 0], full_shape=None) classmethod

Creates ObjectAnnotation from bool_mask (2D np.ndarray)

Parameters:

Name Type Description Default
bool_mask

np.ndarray with bool elements 2D mask of object, should have a shape of height*width

required
category_id int | None

int ID of the object category

None
category_name str | None

str Name of the object category

None
full_shape list[int] | None

List Size of the full image, should be in the form of [height, width]

None
shift_amount list[int] | None

List To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
Source code in sahi/annotation.py
@classmethod
def from_bool_mask(
    cls,
    bool_mask,
    category_id: int | None = None,
    category_name: str | None = None,
    shift_amount: list[int] | None = [0, 0],
    full_shape: list[int] | None = None,
):
    """Creates ObjectAnnotation from bool_mask (2D np.ndarray)

    The boolean mask is converted to COCO polygons before construction.

    Args:
        bool_mask: np.ndarray with bool elements
            2D mask of object, should have a shape of height*width
        category_id: int
            ID of the object category
        category_name: str
            Name of the object category
        full_shape: List
            Size of the full image, should be in the form of [height, width]
        shift_amount: List
            To shift the box and mask predictions from sliced image to full
            sized image, should be in the form of [shift_x, shift_y]
    """
    return cls(
        category_id=category_id,
        segmentation=get_coco_segmentation_from_bool_mask(bool_mask),
        category_name=category_name,
        shift_amount=shift_amount,
        full_shape=full_shape,
    )
from_coco_annotation_dict(annotation_dict, full_shape, category_name=None, shift_amount=[0, 0]) classmethod

Creates ObjectAnnotation object from category name and COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id").

Parameters:

Name Type Description Default
annotation_dict dict

dict COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id")

required
category_name str | None

str Category name of the annotation

None
full_shape list[int]

List Size of the full image, should be in the form of [height, width]

required
shift_amount list[int] | None

List To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
Source code in sahi/annotation.py
@classmethod
def from_coco_annotation_dict(
    cls,
    annotation_dict: dict,
    full_shape: list[int],
    category_name: str | None = None,
    shift_amount: list[int] | None = [0, 0],
):
    """Creates ObjectAnnotation object from category name and COCO formatted annotation dict (with fields "bbox",
    "segmentation", "category_id").

    Args:
        annotation_dict: dict
            COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id")
        category_name: str
            Category name of the annotation
        full_shape: List
            Size of the full image, should be in the form of [height, width]
        shift_amount: List
            To shift the box and mask predictions from sliced image to full
            sized image, should be in the form of [shift_x, shift_y]
    """
    # Guard clause: no (or empty) segmentation -> build from the bbox alone.
    if not annotation_dict["segmentation"]:
        return cls.from_coco_bbox(
            bbox=annotation_dict["bbox"],
            category_id=annotation_dict["category_id"],
            category_name=category_name,
            shift_amount=shift_amount,
            full_shape=full_shape,
        )
    return cls.from_coco_segmentation(
        segmentation=annotation_dict["segmentation"],
        category_id=annotation_dict["category_id"],
        category_name=category_name,
        shift_amount=shift_amount,
        full_shape=full_shape,
    )
from_coco_bbox(bbox, category_id=None, category_name=None, shift_amount=[0, 0], full_shape=None) classmethod

Creates ObjectAnnotation from coco bbox [minx, miny, width, height]

Parameters:

Name Type Description Default
bbox list[int]

List [minx, miny, width, height]

required
category_id int | None

int ID of the object category

None
category_name str | None

str Name of the object category

None
full_shape list[int] | None

List Size of the full image, should be in the form of [height, width]

None
shift_amount list[int] | None

List To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
Source code in sahi/annotation.py
@classmethod
def from_coco_bbox(
    cls,
    bbox: list[int],
    category_id: int | None = None,
    category_name: str | None = None,
    shift_amount: list[int] | None = [0, 0],
    full_shape: list[int] | None = None,
):
    """Creates ObjectAnnotation from coco bbox [minx, miny, width, height]

    Args:
        bbox: List
            [minx, miny, width, height]
        category_id: int
            ID of the object category
        category_name: str
            Name of the object category
        full_shape: List
            Size of the full image, should be in the form of [height, width]
        shift_amount: List
            To shift the box and mask predictions from sliced image to full
            sized image, should be in the form of [shift_x, shift_y]
    """
    # convert COCO xywh to the xyxy form the constructor expects
    minx, miny, box_width, box_height = bbox[0], bbox[1], bbox[2], bbox[3]
    bbox = [minx, miny, minx + box_width, miny + box_height]
    return cls(
        category_id=category_id,
        bbox=bbox,
        category_name=category_name,
        shift_amount=shift_amount,
        full_shape=full_shape,
    )
from_coco_segmentation(segmentation, full_shape, category_id=None, category_name=None, shift_amount=[0, 0]) classmethod

Creates ObjectAnnotation from coco segmentation: [ [x1, y1, x2, y2, x3, y3, ...], [x1, y1, x2, y2, x3, y3, ...], ... ]

Parameters:

Name Type Description Default
segmentation

List[List] [ [x1, y1, x2, y2, x3, y3, ...], [x1, y1, x2, y2, x3, y3, ...], ... ]

required
category_id int | None

int ID of the object category

None
category_name str | None

str Name of the object category

None
full_shape list[int]

List Size of the full image, should be in the form of [height, width]

required
shift_amount list[int] | None

List To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
Source code in sahi/annotation.py
@classmethod
def from_coco_segmentation(
    cls,
    segmentation,
    full_shape: list[int],
    category_id: int | None = None,
    category_name: str | None = None,
    shift_amount: list[int] | None = [0, 0],
):
    """Creates ObjectAnnotation from a COCO segmentation.

    A COCO segmentation is a list of flat polygons:
    [
        [x1, y1, x2, y2, x3, y3, ...],
        [x1, y1, x2, y2, x3, y3, ...],
        ...
    ]

    Args:
        segmentation: List[List]
            COCO-style list of flat [x1, y1, x2, y2, ...] polygons
        category_id: int
            ID of the object category
        category_name: str
            Name of the object category
        full_shape: List
            Size of the full image, should be in the form of [height, width]
        shift_amount: List
            To shift the box and mask predictions from sliced image to full
            sized image, should be in the form of [shift_x, shift_y]
    """
    # Thin forwarding constructor: all validation happens in __init__.
    return cls(
        segmentation=segmentation,
        category_id=category_id,
        category_name=category_name,
        shift_amount=shift_amount,
        full_shape=full_shape,
    )
from_imantics_annotation(annotation, shift_amount=[0, 0], full_shape=None) classmethod

Creates ObjectAnnotation from imantics.annotation.Annotation.

Parameters:

Name Type Description Default
annotation

imantics.annotation.Annotation

required
shift_amount list[int] | None

List To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
full_shape list[int] | None

List Size of the full image, should be in the form of [height, width]

None
Source code in sahi/annotation.py
@classmethod
def from_imantics_annotation(
    cls,
    annotation,
    shift_amount: list[int] | None = [0, 0],
    full_shape: list[int] | None = None,
):
    """Creates ObjectAnnotation from imantics.annotation.Annotation.

    Args:
        annotation: imantics.annotation.Annotation
        shift_amount: List
            To shift the box and mask predictions from sliced image to full
            sized image, should be in the form of [shift_x, shift_y]
        full_shape: List
            Size of the full image, should be in the form of [height, width]
    """
    # __init__ accepts `segmentation`, not `bool_mask` (the previous
    # bool_mask=... call raised TypeError). Convert the imantics boolean
    # mask to COCO polygons first, mirroring from_bool_mask().
    return cls(
        category_id=annotation.category.id,
        segmentation=get_coco_segmentation_from_bool_mask(annotation.mask.array),
        category_name=annotation.category.name,
        shift_amount=shift_amount,
        full_shape=full_shape,
    )
from_shapely_annotation(annotation, full_shape, category_id=None, category_name=None, shift_amount=[0, 0]) classmethod

Creates ObjectAnnotation from shapely_utils.ShapelyAnnotation.

Parameters:

Name Type Description Default
annotation ShapelyAnnotation

shapely_utils.ShapelyAnnotation

required
category_id int | None

int ID of the object category

None
category_name str | None

str Name of the object category

None
full_shape list[int]

List Size of the full image, should be in the form of [height, width]

required
shift_amount list[int] | None

List To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
Source code in sahi/annotation.py
@classmethod
def from_shapely_annotation(
    cls,
    annotation: ShapelyAnnotation,
    full_shape: list[int],
    category_id: int | None = None,
    category_name: str | None = None,
    shift_amount: list[int] | None = [0, 0],
):
    """Creates ObjectAnnotation from shapely_utils.ShapelyAnnotation.

    Args:
        annotation: shapely_utils.ShapelyAnnotation
        category_id: int
            ID of the object category
        category_name: str
            Name of the object category
        full_shape: List
            Size of the full image, should be in the form of [height, width]
        shift_amount: List
            To shift the box and mask predictions from sliced image to full
            sized image, should be in the form of [shift_x, shift_y]
    """
    # Extract the COCO polygons from the shapely wrapper, then delegate.
    coco_segmentation = annotation.to_coco_segmentation()
    return cls(
        category_id=category_id,
        segmentation=coco_segmentation,
        category_name=category_name,
        shift_amount=shift_amount,
        full_shape=full_shape,
    )
to_coco_annotation()

Returns sahi.utils.coco.CocoAnnotation representation of ObjectAnnotation.

Source code in sahi/annotation.py
def to_coco_annotation(self) -> CocoAnnotation:
    """Convert this ObjectAnnotation into a sahi.utils.coco.CocoAnnotation."""
    # Prefer the segmentation representation when a mask is present.
    if self.mask:
        return CocoAnnotation.from_coco_segmentation(
            segmentation=self.mask.segmentation,
            category_id=self.category.id,
            category_name=self.category.name,
        )
    # Otherwise fall back to a bbox-only annotation.
    return CocoAnnotation.from_coco_bbox(
        bbox=self.bbox.to_xywh(),
        category_id=self.category.id,
        category_name=self.category.name,
    )
to_coco_prediction()

Returns sahi.utils.coco.CocoPrediction representation of ObjectAnnotation.

Source code in sahi/annotation.py
def to_coco_prediction(self) -> CocoPrediction:
    """Convert this ObjectAnnotation into a sahi.utils.coco.CocoPrediction.

    The score is hardcoded to 1 since annotations carry no confidence value.
    """
    if self.mask:
        # Mask present: emit a segmentation-based prediction.
        return CocoPrediction.from_coco_segmentation(
            segmentation=self.mask.segmentation,
            category_id=self.category.id,
            category_name=self.category.name,
            score=1,
        )
    # No mask: emit a bbox-based prediction.
    return CocoPrediction.from_coco_bbox(
        bbox=self.bbox.to_xywh(),
        category_id=self.category.id,
        category_name=self.category.name,
        score=1,
    )
to_imantics_annotation()

Returns imantics.annotation.Annotation representation of ObjectAnnotation.

Source code in sahi/annotation.py
def to_imantics_annotation(self):
    """Convert this ObjectAnnotation into an imantics.annotation.Annotation."""
    # imantics is an optional dependency; fail with an actionable message.
    try:
        import imantics
    except ImportError:
        raise ImportError('Please run "pip install -U imantics" to install imantics first for imantics conversion.')

    category = imantics.Category(id=self.category.id, name=self.category.name)
    if self.mask is None:
        # Box-only annotation.
        box = imantics.BBox.create(self.bbox.to_xyxy())
        return imantics.annotation.Annotation.from_bbox(bbox=box, category=category)
    # Mask available: build the annotation from the boolean mask.
    mask = imantics.Mask.create(self.mask.bool_mask)
    return imantics.annotation.Annotation.from_mask(mask=mask, category=category)
to_shapely_annotation()

Returns sahi.utils.shapely.ShapelyAnnotation representation of ObjectAnnotation.

Source code in sahi/annotation.py
def to_shapely_annotation(self) -> ShapelyAnnotation:
    """Convert this ObjectAnnotation into a sahi.utils.shapely.ShapelyAnnotation."""
    # Build from the COCO segmentation when a mask exists, else from the bbox.
    if self.mask:
        return ShapelyAnnotation.from_coco_segmentation(segmentation=self.mask.segmentation)
    return ShapelyAnnotation.from_coco_bbox(bbox=self.bbox.to_xywh())
Functions

auto_model

Classes
AutoDetectionModel
Source code in sahi/auto_model.py
class AutoDetectionModel:
    """Factory that resolves a model-type string to a concrete DetectionModel."""

    @staticmethod
    def from_pretrained(
        model_type: str,
        model_path: str | None = None,
        model: Any | None = None,
        config_path: str | None = None,
        device: str | None = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: dict | None = None,
        category_remapping: dict | None = None,
        load_at_init: bool = True,
        image_size: int | None = None,
        **kwargs,
    ) -> DetectionModel:
        """Build and return the DetectionModel matching ``model_type``.

        Args:
            model_type: str
                Name of the detection framework (example: "ultralytics", "huggingface", "torchvision")
            model_path: str
                Path of the detection model (ex. 'model.pt')
            model: Any
                A pre-initialized model instance, if available
            config_path: str
                Path of the config file (ex. 'mmdet/configs/cascade_rcnn_r50_fpn_1x.py')
            device: str
                Device, "cpu" or "cuda:0"
            mask_threshold: float
                Value to threshold mask pixels, should be between 0 and 1
            confidence_threshold: float
                All predictions with score < confidence_threshold will be discarded
            category_mapping: dict: str to str
                Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
            category_remapping: dict: str to int
                Remap category ids based on category names, after performing inference e.g. {"car": 3}
            load_at_init: bool
                If True, automatically loads the model at initialization
            image_size: int
                Inference input size.

        Returns:
            An instance of the resolved DetectionModel subclass.

        Raises:
            ImportError: If given {model_type} framework is not installed
        """
        # Aliases listed in ULTRALYTICS_MODEL_NAMES all resolve to the
        # generic "ultralytics" backend.
        if model_type in ULTRALYTICS_MODEL_NAMES:
            model_type = "ultralytics"
        # Local name chosen so it does not shadow the DetectionModel type above.
        model_class = import_model_class(model_type, MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type])

        return model_class(
            model_path=model_path,
            model=model,
            config_path=config_path,
            device=device,
            mask_threshold=mask_threshold,
            confidence_threshold=confidence_threshold,
            category_mapping=category_mapping,
            category_remapping=category_remapping,
            load_at_init=load_at_init,
            image_size=image_size,
            **kwargs,
        )
Functions
from_pretrained(model_type, model_path=None, model=None, config_path=None, device=None, mask_threshold=0.5, confidence_threshold=0.3, category_mapping=None, category_remapping=None, load_at_init=True, image_size=None, **kwargs) staticmethod

Loads a DetectionModel from given path.

Parameters:

Name Type Description Default
model_type str

str Name of the detection framework (example: "ultralytics", "huggingface", "torchvision")

required
model_path str | None

str Path of the detection model (ex. 'model.pt')

None
model Any | None

Any A pre-initialized model instance, if available

None
config_path str | None

str Path of the config file (ex. 'mmdet/configs/cascade_rcnn_r50_fpn_1x.py')

None
device str | None

str Device, "cpu" or "cuda:0"

None
mask_threshold float

float Value to threshold mask pixels, should be between 0 and 1

0.5
confidence_threshold float

float All predictions with score < confidence_threshold will be discarded

0.3
category_mapping dict | None

dict: str to str Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}

None
category_remapping dict | None

dict: str to int Remap category ids based on category names, after performing inference e.g. {"car": 3}

None
load_at_init bool

bool If True, automatically loads the model at initialization

True
image_size int | None

int Inference input size.

None

Returns:

Type Description
DetectionModel

Returns an instance of a DetectionModel

Raises:

Type Description
ImportError

If given {model_type} framework is not installed

Source code in sahi/auto_model.py
@staticmethod
def from_pretrained(
    model_type: str,
    model_path: str | None = None,
    model: Any | None = None,
    config_path: str | None = None,
    device: str | None = None,
    mask_threshold: float = 0.5,
    confidence_threshold: float = 0.3,
    category_mapping: dict | None = None,
    category_remapping: dict | None = None,
    load_at_init: bool = True,
    image_size: int | None = None,
    **kwargs,
) -> DetectionModel:
    """Resolve ``model_type`` to a DetectionModel subclass and instantiate it.

    Args:
        model_type: str
            Name of the detection framework (example: "ultralytics", "huggingface", "torchvision")
        model_path: str
            Path of the detection model (ex. 'model.pt')
        model: Any
            A pre-initialized model instance, if available
        config_path: str
            Path of the config file (ex. 'mmdet/configs/cascade_rcnn_r50_fpn_1x.py')
        device: str
            Device, "cpu" or "cuda:0"
        mask_threshold: float
            Value to threshold mask pixels, should be between 0 and 1
        confidence_threshold: float
            All predictions with score < confidence_threshold will be discarded
        category_mapping: dict: str to str
            Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
        category_remapping: dict: str to int
            Remap category ids based on category names, after performing inference e.g. {"car": 3}
        load_at_init: bool
            If True, automatically loads the model at initialization
        image_size: int
            Inference input size.

    Returns:
        Returns an instance of a DetectionModel

    Raises:
        ImportError: If given {model_type} framework is not installed
    """
    # Normalize every Ultralytics-family alias to the single backend name.
    if model_type in ULTRALYTICS_MODEL_NAMES:
        model_type = "ultralytics"
    class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type]
    detection_model_class = import_model_class(model_type, class_name)

    return detection_model_class(
        model_path=model_path,
        model=model,
        config_path=config_path,
        device=device,
        mask_threshold=mask_threshold,
        confidence_threshold=confidence_threshold,
        category_mapping=category_mapping,
        category_remapping=category_remapping,
        load_at_init=load_at_init,
        image_size=image_size,
        **kwargs,
    )
Functions

cli

Functions
app()

Cli app.

Source code in sahi/cli.py
def app() -> None:
    """Command line entry point; dispatches arguments via Fire."""
    # fire builds the CLI directly from the sahi_app object.
    fire.Fire(sahi_app)

logger

Classes
BaseSahiLogger

Bases: Logger, ABC

Source code in sahi/logger.py
class BaseSahiLogger(logging.Logger, ABC):
    """Abstract Logger that additionally exposes a PKG_INFO-level logging hook."""

    @abstractmethod
    def pkg_info(self, message: str, *args, **kws) -> None:
        """Log a package info message at PKG_INFO level."""
        # Concrete subclasses must provide the implementation.
        raise NotImplementedError
Functions
pkg_info(message, *args, **kws) abstractmethod

Log a package info message at PKG_INFO level.

Source code in sahi/logger.py
@abstractmethod
def pkg_info(self, message: str, *args, **kws) -> None:
    """Emit ``message`` at the custom PKG_INFO logging level.

    Abstract: concrete logger subclasses must override this hook.
    """
    raise NotImplementedError

models

Modules
base
Classes
DetectionModel
Source code in sahi/models/base.py
class DetectionModel:
    # Packages that must be importable before the model can be used;
    # subclasses extend this list (e.g. ["torch", "transformers"]).
    required_packages: list[str] | None = None

    def __init__(
        self,
        model_path: str | None = None,
        model: Any | None = None,
        config_path: str | None = None,
        device: str | None = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: dict | None = None,
        category_remapping: dict | None = None,
        load_at_init: bool = True,
        image_size: int | None = None,
    ):
        """Init object detection/instance segmentation model.

        Args:
            model_path: str
                Path for the instance segmentation model weight
            model: Any
                A pre-initialized model instance; used instead of loading from
                model_path when load_at_init is True
            config_path: str
                Path for the mmdetection instance segmentation model config file
            device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
            mask_threshold: float
                Value to threshold mask pixels, should be between 0 and 1
            confidence_threshold: float
                All predictions with score < confidence_threshold will be discarded
            category_mapping: dict: str to str
                Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
            category_remapping: dict: str to int
                Remap category ids based on category names, after performing inference e.g. {"car": 3}
            load_at_init: bool
                If True, automatically loads the model at initialization
            image_size: int
                Inference input size.
        """

        self.model_path = model_path
        self.config_path = config_path
        self.model = None
        self.mask_threshold = mask_threshold
        self.confidence_threshold = confidence_threshold
        self.category_mapping = category_mapping
        self.category_remapping = category_remapping
        self.image_size = image_size
        self._original_predictions = None
        self._object_prediction_list_per_image = None
        self.set_device(device)

        # automatically ensure dependencies
        self.check_dependencies()

        # automatically load model if load_at_init is True
        if load_at_init:
            if model:
                self.set_model(model)
            else:
                self.load_model()

    def check_dependencies(self, packages: list[str] | None = None) -> None:
        """Ensures required dependencies are installed.

        If 'packages' is None, uses self.required_packages. Subclasses may still call with a custom list for dynamic
        needs.
        """
        pkgs = packages if packages is not None else getattr(self, "required_packages", [])
        if pkgs:
            check_requirements(pkgs)

    def load_model(self):
        """This function should be implemented in a way that detection model should be initialized and set to
        self.model.

        (self.model_path, self.config_path, and self.device should be utilized)
        """
        raise NotImplementedError()

    def set_model(self, model: Any, **kwargs):
        """
        This function should be implemented to instantiate a DetectionModel out of an already loaded model
        Args:
            model: Any
                Loaded model
        """
        raise NotImplementedError()

    def set_device(self, device: str | None = None):
        """Sets the device pytorch should use for the model.

        Args:
            device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
        """

        self.device = select_device(device)

    def unload_model(self):
        """Unloads the model from CPU/GPU."""
        self.model = None
        empty_cuda_cache()

    def perform_inference(self, image: np.ndarray):
        """This function should be implemented in a way that prediction should be performed using self.model and the
        prediction result should be set to self._original_predictions.

        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted.
        """
        raise NotImplementedError()

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: list[list[int]] | None = None,
        full_shape_list: list[list[int]] | None = None,
    ):
        """This function should be implemented in a way that self._original_predictions should be converted to a list of
        prediction.ObjectPrediction and set to self._object_prediction_list.

        self.mask_threshold can also be utilized.
        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...].
                None is treated as [[0, 0]] (no shift).
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        raise NotImplementedError()

    def _apply_category_remapping(self):
        """Applies category remapping based on mapping given in self.category_remapping."""
        # confirm self.category_remapping is not None
        if self.category_remapping is None:
            raise ValueError("self.category_remapping cannot be None")
        # remap categories
        if not isinstance(self._object_prediction_list_per_image, list):
            logger.error(
                f"Unknown type for self._object_prediction_list_per_image: "
                f"{type(self._object_prediction_list_per_image)}"
            )
            return
        for object_prediction_list in self._object_prediction_list_per_image:  # type: ignore
            for object_prediction in object_prediction_list:
                old_category_id_str = str(object_prediction.category.id)
                new_category_id_int = self.category_remapping[old_category_id_str]
                object_prediction.category = Category(id=new_category_id_int, name=object_prediction.category.name)

    def convert_original_predictions(
        self,
        shift_amount: list[list[int]] | None = None,
        full_shape: list[list[int]] | None = None,
    ):
        """Converts original predictions of the detection model to a list of prediction.ObjectPrediction object.

        Should be called after perform_inference().
        Args:
            shift_amount: list
                To shift the box and mask predictions from sliced image to full sized image,
                    should be in the form of [shift_x, shift_y]. Defaults to [[0, 0]].
            full_shape: list
                Size of the full image after shifting, should be in the form of [height, width]
        """
        # Materialize the default here instead of using a mutable default
        # argument, so a shared list is never reused across calls.
        if shift_amount is None:
            shift_amount = [[0, 0]]
        self._create_object_prediction_list_from_original_predictions(
            shift_amount_list=shift_amount,
            full_shape_list=full_shape,
        )
        if self.category_remapping:
            self._apply_category_remapping()

    @property
    def object_prediction_list(self) -> list[list[ObjectPrediction]]:
        # Predictions for the first image; empty until conversion has run.
        if not self._object_prediction_list_per_image:
            return []
        return self._object_prediction_list_per_image[0]

    @property
    def object_prediction_list_per_image(self) -> list[list[ObjectPrediction]]:
        return self._object_prediction_list_per_image or []

    @property
    def original_predictions(self):
        return self._original_predictions
Functions
__init__(model_path=None, model=None, config_path=None, device=None, mask_threshold=0.5, confidence_threshold=0.3, category_mapping=None, category_remapping=None, load_at_init=True, image_size=None)

Init object detection/instance segmentation model.

Parameters:

Name Type Description Default
model_path str | None

str Path for the instance segmentation model weight

None
config_path str | None

str Path for the mmdetection instance segmentation model config file

None
device str | None

Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.

None
mask_threshold float

float Value to threshold mask pixels, should be between 0 and 1

0.5
confidence_threshold float

float All predictions with score < confidence_threshold will be discarded

0.3
category_mapping dict | None

dict: str to str Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}

None
category_remapping dict | None

dict: str to int Remap category ids based on category names, after performing inference e.g. {"car": 3}

None
load_at_init bool

bool If True, automatically loads the model at initialization

True
image_size int | None

int Inference input size.

None
Source code in sahi/models/base.py
def __init__(
    self,
    model_path: str | None = None,
    model: Any | None = None,
    config_path: str | None = None,
    device: str | None = None,
    mask_threshold: float = 0.5,
    confidence_threshold: float = 0.3,
    category_mapping: dict | None = None,
    category_remapping: dict | None = None,
    load_at_init: bool = True,
    image_size: int | None = None,
):
    """Init object detection/instance segmentation model.

    Args:
        model_path: str
            Path for the instance segmentation model weight
        model: Any
            A pre-initialized model instance, attached instead of loading from
            model_path when load_at_init is True
        config_path: str
            Path for the mmdetection instance segmentation model config file
        device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
        mask_threshold: float
            Value to threshold mask pixels, should be between 0 and 1
        confidence_threshold: float
            All predictions with score < confidence_threshold will be discarded
        category_mapping: dict: str to str
            Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
        category_remapping: dict: str to int
            Remap category ids based on category names, after performing inference e.g. {"car": 3}
        load_at_init: bool
            If True, automatically loads the model at initialization
        image_size: int
            Inference input size.
    """

    # Plain configuration attributes; the model object itself is attached below.
    self.model_path = model_path
    self.config_path = config_path
    self.model = None
    self.mask_threshold = mask_threshold
    self.confidence_threshold = confidence_threshold
    self.category_mapping = category_mapping
    self.category_remapping = category_remapping
    self.image_size = image_size
    self._original_predictions = None
    self._object_prediction_list_per_image = None
    self.set_device(device)

    # Fail fast when required third-party packages are missing.
    self.check_dependencies()

    if not load_at_init:
        return
    # Reuse a caller-supplied model instance when given, otherwise load from disk.
    if model:
        self.set_model(model)
    else:
        self.load_model()
check_dependencies(packages=None)

Ensures required dependencies are installed.

If 'packages' is None, uses self.required_packages. Subclasses may still call with a custom list for dynamic needs.

Source code in sahi/models/base.py
def check_dependencies(self, packages: list[str] | None = None) -> None:
    """Ensure that required third-party dependencies are installed.

    Falls back to ``self.required_packages`` when ``packages`` is None;
    subclasses may still pass a custom list for dynamic needs.
    """
    if packages is None:
        packages = getattr(self, "required_packages", [])
    if packages:
        check_requirements(packages)
convert_original_predictions(shift_amount=[[0, 0]], full_shape=None)

Converts original predictions of the detection model to a list of prediction.ObjectPrediction object.

Should be called after perform_inference(). Args: shift_amount: list To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y] full_shape: list Size of the full image after shifting, should be in the form of [height, width]

Source code in sahi/models/base.py
def convert_original_predictions(
    self,
    shift_amount: list[list[int]] | None = [[0, 0]],
    full_shape: list[list[int]] | None = None,
):
    """Convert the model's raw output into prediction.ObjectPrediction lists.

    Should be called after perform_inference().

    Args:
        shift_amount: list
            [shift_x, shift_y] offset that maps predictions from the sliced
            image back onto the full sized image.
        full_shape: list
            [height, width] of the full image after shifting.
    """
    self._create_object_prediction_list_from_original_predictions(
        shift_amount_list=shift_amount,
        full_shape_list=full_shape,
    )
    # Optionally remap category ids after conversion.
    if self.category_remapping:
        self._apply_category_remapping()
load_model()

This function should be implemented in a way that detection model should be initialized and set to self.model.

(self.model_path, self.config_path, and self.device should be utilized)

Source code in sahi/models/base.py
def load_model(self):
    """Initialize the detection model and assign it to ``self.model``.

    Subclasses must implement this using ``self.model_path``,
    ``self.config_path`` and ``self.device``.
    """
    raise NotImplementedError()
perform_inference(image)

This function should be implemented in a way that prediction should be performed using self.model and the prediction result should be set to self._original_predictions.

Parameters:

Name Type Description Default
image ndarray

np.ndarray A numpy array that contains the image to be predicted.

required
Source code in sahi/models/base.py
def perform_inference(self, image: np.ndarray):
    """Run the model on ``image`` and store the raw output.

    Subclasses must implement this so that inference runs through
    ``self.model`` and the result is stored on ``self._original_predictions``.

    Args:
        image: np.ndarray
            A numpy array that contains the image to be predicted.
    """
    raise NotImplementedError()
set_device(device=None)

Sets the device pytorch should use for the model.

Parameters:

Name Type Description Default
device str | None

Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.

None
Source code in sahi/models/base.py
def set_device(self, device: str | None = None):
    """Resolve ``device`` and store the torch device on ``self.device``.

    Args:
        device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
    """
    # select_device handles the None/auto case.
    self.device = select_device(device)
set_model(model, **kwargs)

This function should be implemented to instantiate a DetectionModel out of an already loaded model Args: model: Any Loaded model

Source code in sahi/models/base.py
def set_model(self, model: Any, **kwargs):
    """Attach an already-loaded model to this DetectionModel.

    Subclasses must implement this.

    Args:
        model: Any
            Loaded model
    """
    raise NotImplementedError()
unload_model()

Unloads the model from CPU/GPU.

Source code in sahi/models/base.py
def unload_model(self):
    """Drop the loaded model and release cached CUDA memory."""
    self.model = None
    empty_cuda_cache()
Functions
detectron2
Classes
Detectron2DetectionModel

Bases: DetectionModel

Source code in sahi/models/detectron2.py
class Detectron2DetectionModel(DetectionModel):
    def __init__(self, *args, **kwargs):
        """Init detectron2 model; registers "torch" and "detectron2" as required packages."""
        existing_packages = getattr(self, "required_packages", None) or []
        self.required_packages = [*list(existing_packages), "torch", "detectron2"]
        super().__init__(*args, **kwargs)

    def load_model(self):
        """Build a detectron2 DefaultPredictor (from model zoo or local config) and set self.model."""
        from detectron2.config import get_cfg
        from detectron2.data import MetadataCatalog
        from detectron2.engine import DefaultPredictor
        from detectron2.model_zoo import model_zoo

        cfg = get_cfg()

        try:  # try to load from model zoo
            config_file = model_zoo.get_config_file(self.config_path)
            cfg.set_new_allowed(True)
            cfg.merge_from_file(config_file)
            cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.config_path)
        except Exception as e:  # try to load from local
            # Log (instead of print) for consistency with the metadata fallback below.
            logger.warning(e)
            if self.config_path is not None:
                cfg.set_new_allowed(True)
                cfg.merge_from_file(self.config_path)
            cfg.MODEL.WEIGHTS = self.model_path

        # set model device
        cfg.MODEL.DEVICE = self.device.type
        # set input image size
        if self.image_size is not None:
            cfg.INPUT.MIN_SIZE_TEST = self.image_size
            cfg.INPUT.MAX_SIZE_TEST = self.image_size
        # init predictor
        self.model = DefaultPredictor(cfg)

        # detectron2 category mapping
        if self.category_mapping is None:
            try:  # try to parse category names from metadata
                metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
                self.category_names = metadata.thing_classes
                self.category_mapping = {
                    str(ind): category_name for ind, category_name in enumerate(self.category_names)
                }
            except Exception as e:
                logger.warning(e)
                # https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html#update-the-config-for-new-datasets
                if cfg.MODEL.META_ARCHITECTURE == "RetinaNet":
                    num_categories = cfg.MODEL.RETINANET.NUM_CLASSES
                else:  # fasterrcnn/maskrcnn etc
                    num_categories = cfg.MODEL.ROI_HEADS.NUM_CLASSES
                self.category_names = [str(category_id) for category_id in range(num_categories)]
                self.category_mapping = {
                    str(ind): category_name for ind, category_name in enumerate(self.category_names)
                }
        else:
            self.category_names = list(self.category_mapping.values())

    def perform_inference(self, image: np.ndarray):
        """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
        """

        # Confirm model is loaded
        if self.model is None:
            raise RuntimeError("Model is not loaded, load it by calling .load_model()")

        if isinstance(image, np.ndarray) and self.model.input_format == "BGR":
            # convert RGB image to BGR format
            image = image[:, :, ::-1]

        self._original_predictions = self.model(image)

    @property
    def num_categories(self):
        """Returns number of categories."""
        return len(self.category_mapping)

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: list[list[int]] | None = None,
        full_shape_list: list[list[int]] | None = None,
    ):
        """self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.

        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...].
                None is treated as [[0, 0]] (no shift).
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """

        # Avoid a mutable default argument: None means "no shift".
        if shift_amount_list is None:
            shift_amount_list = [[0, 0]]

        original_predictions = self._original_predictions

        # compatibility for sahi v0.8.15: accept a flat [x, y] / [h, w]
        if isinstance(shift_amount_list[0], int):
            shift_amount_list = [shift_amount_list]
        if full_shape_list is not None and isinstance(full_shape_list[0], int):
            full_shape_list = [full_shape_list]

        # detectron2 DefaultPredictor supports single image
        shift_amount = shift_amount_list[0]
        full_shape = None if full_shape_list is None else full_shape_list[0]

        # parse boxes, masks, scores, category_ids from predictions
        instances = original_predictions["instances"]
        boxes = instances.pred_boxes.tensor
        scores = instances.scores
        category_ids = instances.pred_classes

        # check if predictions contain mask
        try:
            masks = instances.pred_masks
        except AttributeError:
            masks = None

        # filter predictions with low confidence
        keep = scores >= self.confidence_threshold
        boxes = boxes[keep]
        scores = scores[keep]
        category_ids = category_ids[keep]
        if masks is not None:
            masks = masks[keep]

        object_prediction_list: list[ObjectPrediction] = []
        if masks is None:
            for box, score, category_id in zip(boxes, scores, category_ids):
                cat_id = category_id.item()
                object_prediction_list.append(
                    ObjectPrediction(
                        bbox=box.tolist(),
                        segmentation=None,
                        category_id=cat_id,
                        category_name=self.category_mapping[str(cat_id)],
                        shift_amount=shift_amount,
                        score=score.item(),
                        full_shape=full_shape,
                    )
                )
        else:
            for box, score, category_id, mask in zip(boxes, scores, category_ids, masks):
                bool_mask = mask.detach().cpu().numpy()
                # Compute the COCO segmentation exactly once per mask (the
                # previous comprehension recomputed it for both the filter
                # and the constructor).
                segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
                # Skip degenerate masks that yield no polygon or no bbox.
                if len(segmentation) == 0 or get_bbox_from_bool_mask(bool_mask) is None:
                    continue
                cat_id = category_id.item()
                object_prediction_list.append(
                    ObjectPrediction(
                        bbox=None,
                        segmentation=segmentation,
                        category_id=cat_id,
                        category_name=self.category_mapping[str(cat_id)],
                        shift_amount=shift_amount,
                        score=score.item(),
                        full_shape=full_shape,
                    )
                )

        # detectron2 DefaultPredictor supports single image
        self._object_prediction_list_per_image = [object_prediction_list]
Attributes
num_categories property

Returns number of categories.

Functions
perform_inference(image)

Prediction is performed using self.model and the prediction result is set to self._original_predictions.

Parameters:

Name Type Description Default
image ndarray

np.ndarray A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

required
Source code in sahi/models/detectron2.py
def perform_inference(self, image: np.ndarray):
    """Run the detectron2 predictor on `image` and stash the raw output.

    The prediction result is stored in self._original_predictions.

    Args:
        image: np.ndarray
            3 channel image to be predicted, in RGB order.

    Raises:
        RuntimeError: If the model has not been loaded yet.
    """
    if self.model is None:
        raise RuntimeError("Model is not loaded, load it by calling .load_model()")

    # detectron2 predictors configured for BGR input need the channel order reversed.
    needs_bgr = isinstance(image, np.ndarray) and self.model.input_format == "BGR"
    model_input = image[:, :, ::-1] if needs_bgr else image

    self._original_predictions = self.model(model_input)
Functions
huggingface
Classes
HuggingfaceDetectionModel

Bases: DetectionModel

Source code in sahi/models/huggingface.py
class HuggingfaceDetectionModel(DetectionModel):
    """SAHI DetectionModel wrapper for Hugging Face `transformers` object-detection models."""

    def __init__(
        self,
        model_path: str | None = None,
        model: Any | None = None,
        processor: Any | None = None,
        config_path: str | None = None,
        device: str | None = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: dict | None = None,
        category_remapping: dict | None = None,
        load_at_init: bool = True,
        image_size: int | None = None,
        token: str | None = None,
    ):
        """Initialize the Hugging Face detection model.

        Args:
            model_path: Hub model id or local path of the detection model.
            model: A pre-initialized transformers ObjectDetection model, if available.
            processor: A pre-initialized transformers image processor, if available.
            config_path: Path of the model config file.
            device: Torch device string, e.g. "cpu" or "cuda:0".
            mask_threshold: Value to threshold mask pixels, should be between 0 and 1.
            confidence_threshold: Predictions with score < confidence_threshold are discarded.
            category_mapping: Mapping from category id (str) to category name (str).
            category_remapping: Remap category ids based on category names after inference.
            load_at_init: If True, automatically loads the model at initialization.
            image_size: Inference input size.
            token: Hugging Face Hub access token (HF_TOKEN env var takes precedence
                in load_model()).
        """
        self._processor = processor
        # Shapes of the image(s) passed to the most recent perform_inference() call.
        self._image_shapes: list = []
        self._token = token
        existing_packages = getattr(self, "required_packages", None) or []
        self.required_packages = [*list(existing_packages), "torch", "transformers"]
        ensure_package_minimum_version("transformers", "4.42.0")
        super().__init__(
            model_path,
            model,
            config_path,
            device,
            mask_threshold,
            confidence_threshold,
            category_mapping,
            category_remapping,
            load_at_init,
            image_size,
        )

    @property
    def processor(self):
        """Returns the transformers image processor instance."""
        return self._processor

    @property
    def image_shapes(self):
        """Returns the shapes of the image(s) used in the last inference call."""
        return self._image_shapes

    @property
    def num_categories(self) -> int:
        """Returns number of categories."""
        return self.model.config.num_labels

    def load_model(self):
        """Loads model and processor from self.model_path and registers them via set_model()."""
        from transformers import AutoModelForObjectDetection, AutoProcessor

        # The HF_TOKEN environment variable takes precedence over the token given at init.
        hf_token = os.getenv("HF_TOKEN", self._token)
        model = AutoModelForObjectDetection.from_pretrained(self.model_path, token=hf_token)
        if self.image_size is not None:
            if model.base_model_prefix == "rt_detr_v2":
                # rt_detr_v2 processors expect an explicit height/width size dict.
                size = {"height": self.image_size, "width": self.image_size}
            else:
                size = {"shortest_edge": self.image_size, "longest_edge": None}
            # use_fast=True raises error: AttributeError: 'SizeDict' object has no attribute 'keys'
            processor = AutoProcessor.from_pretrained(
                self.model_path, size=size, do_resize=True, use_fast=False, token=hf_token
            )
        else:
            processor = AutoProcessor.from_pretrained(self.model_path, use_fast=False, token=hf_token)
        self.set_model(model, processor)

    def set_model(self, model: Any, processor: Any = None, **kwargs):
        """Sets the underlying model/processor pair and moves the model to self.device.

        Raises:
            ValueError: If no processor is available, or the given objects are not a
                transformers ObjectDetection model / ImageProcessor pair.
        """
        processor = processor or self.processor
        if processor is None:
            raise ValueError(f"'processor' is required to be set, got {processor}.")
        elif "ObjectDetection" not in model.__class__.__name__ or "ImageProcessor" not in processor.__class__.__name__:
            raise ValueError(
                "Given 'model' is not an ObjectDetectionModel or 'processor' is not a valid ImageProcessor."
            )
        self.model = model
        self.model.to(self.device)
        self._processor = processor
        # id2label supplies the category id -> name mapping used by SAHI.
        self.category_mapping = self.model.config.id2label

    def perform_inference(self, image: list | np.ndarray):
        """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

        Args:
            image: np.ndarray or list of np.ndarray
                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

        Raises:
            RuntimeError: If the model or processor has not been loaded yet.
        """
        import torch

        # Confirm model is loaded
        if self.model is None or self.processor is None:
            raise RuntimeError("Model is not loaded, load it by calling .load_model()")

        # Preprocess, move tensors to the model's device, and run a gradient-free forward pass.
        with torch.no_grad():
            inputs = self.processor(images=image, return_tensors="pt")
            inputs["pixel_values"] = inputs.pixel_values.to(self.device)
            # Some processors (e.g. DETR-style) also emit a pixel_mask for padded batches.
            if hasattr(inputs, "pixel_mask"):
                inputs["pixel_mask"] = inputs.pixel_mask.to(self.device)
            outputs = self.model(**inputs)

        # Remember the original image shape(s) so boxes can be rescaled in postprocessing.
        if isinstance(image, list):
            self._image_shapes = [img.shape for img in image]
        else:
            self._image_shapes = [image.shape]
        self._original_predictions = outputs

    def get_valid_predictions(self, logits, pred_boxes) -> tuple:
        """Filter raw per-query outputs down to confident, in-range detections.

        Args:
            logits: torch.Tensor
            pred_boxes: torch.Tensor
        Returns:
            scores: torch.Tensor
            cat_ids: torch.Tensor
            boxes: torch.Tensor
        """
        import torch

        probs = logits.softmax(-1)
        scores = probs.max(-1).values
        cat_ids = probs.argmax(-1)
        # Queries whose best class index is >= num_categories hit the "no object"
        # class and are discarded, as are low-confidence queries.
        valid_detections = torch.where(cat_ids < self.num_categories, 1, 0)
        valid_confidences = torch.where(scores >= self.confidence_threshold, 1, 0)
        valid_mask = valid_detections.logical_and(valid_confidences)
        scores = scores[valid_mask]
        cat_ids = cat_ids[valid_mask]
        boxes = pred_boxes[valid_mask]
        return scores, cat_ids, boxes

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: list[list[int]] | None = [[0, 0]],
        full_shape_list: list[list[int]] | None = None,
    ):
        """self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.

        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        original_predictions = self._original_predictions

        # compatibility for sahi v0.8.15
        shift_amount_list = fix_shift_amount_list(shift_amount_list)
        full_shape_list = fix_full_shape_list(full_shape_list)

        n_image = original_predictions.logits.shape[0]
        object_prediction_list_per_image = []
        for image_ind in range(n_image):
            image_height, image_width, _ = self.image_shapes[image_ind]
            scores, cat_ids, boxes = self.get_valid_predictions(
                logits=original_predictions.logits[image_ind], pred_boxes=original_predictions.pred_boxes[image_ind]
            )

            # create object_prediction_list
            object_prediction_list = []

            shift_amount = shift_amount_list[image_ind]
            full_shape = None if full_shape_list is None else full_shape_list[image_ind]

            for ind in range(len(boxes)):
                category_id = cat_ids[ind].item()
                yolo_bbox = boxes[ind].tolist()
                # Model outputs are normalized yolo-style (cx, cy, w, h) boxes;
                # convert to absolute voc-style (xmin, ymin, xmax, ymax) pixels.
                bbox = list(
                    pbf.convert_bbox(
                        yolo_bbox,
                        from_type="yolo",
                        to_type="voc",
                        image_size=(image_width, image_height),
                        return_values=True,
                        strict=False,
                    )
                )

                # fix negative box coords
                bbox[0] = max(0, bbox[0])
                bbox[1] = max(0, bbox[1])
                # clip out-of-image box coords
                bbox[2] = min(bbox[2], image_width)
                bbox[3] = min(bbox[3], image_height)

                object_prediction = ObjectPrediction(
                    bbox=bbox,
                    segmentation=None,
                    category_id=category_id,
                    category_name=self.category_mapping[category_id],
                    shift_amount=shift_amount,
                    score=scores[ind].item(),
                    full_shape=full_shape,
                )
                object_prediction_list.append(object_prediction)
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image
Attributes
num_categories property

Returns number of categories.

Functions
get_valid_predictions(logits, pred_boxes)

Parameters:

Name Type Description Default
logits

torch.Tensor

required
pred_boxes

torch.Tensor

required

Returns: scores: torch.Tensor cat_ids: torch.Tensor boxes: torch.Tensor

Source code in sahi/models/huggingface.py
def get_valid_predictions(self, logits, pred_boxes) -> tuple:
    """Filter raw per-query model outputs down to confident, in-range detections.

    Args:
        logits: torch.Tensor
        pred_boxes: torch.Tensor
    Returns:
        scores: torch.Tensor
        cat_ids: torch.Tensor
        boxes: torch.Tensor
    """
    class_probs = logits.softmax(-1)
    scores = class_probs.max(-1).values
    cat_ids = class_probs.argmax(-1)
    # Keep a query only when its best class is a real category (an index of
    # num_categories or above is the "no object" class) and its confidence
    # clears the configured threshold.
    keep = (cat_ids < self.num_categories) & (scores >= self.confidence_threshold)
    return scores[keep], cat_ids[keep], pred_boxes[keep]
perform_inference(image)

Prediction is performed using self.model and the prediction result is set to self._original_predictions.

Parameters:

Name Type Description Default
image list | ndarray

np.ndarray A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

required
Source code in sahi/models/huggingface.py
def perform_inference(self, image: list | np.ndarray):
    """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

    Args:
        image: np.ndarray or list of np.ndarray
            A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

    Raises:
        RuntimeError: If the model or processor has not been loaded yet.
    """
    import torch

    # Confirm model is loaded
    if self.model is None or self.processor is None:
        raise RuntimeError("Model is not loaded, load it by calling .load_model()")

    # Preprocess, move tensors to the model's device, and run a gradient-free forward pass.
    with torch.no_grad():
        inputs = self.processor(images=image, return_tensors="pt")
        inputs["pixel_values"] = inputs.pixel_values.to(self.device)
        # Some processors (e.g. DETR-style) also emit a pixel_mask for padded batches.
        if hasattr(inputs, "pixel_mask"):
            inputs["pixel_mask"] = inputs.pixel_mask.to(self.device)
        outputs = self.model(**inputs)

    # Remember the original image shape(s) so boxes can be rescaled in postprocessing.
    if isinstance(image, list):
        self._image_shapes = [img.shape for img in image]
    else:
        self._image_shapes = [image.shape]
    self._original_predictions = outputs
Functions
mmdet
Classes
DetInferencerWrapper

Bases: DetInferencer

Source code in sahi/models/mmdet.py
class DetInferencerWrapper(DetInferencer):
    """Thin wrapper over mmdet's DetInferencer that disables the progress bar output
    and optionally forces a fixed square inference resolution."""

    def __init__(
        self,
        model: ModelType | str | None = None,
        weights: str | None = None,
        device: str | None = None,
        scope: str | None = "mmdet",
        palette: str = "none",
        image_size: int | None = None,
    ) -> None:
        # image_size must be set before super().__init__ because the base class
        # builds the test pipeline (calling _init_pipeline below) during init.
        self.image_size = image_size
        super().__init__(model, weights, device, scope, palette)

    def __call__(self, images: list[np.ndarray], batch_size: int = 1) -> dict:
        """
        Emulate DetInferencer(images) without progressbar
        Args:
            images: list of np.ndarray
                A list of numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
            batch_size: int
                Inference batch size. Defaults to 1.
        """
        inputs = self.preprocess(images, batch_size=batch_size)
        results_dict = {"predictions": [], "visualization": []}
        for _, data in inputs:
            preds = self.forward(data)
            # Postprocess without printing results or writing predictions to disk.
            results = self.postprocess(
                preds,
                visualization=None,
                return_datasample=False,
                print_result=False,
                no_save_pred=True,
                pred_out_dir=None,
            )
            results_dict["predictions"].extend(results["predictions"])
        return results_dict

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline, adapting it for in-memory inference."""
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline

        # For inference, the key of ``img_id`` is not used.
        if "meta_keys" in pipeline_cfg[-1]:
            pipeline_cfg[-1]["meta_keys"] = tuple(
                meta_key for meta_key in pipeline_cfg[-1]["meta_keys"] if meta_key != "img_id"
            )

        # Replace file loading with a loader that accepts in-memory arrays.
        load_img_idx = self._get_transform_idx(pipeline_cfg, "LoadImageFromFile")
        if load_img_idx == -1:
            raise ValueError("LoadImageFromFile is not found in the test pipeline")
        pipeline_cfg[load_img_idx]["type"] = "mmdet.InferencerLoader"

        resize_idx = self._get_transform_idx(pipeline_cfg, "Resize")
        if resize_idx == -1:
            raise ValueError("Resize is not found in the test pipeline")
        if self.image_size is not None:
            # Force a square inference resolution when image_size was given.
            pipeline_cfg[resize_idx]["scale"] = (self.image_size, self.image_size)
        return Compose(pipeline_cfg)
Functions
__call__(images, batch_size=1)

Emulate DetInferencer(images) without progressbar Args: images: list of np.ndarray A list of numpy array that contains the image to be predicted. 3 channel image should be in RGB order. batch_size: int Inference batch size. Defaults to 1.

Source code in sahi/models/mmdet.py
def __call__(self, images: list[np.ndarray], batch_size: int = 1) -> dict:
    """
    Emulate DetInferencer(images) without progressbar
    Args:
        images: list of np.ndarray
            A list of numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
        batch_size: int
            Inference batch size. Defaults to 1.
    """
    inputs = self.preprocess(images, batch_size=batch_size)
    results_dict = {"predictions": [], "visualization": []}
    for _, data in inputs:
        preds = self.forward(data)
        results = self.postprocess(
            preds,
            visualization=None,
            return_datasample=False,
            print_result=False,
            no_save_pred=True,
            pred_out_dir=None,
        )
        results_dict["predictions"].extend(results["predictions"])
    return results_dict
MmdetDetectionModel

Bases: DetectionModel

Source code in sahi/models/mmdet.py
class MmdetDetectionModel(DetectionModel):
    """SAHI DetectionModel wrapper for MMDetection models, driven through
    DetInferencerWrapper."""

    def __init__(
        self,
        model_path: str | None = None,
        model: Any | None = None,
        config_path: str | None = None,
        device: str | None = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: dict | None = None,
        category_remapping: dict | None = None,
        load_at_init: bool = True,
        image_size: int | None = None,
        scope: str = "mmdet",
    ):
        """Initialize the MMDetection model.

        Args:
            model_path: Path of the detection model weights (ex. 'model.pth').
            model: A pre-initialized model instance, if available.
            config_path: Path of the mmdetection config file.
            device: Torch device string, e.g. "cpu" or "cuda:0".
            mask_threshold: Value to threshold mask pixels, should be between 0 and 1.
            confidence_threshold: Predictions with score < confidence_threshold are discarded.
            category_mapping: Mapping from category id (str) to category name (str).
            category_remapping: Remap category ids based on category names after inference.
            load_at_init: If True, automatically loads the model at initialization.
            image_size: Inference input size.
            scope: Registry scope passed to the mmdet inferencer (default "mmdet").
        """
        self.scope = scope
        self.image_size = image_size
        existing_packages = getattr(self, "required_packages", None) or []
        self.required_packages = [*list(existing_packages), "mmdet", "mmcv", "torch"]
        super().__init__(
            model_path,
            model,
            config_path,
            device,
            mask_threshold,
            confidence_threshold,
            category_mapping,
            category_remapping,
            load_at_init,
            image_size,
        )

    def load_model(self):
        """Detection model is initialized and set to self.model."""

        # create model
        model = DetInferencerWrapper(
            self.config_path, self.model_path, device=self.device, scope=self.scope, image_size=self.image_size
        )

        self.set_model(model)

    def set_model(self, model: Any):
        """Sets the underlying MMDetection model.

        Args:
            model: Any
                A MMDetection model
        """

        # set self.model
        self.model = model

        # set category_mapping from the model's own class list unless already provided
        if not self.category_mapping:
            category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
            self.category_mapping = category_mapping

    def perform_inference(self, image: np.ndarray):
        """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

        Args:
            image: np.ndarray | list[np.ndarray]
                A numpy array (3 channel, RGB order) or a list of such arrays.

        Raises:
            ValueError: If the model has not been loaded yet.
        """

        # Confirm model is loaded
        if self.model is None:
            raise ValueError("Model is not loaded, load it by calling .load_model()")

        # Supports only batch of 1

        # perform inference
        if isinstance(image, np.ndarray):
            # MMDetection expects BGR input
            # https://github.com/obss/sahi/issues/265
            image = image[:, :, ::-1]
        # compatibility with sahi v0.8.15: the inferencer consumes a list of images.
        # (Previously a list input left `image_list` undefined, raising NameError.)
        image_list = image if isinstance(image, list) else [image]
        prediction_result = self.model(image_list)

        self._original_predictions = prediction_result["predictions"]

    @property
    def num_categories(self):
        """Returns number of categories."""
        return len(self.category_names)

    @property
    def has_mask(self):
        """Returns if model output contains segmentation mask.

        Considers both single dataset and ConcatDataset scenarios.
        """

        def check_pipeline_for_mask(pipeline):
            # A pipeline item such as {"with_mask": True} marks a mask-producing model.
            return any(
                isinstance(item, dict) and any("mask" in key and value is True for key, value in item.items())
                for item in pipeline
            )

        # Access the dataset from the configuration
        dataset_config = self.model.cfg["train_dataloader"]["dataset"]

        if dataset_config["type"] == "ConcatDataset":
            # If using ConcatDataset, check each dataset individually
            datasets = dataset_config["datasets"]
            for dataset in datasets:
                if check_pipeline_for_mask(dataset["pipeline"]):
                    return True
        else:
            # Otherwise, assume a single dataset with its own pipeline
            if check_pipeline_for_mask(dataset_config["pipeline"]):
                return True

        return False

    @property
    def category_names(self):
        """Returns the class names declared in the model's dataset metadata."""
        classes = self.model.model.dataset_meta["classes"]
        if isinstance(classes, str):
            # https://github.com/open-mmlab/mmdetection/pull/4973
            return (classes,)
        else:
            return classes

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: list[list[int]] | None = [[0, 0]],
        full_shape_list: list[list[int]] | None = None,
    ):
        """self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.

        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        # pycocotools is only needed for RLE-encoded masks; degrade gracefully without it.
        try:
            from pycocotools import mask as mask_utils

            can_decode_rle = True
        except ImportError:
            can_decode_rle = False
        original_predictions = self._original_predictions
        category_mapping = self.category_mapping

        # compatibility for sahi v0.8.15
        shift_amount_list = fix_shift_amount_list(shift_amount_list)
        full_shape_list = fix_full_shape_list(full_shape_list)

        # parse boxes and masks from predictions
        object_prediction_list_per_image = []
        for image_ind, original_prediction in enumerate(original_predictions):
            shift_amount = shift_amount_list[image_ind]
            full_shape = None if full_shape_list is None else full_shape_list[image_ind]

            boxes = original_prediction["bboxes"]
            scores = original_prediction["scores"]
            labels = original_prediction["labels"]
            if self.has_mask:
                masks = original_prediction["masks"]

            object_prediction_list = []

            n_detects = len(labels)
            # process predictions
            for i in range(n_detects):
                if self.has_mask:
                    mask = masks[i]

                bbox = boxes[i]
                score = scores[i]
                category_id = labels[i]
                category_name = category_mapping[str(category_id)]

                # ignore low scored predictions
                if score < self.confidence_threshold:
                    continue

                # parse prediction mask
                if self.has_mask:
                    # Masks may arrive RLE-encoded (dict with "counts") or as bool arrays.
                    if "counts" in mask:
                        if can_decode_rle:
                            bool_mask = mask_utils.decode(mask)
                        else:
                            raise ValueError(
                                "Can not decode rle mask. Please install pycocotools. ex: 'pip install pycocotools'"
                            )
                    else:
                        bool_mask = mask
                    # check if mask is valid
                    # https://github.com/obss/sahi/discussions/696
                    if get_bbox_from_bool_mask(bool_mask) is None:
                        continue
                    segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
                else:
                    segmentation = None

                # fix negative box coords
                bbox[0] = max(0, bbox[0])
                bbox[1] = max(0, bbox[1])
                bbox[2] = max(0, bbox[2])
                bbox[3] = max(0, bbox[3])

                # fix out of image box coords
                if full_shape is not None:
                    bbox[0] = min(full_shape[1], bbox[0])
                    bbox[1] = min(full_shape[0], bbox[1])
                    bbox[2] = min(full_shape[1], bbox[2])
                    bbox[3] = min(full_shape[0], bbox[3])

                # ignore invalid predictions
                if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
                    logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
                    continue

                object_prediction = ObjectPrediction(
                    bbox=bbox,
                    category_id=category_id,
                    score=score,
                    segmentation=segmentation,
                    category_name=category_name,
                    shift_amount=shift_amount,
                    full_shape=full_shape,
                )
                object_prediction_list.append(object_prediction)
            object_prediction_list_per_image.append(object_prediction_list)
        self._object_prediction_list_per_image = object_prediction_list_per_image
Attributes
has_mask property

Returns if model output contains segmentation mask.

Considers both single dataset and ConcatDataset scenarios.

num_categories property

Returns number of categories.

Functions
load_model()

Detection model is initialized and set to self.model.

Source code in sahi/models/mmdet.py
def load_model(self):
    """Build a DetInferencerWrapper from the configured paths and register it as self.model."""
    # The wrapper handles config parsing, weight loading and device placement.
    self.set_model(
        DetInferencerWrapper(
            self.config_path,
            self.model_path,
            device=self.device,
            scope=self.scope,
            image_size=self.image_size,
        )
    )
perform_inference(image)

Prediction is performed using self.model and the prediction result is set to self._original_predictions.

Parameters:

Name Type Description Default
image ndarray

np.ndarray A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

required
Source code in sahi/models/mmdet.py
def perform_inference(self, image: np.ndarray):
    """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

    Args:
        image: np.ndarray | list[np.ndarray]
            A numpy array (3 channel, RGB order) or a list of such arrays.

    Raises:
        ValueError: If the model has not been loaded yet.
    """

    # Confirm model is loaded
    if self.model is None:
        raise ValueError("Model is not loaded, load it by calling .load_model()")

    # perform inference
    if isinstance(image, np.ndarray):
        # MMDetection expects BGR input
        # https://github.com/obss/sahi/issues/265
        image = image[:, :, ::-1]
    # compatibility with sahi v0.8.15: the inferencer consumes a list of images.
    # (Previously a list input left `image_list` undefined, raising NameError.)
    image_list = image if isinstance(image, list) else [image]
    prediction_result = self.model(image_list)

    self._original_predictions = prediction_result["predictions"]
set_model(model)

Sets the underlying MMDetection model.

Parameters:

Name Type Description Default
model Any

Any A MMDetection model

required
Source code in sahi/models/mmdet.py
def set_model(self, model: Any):
    """Sets the underlying MMDetection model.

    Args:
        model: Any
            A MMDetection model
    """
    self.model = model

    # Keep any category_mapping supplied by the caller.
    if self.category_mapping:
        return
    # Otherwise derive it from the model's own class list, keyed by stringified index.
    self.category_mapping = {str(index): name for index, name in enumerate(self.category_names)}
Functions
roboflow
Classes
RoboflowDetectionModel

Bases: DetectionModel

Source code in sahi/models/roboflow.py
class RoboflowDetectionModel(DetectionModel):
    def __init__(
        self,
        model: Any | None = None,
        model_path: str | None = None,
        config_path: str | None = None,
        device: str | None = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: dict | None = None,
        category_remapping: dict | None = None,
        load_at_init: bool = True,
        image_size: int | None = None,
        api_key: str | None = None,
    ):
        """Initialize the RoboflowDetectionModel with the given parameters.

        Args:
            model_path: str
                Path for the instance segmentation model weight
            config_path: str
                Path for the mmdetection instance segmentation model config file
            device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
            mask_threshold: float
                Value to threshold mask pixels, should be between 0 and 1
            confidence_threshold: float
                All predictions with score < confidence_threshold will be discarded
            category_mapping: dict: str to str
                Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
            category_remapping: dict: str to int
                Remap category ids based on category names, after performing inference e.g. {"car": 3}
            load_at_init: bool
                If True, automatically loads the model at initialization
            image_size: int
                Inference input size.
        """
        self._use_universe = model and isinstance(model, str)
        self._model = model
        self._device = device
        self._api_key = api_key

        if self._use_universe:
            existing_packages = getattr(self, "required_packages", None) or []
            self.required_packages = [*list(existing_packages), "inference"]
        else:
            existing_packages = getattr(self, "required_packages", None) or []
            self.required_packages = [*list(existing_packages), "rfdetr"]

        super().__init__(
            model=model,
            model_path=model_path,
            config_path=config_path,
            device=device,
            mask_threshold=mask_threshold,
            confidence_threshold=confidence_threshold,
            category_mapping=category_mapping,
            category_remapping=category_remapping,
            load_at_init=False,
            image_size=image_size,
        )

        if load_at_init:
            self.load_model()

    def set_model(self, model: Any, **kwargs):
        """
        This function should be implemented to instantiate a DetectionModel out of an already loaded model
        Args:
            model: Any
                Loaded model
        """
        self.model = model

    def load_model(self):
        """This function should be implemented in a way that detection model should be initialized and set to
        self.model.

        (self.model_path, self.config_path, and self.device should be utilized)
        """
        if self._use_universe:
            from inference import get_model
            from inference.core.env import API_KEY
            from inference.core.exceptions import RoboflowAPINotAuthorizedError

            api_key = self._api_key or API_KEY

            try:
                model = get_model(self._model, api_key=api_key)
            except RoboflowAPINotAuthorizedError as e:
                raise ValueError(
                    "Authorization failed. Please pass a valid API key with "
                    "the `api_key` parameter or set the `ROBOFLOW_API_KEY` environment variable."
                ) from e

            assert model.task_type == "object-detection", "Roboflow model must be an object detection model."

        else:
            from rfdetr.detr import RFDETRBase, RFDETRLarge, RFDETRMedium, RFDETRNano, RFDETRSmall

            model, model_path = self._model, self.model_path
            model_names = ("RFDETRBase", "RFDETRNano", "RFDETRSmall", "RFDETRMedium", "RFDETRLarge")
            if hasattr(model, "__name__") and model.__name__ in model_names:
                model_params = dict(
                    resolution=int(self.image_size) if self.image_size else 560,
                    device=self._device,
                    num_classes=len(self.category_mapping.keys()) if self.category_mapping else None,
                )
                if model_path:
                    model_params["pretrain_weights"] = model_path

                model = model(**model_params)
            elif isinstance(model, (RFDETRBase, RFDETRNano, RFDETRSmall, RFDETRMedium, RFDETRLarge)):
                model = model
            else:
                raise ValueError(
                    f"Model must be a Roboflow model string or one of {model_names} models, got {self.model}."
                )

        self.set_model(model)

    def perform_inference(
        self,
        image: np.ndarray,
    ):
        """This function should be implemented in a way that prediction should be performed using self.model and the
        prediction result should be set to self._original_predictions.

        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted.
        """
        if self._use_universe:
            self._original_predictions = self.model.infer(image, confidence=self.confidence_threshold)
        else:
            self._original_predictions = [self.model.predict(image, threshold=self.confidence_threshold)]

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: list[list[int]] | None = [[0, 0]],
        full_shape_list: list[list[int]] | None = None,
    ):
        """This function should be implemented in a way that self._original_predictions should be converted to a list of
        prediction.ObjectPrediction and set to self._object_prediction_list.

        self.mask_threshold can also be utilized.
        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        # compatibility for sahi v0.8.15
        shift_amount_list = fix_shift_amount_list(shift_amount_list)
        full_shape_list = fix_full_shape_list(full_shape_list)

        object_prediction_list: list[ObjectPrediction] = []

        if self._use_universe:
            # Hosted (Roboflow Universe) path: results are pydantic response
            # objects whose boxes use center-based (x, y, width, height) coords.
            from inference.core.entities.responses.inference import (
                ObjectDetectionInferenceResponse as InferenceObjectDetectionInferenceResponse,
            )
            from inference.core.entities.responses.inference import (
                ObjectDetectionPrediction as InferenceObjectDetectionPrediction,
            )

            original_reponses: list[InferenceObjectDetectionInferenceResponse] = self._original_predictions

            assert len(original_reponses) == len(shift_amount_list) == len(full_shape_list), (
                "Length mismatch between original responses, shift amounts, and full shapes."
            )

            for original_reponse, shift_amount, full_shape in zip(
                original_reponses,
                shift_amount_list,
                full_shape_list,
            ):
                for prediction in original_reponse.predictions:
                    prediction: InferenceObjectDetectionPrediction
                    # Convert center-based (x, y, w, h) to (x_min, y_min, x_max, y_max).
                    bbox = [
                        prediction.x - prediction.width / 2,
                        prediction.y - prediction.height / 2,
                        prediction.x + prediction.width / 2,
                        prediction.y + prediction.height / 2,
                    ]
                    object_prediction = ObjectPrediction(
                        bbox=bbox,
                        category_id=prediction.class_id,
                        category_name=prediction.class_name,
                        score=prediction.confidence,
                        shift_amount=shift_amount,
                        full_shape=full_shape,
                    )
                    object_prediction_list.append(object_prediction)

        else:
            # Local RF-DETR path: results are supervision Detections whose boxes
            # are already in (x_min, y_min, x_max, y_max) format.
            from supervision.detection.core import Detections

            original_detections: list[Detections] = self._original_predictions

            assert len(original_detections) == len(shift_amount_list) == len(full_shape_list), (
                "Length mismatch between original responses, shift amounts, and full shapes."
            )

            for original_detection, shift_amount, full_shape in zip(
                original_detections,
                shift_amount_list,
                full_shape_list,
            ):
                for xyxy, confidence, class_id in zip(
                    original_detection.xyxy,
                    original_detection.confidence,
                    original_detection.class_id,
                ):
                    object_prediction = ObjectPrediction(
                        bbox=xyxy,
                        category_id=int(class_id),
                        # category name may be None when class_id is absent from the mapping
                        category_name=self.category_mapping.get(int(class_id), None),
                        score=float(confidence),
                        shift_amount=shift_amount,
                        full_shape=full_shape,
                    )
                    object_prediction_list.append(object_prediction)

        # NOTE(review): all predictions are grouped under a single image entry;
        # presumably this model always processes one image per call — confirm upstream.
        object_prediction_list_per_image = [object_prediction_list]
        self._object_prediction_list_per_image = object_prediction_list_per_image
Functions
__init__(model=None, model_path=None, config_path=None, device=None, mask_threshold=0.5, confidence_threshold=0.3, category_mapping=None, category_remapping=None, load_at_init=True, image_size=None, api_key=None)

Initialize the RoboflowDetectionModel with the given parameters.

Parameters:

Name Type Description Default
model_path str | None

str Path for the detection model weight

None
config_path str | None

str Path for the detection model config file

None
device str | None

Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.

None
mask_threshold float

float Value to threshold mask pixels, should be between 0 and 1

0.5
confidence_threshold float

float All predictions with score < confidence_threshold will be discarded

0.3
category_mapping dict | None

dict: str to str Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}

None
category_remapping dict | None

dict: str to int Remap category ids based on category names, after performing inference e.g. {"car": 3}

None
load_at_init bool

bool If True, automatically loads the model at initialization

True
image_size int | None

int Inference input size.

None
Source code in sahi/models/roboflow.py
def __init__(
    self,
    model: Any | None = None,
    model_path: str | None = None,
    config_path: str | None = None,
    device: str | None = None,
    mask_threshold: float = 0.5,
    confidence_threshold: float = 0.3,
    category_mapping: dict | None = None,
    category_remapping: dict | None = None,
    load_at_init: bool = True,
    image_size: int | None = None,
    api_key: str | None = None,
):
    """Initialize the RoboflowDetectionModel with the given parameters.

    Args:
        model: Any
            Either a Roboflow Universe model id (str), which triggers hosted
            inference, or an RF-DETR model class/instance for local inference.
        model_path: str
            Path for the detection model weight
        config_path: str
            Path for the model config file
        device: Torch device, "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
        mask_threshold: float
            Value to threshold mask pixels, should be between 0 and 1
        confidence_threshold: float
            All predictions with score < confidence_threshold will be discarded
        category_mapping: dict: str to str
            Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
        category_remapping: dict: str to int
            Remap category ids based on category names, after performing inference e.g. {"car": 3}
        load_at_init: bool
            If True, automatically loads the model at initialization
        image_size: int
            Inference input size.
        api_key: str
            Roboflow API key used for hosted (Universe) inference; when omitted,
            the ROBOFLOW_API_KEY environment variable is used instead.
    """
    # Fix: coerce to a real bool — `model and isinstance(model, str)` could
    # leave a non-bool value (e.g. None) on this attribute.
    self._use_universe = bool(model) and isinstance(model, str)
    self._model = model
    self._device = device
    self._api_key = api_key

    # Hosted inference needs the `inference` package; local RF-DETR needs `rfdetr`.
    existing_packages = getattr(self, "required_packages", None) or []
    extra_package = "inference" if self._use_universe else "rfdetr"
    self.required_packages = [*list(existing_packages), extra_package]

    super().__init__(
        model=model,
        model_path=model_path,
        config_path=config_path,
        device=device,
        mask_threshold=mask_threshold,
        confidence_threshold=confidence_threshold,
        category_mapping=category_mapping,
        category_remapping=category_remapping,
        load_at_init=False,  # defer loading so load_model below controls it
        image_size=image_size,
    )

    if load_at_init:
        self.load_model()
load_model()

This function should be implemented in a way that detection model should be initialized and set to self.model.

(self.model_path, self.config_path, and self.device should be utilized)

Source code in sahi/models/roboflow.py
def load_model(self):
    """This function should be implemented in a way that detection model should be initialized and set to
    self.model.

    (self.model_path, self.config_path, and self.device should be utilized)
    """
    if self._use_universe:
        # Hosted inference: `self._model` is a Roboflow Universe model id string.
        from inference import get_model
        from inference.core.env import API_KEY
        from inference.core.exceptions import RoboflowAPINotAuthorizedError

        # Explicit api_key wins; otherwise fall back to the environment-derived key.
        api_key = self._api_key or API_KEY

        try:
            model = get_model(self._model, api_key=api_key)
        except RoboflowAPINotAuthorizedError as e:
            raise ValueError(
                "Authorization failed. Please pass a valid API key with "
                "the `api_key` parameter or set the `ROBOFLOW_API_KEY` environment variable."
            ) from e

        assert model.task_type == "object-detection", "Roboflow model must be an object detection model."

    else:
        # Local inference: `self._model` is an RF-DETR class or instance.
        from rfdetr.detr import RFDETRBase, RFDETRLarge, RFDETRMedium, RFDETRNano, RFDETRSmall

        model, model_path = self._model, self.model_path
        model_names = ("RFDETRBase", "RFDETRNano", "RFDETRSmall", "RFDETRMedium", "RFDETRLarge")
        if hasattr(model, "__name__") and model.__name__ in model_names:
            # A class (not an instance) was supplied: instantiate it ourselves.
            model_params = dict(
                # 560 appears to be the RF-DETR default resolution — confirm against rfdetr docs
                resolution=int(self.image_size) if self.image_size else 560,
                device=self._device,
                num_classes=len(self.category_mapping.keys()) if self.category_mapping else None,
            )
            if model_path:
                model_params["pretrain_weights"] = model_path

            model = model(**model_params)
        elif isinstance(model, (RFDETRBase, RFDETRNano, RFDETRSmall, RFDETRMedium, RFDETRLarge)):
            # Already instantiated: use as-is.
            model = model
        else:
            raise ValueError(
                f"Model must be a Roboflow model string or one of {model_names} models, got {self.model}."
            )

    self.set_model(model)
perform_inference(image)

This function should be implemented in a way that prediction should be performed using self.model and the prediction result should be set to self._original_predictions.

Parameters:

Name Type Description Default
image ndarray

np.ndarray A numpy array that contains the image to be predicted.

required
Source code in sahi/models/roboflow.py
def perform_inference(
    self,
    image: np.ndarray,
):
    """Run the wrapped Roboflow model on ``image``.

    The raw framework output is stored in ``self._original_predictions`` for
    later conversion into ObjectPrediction instances.

    Args:
        image: np.ndarray
            A numpy array that contains the image to be predicted.
    """
    threshold = self.confidence_threshold
    if self._use_universe:
        raw_result = self.model.infer(image, confidence=threshold)
    else:
        raw_result = [self.model.predict(image, threshold=threshold)]
    self._original_predictions = raw_result
set_model(model, **kwargs)

This function should be implemented to instantiate a DetectionModel out of an already loaded model Args: model: Any Loaded model

Source code in sahi/models/roboflow.py
def set_model(self, model: Any, **kwargs):
    """Attach an already-loaded model to this detection wrapper.

    Args:
        model: Any
            Loaded model
    """
    # No post-processing required; just keep a reference to the model.
    self.model = model
rtdetr
Classes
RTDetrDetectionModel

Bases: UltralyticsDetectionModel

Source code in sahi/models/rtdetr.py
class RTDetrDetectionModel(UltralyticsDetectionModel):
    def load_model(self):
        """Initialize the RT-DETR model and set it to self.model.

        Falls back to the pretrained "rtdetr-l.pt" checkpoint when no
        model_path is given.

        Raises:
            TypeError: If the model path cannot be loaded as an RT-DETR model.
        """
        from ultralytics import RTDETR

        try:
            model_source = self.model_path or "rtdetr-l.pt"
            model = RTDETR(model_source)
            model.to(self.device)
            self.set_model(model)
        except Exception as e:
            # Fix: chain the original exception and raise a single formatted
            # message instead of a (message, exception) tuple.
            raise TypeError(f"model_path is not a valid rtdetr model path: {e}") from e
Functions
load_model()

Detection model is initialized and set to self.model.

Source code in sahi/models/rtdetr.py
def load_model(self):
    """Detection model is initialized and set to self.model."""
    from ultralytics import RTDETR

    try:
        # Fall back to the pretrained "rtdetr-l.pt" checkpoint when no path is given.
        model_source = self.model_path or "rtdetr-l.pt"
        model = RTDETR(model_source)
        model.to(self.device)
        self.set_model(model)
    except Exception as e:
        # NOTE(review): raising TypeError with a (message, exception) tuple loses
        # traceback chaining; consider `raise ... from e`.
        raise TypeError("model_path is not a valid rtdetr model path: ", e)
torchvision
Classes
TorchVisionDetectionModel

Bases: DetectionModel

Source code in sahi/models/torchvision.py
class TorchVisionDetectionModel(DetectionModel):
    def __init__(self, *args, **kwargs):
        # Register runtime requirements before the base class validates them.
        existing_packages = getattr(self, "required_packages", None) or []
        self.required_packages = [*list(existing_packages), "torch", "torchvision"]
        super().__init__(*args, **kwargs)

    def load_model(self):
        """Initialize the torchvision detection model and set it to self.model.

        Optional `model_name` and `num_classes` are read from the YAML file at
        self.config_path; weights are loaded from self.model_path when given,
        otherwise torchvision's pretrained COCO weights are used.

        Raises:
            RuntimeError: If the YAML config cannot be parsed.
            TypeError: If self.model_path is not a valid torchvision checkpoint.
        """
        import torch

        # read config params
        model_name = None
        num_classes = None
        if self.config_path is not None:
            with open(self.config_path) as stream:
                try:
                    config = yaml.safe_load(stream)
                except yaml.YAMLError as exc:
                    raise RuntimeError(exc)

            model_name = config.get("model_name", None)
            num_classes = config.get("num_classes", None)

        # complete params if not provided in config
        if not model_name:
            model_name = "fasterrcnn_resnet50_fpn"
            # Fix: removed a stray trailing quote from the log message.
            logger.warning(f"model_name not provided in config, using default model_type: {model_name}")
        if num_classes is None:
            logger.warning("num_classes not provided in config, using default num_classes: 91")
            num_classes = 91
        if self.model_path is None:
            logger.warning("model_path not provided in config, using pretrained weights and default num_classes: 91.")
            weights = "DEFAULT"
            num_classes = 91
        else:
            weights = None

        # load model
        # Note: torchvision >= 0.13 is required for the 'weights' parameter
        model = MODEL_NAME_TO_CONSTRUCTOR[model_name](num_classes=num_classes, weights=weights)
        if self.model_path:
            try:
                model.load_state_dict(torch.load(self.model_path))
            except Exception as e:
                logger.error(f"Invalid {self.model_path=}")
                raise TypeError("model_path is not a valid torchvision model path: ", e)

        self.set_model(model)

    def set_model(self, model: Any):
        """Sets the underlying TorchVision model.

        Args:
            model: Any
                A TorchVision model
        """
        # Inference-only usage: disable training-mode layers, then move to device.
        model.eval()
        self.model = model.to(self.device)

        # Default category_mapping to the COCO class names when not supplied.
        if self.category_mapping is None:
            category_names = {str(i): COCO_CLASSES[i] for i in range(len(COCO_CLASSES))}
            self.category_mapping = category_names

    def perform_inference(self, image: np.ndarray, image_size: int | None = None):
        """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
            image_size: int
                Inference input size.
        """
        from sahi.utils.torch_utils import to_float_tensor

        # arrange model input size
        if self.image_size is not None:
            # get min and max of image height and width
            min_shape, max_shape = min(image.shape[:2]), max(image.shape[:2])
            # torchvision resize transform scales the shorter dimension to the target size
            # we want to scale the longer dimension to the target size
            image_size = self.image_size * min_shape / max_shape
            self.model.transform.min_size = (image_size,)  # default is (800,)
            self.model.transform.max_size = image_size  # default is 1333

        image = to_float_tensor(image)
        image = image.to(self.device)
        prediction_result = self.model([image])

        self._original_predictions = prediction_result

    @property
    def num_categories(self):
        """Returns number of categories."""
        return len(self.category_mapping)

    @property
    def has_mask(self):
        """Returns if model output contains segmentation mask."""
        return hasattr(self.model, "roi_heads") and hasattr(self.model.roi_heads, "mask_predictor")

    @property
    def category_names(self):
        """Returns the list of category names."""
        return list(self.category_mapping.values())

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: list[list[int]] | None = [[0, 0]],
        full_shape_list: list[list[int]] | None = None,
    ):
        """self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.

        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        original_predictions = self._original_predictions

        # compatibility for sahi v0.8.20
        if isinstance(shift_amount_list[0], int):
            shift_amount_list = [shift_amount_list]
        if full_shape_list is not None and isinstance(full_shape_list[0], int):
            full_shape_list = [full_shape_list]

        # Fix: initialize once, outside the loop — previously this list was
        # re-created per image, so only the last image's predictions survived.
        object_prediction_list_per_image = []

        for image_predictions in original_predictions:
            # get indices of boxes with score > confidence_threshold
            scores = image_predictions["scores"].cpu().detach().numpy()
            selected_indices = np.where(scores > self.confidence_threshold)[0]

            # parse boxes, masks, scores, category_ids from predictions
            category_ids = list(image_predictions["labels"][selected_indices].cpu().detach().numpy())
            boxes = list(image_predictions["boxes"][selected_indices].cpu().detach().numpy())
            scores = scores[selected_indices]

            # threshold soft masks into boolean masks when the model returns them
            masks = image_predictions.get("masks", None)
            if masks is not None:
                masks = list(
                    (image_predictions["masks"][selected_indices] > self.mask_threshold).cpu().detach().numpy()
                )

            # create object_prediction_list
            object_prediction_list = []

            # NOTE(review): shift/shape are taken from the first entry only;
            # presumably a single image is processed per call — confirm upstream.
            shift_amount = shift_amount_list[0]
            full_shape = None if full_shape_list is None else full_shape_list[0]

            for ind in range(len(boxes)):
                if masks is not None:
                    segmentation = get_coco_segmentation_from_bool_mask(np.array(masks[ind]))
                else:
                    segmentation = None

                object_prediction = ObjectPrediction(
                    bbox=boxes[ind],
                    segmentation=segmentation,
                    category_id=int(category_ids[ind]),
                    category_name=self.category_mapping[str(int(category_ids[ind]))],
                    shift_amount=shift_amount,
                    score=scores[ind],
                    full_shape=full_shape,
                )
                object_prediction_list.append(object_prediction)
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image
Attributes
has_mask property

Returns if model output contains segmentation mask.

num_categories property

Returns number of categories.

Functions
perform_inference(image, image_size=None)

Prediction is performed using self.model and the prediction result is set to self._original_predictions.

Parameters:

Name Type Description Default
image ndarray

np.ndarray A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

required
image_size int | None

int Inference input size.

None
Source code in sahi/models/torchvision.py
def perform_inference(self, image: np.ndarray, image_size: int | None = None):
    """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

    Args:
        image: np.ndarray
            A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
        image_size: int
            Inference input size.
    """
    from sahi.utils.torch_utils import to_float_tensor

    # arrange model input size
    if self.image_size is not None:
        # get min and max of image height and width
        min_shape, max_shape = min(image.shape[:2]), max(image.shape[:2])
        # torchvision resize transform scales the shorter dimension to the target size
        # we want to scale the longer dimension to the target size
        # NOTE(review): this division yields a float; torchvision appears to
        # tolerate it, but an int may be intended — confirm.
        image_size = self.image_size * min_shape / max_shape
        self.model.transform.min_size = (image_size,)  # default is (800,)
        self.model.transform.max_size = image_size  # default is 1333

    image = to_float_tensor(image)
    image = image.to(self.device)
    prediction_result = self.model([image])

    self._original_predictions = prediction_result
set_model(model)

Sets the underlying TorchVision model.

Parameters:

Name Type Description Default
model Any

Any A TorchVision model

required
Source code in sahi/models/torchvision.py
def set_model(self, model: Any):
    """Sets the underlying TorchVision model.

    Args:
        model: Any
            A TorchVision model
    """
    # Inference-only usage: disable training-mode layers, then move to device.
    model.eval()
    self.model = model.to(self.device)

    # Default to the COCO class names when no mapping was supplied.
    if self.category_mapping is None:
        self.category_mapping = {str(i): COCO_CLASSES[i] for i in range(len(COCO_CLASSES))}
Functions
ultralytics
Classes
UltralyticsDetectionModel

Bases: DetectionModel

Detection model for Ultralytics YOLO models.

Supports both PyTorch (.pt) and ONNX (.onnx) models.

Source code in sahi/models/ultralytics.py
class UltralyticsDetectionModel(DetectionModel):
    """Detection model for Ultralytics YOLO models.

    Supports both PyTorch (.pt) and ONNX (.onnx) models.
    """

    def __init__(self, *args, **kwargs):
        # `fuse` is SAHI-specific; pop it before forwarding kwargs to the base class.
        self.fuse: bool = kwargs.pop("fuse", False)
        existing_packages = getattr(self, "required_packages", None) or []
        self.required_packages = [*list(existing_packages), "ultralytics"]
        super().__init__(*args, **kwargs)

    def load_model(self):
        """Detection model is initialized and set to self.model.

        Supports both PyTorch (.pt) and ONNX (.onnx) models.
        """

        from ultralytics import YOLO

        # ONNX checkpoints additionally require the onnx runtime packages.
        if self.model_path and ".onnx" in self.model_path:
            check_requirements(["onnx", "onnxruntime"])

        try:
            model = YOLO(self.model_path)
            # Only call .to(device) for PyTorch models, not ONNX
            if self.model_path and not self.model_path.endswith(".onnx"):
                model.to(self.device)
            self.set_model(model)
            if self.fuse and hasattr(model, "fuse"):
                model.fuse()

        except Exception as e:
            raise TypeError("model_path is not a valid Ultralytics model path: ", e)

    def set_model(self, model: Any, **kwargs):
        """Sets the underlying Ultralytics model.

        Args:
            model: Any
                A Ultralytics model
        """

        self.model = model
        # set category_mapping
        if not self.category_mapping:
            category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
            self.category_mapping = category_mapping

    def perform_inference(self, image: np.ndarray):
        """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
        """

        # Confirm model is loaded

        import torch

        if self.model is None:
            raise ValueError("Model is not loaded, load it by calling .load_model()")

        kwargs = {"cfg": self.config_path, "verbose": False, "conf": self.confidence_threshold, "device": self.device}

        if self.image_size is not None:
            kwargs = {"imgsz": self.image_size, **kwargs}

        prediction_result = self.model(image[:, :, ::-1], **kwargs)  # YOLO expects numpy arrays to have BGR

        # Handle different result types for PyTorch vs ONNX models
        # ONNX models might return results in a different format
        if self.has_mask:
            from ultralytics.engine.results import Masks

            if not prediction_result[0].masks:
                # Create empty masks if none exist
                if hasattr(self.model, "device"):
                    device = self.model.device
                else:
                    device = "cpu"  # Default for ONNX models
                prediction_result[0].masks = Masks(
                    torch.tensor([], device=device), prediction_result[0].boxes.orig_shape
                )

            # We do not filter results again as confidence threshold is already applied above
            prediction_result = [
                (
                    result.boxes.data,
                    result.masks.data,
                )
                for result in prediction_result
            ]
        elif self.is_obb:
            # For OBB task, get OBB points in xyxyxyxy format
            device = getattr(self.model, "device", "cpu")
            prediction_result = [
                (
                    # Get OBB data: xyxy, conf, cls
                    torch.cat(
                        [
                            result.obb.xyxy,  # box coordinates
                            result.obb.conf.unsqueeze(-1),  # confidence scores
                            result.obb.cls.unsqueeze(-1),  # class ids
                        ],
                        dim=1,
                    )
                    if result.obb is not None
                    else torch.empty((0, 6), device=device),
                    # Get OBB points in (N, 4, 2) format
                    result.obb.xyxyxyxy if result.obb is not None else torch.empty((0, 4, 2), device=device),
                )
                for result in prediction_result
            ]
        else:  # If model doesn't do segmentation or OBB then no need to check masks
            # We do not filter results again as confidence threshold is already applied above
            prediction_result = [result.boxes.data for result in prediction_result]

        self._original_predictions = prediction_result
        self._original_shape = image.shape

    @property
    def category_names(self):
        """Returns category names, preferring the model's own `names` mapping."""
        # For ONNX models, names might not be available, use category_mapping
        if hasattr(self.model, "names") and self.model.names:
            # Fix: return a list (consistent with the other branch) instead of a
            # non-indexable dict_values view.
            return list(self.model.names.values())
        elif self.category_mapping:
            return list(self.category_mapping.values())
        else:
            raise ValueError("Category names not available. Please provide category_mapping for ONNX models.")

    @property
    def num_categories(self):
        """Returns number of categories."""
        if hasattr(self.model, "names") and self.model.names:
            return len(self.model.names)
        elif self.category_mapping:
            return len(self.category_mapping)
        else:
            raise ValueError("Cannot determine number of categories. Please provide category_mapping for ONNX models.")

    @property
    def has_mask(self):
        """Returns if model output contains segmentation mask."""
        # Check if model has 'task' attribute (for both .pt and .onnx models)
        if hasattr(self.model, "overrides") and "task" in self.model.overrides:
            return self.model.overrides["task"] == "segment"
        # For ONNX models, task might be stored differently
        elif hasattr(self.model, "task"):
            return self.model.task == "segment"
        # For ONNX models without task info, check model path
        elif self.model_path and isinstance(self.model_path, str):
            return "seg" in self.model_path.lower()
        return False

    @property
    def is_obb(self):
        """Returns if model output contains oriented bounding boxes."""
        # Check if model has 'task' attribute (for both .pt and .onnx models)
        if hasattr(self.model, "overrides") and "task" in self.model.overrides:
            return self.model.overrides["task"] == "obb"
        # For ONNX models, task might be stored differently
        elif hasattr(self.model, "task"):
            return self.model.task == "obb"
        # For ONNX models without task info, check model path
        elif self.model_path and isinstance(self.model_path, str):
            return "obb" in self.model_path.lower()
        return False

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: list[list[int]] | None = [[0, 0]],
        full_shape_list: list[list[int]] | None = None,
    ):
        """self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.

        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        original_predictions = self._original_predictions

        # compatibility for sahi v0.8.15
        shift_amount_list = fix_shift_amount_list(shift_amount_list)
        full_shape_list = fix_full_shape_list(full_shape_list)

        # handle all predictions
        object_prediction_list_per_image = []

        for image_ind, image_predictions in enumerate(original_predictions):
            shift_amount = shift_amount_list[image_ind]
            full_shape = None if full_shape_list is None else full_shape_list[image_ind]
            object_prediction_list = []

            # Extract boxes and optional masks/obb
            if self.has_mask or self.is_obb:
                boxes = image_predictions[0].cpu().detach().numpy()
                masks_or_points = image_predictions[1].cpu().detach().numpy()
            else:
                boxes = image_predictions.data.cpu().detach().numpy()
                masks_or_points = None

            # Process each prediction
            for pred_ind, prediction in enumerate(boxes):
                # Get bbox coordinates
                bbox = prediction[:4].tolist()
                score = prediction[4]
                category_id = int(prediction[5])
                category_name = self.category_mapping[str(category_id)]

                # Fix box coordinates
                bbox = [max(0, coord) for coord in bbox]
                if full_shape is not None:
                    bbox[0] = min(full_shape[1], bbox[0])
                    bbox[1] = min(full_shape[0], bbox[1])
                    bbox[2] = min(full_shape[1], bbox[2])
                    bbox[3] = min(full_shape[0], bbox[3])

                # Ignore invalid predictions
                if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
                    logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
                    continue

                # Get segmentation or OBB points
                segmentation = None
                if masks_or_points is not None:
                    if self.has_mask:
                        bool_mask = masks_or_points[pred_ind]
                        # Resize mask to original image size
                        bool_mask = cv2.resize(
                            bool_mask.astype(np.uint8), (self._original_shape[1], self._original_shape[0])
                        )
                        segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
                    else:  # is_obb
                        obb_points = masks_or_points[pred_ind]  # Get OBB points for this prediction
                        segmentation = [obb_points.reshape(-1).tolist()]

                    if len(segmentation) == 0:
                        continue

                # Create and append object prediction
                object_prediction = ObjectPrediction(
                    bbox=bbox,
                    category_id=category_id,
                    score=score,
                    segmentation=segmentation,
                    category_name=category_name,
                    shift_amount=shift_amount,
                    full_shape=self._original_shape[:2] if full_shape is None else full_shape,  # (height, width)
                )
                object_prediction_list.append(object_prediction)

            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image
Attributes
has_mask property

Returns if model output contains segmentation mask.

is_obb property

Returns if model output contains oriented bounding boxes.

num_categories property

Returns number of categories.

Functions
load_model()

Detection model is initialized and set to self.model.

Supports both PyTorch (.pt) and ONNX (.onnx) models.

Source code in sahi/models/ultralytics.py
def load_model(self):
    """Detection model is initialized and set to self.model.

    Supports both PyTorch (.pt) and ONNX (.onnx) models.

    Raises:
        TypeError: If self.model_path cannot be loaded as an Ultralytics model.
    """

    from ultralytics import YOLO

    # ONNX checkpoints additionally need the onnx runtime packages
    if self.model_path and ".onnx" in self.model_path:
        check_requirements(["onnx", "onnxruntime"])

    try:
        model = YOLO(self.model_path)
        # Only call .to(device) for PyTorch models, not ONNX
        if self.model_path and not self.model_path.endswith(".onnx"):
            model.to(self.device)
        self.set_model(model)
        if self.fuse and hasattr(model, "fuse"):
            model.fuse()

    except Exception as e:
        # f-string + explicit chaining instead of passing the exception as a
        # second TypeError argument (which rendered the message as a tuple)
        raise TypeError(f"model_path is not a valid Ultralytics model path: {e}") from e
perform_inference(image)

Prediction is performed using self.model and the prediction result is set to self._original_predictions.

Parameters:

Name Type Description Default
image ndarray

np.ndarray A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

required
Source code in sahi/models/ultralytics.py
def perform_inference(self, image: np.ndarray):
    """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

    Depending on the model task, self._original_predictions is set to a list of:
    - (boxes, masks) tuples when the model outputs segmentation masks,
    - (boxes, obb_points) tuples for oriented-bounding-box models,
    - plain box tensors otherwise.

    Args:
        image: np.ndarray
            A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

    Raises:
        ValueError: If the model has not been loaded via .load_model() yet.
    """

    # Confirm model is loaded

    import torch

    if self.model is None:
        raise ValueError("Model is not loaded, load it by calling .load_model()")

    # Shared options forwarded to the Ultralytics predictor call
    kwargs = {"cfg": self.config_path, "verbose": False, "conf": self.confidence_threshold, "device": self.device}

    if self.image_size is not None:
        kwargs = {"imgsz": self.image_size, **kwargs}

    prediction_result = self.model(image[:, :, ::-1], **kwargs)  # YOLO expects numpy arrays to have BGR

    # Handle different result types for PyTorch vs ONNX models
    # ONNX models might return results in a different format
    if self.has_mask:
        from ultralytics.engine.results import Masks

        # NOTE(review): relies on an absent/empty Masks object being falsy — confirm
        if not prediction_result[0].masks:
            # Create empty masks if none exist
            if hasattr(self.model, "device"):
                device = self.model.device
            else:
                device = "cpu"  # Default for ONNX models
            prediction_result[0].masks = Masks(
                torch.tensor([], device=device), prediction_result[0].boxes.orig_shape
            )

        # We do not filter results again as confidence threshold is already applied above
        prediction_result = [
            (
                result.boxes.data,
                result.masks.data,
            )
            for result in prediction_result
        ]
    elif self.is_obb:
        # For OBB task, get OBB points in xyxyxyxy format
        device = getattr(self.model, "device", "cpu")
        prediction_result = [
            (
                # Get OBB data: xyxy, conf, cls
                torch.cat(
                    [
                        result.obb.xyxy,  # box coordinates
                        result.obb.conf.unsqueeze(-1),  # confidence scores
                        result.obb.cls.unsqueeze(-1),  # class ids
                    ],
                    dim=1,
                )
                if result.obb is not None
                else torch.empty((0, 6), device=device),
                # Get OBB points in (N, 4, 2) format
                result.obb.xyxyxyxy if result.obb is not None else torch.empty((0, 4, 2), device=device),
            )
            for result in prediction_result
        ]
    else:  # If model doesn't do segmentation or OBB then no need to check masks
        # We do not filter results again as confidence threshold is already applied above
        prediction_result = [result.boxes.data for result in prediction_result]

    self._original_predictions = prediction_result
    # Original image shape is kept so masks can be resized back later
    self._original_shape = image.shape
set_model(model, **kwargs)

Sets the underlying Ultralytics model.

Parameters:

Name Type Description Default
model Any

Any An Ultralytics model

required
Source code in sahi/models/ultralytics.py
def set_model(self, model: Any, **kwargs):
    """Sets the underlying Ultralytics model.

    Args:
        model: Any
            An Ultralytics model
    """

    self.model = model
    # Derive a default index -> name mapping when the caller supplied none
    if not self.category_mapping:
        self.category_mapping = {str(idx): name for idx, name in enumerate(self.category_names)}
Functions
yolo-world
Classes
YOLOWorldDetectionModel

Bases: UltralyticsDetectionModel

Source code in sahi/models/yolo-world.py
class YOLOWorldDetectionModel(UltralyticsDetectionModel):
    """SAHI wrapper for YOLO-World open-vocabulary detection models."""

    def load_model(self):
        """Detection model is initialized and set to self.model.

        Falls back to the default "yolov8s-worldv2.pt" checkpoint when no
        model_path was given.

        Raises:
            TypeError: If self.model_path is not a valid YOLO-World model path.
        """

        from ultralytics import YOLOWorld

        try:
            model_source = self.model_path or "yolov8s-worldv2.pt"
            model = YOLOWorld(model_source)
            model.to(self.device)
            self.set_model(model)
        except Exception as e:
            # f-string + explicit chaining instead of passing the exception as a
            # second TypeError argument (consistent with YOLOEDetectionModel)
            raise TypeError(f"model_path is not a valid yolo world model path: {e}") from e
Functions
load_model()

Detection model is initialized and set to self.model.

Source code in sahi/models/yolo-world.py
def load_model(self):
    """Detection model is initialized and set to self.model.

    Falls back to the default "yolov8s-worldv2.pt" checkpoint when no
    model_path was given.

    Raises:
        TypeError: If self.model_path is not a valid YOLO-World model path.
    """

    from ultralytics import YOLOWorld

    try:
        model_source = self.model_path or "yolov8s-worldv2.pt"
        model = YOLOWorld(model_source)
        model.to(self.device)
        self.set_model(model)
    except Exception as e:
        # f-string + explicit chaining instead of passing the exception as a
        # second TypeError argument (which rendered the message as a tuple)
        raise TypeError(f"model_path is not a valid yolo world model path: {e}") from e
yoloe
Classes
YOLOEDetectionModel

Bases: UltralyticsDetectionModel

YOLOE Detection Model for open-vocabulary detection and segmentation.

YOLOE (Real-Time Seeing Anything) is a zero-shot, promptable YOLO model designed for open-vocabulary detection and segmentation. It supports text prompts, visual prompts, and prompt-free detection with internal vocabulary (1200+ categories).

Key Features
  • Open-vocabulary detection: Detect any object class via text prompts
  • Visual prompting: One-shot detection using reference images
  • Instance segmentation: Built-in segmentation for detected objects
  • Real-time performance: Maintains YOLO speed with no inference overhead
  • Prompt-free mode: Uses internal vocabulary for open-set recognition
Available Models

Text/Visual Prompt models: - yoloe-11s-seg.pt, yoloe-11m-seg.pt, yoloe-11l-seg.pt - yoloe-v8s-seg.pt, yoloe-v8m-seg.pt, yoloe-v8l-seg.pt

Prompt-free models: - yoloe-11s-seg-pf.pt, yoloe-11m-seg-pf.pt, yoloe-11l-seg-pf.pt - yoloe-v8s-seg-pf.pt, yoloe-v8m-seg-pf.pt, yoloe-v8l-seg-pf.pt

Usage Text Prompts

from sahi import AutoDetectionModel

# Load YOLOE model
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yoloe",
    model_path="yoloe-11l-seg.pt",
    confidence_threshold=0.3,
    device="cuda:0"
)

# Set text prompts for specific classes
detection_model.model.set_classes(
    ["person", "car", "traffic light"],
    detection_model.model.get_text_pe(["person", "car", "traffic light"])
)

# Perform prediction
from sahi.predict import get_prediction
result = get_prediction("image.jpg", detection_model)

Usage for standard detection (no prompts)

from sahi import AutoDetectionModel

# Load YOLOE model (works like standard YOLO)
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yoloe",
    model_path="yoloe-11l-seg.pt",
    confidence_threshold=0.3,
    device="cuda:0"
)

# Perform prediction without prompts (uses internal vocabulary)
from sahi.predict import get_sliced_prediction
result = get_sliced_prediction(
    "image.jpg",
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2
)
Note
  • YOLOE models perform instance segmentation by default
  • When used without prompts, YOLOE performs like standard YOLO11 with identical speed
  • For visual prompting, see Ultralytics YOLOE documentation
  • YOLOE achieves +3.5 AP over YOLO-Worldv2 on LVIS with 1.4x faster inference
References
  • Paper: https://arxiv.org/abs/2503.07465
  • Docs: https://docs.ultralytics.com/models/yoloe/
  • GitHub: https://github.com/THU-MIG/yoloe
Source code in sahi/models/yoloe.py
class YOLOEDetectionModel(UltralyticsDetectionModel):
    """SAHI wrapper for YOLOE open-vocabulary detection and segmentation models.

    YOLOE ("Real-Time Seeing Anything") is a promptable, zero-shot YOLO variant.
    It can be driven by text prompts, by visual prompts, or run prompt-free
    against its built-in vocabulary of 1200+ categories, and it produces
    instance segmentation masks alongside boxes while keeping YOLO-level speed.

    Supported checkpoints:
        Text/visual prompt models:
            yoloe-11s-seg.pt, yoloe-11m-seg.pt, yoloe-11l-seg.pt,
            yoloe-v8s-seg.pt, yoloe-v8m-seg.pt, yoloe-v8l-seg.pt
        Prompt-free models:
            yoloe-11s-seg-pf.pt, yoloe-11m-seg-pf.pt, yoloe-11l-seg-pf.pt,
            yoloe-v8s-seg-pf.pt, yoloe-v8m-seg-pf.pt, yoloe-v8l-seg-pf.pt

    !!! example "Usage with text prompts"
        ```python
        from sahi import AutoDetectionModel

        detection_model = AutoDetectionModel.from_pretrained(
            model_type="yoloe",
            model_path="yoloe-11l-seg.pt",
            confidence_threshold=0.3,
            device="cuda:0",
        )

        # Register the classes of interest as text prompts
        detection_model.model.set_classes(
            ["person", "car", "traffic light"],
            detection_model.model.get_text_pe(["person", "car", "traffic light"]),
        )

        from sahi.predict import get_prediction
        result = get_prediction("image.jpg", detection_model)
        ```

    !!! example "Usage without prompts (internal vocabulary)"
        ```python
        from sahi import AutoDetectionModel
        from sahi.predict import get_sliced_prediction

        detection_model = AutoDetectionModel.from_pretrained(
            model_type="yoloe",
            model_path="yoloe-11l-seg.pt",
            confidence_threshold=0.3,
            device="cuda:0",
        )

        result = get_sliced_prediction(
            "image.jpg",
            detection_model,
            slice_height=512,
            slice_width=512,
            overlap_height_ratio=0.2,
            overlap_width_ratio=0.2,
        )
        ```

    Note:
        - YOLOE models perform instance segmentation by default.
        - Without prompts, YOLOE behaves like a standard YOLO11 detector with
          identical speed.
        - See the Ultralytics YOLOE documentation for visual prompting.
        - YOLOE reports +3.5 AP over YOLO-Worldv2 on LVIS with 1.4x faster inference.

    References:
        - Paper: https://arxiv.org/abs/2503.07465
        - Docs: https://docs.ultralytics.com/models/yoloe/
        - GitHub: https://github.com/THU-MIG/yoloe
    """

    def load_model(self):
        """Initialize the YOLOE model and assign it to self.model.

        Uses self.model_path when given, otherwise falls back to the default
        'yoloe-11s-seg.pt' checkpoint, then moves the model to self.device.

        By default YOLOE runs prompt-free on its internal vocabulary; text
        prompts can be registered after loading:

            model.set_classes(["person", "car"], model.get_text_pe(["person", "car"]))

        Raises:
            TypeError: If the checkpoint cannot be loaded as a YOLOE model
                (invalid path, or the installed ultralytics lacks YOLOE support).
        """
        from ultralytics import YOLOE

        try:
            checkpoint = self.model_path or "yoloe-11s-seg.pt"
            yoloe_model = YOLOE(checkpoint)
            yoloe_model.to(self.device)
            self.set_model(yoloe_model)
        except Exception as e:
            raise TypeError(f"model_path is not a valid YOLOE model path: {e}") from e
Functions
load_model()

Loads the YOLOE detection model from the specified path.

Initializes the YOLOE model with the given model path or uses the default 'yoloe-11s-seg.pt' if no path is provided. The model is then moved to the specified device (CPU/GPU).

By default, YOLOE works in prompt-free mode using its internal vocabulary of 1200+ categories. To use text prompts for specific classes, call model.set_classes() after loading:

model.set_classes(["person", "car"], model.get_text_pe(["person", "car"]))

Raises:

Type Description
TypeError

If the model_path is not a valid YOLOE model path or if the ultralytics package with YOLOE support is not installed.

Source code in sahi/models/yoloe.py
def load_model(self):
    """Initialize the YOLOE model and assign it to self.model.

    Uses self.model_path when given, otherwise falls back to the default
    'yoloe-11s-seg.pt' checkpoint, then moves the model to self.device.

    By default YOLOE runs prompt-free on its internal vocabulary; text
    prompts can be registered after loading:

        model.set_classes(["person", "car"], model.get_text_pe(["person", "car"]))

    Raises:
        TypeError: If the checkpoint cannot be loaded as a YOLOE model
            (invalid path, or the installed ultralytics lacks YOLOE support).
    """
    from ultralytics import YOLOE

    try:
        checkpoint = self.model_path or "yoloe-11s-seg.pt"
        yoloe_model = YOLOE(checkpoint)
        yoloe_model.to(self.device)
        self.set_model(yoloe_model)
    except Exception as e:
        raise TypeError(f"model_path is not a valid YOLOE model path: {e}") from e
yolov5
Classes
Yolov5DetectionModel

Bases: DetectionModel

Source code in sahi/models/yolov5.py
class Yolov5DetectionModel(DetectionModel):
    """SAHI detection model wrapper around a YOLOv5 model."""

    def __init__(self, *args, **kwargs):
        # Extend (never overwrite) packages a parent class may already require
        existing_packages = getattr(self, "required_packages", None) or []
        self.required_packages = [*list(existing_packages), "yolov5", "torch"]
        super().__init__(*args, **kwargs)

    def load_model(self):
        """Detection model is initialized and set to self.model.

        Raises:
            TypeError: If self.model_path is not a valid yolov5 model path.
        """
        import yolov5

        try:
            model = yolov5.load(self.model_path, device=self.device)
            self.set_model(model)
        except Exception as e:
            # f-string + explicit chaining instead of passing the exception as a
            # second TypeError argument (which rendered the message as a tuple)
            raise TypeError(f"model_path is not a valid yolov5 model path: {e}") from e

    def set_model(self, model: Any):
        """Sets the underlying YOLOv5 model.

        Args:
            model: Any
                A YOLOv5 model

        Raises:
            Exception: If the given object is not a yolov5 model.
        """

        if model.__class__.__module__ not in ["yolov5.models.common", "models.common"]:
            raise Exception(f"Not a yolov5 model: {type(model)}")

        # YOLOv5 applies the confidence threshold internally via model.conf
        model.conf = self.confidence_threshold
        self.model = model

        # set category_mapping
        if not self.category_mapping:
            category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
            self.category_mapping = category_mapping

    def perform_inference(self, image: np.ndarray):
        """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

        Raises:
            ValueError: If the model has not been loaded via .load_model() yet.
        """

        # Confirm model is loaded
        if self.model is None:
            raise ValueError("Model is not loaded, load it by calling .load_model()")
        if self.image_size is not None:
            prediction_result = self.model(image, size=self.image_size)
        else:
            prediction_result = self.model(image)

        self._original_predictions = prediction_result

    @property
    def num_categories(self):
        """Returns number of categories."""
        return len(self.model.names)

    @property
    def has_mask(self):
        """Returns if model output contains segmentation mask."""

        return False  # fix when yolov5 supports segmentation models

    @property
    def category_names(self):
        """Returns category names reported by the underlying model."""
        # yolov5>=6.2.0 stores names as a dict {id: name}; older versions use a list
        if check_package_minimum_version("yolov5", "6.2.0"):
            return list(self.model.names.values())
        else:
            return self.model.names

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: list[list[int]] | None = [[0, 0]],
        full_shape_list: list[list[int]] | None = None,
    ):
        """self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.

        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        original_predictions = self._original_predictions

        # compatibility for sahi v0.8.15
        shift_amount_list = fix_shift_amount_list(shift_amount_list)
        full_shape_list = fix_full_shape_list(full_shape_list)

        # handle all predictions
        object_prediction_list_per_image = []
        for image_ind, image_predictions_in_xyxy_format in enumerate(original_predictions.xyxy):
            shift_amount = shift_amount_list[image_ind]
            full_shape = None if full_shape_list is None else full_shape_list[image_ind]
            object_prediction_list = []

            # process predictions: each row is [x1, y1, x2, y2, score, category_id]
            for prediction in image_predictions_in_xyxy_format.cpu().detach().numpy():
                x1 = prediction[0]
                y1 = prediction[1]
                x2 = prediction[2]
                y2 = prediction[3]
                bbox = [x1, y1, x2, y2]
                score = prediction[4]
                category_id = int(prediction[5])
                category_name = self.category_mapping[str(category_id)]

                # fix negative box coords
                bbox[0] = max(0, bbox[0])
                bbox[1] = max(0, bbox[1])
                bbox[2] = max(0, bbox[2])
                bbox[3] = max(0, bbox[3])

                # fix out of image box coords
                if full_shape is not None:
                    bbox[0] = min(full_shape[1], bbox[0])
                    bbox[1] = min(full_shape[0], bbox[1])
                    bbox[2] = min(full_shape[1], bbox[2])
                    bbox[3] = min(full_shape[0], bbox[3])

                # ignore invalid predictions (zero or negative width/height after clamping)
                if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
                    logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
                    continue

                object_prediction = ObjectPrediction(
                    bbox=bbox,
                    category_id=category_id,
                    score=score,
                    segmentation=None,
                    category_name=category_name,
                    shift_amount=shift_amount,
                    full_shape=full_shape,
                )
                object_prediction_list.append(object_prediction)
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image
Attributes
has_mask property

Returns if model output contains segmentation mask.

num_categories property

Returns number of categories.

Functions
load_model()

Detection model is initialized and set to self.model.

Source code in sahi/models/yolov5.py
def load_model(self):
    """Detection model is initialized and set to self.model.

    Raises:
        TypeError: If self.model_path is not a valid yolov5 model path.
    """
    import yolov5

    try:
        model = yolov5.load(self.model_path, device=self.device)
        self.set_model(model)
    except Exception as e:
        # f-string + explicit chaining instead of passing the exception as a
        # second TypeError argument (which rendered the message as a tuple)
        raise TypeError(f"model_path is not a valid yolov5 model path: {e}") from e
perform_inference(image)

Prediction is performed using self.model and the prediction result is set to self._original_predictions.

Parameters:

Name Type Description Default
image ndarray

np.ndarray A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

required
Source code in sahi/models/yolov5.py
def perform_inference(self, image: np.ndarray):
    """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

    Args:
        image: np.ndarray
            A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.

    Raises:
        ValueError: If the model has not been loaded via .load_model() yet.
    """

    # A model must have been loaded before inference can run
    if self.model is None:
        raise ValueError("Model is not loaded, load it by calling .load_model()")

    # Forward the configured inference size only when one was set
    if self.image_size is None:
        prediction_result = self.model(image)
    else:
        prediction_result = self.model(image, size=self.image_size)

    self._original_predictions = prediction_result
set_model(model)

Sets the underlying YOLOv5 model.

Parameters:

Name Type Description Default
model Any

Any A YOLOv5 model

required
Source code in sahi/models/yolov5.py
def set_model(self, model: Any):
    """Sets the underlying YOLOv5 model.

    Args:
        model: Any
            A YOLOv5 model

    Raises:
        Exception: If the given object is not a yolov5 model.
    """

    # Accept models from both packaged and in-repo yolov5 module layouts
    accepted_modules = ("yolov5.models.common", "models.common")
    if model.__class__.__module__ not in accepted_modules:
        raise Exception(f"Not a yolov5 model: {type(model)}")

    # YOLOv5 applies the confidence threshold internally via model.conf
    model.conf = self.confidence_threshold
    self.model = model

    # Derive a default index -> name mapping when the caller supplied none
    if not self.category_mapping:
        self.category_mapping = {str(idx): name for idx, name in enumerate(self.category_names)}
Functions

postprocess

Modules
combine
Classes
PostprocessPredictions

Utilities for calculating IOU/IOS based match for given ObjectPredictions.

Source code in sahi/postprocess/combine.py
class PostprocessPredictions:
    """Utilities for calculating IOU/IOS based match for given ObjectPredictions."""

    def __init__(
        self,
        match_threshold: float = 0.5,
        match_metric: str = "IOU",
        class_agnostic: bool = True,
    ):
        # Matching configuration consumed by concrete __call__ implementations
        self.match_metric = match_metric
        self.match_threshold = match_threshold
        self.class_agnostic = class_agnostic

        # All concrete postprocessors rely on torch being installed
        check_requirements(["torch"])

    def __call__(self, predictions: list[ObjectPrediction]):
        # Subclasses implement the actual combine/suppress logic
        raise NotImplementedError()
Functions
batched_greedy_nmm(object_predictions_as_tensor, match_metric='IOU', match_threshold=0.5)

Apply greedy version of non-maximum merging per category to avoid detecting too many overlapping bounding boxes for a given object.

Parameters:

Name Type Description Default
object_predictions_as_tensor tensor

(tensor) The location preds for the image along with the class pred scores, Shape: [num_boxes,5].

required
match_metric str

(str) IOU or IOS

'IOU'
match_threshold float

(float) The overlap thresh for match metric.

0.5

Returns: keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices to keep to a list of prediction indices to be merged.

Source code in sahi/postprocess/combine.py
def batched_greedy_nmm(
    object_predictions_as_tensor: torch.tensor,
    match_metric: str = "IOU",
    match_threshold: float = 0.5,
):
    """Run greedy non-maximum merging independently for every predicted category.

    Args:
        object_predictions_as_tensor: (tensor) box predictions with scores and
            category ids; column 5 holds the category id.
        match_metric: (str) "IOU" or "IOS".
        match_threshold: (float) overlap threshold for the match metric.

    Returns:
        keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices
        to keep to a list of prediction indices to be merged.
    """
    labels = object_predictions_as_tensor[:, 5].squeeze()
    keep_to_merge_list = {}
    for label in torch.unique(labels):
        # Global indices of all predictions belonging to this category
        global_indices = torch.where(labels == label)[0]
        per_category_result = greedy_nmm(
            object_predictions_as_tensor[global_indices],
            match_metric,
            match_threshold,
        )
        # Translate category-local indices back to global ones
        index_lookup = global_indices.tolist()
        for local_keep, local_merges in per_category_result.items():
            keep_to_merge_list[index_lookup[local_keep]] = [index_lookup[m] for m in local_merges]
    return keep_to_merge_list
batched_nmm(object_predictions_as_tensor, match_metric='IOU', match_threshold=0.5)

Apply non-maximum merging per category to avoid detecting too many overlapping bounding boxes for a given object.

Parameters:

Name Type Description Default
object_predictions_as_tensor Tensor

(tensor) The location preds for the image along with the class pred scores, Shape: [num_boxes,5].

required
match_metric str

(str) IOU or IOS

'IOU'
match_threshold float

(float) The overlap thresh for match metric.

0.5

Returns: keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices to keep to a list of prediction indices to be merged.

Source code in sahi/postprocess/combine.py
def batched_nmm(
    object_predictions_as_tensor: torch.Tensor,
    match_metric: str = "IOU",
    match_threshold: float = 0.5,
):
    """Run non-maximum merging independently for every predicted category.

    Args:
        object_predictions_as_tensor: (tensor) box predictions with scores and
            category ids; column 5 holds the category id.
        match_metric: (str) "IOU" or "IOS".
        match_threshold: (float) overlap threshold for the match metric.

    Returns:
        keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices
        to keep to a list of prediction indices to be merged.
    """
    labels = object_predictions_as_tensor[:, 5].squeeze()
    keep_to_merge_list = {}
    for label in torch.unique(labels):
        # Global indices of all predictions belonging to this category
        global_indices = torch.where(labels == label)[0]
        per_category_result = nmm(
            object_predictions_as_tensor[global_indices],
            match_metric,
            match_threshold,
        )
        # Translate category-local indices back to global ones
        index_lookup = global_indices.tolist()
        for local_keep, local_merges in per_category_result.items():
            keep_to_merge_list[index_lookup[local_keep]] = [index_lookup[m] for m in local_merges]
    return keep_to_merge_list
batched_nms(predictions, match_metric='IOU', match_threshold=0.5)

Apply non-maximum suppression to avoid detecting too many overlapping bounding boxes for a given object.

Parameters:

Name Type Description Default
predictions tensor

(tensor) The location preds for the image along with the class pred scores, Shape: [num_boxes,5].

required
match_metric str

(str) IOU or IOS

'IOU'
match_threshold float

(float) The overlap thresh for match metric.

0.5

Returns: A list of filtered indexes, Shape: [ ,]

Source code in sahi/postprocess/combine.py
def batched_nms(predictions: torch.tensor, match_metric: str = "IOU", match_threshold: float = 0.5):
    """Run non-maximum suppression independently for every predicted category.

    Args:
        predictions: (tensor) box predictions with scores and category ids;
            column 4 holds the score, column 5 the category id.
        match_metric: (str) "IOU" or "IOS".
        match_threshold: (float) overlap threshold for the match metric.

    Returns:
        A list of kept prediction indices, sorted by descending score.
    """

    confidences = predictions[:, 4].squeeze()
    labels = predictions[:, 5].squeeze()
    selected = torch.zeros_like(labels, dtype=torch.bool)
    for label in torch.unique(labels):
        # Suppress within this category only, then mark survivors globally
        label_indices = torch.where(labels == label)[0]
        kept_local = nms(predictions[label_indices], match_metric, match_threshold)
        selected[label_indices[kept_local]] = True
    kept = torch.where(selected)[0]
    # Order the surviving indices by descending confidence
    order = confidences[kept].sort(descending=True)[1]
    return kept[order].tolist()
greedy_nmm(object_predictions_as_tensor, match_metric='IOU', match_threshold=0.5)

Optimized greedy non-maximum merging for axis-aligned bounding boxes using STRTree.

Parameters:

Name Type Description Default
object_predictions_as_tensor Tensor

(tensor) The location preds for the image along with the class pred scores, Shape: [num_boxes,5].

required
match_metric str

(str) IOU or IOS

'IOU'
match_threshold float

(float) The overlap thresh for match metric.

0.5

Returns: keep_to_merge_list: (dict[int, list[int]]) mapping from prediction indices to keep to a list of prediction indices to be merged.

Source code in sahi/postprocess/combine.py
def greedy_nmm(
    object_predictions_as_tensor: torch.Tensor,
    match_metric: str = "IOU",
    match_threshold: float = 0.5,
):
    """
    Optimized greedy non-maximum merging for axis-aligned bounding boxes using STRTree.

    Args:
        object_predictions_as_tensor: (tensor) The location preds for the image
            along with the class pred scores, Shape: [num_boxes,5].
        match_metric: (str) IOU or IOS
        match_threshold: (float) The overlap thresh for match metric.
    Returns:
        keep_to_merge_list: (dict[int, list[int]]) mapping from prediction indices
        to keep to a list of prediction indices to be merged.
    Raises:
        ValueError: If match_metric is neither "IOU" nor "IOS".
    """
    # Extract coordinates and scores as tensors
    x1 = object_predictions_as_tensor[:, 0]
    y1 = object_predictions_as_tensor[:, 1]
    x2 = object_predictions_as_tensor[:, 2]
    y2 = object_predictions_as_tensor[:, 3]
    scores = object_predictions_as_tensor[:, 4]

    # Calculate areas as tensor (vectorized operation)
    areas = (x2 - x1) * (y2 - y1)

    # Create Shapely boxes only once
    boxes = []
    for i in range(len(object_predictions_as_tensor)):
        boxes.append(
            box(
                x1[i].item(),  # Convert only individual values
                y1[i].item(),
                x2[i].item(),
                y2[i].item(),
            )
        )

    # Sort indices by score (descending) using torch
    sorted_idxs = torch.argsort(scores, descending=True).tolist()

    # Build STRtree
    tree = STRtree(boxes)

    keep_to_merge_list = {}
    # Indices already merged into a higher-scoring keeper; never become keepers themselves
    suppressed = set()

    for current_idx in sorted_idxs:
        if current_idx in suppressed:
            continue

        current_box = boxes[current_idx]
        current_area = areas[current_idx].item()  # Convert only when needed

        # Query potential intersections using STRtree
        # (returns candidate indices whose envelopes may intersect current_box)
        candidate_idxs = tree.query(current_box)

        merge_list = []
        for candidate_idx in candidate_idxs:
            if candidate_idx == current_idx or candidate_idx in suppressed:
                continue

            # Only consider candidates with lower or equal score
            if scores[candidate_idx] > scores[current_idx]:
                continue

            # For equal scores, use deterministic tie-breaking based on box coordinates
            if scores[candidate_idx] == scores[current_idx]:
                # Use box coordinates for stable ordering
                current_coords = (
                    x1[current_idx].item(),
                    y1[current_idx].item(),
                    x2[current_idx].item(),
                    y2[current_idx].item(),
                )
                candidate_coords = (
                    x1[candidate_idx].item(),
                    y1[candidate_idx].item(),
                    x2[candidate_idx].item(),
                    y2[candidate_idx].item(),
                )

                # Compare coordinates lexicographically; the lexicographically
                # larger box of a tied pair gets to act as the keeper later
                if candidate_coords > current_coords:
                    continue

            # Calculate intersection area
            candidate_box = boxes[candidate_idx]
            intersection = current_box.intersection(candidate_box).area

            # Calculate metric
            if match_metric == "IOU":
                # intersection over union
                union = current_area + areas[candidate_idx].item() - intersection
                metric = intersection / union if union > 0 else 0
            elif match_metric == "IOS":
                # intersection over the smaller of the two areas
                smaller = min(current_area, areas[candidate_idx].item())
                metric = intersection / smaller if smaller > 0 else 0
            else:
                raise ValueError("Invalid match_metric")

            # Add to merge list if overlap exceeds threshold
            if metric >= match_threshold:
                merge_list.append(candidate_idx)
                suppressed.add(candidate_idx)

        keep_to_merge_list[int(current_idx)] = [int(idx) for idx in merge_list]

    return keep_to_merge_list
nmm(object_predictions_as_tensor, match_metric='IOU', match_threshold=0.5)

Apply non-maximum merging to avoid detecting too many overlapping bounding boxes for a given object.

Parameters:

Name Type Description Default
object_predictions_as_tensor Tensor

(tensor) The location preds for the image along with the class pred scores, Shape: [num_boxes,5].

required
match_metric str

(str) IOU or IOS

'IOU'
match_threshold float

(float) The overlap thresh for match metric.

0.5

Returns: keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices to keep to a list of prediction indices to be merged.

Source code in sahi/postprocess/combine.py
def nmm(
    object_predictions_as_tensor: torch.Tensor,
    match_metric: str = "IOU",
    match_threshold: float = 0.5,
):
    """Apply non-maximum merging to avoid detecting too many overlapping bounding boxes for a given object.

    Unlike NMS, overlapping predictions are not discarded: each kept
    (highest-score) prediction index is mapped to the lower-score prediction
    indices that should be merged into it. The `merge_to_keep` bookkeeping
    ensures every prediction index is assigned to at most one keep index.

    Args:
        object_predictions_as_tensor: (tensor) The location preds for the image
            along with the class predscores, Shape: [num_boxes,5]
            (rows are [x1, y1, x2, y2, score]).
        match_metric: (str) IOU (intersection over union) or
            IOS (intersection over smaller box area)
        match_threshold: (float) The overlap thresh for match metric;
            pairs with metric >= threshold are merged.
    Returns:
        keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices
        to keep to a list of prediction indices to be merged.
    """
    # Extract coordinates and scores as tensors
    x1 = object_predictions_as_tensor[:, 0]
    y1 = object_predictions_as_tensor[:, 1]
    x2 = object_predictions_as_tensor[:, 2]
    y2 = object_predictions_as_tensor[:, 3]
    scores = object_predictions_as_tensor[:, 4]

    # Calculate all box areas once (vectorized operation)
    areas = (x2 - x1) * (y2 - y1)

    # Create Shapely boxes only once; .item() converts only the scalar
    # values actually needed by the `box` constructor
    boxes = []
    for i in range(len(object_predictions_as_tensor)):
        boxes.append(
            box(
                x1[i].item(),  # Convert only individual values
                y1[i].item(),
                x2[i].item(),
                y2[i].item(),
            )
        )

    # Sort indices by score (descending) using torch
    sorted_idxs = torch.argsort(scores, descending=True).tolist()

    # Build STRtree spatial index for fast intersection queries
    tree = STRtree(boxes)

    keep_to_merge_list = {}
    # Reverse mapping: merged prediction index -> its assigned keep index
    merge_to_keep = {}

    for current_idx in sorted_idxs:
        current_box = boxes[current_idx]
        current_area = areas[current_idx].item()  # Convert only when needed

        # Query potential intersections using STRtree
        candidate_idxs = tree.query(current_box)

        matched_box_indices = []
        for candidate_idx in candidate_idxs:
            if candidate_idx == current_idx:
                continue

            # Only consider candidates with lower or equal score
            if scores[candidate_idx] > scores[current_idx]:
                continue

            # For equal scores, use deterministic tie-breaking based on box coordinates
            if scores[candidate_idx] == scores[current_idx]:
                # Use box coordinates for stable ordering
                current_coords = (
                    x1[current_idx].item(),
                    y1[current_idx].item(),
                    x2[current_idx].item(),
                    y2[current_idx].item(),
                )
                candidate_coords = (
                    x1[candidate_idx].item(),
                    y1[candidate_idx].item(),
                    x2[candidate_idx].item(),
                    y2[candidate_idx].item(),
                )

                # Compare coordinates lexicographically so only one of the two
                # equal-score boxes processes the pair
                if candidate_coords > current_coords:
                    continue

            # Calculate intersection area
            candidate_box = boxes[candidate_idx]
            intersection = current_box.intersection(candidate_box).area

            # Calculate overlap metric (guarding against zero division)
            if match_metric == "IOU":
                union = current_area + areas[candidate_idx].item() - intersection
                metric = intersection / union if union > 0 else 0
            elif match_metric == "IOS":
                smaller = min(current_area, areas[candidate_idx].item())
                metric = intersection / smaller if smaller > 0 else 0
            else:
                raise ValueError("Invalid match_metric")

            # Add to matched list if overlap exceeds threshold
            if metric >= match_threshold:
                matched_box_indices.append(candidate_idx)

        # Convert current_idx to native Python int
        current_idx_native = int(current_idx)

        # Create keep_ind to merge_ind_list mapping
        if current_idx_native not in merge_to_keep:
            # current_idx is itself a keep box: claim its not-yet-assigned matches
            keep_to_merge_list[current_idx_native] = []

            for matched_box_idx in matched_box_indices:
                matched_box_idx_native = int(matched_box_idx)
                if matched_box_idx_native not in merge_to_keep:
                    keep_to_merge_list[current_idx_native].append(matched_box_idx_native)
                    merge_to_keep[matched_box_idx_native] = current_idx_native
        else:
            # current_idx was already merged into a keep box: forward its
            # not-yet-assigned matches to that same keep box
            keep_idx = merge_to_keep[current_idx_native]
            for matched_box_idx in matched_box_indices:
                matched_box_idx_native = int(matched_box_idx)
                if (
                    matched_box_idx_native not in keep_to_merge_list.get(keep_idx, [])
                    and matched_box_idx_native not in merge_to_keep
                ):
                    if keep_idx not in keep_to_merge_list:
                        keep_to_merge_list[keep_idx] = []
                    keep_to_merge_list[keep_idx].append(matched_box_idx_native)
                    merge_to_keep[matched_box_idx_native] = keep_idx

    return keep_to_merge_list
nms(predictions, match_metric='IOU', match_threshold=0.5)

Optimized non-maximum suppression for axis-aligned bounding boxes using STRTree.

Parameters:

Name Type Description Default
predictions Tensor

(tensor) The location preds for the image along with the class prediction scores, Shape: [num_boxes,5].

required
match_metric str

(str) IOU or IOS

'IOU'
match_threshold float

(float) The overlap thresh for match metric.

0.5

Returns:

Type Description

A list of filtered indexes, Shape: [ ,]

Source code in sahi/postprocess/combine.py
def nms(
    predictions: torch.Tensor,
    match_metric: str = "IOU",
    match_threshold: float = 0.5,
):
    """
    Optimized non-maximum suppression for axis-aligned bounding boxes using STRTree.

    Args:
        predictions: (tensor) The location preds for the image along with the class
            predscores, Shape: [num_boxes,5] (rows are [x1, y1, x2, y2, score]).
        match_metric: (str) IOU (intersection over union) or
            IOS (intersection over smaller box area)
        match_threshold: (float) The overlap thresh for match metric.

    Returns:
        A list of kept indexes as native Python ints, Shape: [ ,]
    """
    if len(predictions) == 0:
        return []

    # Ensure predictions are on CPU and convert to numpy
    if predictions.device.type != "cpu":
        predictions = predictions.cpu()

    predictions_np = predictions.numpy()

    # Extract coordinates and scores
    x1 = predictions_np[:, 0]
    y1 = predictions_np[:, 1]
    x2 = predictions_np[:, 2]
    y2 = predictions_np[:, 3]
    scores = predictions_np[:, 4]

    # Calculate areas
    areas = (x2 - x1) * (y2 - y1)

    # Create Shapely boxes (vectorized)
    boxes = box(x1, y1, x2, y2)

    # Sort indices by score (descending)
    sorted_idxs = np.argsort(scores)[::-1]

    # Build STRtree spatial index for fast intersection queries
    tree = STRtree(boxes)

    keep = []
    suppressed = set()

    for current_idx in sorted_idxs:
        if current_idx in suppressed:
            continue

        # Store native Python ints (not np.int64) for consistency with the
        # nmm variants and safe downstream use (e.g. serialization).
        keep.append(int(current_idx))
        current_box = boxes[current_idx]
        current_area = areas[current_idx]

        # Query potential intersections using STRtree
        candidate_idxs = tree.query(current_box)

        for candidate_idx in candidate_idxs:
            if candidate_idx == current_idx or candidate_idx in suppressed:
                continue

            # Skip candidates with higher scores (already processed)
            if scores[candidate_idx] > scores[current_idx]:
                continue

            # For equal scores, use deterministic tie-breaking based on box coordinates
            if scores[candidate_idx] == scores[current_idx]:
                # Use box coordinates for stable ordering
                current_coords = (
                    x1[current_idx],
                    y1[current_idx],
                    x2[current_idx],
                    y2[current_idx],
                )
                candidate_coords = (
                    x1[candidate_idx],
                    y1[candidate_idx],
                    x2[candidate_idx],
                    y2[candidate_idx],
                )

                # Compare coordinates lexicographically so only one of the two
                # equal-score boxes suppresses the other
                if candidate_coords > current_coords:
                    continue

            # Calculate intersection area
            candidate_box = boxes[candidate_idx]
            intersection = current_box.intersection(candidate_box).area

            # Calculate overlap metric (guarding against zero division)
            if match_metric == "IOU":
                union = current_area + areas[candidate_idx] - intersection
                metric = intersection / union if union > 0 else 0
            elif match_metric == "IOS":
                smaller = min(current_area, areas[candidate_idx])
                metric = intersection / smaller if smaller > 0 else 0
            else:
                raise ValueError("Invalid match_metric")

            # Suppress if overlap exceeds threshold
            if metric >= match_threshold:
                suppressed.add(candidate_idx)

    return keep
legacy
Modules
combine
Classes
PostprocessPredictions

Utilities for calculating IOU/IOS based match for given ObjectPredictions.

Source code in sahi/postprocess/legacy/combine.py
class PostprocessPredictions:
    """Utilities for calculating IOU/IOS based match for given ObjectPredictions."""

    def __init__(
        self,
        match_threshold: float = 0.5,
        match_metric: str = "IOU",
        class_agnostic: bool = True,
    ):
        self.match_threshold = match_threshold
        self.class_agnostic = class_agnostic
        # Bind the overlap-metric callable once up front; unknown metrics fail fast.
        metric_to_func = {
            "IOU": self.calculate_bbox_iou,
            "IOS": self.calculate_bbox_ios,
        }
        if match_metric not in metric_to_func:
            raise ValueError(f"'match_metric' should be one of ['IOU', 'IOS'] but given as {match_metric}")
        self.calculate_match = metric_to_func[match_metric]

    def _has_match(self, pred1: ObjectPrediction, pred2: ObjectPrediction) -> bool:
        # A pair matches when the overlap metric exceeds the threshold and,
        # unless class-agnostic, the category ids agree.
        overlap_ok = self.calculate_match(pred1, pred2) > self.match_threshold
        category_ok = self.class_agnostic or self.has_same_category_id(pred1, pred2)
        return overlap_ok and category_ok

    @staticmethod
    def get_score_func(object_prediction: ObjectPrediction):
        """Sort key for predictions: the confidence score value."""
        return object_prediction.score.value

    @staticmethod
    def has_same_category_id(pred1: ObjectPrediction, pred2: ObjectPrediction) -> bool:
        """True when both predictions carry the same category id."""
        return pred1.category.id == pred2.category.id

    @staticmethod
    def calculate_bbox_iou(pred1: ObjectPrediction, pred2: ObjectPrediction) -> float:
        """Returns the ratio of intersection area to the union."""
        xyxy1 = np.array(pred1.bbox.to_xyxy())
        xyxy2 = np.array(pred2.bbox.to_xyxy())
        intersect = calculate_intersection_area(xyxy1, xyxy2)
        union = calculate_area(xyxy1) + calculate_area(xyxy2) - intersect
        return intersect / union

    @staticmethod
    def calculate_bbox_ios(pred1: ObjectPrediction, pred2: ObjectPrediction) -> float:
        """Returns the ratio of intersection area to the smaller box's area."""
        xyxy1 = np.array(pred1.bbox.to_xyxy())
        xyxy2 = np.array(pred2.bbox.to_xyxy())
        intersect = calculate_intersection_area(xyxy1, xyxy2)
        return intersect / np.minimum(calculate_area(xyxy1), calculate_area(xyxy2))

    def __call__(self):
        raise NotImplementedError()
Functions
calculate_bbox_ios(pred1, pred2) staticmethod

Returns the ratio of intersection area to the smaller box's area.

Source code in sahi/postprocess/legacy/combine.py
@staticmethod
def calculate_bbox_ios(pred1: ObjectPrediction, pred2: ObjectPrediction) -> float:
    """Returns the ratio of intersection area to the smaller box's area."""
    box1 = np.array(pred1.bbox.to_xyxy())
    box2 = np.array(pred2.bbox.to_xyxy())
    area1 = calculate_area(box1)
    area2 = calculate_area(box2)
    intersect = calculate_intersection_area(box1, box2)
    smaller_area = np.minimum(area1, area2)
    return intersect / smaller_area
calculate_bbox_iou(pred1, pred2) staticmethod

Returns the ratio of intersection area to the union.

Source code in sahi/postprocess/legacy/combine.py
@staticmethod
def calculate_bbox_iou(pred1: ObjectPrediction, pred2: ObjectPrediction) -> float:
    """Returns the ratio of intersection area to the union."""
    box1 = np.array(pred1.bbox.to_xyxy())
    box2 = np.array(pred2.bbox.to_xyxy())
    area1 = calculate_area(box1)
    area2 = calculate_area(box2)
    intersect = calculate_intersection_area(box1, box2)
    return intersect / (area1 + area2 - intersect)
get_score_func(object_prediction) staticmethod

Used for sorting predictions.

Source code in sahi/postprocess/legacy/combine.py
@staticmethod
def get_score_func(object_prediction: ObjectPrediction):
    """Used for sorting predictions."""
    return object_prediction.score.value
Functions
utils
Classes Functions
calculate_area(box)

Parameters:

Name Type Description Default
box List[int]

[x1, y1, x2, y2]

required
Source code in sahi/postprocess/utils.py
def calculate_area(box: list[int] | np.ndarray) -> float:
    """Return the area of an axis-aligned box given as [x1, y1, x2, y2]."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
calculate_bbox_ios(pred1, pred2)

Returns the ratio of intersection area to the smaller box's area.

Source code in sahi/postprocess/utils.py
def calculate_bbox_ios(pred1: ObjectPrediction, pred2: ObjectPrediction) -> float:
    """Returns the ratio of intersection area to the smaller box's area."""
    xyxy1 = np.array(pred1.bbox.to_xyxy())
    xyxy2 = np.array(pred2.bbox.to_xyxy())
    intersect = calculate_intersection_area(xyxy1, xyxy2)
    # Normalize by whichever of the two boxes is smaller
    smaller_area = np.minimum(calculate_area(xyxy1), calculate_area(xyxy2))
    return intersect / smaller_area
calculate_bbox_iou(pred1, pred2)

Returns the ratio of intersection area to the union.

Source code in sahi/postprocess/utils.py
def calculate_bbox_iou(pred1: ObjectPrediction, pred2: ObjectPrediction) -> float:
    """Returns the ratio of intersection area to the union."""
    xyxy1 = np.array(pred1.bbox.to_xyxy())
    xyxy2 = np.array(pred2.bbox.to_xyxy())
    intersect = calculate_intersection_area(xyxy1, xyxy2)
    # Inclusion-exclusion: union = area1 + area2 - intersection
    union = calculate_area(xyxy1) + calculate_area(xyxy2) - intersect
    return intersect / union
calculate_box_union(box1, box2)

Parameters:

Name Type Description Default
box1 List[int]

[x1, y1, x2, y2]

required
box2 List[int]

[x1, y1, x2, y2]

required
Source code in sahi/postprocess/utils.py
def calculate_box_union(box1: list[int] | np.ndarray, box2: list[int] | np.ndarray) -> list[int]:
    """Return the smallest [x1, y1, x2, y2] box enclosing both input boxes."""
    a = np.array(box1)
    b = np.array(box2)
    # Top-left corner is the elementwise minimum, bottom-right the maximum
    top_left = np.minimum(a[:2], b[:2])
    bottom_right = np.maximum(a[2:], b[2:])
    return list(np.concatenate((top_left, bottom_right)))
calculate_intersection_area(box1, box2)

Parameters:

Name Type Description Default
box1 ndarray

np.array([x1, y1, x2, y2])

required
box2 ndarray

np.array([x1, y1, x2, y2])

required
Source code in sahi/postprocess/utils.py
def calculate_intersection_area(box1: np.ndarray, box2: np.ndarray) -> float:
    """Return the overlap area of two [x1, y1, x2, y2] boxes (0 when disjoint)."""
    overlap_min = np.maximum(box1[:2], box2[:2])
    overlap_max = np.minimum(box1[2:], box2[2:])
    # Clip negative extents to zero so non-overlapping boxes yield area 0
    extent = (overlap_max - overlap_min).clip(min=0)
    return extent[0] * extent[1]
coco_segmentation_to_shapely(segmentation)

Fix segment data in COCO format :param segmentation: segment data in COCO format :return:

Source code in sahi/postprocess/utils.py
def coco_segmentation_to_shapely(segmentation: list | list[list]):
    """Convert COCO segmentation data to a repaired shapely MultiPolygon.

    Accepts either a flat [x1, y1, x2, y2, ...] coordinate list (a single
    polygon) or a list of such lists (multiple polygons).
    """
    if not isinstance(segmentation, list):
        raise ValueError("segmentation must be List or List[List]")

    nested_flags = [isinstance(seg, list) for seg in segmentation]
    if not any(nested_flags):
        # Flat coordinate list: wrap into a single-polygon list
        segmentation = [segmentation]
    elif not all(nested_flags):
        # Mixture of lists and scalars is malformed
        raise ValueError("segmentation must be List or List[List]")

    polygon_list = []
    for coco_polygon in segmentation:
        # Pair up alternating x/y values into (x, y) points
        points = list(zip(coco_polygon[::2], coco_polygon[1::2]))
        polygon_list.append(repair_polygon(Polygon(points)))

    return repair_multipolygon(MultiPolygon(polygon_list))
object_prediction_list_to_numpy(object_prediction_list)

Returns:

Type Description
ndarray

np.ndarray of size N x [x1, y1, x2, y2, score, category_id]

Source code in sahi/postprocess/utils.py
def object_prediction_list_to_numpy(object_prediction_list: ObjectPredictionList) -> np.ndarray:
    """Convert an ObjectPredictionList to a numpy prediction matrix.

    Returns:
        np.ndarray of size N x [x1, y1, x2, y2, score, category_id]
    """
    num_predictions = len(object_prediction_list)
    numpy_predictions = np.zeros([num_predictions, 6], dtype=np.float32)
    for ind, object_prediction in enumerate(object_prediction_list):
        # Unwrap once per prediction instead of calling tolist() three times
        prediction = object_prediction.tolist()
        numpy_predictions[ind, :4] = np.array(prediction.bbox.to_xyxy(), dtype=np.float32)
        numpy_predictions[ind, 4] = prediction.score.value
        numpy_predictions[ind, 5] = prediction.category.id
    return numpy_predictions
object_prediction_list_to_torch(object_prediction_list)

Returns:

Type Description
tensor

torch.tensor of size N x [x1, y1, x2, y2, score, category_id]

Source code in sahi/postprocess/utils.py
def object_prediction_list_to_torch(object_prediction_list: ObjectPredictionList) -> torch.tensor:
    """Convert an ObjectPredictionList to a torch prediction matrix.

    Returns:
        torch.tensor of size N x [x1, y1, x2, y2, score, category_id]
    """
    num_predictions = len(object_prediction_list)
    torch_predictions = torch.zeros([num_predictions, 6], dtype=torch.float32)
    for ind, object_prediction in enumerate(object_prediction_list):
        # Unwrap once per prediction instead of calling tolist() three times
        prediction = object_prediction.tolist()
        torch_predictions[ind, :4] = torch.tensor(prediction.bbox.to_xyxy(), dtype=torch.float32)
        torch_predictions[ind, 4] = prediction.score.value
        torch_predictions[ind, 5] = prediction.category.id
    return torch_predictions
repair_multipolygon(shapely_multipolygon)

Fix invalid MultiPolygon objects :param shapely_multipolygon: Imported shapely MultiPolygon object :return:

Source code in sahi/postprocess/utils.py
def repair_multipolygon(shapely_multipolygon: MultiPolygon) -> MultiPolygon:
    """Repair an invalid MultiPolygon via buffer(0), normalizing the result back
    to a MultiPolygon; returns the input unchanged when already valid or when
    repair fails."""
    if shapely_multipolygon.is_valid:
        return shapely_multipolygon

    # buffer(0) is the standard shapely trick for fixing self-intersections
    repaired = shapely_multipolygon.buffer(0)
    if not repaired.is_valid:
        return shapely_multipolygon

    if isinstance(repaired, MultiPolygon):
        return repaired
    if isinstance(repaired, Polygon):
        # Wrap a single polygon so the return type stays MultiPolygon
        return MultiPolygon([repaired])
    if isinstance(repaired, GeometryCollection):
        # Keep only polygonal parts; fall back to the original if none remain
        polygons = [geom for geom in repaired.geoms if isinstance(geom, Polygon)]
        return MultiPolygon(polygons) if polygons else shapely_multipolygon

    return shapely_multipolygon
repair_polygon(shapely_polygon)

Fix polygons :param shapely_polygon: Shapely polygon object :return:

Source code in sahi/postprocess/utils.py
def repair_polygon(shapely_polygon: Polygon) -> Polygon:
    """Repair an invalid Polygon via buffer(0); when the repair splits the shape,
    keep the largest polygonal part. Returns the input unchanged when already
    valid or when repair fails."""
    if shapely_polygon.is_valid:
        return shapely_polygon

    # buffer(0) is the standard shapely trick for fixing self-intersections
    repaired = shapely_polygon.buffer(0)
    if not repaired.is_valid:
        return shapely_polygon

    if isinstance(repaired, Polygon):
        return repaired
    if isinstance(repaired, MultiPolygon):
        # Repair split the shape: keep the dominant (largest-area) part
        return max(repaired.geoms, key=lambda p: p.area)
    if isinstance(repaired, GeometryCollection):
        polygons = [geom for geom in repaired.geoms if isinstance(geom, Polygon)]
        return max(polygons, key=lambda p: p.area) if polygons else shapely_polygon

    return shapely_polygon

predict

Classes
Functions
bbox_sort(a, b, thresh)

a, b - function receives two bounding bboxes

thresh - the threshold takes into account how far two bounding bboxes differ in Y where thresh is the threshold we set for the minimum allowable difference in height between adjacent bboxes and sorts them by the X coordinate

Source code in sahi/predict.py
def bbox_sort(a, b, thresh):
    """Comparator ordering two bounding boxes in reading order.

    Boxes whose top edges (y values) differ by at most `thresh` are treated as
    being on the same row and ordered left-to-right by x; otherwise they are
    ordered top-to-bottom by y.
    """
    y_gap = a[1] - b[1]
    if abs(y_gap) <= thresh:
        # Same row: sort by horizontal position
        return a[0] - b[0]
    # Different rows: sort by vertical position
    return y_gap
get_prediction(image, detection_model, shift_amount=None, full_shape=None, postprocess=None, verbose=0, exclude_classes_by_name=None, exclude_classes_by_id=None)

Function for performing prediction for given image using given detection_model.

Parameters:

Name Type Description Default
image

str or np.ndarray Location of image or numpy image matrix to slice

required
detection_model

model.DetectionMode

required
shift_amount list | None

List To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

None
full_shape

List Size of the full image, should be in the form of [height, width]

None
postprocess PostprocessPredictions | None

sahi.postprocess.combine.PostprocessPredictions

None
verbose int

int 0: no print (default) 1: print prediction duration

0
exclude_classes_by_name list[str] | None

Optional[List[str]] None: if no classes are excluded List[str]: set of classes to exclude using its/their class label name/s

None
exclude_classes_by_id list[int] | None

Optional[List[int]] None: if no classes are excluded List[int]: set of classes to exclude using one or more IDs

None

Returns: A dict with fields: object_prediction_list: a list of ObjectPrediction durations_in_seconds: a dict containing elapsed times for profiling

Source code in sahi/predict.py
def get_prediction(
    image,
    detection_model,
    shift_amount: list | None = None,
    full_shape=None,
    postprocess: PostprocessPredictions | None = None,
    verbose: int = 0,
    exclude_classes_by_name: list[str] | None = None,
    exclude_classes_by_id: list[int] | None = None,
) -> PredictionResult:
    """Run the given detection model on a single image.

    Args:
        image: str or np.ndarray
            Location of image or numpy image matrix to slice
        detection_model: model.DetectionMode
        shift_amount: List
            To shift the box and mask predictions from sliced image to full
            sized image, should be in the form of [shift_x, shift_y]
        full_shape: List
            Size of the full image, should be in the form of [height, width]
        postprocess: sahi.postprocess.combine.PostprocessPredictions
        verbose: int
            0: no print (default)
            1: print prediction duration
        exclude_classes_by_name: Optional[List[str]]
            None: if no classes are excluded
            List[str]: set of classes to exclude using its/their class label name/s
        exclude_classes_by_id: Optional[List[int]]
            None: if no classes are excluded
            List[int]: set of classes to exclude using one or more IDs
    Returns:
        A dict with fields:
            object_prediction_list: a list of ObjectPrediction
            durations_in_seconds: a dict containing elapsed times for profiling
    """
    durations_in_seconds = {}

    image_as_pil = read_image_as_pil(image)

    # Default here rather than in the signature to avoid a mutable default arg
    if shift_amount is None:
        shift_amount = [0, 0]

    # --- inference ---
    start = time.perf_counter()
    detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
    durations_in_seconds["prediction"] = time.perf_counter() - start

    if full_shape is None:
        full_shape = [image_as_pil.height, image_as_pil.width]

    # --- postprocess (single batch only) ---
    start = time.perf_counter()
    detection_model.convert_original_predictions(
        shift_amount=shift_amount,
        full_shape=full_shape,
    )
    object_prediction_list: list[ObjectPrediction] = detection_model.object_prediction_list
    object_prediction_list = filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id)

    # Merge/suppress overlapping predictions when a postprocessor is provided
    if postprocess is not None:
        object_prediction_list = postprocess(object_prediction_list)

    durations_in_seconds["postprocess"] = time.perf_counter() - start

    if verbose == 1:
        print("Prediction performed in", durations_in_seconds["prediction"], "seconds.")

    return PredictionResult(
        image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
    )
get_sliced_prediction(image, detection_model=None, slice_height=None, slice_width=None, overlap_height_ratio=0.2, overlap_width_ratio=0.2, perform_standard_pred=True, postprocess_type='GREEDYNMM', postprocess_match_metric='IOS', postprocess_match_threshold=0.5, postprocess_class_agnostic=False, verbose=1, merge_buffer_length=None, auto_slice_resolution=True, slice_export_prefix=None, slice_dir=None, exclude_classes_by_name=None, exclude_classes_by_id=None, progress_bar=False, progress_callback=None)

Function for slice image + get prediction for each slice + combine predictions in full image.

Parameters:

Name Type Description Default
image

str or np.ndarray Location of image or numpy image matrix to slice

required
detection_model

model.DetectionModel

None
slice_height int | None

int Height of each slice. Defaults to None.

None
slice_width int | None

int Width of each slice. Defaults to None.

None
overlap_height_ratio float

float Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window of size 512 yields an overlap of 102 pixels). Default to 0.2.

0.2
overlap_width_ratio float

float Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window of size 512 yields an overlap of 102 pixels). Default to 0.2.

0.2
perform_standard_pred bool

bool Perform a standard prediction on top of sliced predictions to increase large object detection accuracy. Default: True.

True
postprocess_type str

str Type of the postprocess to be used after sliced inference while merging/eliminating predictions. Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.

'GREEDYNMM'
postprocess_match_metric str

str Metric to be used during object prediction matching after sliced prediction. 'IOU' for intersection over union, 'IOS' for intersection over smaller area.

'IOS'
postprocess_match_threshold float

float Sliced predictions having higher iou than postprocess_match_threshold will be postprocessed after sliced prediction.

0.5
postprocess_class_agnostic bool

bool If True, postprocess will ignore category ids.

False
verbose int

int 0: no print 1: print number of slices (default) 2: print number of slices and slice/prediction durations

1
merge_buffer_length int | None

int The length of buffer for slices to be used during sliced prediction, which is suitable for low memory. It may affect the AP if it is specified. The higher the amount, the closer results to the non-buffered. scenario. See the discussion.

None
auto_slice_resolution bool

bool if slice parameters (slice_height, slice_width) are not given, it enables automatically calculate these params from image resolution and orientation.

True
slice_export_prefix str | None

str Prefix for the exported slices. Defaults to None.

None
slice_dir str | None

str Directory to save the slices. Defaults to None.

None
exclude_classes_by_name list[str] | None

Optional[List[str]] None: if no classes are excluded List[str]: set of classes to exclude using its/their class label name/s

None
exclude_classes_by_id list[int] | None

Optional[List[int]] None: if no classes are excluded List[int]: set of classes to exclude using one or more IDs

None
progress_bar bool

bool Whether to show progress bar for slice processing. Default: False.

False
progress_callback

callable A callback function that will be called after each slice is processed. The function should accept two arguments: (current_slice, total_slices)

None

Returns: A Dict with fields: object_prediction_list: a list of sahi.prediction.ObjectPrediction durations_in_seconds: a dict containing elapsed times for profiling

Source code in sahi/predict.py
def get_sliced_prediction(
    image,
    detection_model=None,
    slice_height: int | None = None,
    slice_width: int | None = None,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    perform_standard_pred: bool = True,
    postprocess_type: str = "GREEDYNMM",
    postprocess_match_metric: str = "IOS",
    postprocess_match_threshold: float = 0.5,
    postprocess_class_agnostic: bool = False,
    verbose: int = 1,
    merge_buffer_length: int | None = None,
    auto_slice_resolution: bool = True,
    slice_export_prefix: str | None = None,
    slice_dir: str | None = None,
    exclude_classes_by_name: list[str] | None = None,
    exclude_classes_by_id: list[int] | None = None,
    progress_bar: bool = False,
    progress_callback=None,
) -> PredictionResult:
    """Function for slice image + get prediction for each slice + combine predictions in full image.

    Args:
        image: str or np.ndarray
            Location of image or numpy image matrix to slice
        detection_model: model.DetectionModel
        slice_height: int
            Height of each slice.  Defaults to ``None``.
        slice_width: int
            Width of each slice.  Defaults to ``None``.
        overlap_height_ratio: float
            Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window
            of size 512 yields an overlap of 102 pixels).
            Default to ``0.2``.
        overlap_width_ratio: float
            Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window
            of size 512 yields an overlap of 102 pixels).
            Default to ``0.2``.
        perform_standard_pred: bool
            Perform a standard prediction on top of sliced predictions to increase large object
            detection accuracy. Default: True.
        postprocess_type: str
            Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
            Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
        postprocess_match_metric: str
            Metric to be used during object prediction matching after sliced prediction.
            'IOU' for intersection over union, 'IOS' for intersection over smaller area.
        postprocess_match_threshold: float
            Sliced predictions having higher iou than postprocess_match_threshold will be
            postprocessed after sliced prediction.
        postprocess_class_agnostic: bool
            If True, postprocess will ignore category ids.
        verbose: int
            0: no print
            1: print number of slices (default)
            2: print number of slices and slice/prediction durations
        merge_buffer_length: int
            The length of buffer for slices to be used during sliced prediction, which is suitable for low memory.
            It may affect the AP if it is specified. The higher the amount, the closer the results are to the
            non-buffered scenario. See [the discussion](https://github.com/obss/sahi/pull/445).
        auto_slice_resolution: bool
            if slice parameters (slice_height, slice_width) are not given,
            it enables automatically calculate these params from image resolution and orientation.
        slice_export_prefix: str
            Prefix for the exported slices. Defaults to None.
        slice_dir: str
            Directory to save the slices. Defaults to None.
        exclude_classes_by_name: Optional[List[str]]
            None: if no classes are excluded
            List[str]: set of classes to exclude using its/their class label name/s
        exclude_classes_by_id: Optional[List[int]]
            None: if no classes are excluded
            List[int]: set of classes to exclude using one or more IDs
        progress_bar: bool
            Whether to show progress bar for slice processing. Default: False.
        progress_callback: callable
            A callback function that will be called after each slice is processed.
            The function should accept two arguments: (current_slice, total_slices)

    Returns:
        A PredictionResult with fields:
            object_prediction_list: a list of sahi.prediction.ObjectPrediction
            durations_in_seconds: a dict containing elapsed times for profiling
    """
    # lazy import to avoid a hard dependency on ultralytics; done before the
    # slice timer starts so a first-call import is not billed to "slice"
    from sahi.models.ultralytics import UltralyticsDetectionModel

    # for profiling
    durations_in_seconds = dict()

    # currently only 1 batch supported
    num_batch = 1

    # create slices from full image
    time_start = time.perf_counter()
    slice_image_result = slice_image(
        image=image,
        output_file_name=slice_export_prefix,
        output_dir=slice_dir,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        auto_slice_resolution=auto_slice_resolution,
    )
    num_slices = len(slice_image_result)
    durations_in_seconds["slice"] = time.perf_counter() - time_start

    if isinstance(detection_model, UltralyticsDetectionModel) and detection_model.is_obb:
        # Only NMS is supported for OBB model outputs
        postprocess_type = "NMS"

    # init match postprocess instance
    if postprocess_type not in POSTPROCESS_NAME_TO_CLASS.keys():
        raise ValueError(
            f"postprocess_type should be one of {list(POSTPROCESS_NAME_TO_CLASS.keys())} "
            f"but given as {postprocess_type}"
        )
    postprocess_constructor = POSTPROCESS_NAME_TO_CLASS[postprocess_type]
    postprocess = postprocess_constructor(
        match_threshold=postprocess_match_threshold,
        match_metric=postprocess_match_metric,
        class_agnostic=postprocess_class_agnostic,
    )

    postprocess_time = 0
    time_start = time.perf_counter()
    # create prediction input
    num_group = num_slices // num_batch
    if verbose in (1, 2):
        tqdm.write(f"Performing prediction on {num_slices} slices.")

    if progress_bar:
        slice_iterator = tqdm(range(num_group), desc="Processing slices", total=num_group)
    else:
        slice_iterator = range(num_group)

    object_prediction_list = []
    # perform sliced prediction
    for group_ind in slice_iterator:
        # prepare batch (currently supports only 1 batch)
        image_list = []
        shift_amount_list = []
        for image_ind in range(num_batch):
            image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
            shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
        # perform batch prediction
        prediction_result = get_prediction(
            image=image_list[0],
            detection_model=detection_model,
            shift_amount=shift_amount_list[0],
            full_shape=[
                slice_image_result.original_image_height,
                slice_image_result.original_image_width,
            ],
            exclude_classes_by_name=exclude_classes_by_name,
            exclude_classes_by_id=exclude_classes_by_id,
        )
        # convert sliced predictions to full predictions
        for object_prediction in prediction_result.object_prediction_list:
            if object_prediction:  # if not empty
                object_prediction_list.append(object_prediction.get_shifted_object_prediction())

        # merge matching predictions during sliced prediction
        if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
            # use the same monotonic clock as the other profiling timers
            postprocess_time_start = time.perf_counter()
            object_prediction_list = postprocess(object_prediction_list)
            postprocess_time += time.perf_counter() - postprocess_time_start

        # Call progress callback if provided
        if progress_callback is not None:
            progress_callback(group_ind + 1, num_group)

    # perform standard prediction
    if num_slices > 1 and perform_standard_pred:
        prediction_result = get_prediction(
            image=image,
            detection_model=detection_model,
            shift_amount=[0, 0],
            full_shape=[
                slice_image_result.original_image_height,
                slice_image_result.original_image_width,
            ],
            postprocess=None,
            exclude_classes_by_name=exclude_classes_by_name,
            exclude_classes_by_id=exclude_classes_by_id,
        )
        object_prediction_list.extend(prediction_result.object_prediction_list)

    # merge matching predictions
    if len(object_prediction_list) > 1:
        postprocess_time_start = time.perf_counter()
        object_prediction_list = postprocess(object_prediction_list)
        postprocess_time += time.perf_counter() - postprocess_time_start

    time_end = time.perf_counter() - time_start
    durations_in_seconds["prediction"] = time_end - postprocess_time
    durations_in_seconds["postprocess"] = postprocess_time

    if verbose == 2:
        print(
            "Slicing performed in",
            durations_in_seconds["slice"],
            "seconds.",
        )
        print(
            "Prediction performed in",
            durations_in_seconds["prediction"],
            "seconds.",
        )
        print(
            "Postprocessing performed in",
            durations_in_seconds["postprocess"],
            "seconds.",
        )

    return PredictionResult(
        image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
    )
predict(detection_model=None, model_type='ultralytics', model_path=None, model_config_path=None, model_confidence_threshold=0.25, model_device=None, model_category_mapping=None, model_category_remapping=None, source=None, no_standard_prediction=False, no_sliced_prediction=False, image_size=None, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2, postprocess_type='GREEDYNMM', postprocess_match_metric='IOS', postprocess_match_threshold=0.5, postprocess_class_agnostic=False, novisual=False, view_video=False, frame_skip_interval=0, export_pickle=False, export_crop=False, dataset_json_path=None, project='runs/predict', name='exp', visual_bbox_thickness=None, visual_text_size=None, visual_text_thickness=None, visual_hide_labels=False, visual_hide_conf=False, visual_export_format='png', verbose=1, return_dict=False, force_postprocess_type=False, exclude_classes_by_name=None, exclude_classes_by_id=None, progress_bar=False, **kwargs)

Performs prediction for all present images in given folder.

Parameters:

Name Type Description Default
detection_model DetectionModel | None

sahi.model.DetectionModel Optionally provide custom DetectionModel to be used for inference. When provided, model_type, model_path, config_path, model_device, model_category_mapping, image_size params will be ignored

None
model_type str

str 'mmdet' for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'.

'ultralytics'
model_path str | None

str Path for the model weight

None
model_config_path str | None

str Path for the detection model config file

None
model_confidence_threshold float

float All predictions with score < model_confidence_threshold will be discarded.

0.25
model_device str | None

str Torch device, "cpu" or "cuda"

None
model_category_mapping dict | None

dict Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}

None
model_category_remapping dict | None

dict: str to int Remap category ids after performing inference

None
source str | None

str Folder directory that contains images or path of the image to be predicted. Also video to be predicted.

None
no_standard_prediction bool

bool Dont perform standard prediction. Default: False.

False
no_sliced_prediction bool

bool Dont perform sliced prediction. Default: False.

False
image_size int | None

int Input image size for each inference (image is scaled by preserving asp. rat.).

None
slice_height int

int Height of each slice. Defaults to 512.

512
slice_width int

int Width of each slice. Defaults to 512.

512
overlap_height_ratio float

float Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window of size 512 yields an overlap of 102 pixels). Default to 0.2.

0.2
overlap_width_ratio float

float Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window of size 512 yields an overlap of 102 pixels). Default to 0.2.

0.2
postprocess_type str

str Type of the postprocess to be used after sliced inference while merging/eliminating predictions. Options are 'NMM', 'GREEDYNMM', 'LSNMS' or 'NMS'. Default is 'GREEDYNMM'.

'GREEDYNMM'
postprocess_match_metric str

str Metric to be used during object prediction matching after sliced prediction. 'IOU' for intersection over union, 'IOS' for intersection over smaller area.

'IOS'
postprocess_match_threshold float

float Sliced predictions having higher iou than postprocess_match_threshold will be postprocessed after sliced prediction.

0.5
postprocess_class_agnostic bool

bool If True, postprocess will ignore category ids.

False
novisual bool

bool Dont export predicted video/image visuals.

False
view_video bool

bool View result of prediction during video inference.

False
frame_skip_interval int

int If view_video or export_visual is slow, you can process one frame out of every N (for example: --frame_skip_interval=3).

0
export_pickle bool

bool Export predictions as .pickle

False
export_crop bool

bool Export predictions as cropped images.

False
dataset_json_path str | None

str If coco file path is provided, detection results will be exported in coco json format.

None
project str

str Save results to project/name.

'runs/predict'
name str

str Save results to project/name.

'exp'
visual_bbox_thickness int | None

int, optional Line thickness (in pixels) for bounding boxes in exported visualizations. If None, a default thickness is chosen based on image size.

None
visual_text_size float | None

float, optional Font scale/size for label text in exported visualizations. If None, a sensible default is used.

None
visual_text_thickness int | None

int, optional Thickness of text labels. If None, a sensible default is used.

None
visual_hide_labels bool

bool, optional If True, class label names won't be shown on the exported visuals.

False
visual_hide_conf bool

bool, optional If True, confidence scores won't be shown on the exported visuals.

False
visual_export_format str

str, optional Output image format to use when exporting visuals. Supported values are 'png' (default) and 'jpg'. Note that 'jpg' uses lossy compression and may produce smaller files. This parameter is ignored when novisual is True. Exported visuals are written under the run directory: project/name/visuals (and project/name/visuals_with_gt when ground-truth overlays are created).

'png'
verbose int

int 0: no print 1: print slice/prediction durations, number of slices 2: print model loading/file exporting durations

1
return_dict bool

bool If True, returns a dict with 'export_dir' field.

False
force_postprocess_type bool

bool If True, the automatic postprocess type check will be disabled

False
exclude_classes_by_name list[str] | None

Optional[List[str]] None: if no classes are excluded List[str]: set of classes to exclude using its/their class label name/s

None
exclude_classes_by_id list[int] | None

Optional[List[int]] None: if no classes are excluded List[int]: set of classes to exclude using one or more IDs

None
progress_bar bool

bool Whether to show a progress bar. Default is False.

False
Source code in sahi/predict.py
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
def predict(
    detection_model: DetectionModel | None = None,
    model_type: str = "ultralytics",
    model_path: str | None = None,
    model_config_path: str | None = None,
    model_confidence_threshold: float = 0.25,
    model_device: str | None = None,
    model_category_mapping: dict | None = None,
    model_category_remapping: dict | None = None,
    source: str | None = None,
    no_standard_prediction: bool = False,
    no_sliced_prediction: bool = False,
    image_size: int | None = None,
    slice_height: int = 512,
    slice_width: int = 512,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    postprocess_type: str = "GREEDYNMM",
    postprocess_match_metric: str = "IOS",
    postprocess_match_threshold: float = 0.5,
    postprocess_class_agnostic: bool = False,
    novisual: bool = False,
    view_video: bool = False,
    frame_skip_interval: int = 0,
    export_pickle: bool = False,
    export_crop: bool = False,
    dataset_json_path: str | None = None,
    project: str = "runs/predict",
    name: str = "exp",
    visual_bbox_thickness: int | None = None,
    visual_text_size: float | None = None,
    visual_text_thickness: int | None = None,
    visual_hide_labels: bool = False,
    visual_hide_conf: bool = False,
    visual_export_format: str = "png",
    verbose: int = 1,
    return_dict: bool = False,
    force_postprocess_type: bool = False,
    exclude_classes_by_name: list[str] | None = None,
    exclude_classes_by_id: list[int] | None = None,
    progress_bar: bool = False,
    **kwargs,
):
    """Performs prediction for all present images in given folder.

    Args:
        detection_model: sahi.model.DetectionModel
            Optionally provide custom DetectionModel to be used for inference. When provided,
            model_type, model_path, config_path, model_device, model_category_mapping, image_size
            params will be ignored
        model_type: str
            mmdet for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'.
        model_path: str
            Path for the model weight
        model_config_path: str
            Path for the detection model config file
        model_confidence_threshold: float
            All predictions with score < model_confidence_threshold will be discarded.
        model_device: str
            Torch device, "cpu" or "cuda"
        model_category_mapping: dict
            Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
        model_category_remapping: dict: str to int
            Remap category ids after performing inference
        source: str
            Folder directory that contains images or path of the image to be predicted. Also video to be predicted.
        no_standard_prediction: bool
            Dont perform standard prediction. Default: False.
        no_sliced_prediction: bool
            Dont perform sliced prediction. Default: False.
        image_size: int
            Input image size for each inference (image is scaled by preserving asp. rat.).
        slice_height: int
            Height of each slice.  Defaults to ``512``.
        slice_width: int
            Width of each slice.  Defaults to ``512``.
        overlap_height_ratio: float
            Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window
            of size 512 yields an overlap of 102 pixels).
            Default to ``0.2``.
        overlap_width_ratio: float
            Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window
            of size 512 yields an overlap of 102 pixels).
            Default to ``0.2``.
        postprocess_type: str
            Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
            Options are 'NMM', 'GREEDYNMM', 'LSNMS' or 'NMS'. Default is 'GREEDYNMM'.
        postprocess_match_metric: str
            Metric to be used during object prediction matching after sliced prediction.
            'IOU' for intersection over union, 'IOS' for intersection over smaller area.
        postprocess_match_threshold: float
            Sliced predictions having higher iou than postprocess_match_threshold will be
            postprocessed after sliced prediction.
        postprocess_class_agnostic: bool
            If True, postprocess will ignore category ids.
        novisual: bool
            Dont export predicted video/image visuals.
        view_video: bool
            View result of prediction during video inference.
        frame_skip_interval: int
            If view_video or export_visual is slow, you can process one frames of 3(for exp: --frame_skip_interval=3).
        export_pickle: bool
            Export predictions as .pickle
        export_crop: bool
            Export predictions as cropped images.
        dataset_json_path: str
            If coco file path is provided, detection results will be exported in coco json format.
        project: str
            Save results to project/name.
        name: str
            Save results to project/name.
        visual_bbox_thickness: int, optional
            Line thickness (in pixels) for bounding boxes in exported visualizations.
            If None, a default thickness is chosen based on image size.
        visual_text_size: float, optional
            Font scale/size for label text in exported visualizations. If None, a
            sensible default is used.
        visual_text_thickness: int, optional
            Thickness of text labels. If None, a sensible default is used.
        visual_hide_labels: bool, optional
            If True, class label names won't be shown on the exported visuals.
        visual_hide_conf: bool, optional
            If True, confidence scores won't be shown on the exported visuals.
        visual_export_format: str, optional
            Output image format to use when exporting visuals. Supported values are
            'png' (default) and 'jpg'. Note that 'jpg' uses lossy compression and may
            produce smaller files. This parameter is ignored when `novisual` is True.
            Exported visuals are written under the run directory: `project/name/visuals`
            (and `project/name/visuals_with_gt` when ground-truth overlays are created).
        verbose: int
            0: no print
            1: print slice/prediction durations, number of slices
            2: print model loading/file exporting durations
        return_dict: bool
            If True, returns a dict with 'export_dir' field.
        force_postprocess_type: bool
            If True, auto postprocess check will e disabled
        exclude_classes_by_name: Optional[List[str]]
            None: if no classes are excluded
            List[str]: set of classes to exclude using its/their class label name/s
        exclude_classes_by_id: Optional[List[int]]
            None: if no classes are excluded
            List[int]: set of classes to exclude using one or more IDs
        progress_bar: bool
            Whether to show a progress bar. Default is False.
    """
    # assert prediction type
    if no_standard_prediction and no_sliced_prediction:
        raise ValueError("'no_standard_prediction' and 'no_sliced_prediction' cannot be True at the same time.")

    # auto postprocess type
    if not force_postprocess_type and model_confidence_threshold < LOW_MODEL_CONFIDENCE and postprocess_type != "NMS":
        logger.warning(
            f"Switching postprocess type/metric to NMS/IOU since confidence "
            f"threshold is low ({model_confidence_threshold})."
        )
        postprocess_type = "NMS"
        postprocess_match_metric = "IOU"

    # for profiling
    durations_in_seconds = dict()

    # Init export directories
    save_dir = Path(increment_path(Path(project) / name, exist_ok=False))  # increment run
    crop_dir = save_dir / "crops"
    visual_dir = save_dir / "visuals"
    visual_with_gt_dir = save_dir / "visuals_with_gt"
    pickle_dir = save_dir / "pickles"
    if not novisual or export_pickle or export_crop or dataset_json_path is not None:
        save_dir.mkdir(parents=True, exist_ok=True)  # make dir

    # Init image iterator
    # TODO: rewrite this as iterator class as in https://github.com/ultralytics/yolov5/blob/d059d1da03aee9a3c0059895aa4c7c14b7f25a9e/utils/datasets.py#L178
    source_is_video = False
    num_frames = None
    image_iterator: list[str] | Generator[Image.Image]
    if dataset_json_path and source:
        coco: Coco = Coco.from_coco_dict_or_path(dataset_json_path)
        image_iterator = [str(Path(source) / Path(coco_image.file_name)) for coco_image in coco.images]
        coco_json = []
    elif source and os.path.isdir(source):
        image_iterator = list_files(directory=source, contains=IMAGE_EXTENSIONS, verbose=verbose)
    elif source and Path(source).suffix in VIDEO_EXTENSIONS:
        source_is_video = True
        read_video_frame, output_video_writer, video_file_name, num_frames = get_video_reader(
            source, str(save_dir), frame_skip_interval, not novisual, view_video
        )
        image_iterator = read_video_frame
    elif source:
        image_iterator = [source]
    else:
        logger.error("No valid input given to predict function")
        return

    # init model instance
    time_start = time.time()
    if detection_model is None:
        detection_model = AutoDetectionModel.from_pretrained(
            model_type=model_type,
            model_path=model_path,
            config_path=model_config_path,
            confidence_threshold=model_confidence_threshold,
            device=model_device,
            category_mapping=model_category_mapping,
            category_remapping=model_category_remapping,
            load_at_init=False,
            image_size=image_size,
            **kwargs,
        )
        detection_model.load_model()
    time_end = time.time() - time_start
    durations_in_seconds["model_load"] = time_end

    # iterate over source images
    durations_in_seconds["prediction"] = 0
    durations_in_seconds["slice"] = 0

    input_type_str = "video frames" if source_is_video else "images"
    for ind, image_path in enumerate(
        tqdm(image_iterator, f"Performing inference on {input_type_str}", total=num_frames)
    ):
        # Source is an image: Iterating over Image objects
        if source and source_is_video:
            video_name = Path(source).stem
            relative_filepath = video_name + "_frame_" + str(ind)
        elif isinstance(image_path, Image.Image):
            raise RuntimeError("Source is not a video, but image is still an Image object ")
        # preserve source folder structure in export
        elif source and os.path.isdir(source):
            relative_filepath = str(Path(image_path)).split(str(Path(source)))[-1]
            relative_filepath = relative_filepath[1:] if relative_filepath[0] == os.sep else relative_filepath
        else:  # no process if source is single file
            relative_filepath = Path(image_path).name

        filename_without_extension = Path(relative_filepath).stem

        # load image
        image_as_pil = read_image_as_pil(image_path)

        # perform prediction
        if not no_sliced_prediction:
            # get sliced prediction
            prediction_result = get_sliced_prediction(
                image=image_as_pil,
                detection_model=detection_model,
                slice_height=slice_height,
                slice_width=slice_width,
                overlap_height_ratio=overlap_height_ratio,
                overlap_width_ratio=overlap_width_ratio,
                perform_standard_pred=not no_standard_prediction,
                postprocess_type=postprocess_type,
                postprocess_match_metric=postprocess_match_metric,
                postprocess_match_threshold=postprocess_match_threshold,
                postprocess_class_agnostic=postprocess_class_agnostic,
                verbose=1 if verbose else 0,
                exclude_classes_by_name=exclude_classes_by_name,
                exclude_classes_by_id=exclude_classes_by_id,
                progress_bar=progress_bar,
            )
            object_prediction_list = prediction_result.object_prediction_list
            if prediction_result.durations_in_seconds:
                durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"]
        else:
            # get standard prediction
            prediction_result = get_prediction(
                image=image_as_pil,
                detection_model=detection_model,
                shift_amount=[0, 0],
                full_shape=None,
                postprocess=None,
                verbose=0,
                exclude_classes_by_name=exclude_classes_by_name,
                exclude_classes_by_id=exclude_classes_by_id,
            )
            object_prediction_list = prediction_result.object_prediction_list

        durations_in_seconds["prediction"] += prediction_result.durations_in_seconds["prediction"]
        # Show prediction time
        if verbose:
            tqdm.write(
                "Prediction time is: {:.2f} ms".format(prediction_result.durations_in_seconds["prediction"] * 1000)
            )

        if dataset_json_path:
            if source_is_video is True:
                raise NotImplementedError("Video input type not supported with coco formatted dataset json")

            # append predictions in coco format
            for object_prediction in object_prediction_list:
                coco_prediction = object_prediction.to_coco_prediction()
                coco_prediction.image_id = coco.images[ind].id
                coco_prediction_json = coco_prediction.json
                if coco_prediction_json["bbox"]:
                    coco_json.append(coco_prediction_json)
            if not novisual:
                # convert ground truth annotations to object_prediction_list
                coco_image: CocoImage = coco.images[ind]
                object_prediction_gt_list: list[ObjectPrediction] = []
                for coco_annotation in coco_image.annotations:
                    coco_annotation_dict = coco_annotation.json
                    category_name = coco_annotation.category_name
                    full_shape = [coco_image.height, coco_image.width]
                    object_prediction_gt = ObjectPrediction.from_coco_annotation_dict(
                        annotation_dict=coco_annotation_dict, category_name=category_name, full_shape=full_shape
                    )
                    object_prediction_gt_list.append(object_prediction_gt)
                # export visualizations with ground truths
                output_dir = str(visual_with_gt_dir / Path(relative_filepath).parent)
                color = (0, 255, 0)  # original annotations in green
                result = visualize_object_predictions(
                    np.ascontiguousarray(image_as_pil),
                    object_prediction_list=object_prediction_gt_list,
                    rect_th=visual_bbox_thickness,
                    text_size=visual_text_size,
                    text_th=visual_text_thickness,
                    color=color,
                    hide_labels=visual_hide_labels,
                    hide_conf=visual_hide_conf,
                    output_dir=None,
                    file_name=None,
                    export_format=None,
                )
                color = (255, 0, 0)  # model predictions in red
                _ = visualize_object_predictions(
                    result["image"],
                    object_prediction_list=object_prediction_list,
                    rect_th=visual_bbox_thickness,
                    text_size=visual_text_size,
                    text_th=visual_text_thickness,
                    color=color,
                    hide_labels=visual_hide_labels,
                    hide_conf=visual_hide_conf,
                    output_dir=output_dir,
                    file_name=filename_without_extension,
                    export_format=visual_export_format,
                )

        time_start = time.time()
        # export prediction boxes
        if export_crop:
            output_dir = str(crop_dir / Path(relative_filepath).parent)
            crop_object_predictions(
                image=np.ascontiguousarray(image_as_pil),
                object_prediction_list=object_prediction_list,
                output_dir=output_dir,
                file_name=filename_without_extension,
                export_format=visual_export_format,
            )
        # export prediction list as pickle
        if export_pickle:
            save_path = str(pickle_dir / Path(relative_filepath).parent / (filename_without_extension + ".pickle"))
            save_pickle(data=object_prediction_list, save_path=save_path)

        # export visualization
        if not novisual or view_video:
            output_dir = str(visual_dir / Path(relative_filepath).parent)
            result = visualize_object_predictions(
                np.ascontiguousarray(image_as_pil),
                object_prediction_list=object_prediction_list,
                rect_th=visual_bbox_thickness,
                text_size=visual_text_size,
                text_th=visual_text_thickness,
                hide_labels=visual_hide_labels,
                hide_conf=visual_hide_conf,
                output_dir=output_dir if not source_is_video else None,
                file_name=filename_without_extension,
                export_format=visual_export_format,
            )
            if not novisual and source_is_video:  # export video
                if output_video_writer is None:
                    raise RuntimeError("Output video writer could not be created")
                output_video_writer.write(cv2.cvtColor(result["image"], cv2.COLOR_RGB2BGR))

        # render video inference
        if view_video:
            cv2.imshow(f"Prediction of {video_file_name!s}", result["image"])
            cv2.waitKey(1)

        time_end = time.time() - time_start
        durations_in_seconds["export_files"] = time_end

    # export coco results
    if dataset_json_path:
        save_path = str(save_dir / "result.json")
        save_json(coco_json, save_path)

    if not novisual or export_pickle or export_crop or dataset_json_path is not None:
        print(f"Prediction results are successfully exported to {save_dir}")

    # print prediction duration
    if verbose == 2:
        print(
            "Model loaded in",
            durations_in_seconds["model_load"],
            "seconds.",
        )
        print(
            "Slicing performed in",
            durations_in_seconds["slice"],
            "seconds.",
        )
        print(
            "Prediction performed in",
            durations_in_seconds["prediction"],
            "seconds.",
        )
        if not novisual:
            print(
                "Exporting performed in",
                durations_in_seconds["export_files"],
                "seconds.",
            )

    if return_dict:
        return {"export_dir": save_dir}
predict_fiftyone(model_type='mmdet', model_path=None, model_config_path=None, model_confidence_threshold=0.25, model_device=None, model_category_mapping=None, model_category_remapping=None, dataset_json_path='', image_dir='', no_standard_prediction=False, no_sliced_prediction=False, image_size=None, slice_height=256, slice_width=256, overlap_height_ratio=0.2, overlap_width_ratio=0.2, postprocess_type='GREEDYNMM', postprocess_match_metric='IOS', postprocess_match_threshold=0.5, postprocess_class_agnostic=False, verbose=1, exclude_classes_by_name=None, exclude_classes_by_id=None, progress_bar=False)

Performs prediction for all present images in given folder.

Parameters:

Name Type Description Default
model_type str

str mmdet for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'.

'mmdet'
model_path str | None

str Path for the model weight

None
model_config_path str | None

str Path for the detection model config file

None
model_confidence_threshold float

float All predictions with score < model_confidence_threshold will be discarded.

0.25
model_device str | None

str Torch device, "cpu" or "cuda"

None
model_category_mapping dict | None

dict Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}

None
model_category_remapping dict | None

dict: str to int Remap category ids after performing inference

None
dataset_json_path str

str If coco file path is provided, detection results will be exported in coco json format.

''
image_dir str

str Folder directory that contains images or path of the image to be predicted.

''
no_standard_prediction bool

bool Don't perform standard prediction. Default: False.

False
no_sliced_prediction bool

bool Don't perform sliced prediction. Default: False.

False
image_size int | None

int Input image size for each inference (image is scaled by preserving asp. rat.).

None
slice_height int

int Height of each slice. Defaults to 256.

256
slice_width int

int Width of each slice. Defaults to 256.

256
overlap_height_ratio float

float Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window of size 256 yields an overlap of 51 pixels). Default to 0.2.

0.2
overlap_width_ratio float

float Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window of size 256 yields an overlap of 51 pixels). Default to 0.2.

0.2
postprocess_type str

str Type of the postprocess to be used after sliced inference while merging/eliminating predictions. Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.

'GREEDYNMM'
postprocess_match_metric str

str Metric to be used during object prediction matching after sliced prediction. 'IOU' for intersection over union, 'IOS' for intersection over smaller area.

'IOS'
postprocess_match_threshold float

float Sliced predictions having higher iou than postprocess_match_threshold will be postprocessed after sliced prediction.

0.5
postprocess_class_agnostic bool

bool If True, postprocess will ignore category ids.

False
verbose int

int 0: no print 1: print slice/prediction durations, number of slices, model loading/file exporting durations

1
exclude_classes_by_name list[str] | None

Optional[List[str]] None: if no classes are excluded List[str]: set of classes to exclude using its/their class label name/s

None
exclude_classes_by_id list[int] | None

Optional[List[int]] None: if no classes are excluded List[int]: set of classes to exclude using one or more IDs

None
progress_bar bool

bool Whether to show progress bar for slice processing. Default: False.

False
Source code in sahi/predict.py
def predict_fiftyone(
    model_type: str = "mmdet",
    model_path: str | None = None,
    model_config_path: str | None = None,
    model_confidence_threshold: float = 0.25,
    model_device: str | None = None,
    model_category_mapping: dict | None = None,
    model_category_remapping: dict | None = None,
    dataset_json_path: str = "",
    image_dir: str = "",
    no_standard_prediction: bool = False,
    no_sliced_prediction: bool = False,
    image_size: int | None = None,
    slice_height: int = 256,
    slice_width: int = 256,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    postprocess_type: str = "GREEDYNMM",
    postprocess_match_metric: str = "IOS",
    postprocess_match_threshold: float = 0.5,
    postprocess_class_agnostic: bool = False,
    verbose: int = 1,
    exclude_classes_by_name: list[str] | None = None,
    exclude_classes_by_id: list[int] | None = None,
    progress_bar: bool = False,
):
    """Performs prediction for all present images in given folder.

    Runs (sliced or standard) inference over every sample of a FiftyOne dataset
    built from a COCO file, attaches the predictions to each sample, prints an
    evaluation report, then launches the FiftyOne app and blocks forever to keep
    the session alive — this function never returns. Intended as a standalone
    tool entry point, not for use from library code.

    Args:
        model_type: str
            mmdet for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'.
        model_path: str
            Path for the model weight
        model_config_path: str
            Path for the detection model config file
        model_confidence_threshold: float
            All predictions with score < model_confidence_threshold will be discarded.
        model_device: str
            Torch device, "cpu" or "cuda"
        model_category_mapping: dict
            Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
        model_category_remapping: dict: str to int
            Remap category ids after performing inference
        dataset_json_path: str
            If coco file path is provided, detection results will be exported in coco json format.
        image_dir: str
            Folder directory that contains images or path of the image to be predicted.
        no_standard_prediction: bool
            Don't perform standard prediction. Default: False.
        no_sliced_prediction: bool
            Don't perform sliced prediction. Default: False.
        image_size: int
            Input image size for each inference (image is scaled by preserving asp. rat.).
        slice_height: int
            Height of each slice.  Defaults to ``256``.
        slice_width: int
            Width of each slice.  Defaults to ``256``.
        overlap_height_ratio: float
            Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window
            of size 256 yields an overlap of 51 pixels).
            Default to ``0.2``.
        overlap_width_ratio: float
            Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window
            of size 256 yields an overlap of 51 pixels).
            Default to ``0.2``.
        postprocess_type: str
            Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
            Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
        postprocess_match_metric: str
            Metric to be used during object prediction matching after sliced prediction.
            'IOU' for intersection over union, 'IOS' for intersection over smaller area.
        postprocess_match_threshold: float
            Sliced predictions having higher iou than postprocess_match_threshold will be
            postprocessed after sliced prediction.
        postprocess_class_agnostic: bool
            If True, postprocess will ignore category ids.
        verbose: int
            0: no print
            1: print slice/prediction durations, number of slices, model loading/file exporting durations
        exclude_classes_by_name: Optional[List[str]]
            None: if no classes are excluded
            List[str]: set of classes to exclude using its/their class label name/s
        exclude_classes_by_id: Optional[List[int]]
            None: if no classes are excluded
            List[int]: set of classes to exclude using one or more IDs
        progress_bar: bool
            Whether to show progress bar for slice processing. Default: False.

    Raises:
        ValueError: If both ``no_standard_prediction`` and ``no_sliced_prediction`` are True.
    """
    check_requirements(["fiftyone"])

    # fiftyone is an optional dependency, so it is imported lazily after the
    # requirement check above
    from sahi.utils.fiftyone import create_fiftyone_dataset_from_coco_file, fo

    # assert prediction type
    # NOTE(review): the message says 'no_standard_pred' but the parameter is
    # named 'no_standard_prediction' — looks like a stale name; confirm before
    # changing the user-facing string.
    if no_standard_prediction and no_sliced_prediction:
        raise ValueError("'no_standard_pred' and 'no_sliced_prediction' cannot be True at the same time.")
    # for profiling
    durations_in_seconds = dict()

    dataset = create_fiftyone_dataset_from_coco_file(image_dir, dataset_json_path)

    # init model instance
    time_start = time.time()
    detection_model = AutoDetectionModel.from_pretrained(
        model_type=model_type,
        model_path=model_path,
        config_path=model_config_path,
        confidence_threshold=model_confidence_threshold,
        device=model_device,
        category_mapping=model_category_mapping,
        category_remapping=model_category_remapping,
        load_at_init=False,
        image_size=image_size,
    )
    detection_model.load_model()
    time_end = time.time() - time_start
    durations_in_seconds["model_load"] = time_end

    # iterate over source images
    durations_in_seconds["prediction"] = 0
    durations_in_seconds["slice"] = 0
    # Add predictions to samples
    with fo.ProgressBar() as pb:
        for sample in pb(dataset):
            # perform prediction
            if not no_sliced_prediction:
                # get sliced prediction
                prediction_result = get_sliced_prediction(
                    image=sample.filepath,
                    detection_model=detection_model,
                    slice_height=slice_height,
                    slice_width=slice_width,
                    overlap_height_ratio=overlap_height_ratio,
                    overlap_width_ratio=overlap_width_ratio,
                    perform_standard_pred=not no_standard_prediction,
                    postprocess_type=postprocess_type,
                    postprocess_match_threshold=postprocess_match_threshold,
                    postprocess_match_metric=postprocess_match_metric,
                    postprocess_class_agnostic=postprocess_class_agnostic,
                    verbose=verbose,
                    exclude_classes_by_name=exclude_classes_by_name,
                    exclude_classes_by_id=exclude_classes_by_id,
                    progress_bar=progress_bar,
                )
                durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"]
            else:
                # get standard prediction
                prediction_result = get_prediction(
                    image=sample.filepath,
                    detection_model=detection_model,
                    shift_amount=[0, 0],
                    full_shape=None,
                    postprocess=None,
                    verbose=0,
                    exclude_classes_by_name=exclude_classes_by_name,
                    exclude_classes_by_id=exclude_classes_by_id,
                )
                durations_in_seconds["prediction"] += prediction_result.durations_in_seconds["prediction"]

            # Save predictions to dataset, under a field named after the model type
            sample[model_type] = fo.Detections(detections=prediction_result.to_fiftyone_detections())
            sample.save()

    # print prediction duration
    if verbose == 1:
        print(
            "Model loaded in",
            durations_in_seconds["model_load"],
            "seconds.",
        )
        print(
            "Slicing performed in",
            durations_in_seconds["slice"],
            "seconds.",
        )
        print(
            "Prediction performed in",
            durations_in_seconds["prediction"],
            "seconds.",
        )

    # visualize results
    session = fo.launch_app()  # pyright: ignore[reportArgumentType]
    session.dataset = dataset
    # Evaluate the predictions (reuses postprocess_match_threshold as the eval IoU)
    results = dataset.evaluate_detections(
        model_type,
        gt_field="ground_truth",
        eval_key="eval",
        iou=postprocess_match_threshold,
        compute_mAP=True,
    )
    # Get the 10 most common classes in the dataset
    counts = dataset.count_values("ground_truth.detections.label")
    classes_top10 = sorted(counts, key=counts.get, reverse=True)[:10]
    # Print a classification report for the top-10 classes
    results.print_report(classes=classes_top10)
    # Load the view on which we ran the `eval` evaluation
    eval_view = dataset.load_evaluation_view("eval")
    # Show samples with most false positives
    session.view = eval_view.sort_by("eval_fp", reverse=True)
    # Keep the process (and thus the FiftyOne app session) alive indefinitely;
    # this loop intentionally never exits.
    while 1:
        time.sleep(3)

prediction

Classes
ObjectPrediction

Bases: ObjectAnnotation

Class for handling detection model predictions.

Source code in sahi/prediction.py
class ObjectPrediction(ObjectAnnotation):
    """Class for handling detection model predictions."""

    def __init__(
        self,
        bbox: list[int] | None = None,
        category_id: int | None = None,
        category_name: str | None = None,
        segmentation: list[list[float]] | None = None,
        score: float = 0.0,
        shift_amount: list[int] | None = None,
        full_shape: list[int] | None = None,
    ):
        """Creates ObjectPrediction from bbox, score, category_id, category_name, segmentation.

        Args:
            bbox: list
                [minx, miny, maxx, maxy]
            score: float
                Prediction score between 0 and 1
            category_id: int
                ID of the object category
            category_name: str
                Name of the object category
            segmentation: List[List]
                [
                    [x1, y1, x2, y2, x3, y3, ...],
                    [x1, y1, x2, y2, x3, y3, ...],
                    ...
                ]
            shift_amount: list
                To shift the box and mask predictions from sliced image
                to full sized image, should be in the form of [shift_x, shift_y].
                Defaults to [0, 0] when omitted.
            full_shape: list
                Size of the full image after shifting, should be in
                the form of [height, width]
        """
        # A literal `[0, 0]` default would be a shared mutable default argument
        # (one list object reused across every call); build a fresh list instead.
        if shift_amount is None:
            shift_amount = [0, 0]
        self.score = PredictionScore(score)
        super().__init__(
            bbox=bbox,
            category_id=category_id,
            segmentation=segmentation,
            category_name=category_name,
            shift_amount=shift_amount,
            full_shape=full_shape,
        )

    def get_shifted_object_prediction(self):
        """Returns shifted version ObjectPrediction.

        Shifts bbox and mask coords. Used for mapping sliced predictions over full image.
        """
        # When a mask is present, its segmentation/full_shape are carried over
        # from the shifted mask; otherwise both stay None.
        shifted_mask = self.mask.get_shifted_mask() if self.mask else None
        return ObjectPrediction(
            bbox=self.bbox.get_shifted_box().to_xyxy(),
            category_id=self.category.id,
            score=self.score.value,
            segmentation=shifted_mask.segmentation if shifted_mask else None,
            category_name=self.category.name,
            shift_amount=[0, 0],
            full_shape=shifted_mask.full_shape if shifted_mask else None,
        )

    def to_coco_prediction(self, image_id=None):
        """Returns sahi.utils.coco.CocoPrediction representation of ObjectAnnotation."""
        # Prefer the segmentation-based representation when a mask exists,
        # otherwise fall back to the xywh bbox representation.
        if self.mask:
            return CocoPrediction.from_coco_segmentation(
                segmentation=self.mask.segmentation,
                category_id=self.category.id,
                category_name=self.category.name,
                score=self.score.value,
                image_id=image_id,
            )
        return CocoPrediction.from_coco_bbox(
            bbox=self.bbox.to_xywh(),
            category_id=self.category.id,
            category_name=self.category.name,
            score=self.score.value,
            image_id=image_id,
        )

    def to_fiftyone_detection(self, image_height: int, image_width: int):
        """Returns fiftyone.Detection representation of ObjectPrediction.

        Raises:
            ImportError: If fiftyone is not installed.
        """
        try:
            import fiftyone as fo
        except ImportError:
            raise ImportError('Please run "pip install -U fiftyone" to install fiftyone first for fiftyone conversion.')

        # fiftyone expects [x, y, w, h] normalized to the image dimensions
        x1, y1, x2, y2 = self.bbox.to_xyxy()
        rel_box = [x1 / image_width, y1 / image_height, (x2 - x1) / image_width, (y2 - y1) / image_height]
        fiftyone_detection = fo.Detection(label=self.category.name, bounding_box=rel_box, confidence=self.score.value)
        return fiftyone_detection

    def __repr__(self):
        return f"""ObjectPrediction<
    bbox: {self.bbox},
    mask: {self.mask},
    score: {self.score},
    category: {self.category}>"""
Functions
__init__(bbox=None, category_id=None, category_name=None, segmentation=None, score=0.0, shift_amount=[0, 0], full_shape=None)

Creates ObjectPrediction from bbox, score, category_id, category_name, segmentation.

Parameters:

Name Type Description Default
bbox list[int] | None

list [minx, miny, maxx, maxy]

None
score float

float Prediction score between 0 and 1

0.0
category_id int | None

int ID of the object category

None
category_name str | None

str Name of the object category

None
segmentation list[list[float]] | None

List[List] [ [x1, y1, x2, y2, x3, y3, ...], [x1, y1, x2, y2, x3, y3, ...], ... ]

None
shift_amount list[int] | None

list To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]

[0, 0]
full_shape list[int] | None

list Size of the full image after shifting, should be in the form of [height, width]

None
Source code in sahi/prediction.py
def __init__(
    self,
    bbox: list[int] | None = None,
    category_id: int | None = None,
    category_name: str | None = None,
    segmentation: list[list[float]] | None = None,
    score: float = 0.0,
    shift_amount: list[int] | None = None,
    full_shape: list[int] | None = None,
):
    """Creates ObjectPrediction from bbox, score, category_id, category_name, segmentation.

    Args:
        bbox: list
            [minx, miny, maxx, maxy]
        score: float
            Prediction score between 0 and 1
        category_id: int
            ID of the object category
        category_name: str
            Name of the object category
        segmentation: List[List]
            [
                [x1, y1, x2, y2, x3, y3, ...],
                [x1, y1, x2, y2, x3, y3, ...],
                ...
            ]
        shift_amount: list
            To shift the box and mask predictions from sliced image
            to full sized image, should be in the form of [shift_x, shift_y].
            Defaults to [0, 0] when omitted.
        full_shape: list
            Size of the full image after shifting, should be in
            the form of [height, width]
    """
    # A literal `[0, 0]` default would be a shared mutable default argument
    # (one list object reused across every call); build a fresh list instead.
    if shift_amount is None:
        shift_amount = [0, 0]
    self.score = PredictionScore(score)
    super().__init__(
        bbox=bbox,
        category_id=category_id,
        segmentation=segmentation,
        category_name=category_name,
        shift_amount=shift_amount,
        full_shape=full_shape,
    )
get_shifted_object_prediction()

Returns shifted version ObjectPrediction.

Shifts bbox and mask coords. Used for mapping sliced predictions over full image.

Source code in sahi/prediction.py
def get_shifted_object_prediction(self):
    """Return a copy of this prediction shifted into full-image coordinates.

    Shifts bbox and mask coords; used for mapping sliced predictions over the
    full image. The returned prediction carries shift_amount=[0, 0].
    """
    # The mask (if any) is shifted first; its segmentation and full_shape are
    # propagated, otherwise both remain None.
    moved_mask = self.mask.get_shifted_mask() if self.mask else None
    return ObjectPrediction(
        bbox=self.bbox.get_shifted_box().to_xyxy(),
        category_id=self.category.id,
        score=self.score.value,
        segmentation=moved_mask.segmentation if moved_mask else None,
        category_name=self.category.name,
        shift_amount=[0, 0],
        full_shape=moved_mask.full_shape if moved_mask else None,
    )
to_coco_prediction(image_id=None)

Returns sahi.utils.coco.CocoPrediction representation of ObjectAnnotation.

Source code in sahi/prediction.py
def to_coco_prediction(self, image_id=None):
    """Returns sahi.utils.coco.CocoPrediction representation of ObjectAnnotation."""
    if self.mask:
        coco_prediction = CocoPrediction.from_coco_segmentation(
            segmentation=self.mask.segmentation,
            category_id=self.category.id,
            category_name=self.category.name,
            score=self.score.value,
            image_id=image_id,
        )
    else:
        coco_prediction = CocoPrediction.from_coco_bbox(
            bbox=self.bbox.to_xywh(),
            category_id=self.category.id,
            category_name=self.category.name,
            score=self.score.value,
            image_id=image_id,
        )
    return coco_prediction
to_fiftyone_detection(image_height, image_width)

Returns fiftyone.Detection representation of ObjectPrediction.

Source code in sahi/prediction.py
def to_fiftyone_detection(self, image_height: int, image_width: int):
    """Return the fiftyone.Detection representation of this prediction.

    Raises:
        ImportError: If fiftyone is not installed.
    """
    try:
        import fiftyone as fo
    except ImportError:
        raise ImportError('Please run "pip install -U fiftyone" to install fiftyone first for fiftyone conversion.')

    # fiftyone expects a [x, y, w, h] box normalized to image dimensions
    xmin, ymin, xmax, ymax = self.bbox.to_xyxy()
    norm_w = (xmax - xmin) / image_width
    norm_h = (ymax - ymin) / image_height
    rel_box = [xmin / image_width, ymin / image_height, norm_w, norm_h]
    return fo.Detection(label=self.category.name, bounding_box=rel_box, confidence=self.score.value)
PredictionResult
Source code in sahi/prediction.py
class PredictionResult:
    """Holds the object predictions for a single image together with the
    PIL-converted image itself and per-stage timing information."""

    def __init__(
        self,
        object_prediction_list: list[ObjectPrediction],
        image: Image.Image | str | np.ndarray,
        durations_in_seconds: dict[str, Any] | None = None,
    ):
        """
        Args:
            object_prediction_list: predictions belonging to `image`
            image: PIL image, image path or numpy array; converted to PIL internally
            durations_in_seconds: optional per-stage timing info; a fresh empty
                dict is used when omitted
        """
        self.image: Image.Image = read_image_as_pil(image)
        self.image_width, self.image_height = self.image.size
        self.object_prediction_list: list[ObjectPrediction] = object_prediction_list
        # `= dict()` in the signature would be a shared mutable default argument
        # (the same dict object reused by every instance); create per-instance.
        self.durations_in_seconds = durations_in_seconds if durations_in_seconds is not None else {}

    def export_visuals(
        self,
        export_dir: str,
        text_size: float | None = None,
        rect_th: int | None = None,
        hide_labels: bool = False,
        hide_conf: bool = False,
        file_name: str = "prediction_visual",
    ):
        """Export a visualization of the predictions over the image as a PNG file.

        Args:
            export_dir: directory for resulting visualization to be exported
            text_size: size of the category name over box
            rect_th: rectangle thickness
            hide_labels: hide labels
            hide_conf: hide confidence
            file_name: saving name
        """
        Path(export_dir).mkdir(parents=True, exist_ok=True)
        visualize_object_predictions(
            image=np.ascontiguousarray(self.image),
            object_prediction_list=self.object_prediction_list,
            rect_th=rect_th,
            text_size=text_size,
            text_th=None,
            color=None,
            hide_labels=hide_labels,
            hide_conf=hide_conf,
            output_dir=export_dir,
            file_name=file_name,
            export_format="png",
        )

    def to_coco_annotations(self):
        """Return predictions as a list of COCO-format dicts (no image_id)."""
        return [object_prediction.to_coco_prediction().json for object_prediction in self.object_prediction_list]

    def to_coco_predictions(self, image_id: int | None = None):
        """Return predictions as a list of COCO-format dicts tagged with `image_id`."""
        return [
            object_prediction.to_coco_prediction(image_id=image_id).json
            for object_prediction in self.object_prediction_list
        ]

    def to_imantics_annotations(self):
        """Return predictions as a list of imantics annotations."""
        return [object_prediction.to_imantics_annotation() for object_prediction in self.object_prediction_list]

    def to_fiftyone_detections(self):
        """Return predictions as a list of fiftyone.Detection objects.

        Raises:
            ImportError: If fiftyone is not installed.
        """
        try:
            import fiftyone as fo
        except ImportError:
            raise ImportError('Please run "uv pip install -U fiftyone" to install fiftyone for conversion.')

        # relative boxes require the image dimensions captured at construction
        fiftyone_detection_list: list[fo.Detection] = [
            object_prediction.to_fiftyone_detection(image_height=self.image_height, image_width=self.image_width)
            for object_prediction in self.object_prediction_list
        ]
        return fiftyone_detection_list
Functions
export_visuals(export_dir, text_size=None, rect_th=None, hide_labels=False, hide_conf=False, file_name='prediction_visual')

Parameters:

Name Type Description Default
export_dir str

directory for resulting visualization to be exported

required
text_size float | None

size of the category name over box

None
rect_th int | None

rectangle thickness

None
hide_labels bool

hide labels

False
hide_conf bool

hide confidence

False
file_name str

saving name

'prediction_visual'

Returns:

Source code in sahi/prediction.py
def export_visuals(
    self,
    export_dir: str,
    text_size: float | None = None,
    rect_th: int | None = None,
    hide_labels: bool = False,
    hide_conf: bool = False,
    file_name: str = "prediction_visual",
):
    """Export a PNG visualization of the predictions over the image.

    Args:
        export_dir: directory for resulting visualization to be exported
        text_size: size of the category name over box
        rect_th: rectangle thickness
        hide_labels: hide labels
        hide_conf: hide confidence
        file_name: saving name
    """
    # make sure the target directory exists before exporting
    target_dir = Path(export_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    # visualize_object_predictions needs a contiguous array copy of the image
    image_array = np.ascontiguousarray(self.image)
    visualize_object_predictions(
        image=image_array,
        object_prediction_list=self.object_prediction_list,
        rect_th=rect_th,
        text_size=text_size,
        text_th=None,
        color=None,
        hide_labels=hide_labels,
        hide_conf=hide_conf,
        output_dir=export_dir,
        file_name=file_name,
        export_format="png",
    )
PredictionScore
Source code in sahi/prediction.py
class PredictionScore:
    def __init__(self, value: float | np.ndarray):
        """
        Args:
            score: prediction score between 0 and 1
        """
        # if score is a numpy object, convert it to python variable
        if type(value).__module__ == "numpy":
            value = copy.deepcopy(value).tolist()
        # set score
        self.value = value

    def is_greater_than_threshold(self, threshold):
        """Check if score is greater than threshold."""
        return self.value > threshold

    def __eq__(self, threshold):
        return self.value == threshold

    def __gt__(self, threshold):
        return self.value > threshold

    def __lt__(self, threshold):
        return self.value < threshold

    def __repr__(self):
        return f"PredictionScore: <value: {self.value}>"
Functions
__init__(value)

Parameters:

Name Type Description Default
score

prediction score between 0 and 1

required
Source code in sahi/prediction.py
def __init__(self, value: float | np.ndarray):
    """
    Args:
        value: prediction score between 0 and 1
    """
    # if value is a numpy scalar/array, convert it to a plain python object
    if type(value).__module__ == "numpy":
        value = copy.deepcopy(value).tolist()
    # set score
    self.value = value
is_greater_than_threshold(threshold)

Check if score is greater than threshold.

Source code in sahi/prediction.py
def is_greater_than_threshold(self, threshold):
    """Return True when the stored score exceeds the given threshold."""
    return threshold < self.value
Functions

scripts

Modules
coco2fiftyone
Functions
main(image_dir, dataset_json_path, *result_json_paths, iou_thresh=0.5)

Parameters:

Name Type Description Default
image_dir str

directory for coco images

required
dataset_json_path str

file path for the coco dataset json file

required
result_json_paths str

one or more paths for the coco result json file

()
iou_thresh float

iou threshold for coco evaluation

0.5
Source code in sahi/scripts/coco2fiftyone.py
def main(
    image_dir: str,
    dataset_json_path: str,
    *result_json_paths,
    iou_thresh: float = 0.5,
):
    """Launch a FiftyOne app visualizing a COCO dataset and optional results.

    Args:
        image_dir (str): directory for coco images
        dataset_json_path (str): file path for the coco dataset json file
        result_json_paths (str): one or more paths for the coco result json file
        iou_thresh (float): iou threshold for coco evaluation
    """

    from fiftyone.utils.coco import add_coco_labels

    from sahi.utils.fiftyone import create_fiftyone_dataset_from_coco_file, fo

    coco_result_list = []
    result_name_list = []
    for result_json_path in result_json_paths:
        coco_result_list.append(load_json(result_json_path))

        # derive a unique fiftyone name from the file stem,
        # suffixing _2, _3, ... on duplicates
        base_name = Path(result_json_path).stem
        candidate_name = base_name
        suffix = 2
        while candidate_name in result_name_list:
            candidate_name = base_name + "_" + str(suffix)
            suffix += 1
        result_name_list.append(candidate_name)

    dataset = create_fiftyone_dataset_from_coco_file(image_dir, dataset_json_path)

    # submit detections for every provided coco result
    for result_name, coco_result in zip(result_name_list, coco_result_list):
        add_coco_labels(dataset, result_name, coco_result, coco_id_field="gt_coco_id")

    # visualize results
    session = fo.launch_app()  # pyright: ignore[reportArgumentType]
    session.dataset = dataset

    # order by false positives if any coco result is given
    if result_json_paths:
        # evaluate the predictions of the first result set
        first_coco_result_name = result_name_list[0]
        _ = dataset.evaluate_detections(
            first_coco_result_name,
            gt_field="gt_detections",
            eval_key=f"{first_coco_result_name}_eval",
            iou=iou_thresh,
            compute_mAP=False,
        )
        # load the evaluated view and show samples with most false positives first
        eval_view = dataset.load_evaluation_view(f"{first_coco_result_name}_eval")
        session.view = eval_view.sort_by(f"{first_coco_result_name}_eval_fp", reverse=True)

        print(f"SAHI has successfully launched a Fiftyone app at http://localhost:{fo.config.default_app_port}")

    # keep the process alive so the app stays reachable
    while True:
        time.sleep(3)
coco2yolo
Classes Functions
main(image_dir, dataset_json_path, train_split=0.9, project='runs/coco2yolo', name='exp', seed=1, disable_symlink=False)

Parameters:

Name Type Description Default
image_dir str

directory for coco images

required
dataset_json_path str

file path for the coco json file to be converted

required
train_split float or int

set the training split ratio

0.9
project str

save results to project/name

'runs/coco2yolo'
name str

save results to project/name

'exp'
seed int

fix the seed for reproducibility

1
disable_symlink bool

required in google colab env

False
Source code in sahi/scripts/coco2yolo.py
def main(
    image_dir: str,
    dataset_json_path: str,
    train_split: int | float = 0.9,
    project: str = "runs/coco2yolo",
    name: str = "exp",
    seed: int = 1,
    disable_symlink=False,
):
    """Convert a COCO dataset into YOLO format under project/name.

    Args:
        image_dir (str): directory for coco images
        dataset_json_path (str): file path for the coco json file to be converted
        train_split (float or int): set the training split ratio
        project (str): save results to project/name
        name (str): save results to project/name
        seed (int): fix the seed for reproducibility
        disable_symlink (bool): required in google colab env
    """
    # pick a fresh run directory (suffix incremented if project/name exists)
    save_dir = Path(increment_path(Path(project) / name, exist_ok=False))

    # load the coco dataset, then export it in YOLO layout
    coco = Coco.from_coco_dict_or_path(coco_dict_or_path=dataset_json_path, image_dir=image_dir)
    coco.export_as_yolo(
        output_dir=str(save_dir),
        train_split_rate=train_split,
        numpy_seed=seed,
        disable_symlink=disable_symlink,
    )

    print(f"COCO to YOLO conversion results are successfully exported to {save_dir}")
coco_error_analysis
Functions
analyse(dataset_json_path, result_json_path, out_dir=None, type='bbox', no_extraplots=False, areas=[1024, 9216, 10000000000], max_detections=500, return_dict=False)

Parameters:

Name Type Description Default
dataset_json_path str

file path for the coco dataset json file

required
result_json_path str

file path for the coco result json file

required
out_dir str

dir to save analyse result images

None
no_extraplots bool

don't export extra bar/stat plots

False
type str

'bbox' or 'mask'

'bbox'
areas List[int]

area regions for coco evaluation calculations

[1024, 9216, 10000000000]
max_detections int

Maximum number of detections to consider for AP calculation. Default: 500

500
return_dict bool

If True, returns a dict export paths.

False
Source code in sahi/scripts/coco_error_analysis.py
def analyse(
    dataset_json_path: str,
    result_json_path: str,
    out_dir: str | None = None,
    type: str = "bbox",
    no_extraplots: bool = False,
    areas: list[int] | None = None,
    max_detections: int = 500,
    return_dict: bool = False,
) -> dict | None:
    """Run COCO error analysis and export the resulting plots.

    Args:
        dataset_json_path (str): file path for the coco dataset json file
        result_json_path (str): file path for the coco result json file
        out_dir (str): dir to save analyse result images
        no_extraplots (bool): don't export extra bar/stat plots
        type (str): 'bbox' or 'mask'
        areas (List[int]): area regions for coco evaluation calculations.
            Defaults to [1024, 9216, 10000000000].
        max_detections (int): Maximum number of detections to consider for AP calculation. Default: 500
        return_dict (bool): If True, returns a dict of export paths.

    Raises:
        ModuleNotFoundError: if matplotlib or pycocotools is not installed.
    """
    # avoid the shared-mutable-default-argument pitfall; the effective
    # default is unchanged
    if areas is None:
        areas = [1024, 9216, 10000000000]

    if not has_matplotlib:
        logger.error("Please run 'uv pip install -U matplotlib' first for visualization.")
        raise ModuleNotFoundError("matplotlib not installed")
    if not has_pycocotools:
        logger.error("Please run 'uv pip install -U pycocotools' first for Coco analysis.")
        raise ModuleNotFoundError("pycocotools not installed")

    result = _analyse_results(
        result_json_path,
        dataset_json_path,
        res_types=[type],
        out_dir=out_dir,
        extraplots=not no_extraplots,
        areas=areas,
        max_detections=max_detections,
    )
    # implicit None return when the caller does not ask for the dict
    if return_dict:
        return result
coco_evaluation
Functions
evaluate(dataset_json_path, result_json_path, out_dir=None, type='bbox', classwise=False, max_detections=500, iou_thrs=None, areas=[1024, 9216, 10000000000], return_dict=False)

Parameters:

Name Type Description Default
dataset_json_path str

file path for the coco dataset json file

required
result_json_path str

file path for the coco result json file

required
out_dir str

dir to save eval result

None
type str

'bbox' or 'segm'

'bbox'
classwise bool

whether to evaluate the AP for each class

False
max_detections int

Maximum number of detections to consider for AP calculation. Default: 500

500
iou_thrs float

IoU threshold used for evaluating recalls/mAPs

None
areas List[int]

area regions for coco evaluation calculations

[1024, 9216, 10000000000]
return_dict bool

If True, returns a dict with 'eval_results' 'export_path' fields.

False
Source code in sahi/scripts/coco_evaluation.py
def evaluate(
    dataset_json_path: str,
    result_json_path: str,
    out_dir: str | None = None,
    type: Literal["bbox", "segm"] = "bbox",
    classwise: bool = False,
    max_detections: int = 500,
    iou_thrs: list[float] | float | None = None,
    areas: list[int] | None = None,
    return_dict: bool = False,
):
    """Run COCO-protocol evaluation of a result file against a dataset file.

    Args:
        dataset_json_path (str): file path for the coco dataset json file
        result_json_path (str): file path for the coco result json file
        out_dir (str): dir to save eval result
        type (str): 'bbox' or 'segm'
        classwise (bool): whether to evaluate the AP for each class
        max_detections (int): Maximum number of detections to consider for AP calculation. Default: 500
        iou_thrs (float): IoU threshold used for evaluating recalls/mAPs
        areas (List[int]): area regions for coco evaluation calculations.
            Defaults to [1024, 9216, 10000000000].
        return_dict (bool): If True, returns a dict with 'eval_results' 'export_path' fields.

    Raises:
        ModuleNotFoundError: if pycocotools is not installed.
    """
    # avoid the shared-mutable-default-argument pitfall; the effective
    # default is unchanged
    if areas is None:
        areas = [1024, 9216, 10000000000]

    try:
        from pycocotools.coco import COCO
        from pycocotools.cocoeval import COCOeval
    except ModuleNotFoundError:
        raise ModuleNotFoundError(
            'Please run "pip install -U pycocotools" to install pycocotools first for coco evaluation.'
        )

    # perform coco eval
    result = evaluate_core(
        dataset_path=dataset_json_path,
        result_path=result_json_path,
        metric=type,
        classwise=classwise,
        max_detections=max_detections,
        iou_thrs=iou_thrs,
        out_dir=out_dir,
        areas=areas,
        COCO=COCO,
        COCOeval=COCOeval,
    )
    if return_dict:
        return result
evaluate_core(dataset_path, result_path, COCO, COCOeval, metric='bbox', classwise=False, max_detections=500, iou_thrs=None, metric_items=None, out_dir=None, areas=[1024, 9216, 10000000000])

Evaluation in COCO protocol.

Parameters:

Name Type Description Default
dataset_path str

COCO dataset json path.

required
result_path str

COCO result json path.

required
COCO, COCOeval

Pass COCO and COCOeval class after safely imported

required
metric str | list[str]

Metrics to be evaluated. Options are 'bbox', 'segm', 'proposal'.

'bbox'
classwise bool

Whether to evaluate the AP for each class.

False
max_detections int

Maximum number of detections to consider for AP calculation. Default: 500

500
iou_thrs List[float]

IoU threshold used for evaluating recalls/mAPs. If set to a list, the average of all IoUs will also be computed. If not specified, [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used. Default: None.

None
metric_items list[str] | str

Metric items that will be returned. If not specified, ['AR@10', 'AR@100', 'AR@500', 'AR_s@500', 'AR_m@500', 'AR_l@500' ] will be used when metric=='proposal', ['mAP', 'mAP50', 'mAP75', 'mAP_s', 'mAP_m', 'mAP_l', 'mAP50_s', 'mAP50_m', 'mAP50_l'] will be used when metric=='bbox' or metric=='segm'.

None
out_dir str

Directory to save evaluation result json.

None
areas List[int]

area regions for coco evaluation calculations

[1024, 9216, 10000000000]

Returns: dict: eval_results (dict[str, float]): COCO style evaluation metric. export_path (str): Path for the exported eval result json.

Source code in sahi/scripts/coco_evaluation.py
def evaluate_core(
    dataset_path: str,
    result_path: str,
    COCO: type,
    COCOeval: type,
    metric: str | list[str] = "bbox",
    classwise: bool = False,
    max_detections: int = 500,
    iou_thrs: list[float] | float | None = None,
    metric_items: list[str] | str | None = None,
    out_dir: str | Path | None = None,
    areas: list[int] = [1024, 9216, 10000000000],
):
    """Evaluation in COCO protocol.

    Args:
        dataset_path (str): COCO dataset json path.
        result_path (str): COCO result json path.
        COCO, COCOeval: Pass COCO and COCOeval class after safely imported
        metric (str | list[str]): Metrics to be evaluated. Options are
            'bbox', 'segm' (note: 'proposal' is rejected by the allowed-metric
            check below).
        classwise (bool): Whether to evaluate the AP for each class.
        max_detections (int): Maximum number of detections to consider for AP
            calculation.
            Default: 500
        iou_thrs (List[float], optional): IoU threshold used for
            evaluating recalls/mAPs. If set to a list, the average of all
            IoUs will also be computed. If not specified, [0.50, 0.55,
            0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
            Default: None.
        metric_items (list[str] | str, optional): Metric items that will
            be returned. If not specified, ``['mAP', 'mAP50', 'mAP75',
            'mAP_s', 'mAP_m', 'mAP_l', 'mAP50_s', 'mAP50_m', 'mAP50_l']``
            will be used when ``metric=='bbox' or metric=='segm'``.
        out_dir (str): Directory to save evaluation result json.
        areas (List[int]): area regions for coco evaluation calculations
    Returns:
        dict:
            eval_results (dict[str, float]): COCO style evaluation metric.
            export_path (str): Path for the exported eval result json.
    Raises:
        KeyError: for an unsupported metric or metric item.
        ValueError: when `areas` does not contain exactly 3 values.
    """

    # NOTE: `metric` is reused as the loop variable below, shadowing the parameter.
    metrics = metric if isinstance(metric, list) else [metric]
    allowed_metrics = ["bbox", "segm"]
    for metric in metrics:
        if metric not in allowed_metrics:
            raise KeyError(f"metric {metric} is not supported")
    if iou_thrs is None:
        # default IoU sweep 0.50:0.05:0.95 (10 thresholds)
        iou_thrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
    if metric_items is not None:
        if not isinstance(metric_items, list):
            metric_items = [metric_items]
    if areas is not None:
        if len(areas) != 3:
            raise ValueError("3 integers should be specified as areas, representing 3 area regions")
    eval_results = OrderedDict()

    # Load dataset json and add empty 'info' field if missing
    with open(dataset_path) as f:
        dataset_dict = json.load(f)
    if "info" not in dataset_dict:
        dataset_dict["info"] = {}

    # Create temporary file with updated dataset
    # (so the original dataset file on disk is never modified)
    import tempfile

    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp_file:
        json.dump(dataset_dict, tmp_file)
        temp_dataset_path = tmp_file.name

    try:
        cocoGt = COCO(temp_dataset_path)
        cat_ids = list(cocoGt.cats.keys())
        for metric in metrics:
            msg = f"Evaluating {metric}..."
            msg = "\n" + msg
            print(msg)

            iou_type = metric
            with open(result_path) as json_file:
                results = json.load(json_file)
            try:
                cocoDt = cocoGt.loadRes(results)
            except IndexError:
                # loadRes raises IndexError on an empty result list
                print("The testing results of the whole dataset is empty.")
                break

            cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
            if areas is not None:
                # area ranges [all, small, medium, large] derived from the 3 thresholds
                cocoEval.params.areaRng = [
                    [0**2, areas[2]],
                    [0**2, areas[0]],
                    [areas[0], areas[1]],
                    [areas[1], areas[2]],
                ]
            cocoEval.params.catIds = cat_ids
            cocoEval.params.maxDets = [max_detections]
            cocoEval.params.iouThrs = (
                [iou_thrs] if not isinstance(iou_thrs, list) and not isinstance(iou_thrs, np.ndarray) else iou_thrs
            )
            # mapping of cocoEval.stats
            # (matches the custom stats vector assembled below, NOT the default
            # pycocotools stats ordering)
            coco_metric_names = {
                "mAP": 0,
                "mAP75": 1,
                "mAP50": 2,
                "mAP_s": 3,
                "mAP_m": 4,
                "mAP_l": 5,
                "mAP50_s": 6,
                "mAP50_m": 7,
                "mAP50_l": 8,
                "AR_s": 9,
                "AR_m": 10,
                "AR_l": 11,
            }
            if metric_items is not None:
                for metric_item in metric_items:
                    if metric_item not in coco_metric_names:
                        raise KeyError(f"metric item {metric_item} is not supported")

            cocoEval.evaluate()
            cocoEval.accumulate()
            # calculate mAP50_s/m/l
            mAP = _cocoeval_summarize(cocoEval, ap=1, iouThr=None, areaRng="all", maxDets=max_detections)
            mAP50 = _cocoeval_summarize(cocoEval, ap=1, iouThr=0.5, areaRng="all", maxDets=max_detections)
            mAP75 = _cocoeval_summarize(cocoEval, ap=1, iouThr=0.75, areaRng="all", maxDets=max_detections)
            mAP50_s = _cocoeval_summarize(cocoEval, ap=1, iouThr=0.5, areaRng="small", maxDets=max_detections)
            mAP50_m = _cocoeval_summarize(cocoEval, ap=1, iouThr=0.5, areaRng="medium", maxDets=max_detections)
            mAP50_l = _cocoeval_summarize(cocoEval, ap=1, iouThr=0.5, areaRng="large", maxDets=max_detections)
            mAP_s = _cocoeval_summarize(cocoEval, ap=1, iouThr=None, areaRng="small", maxDets=max_detections)
            mAP_m = _cocoeval_summarize(cocoEval, ap=1, iouThr=None, areaRng="medium", maxDets=max_detections)
            mAP_l = _cocoeval_summarize(cocoEval, ap=1, iouThr=None, areaRng="large", maxDets=max_detections)
            AR_s = _cocoeval_summarize(cocoEval, ap=0, iouThr=None, areaRng="small", maxDets=max_detections)
            AR_m = _cocoeval_summarize(cocoEval, ap=0, iouThr=None, areaRng="medium", maxDets=max_detections)
            AR_l = _cocoeval_summarize(cocoEval, ap=0, iouThr=None, areaRng="large", maxDets=max_detections)
            # overwrite cocoEval.stats so indices match coco_metric_names above
            cocoEval.stats = np.append(
                [mAP, mAP75, mAP50, mAP_s, mAP_m, mAP_l, mAP50_s, mAP50_m, mAP50_l, AR_s, AR_m, AR_l], 0
            )

            if classwise:  # Compute per-category AP
                # Compute per-category AP
                # from https://github.com/facebookresearch/detectron2/
                precisions = cocoEval.eval["precision"]
                # precision: (iou, recall, cls, area range, max dets)
                if len(cat_ids) != precisions.shape[2]:
                    raise ValueError(
                        f"The number of categories {len(cat_ids)} is not equal "
                        f"to the number of precisions {precisions.shape[2]}"
                    )
                # longest category name, used to pad the printed table
                max_cat_name_len = 0
                for idx, catId in enumerate(cat_ids):
                    nm = cocoGt.loadCats(catId)[0]
                    cat_name_len = len(nm["name"])
                    max_cat_name_len = cat_name_len if cat_name_len > max_cat_name_len else max_cat_name_len

                results_per_category = []
                for idx, catId in enumerate(cat_ids):
                    # skip if no image with this category
                    image_ids = cocoGt.getImgIds(catIds=[catId])
                    if len(image_ids) == 0:
                        continue
                    # area range index 0: all area ranges
                    # max dets index -1: typically 100 per image
                    nm = cocoGt.loadCats(catId)[0]
                    ap = _cocoeval_summarize(
                        cocoEval,
                        ap=1,
                        catIdx=idx,
                        areaRng="all",
                        maxDets=max_detections,
                        catName=nm["name"],
                        nameStrLen=max_cat_name_len,
                    )
                    ap_s = _cocoeval_summarize(
                        cocoEval,
                        ap=1,
                        catIdx=idx,
                        areaRng="small",
                        maxDets=max_detections,
                        catName=nm["name"],
                        nameStrLen=max_cat_name_len,
                    )
                    ap_m = _cocoeval_summarize(
                        cocoEval,
                        ap=1,
                        catIdx=idx,
                        areaRng="medium",
                        maxDets=max_detections,
                        catName=nm["name"],
                        nameStrLen=max_cat_name_len,
                    )
                    ap_l = _cocoeval_summarize(
                        cocoEval,
                        ap=1,
                        catIdx=idx,
                        areaRng="large",
                        maxDets=max_detections,
                        catName=nm["name"],
                        nameStrLen=max_cat_name_len,
                    )
                    ap50 = _cocoeval_summarize(
                        cocoEval,
                        ap=1,
                        iouThr=0.5,
                        catIdx=idx,
                        areaRng="all",
                        maxDets=max_detections,
                        catName=nm["name"],
                        nameStrLen=max_cat_name_len,
                    )
                    ap50_s = _cocoeval_summarize(
                        cocoEval,
                        ap=1,
                        iouThr=0.5,
                        catIdx=idx,
                        areaRng="small",
                        maxDets=max_detections,
                        catName=nm["name"],
                        nameStrLen=max_cat_name_len,
                    )
                    ap50_m = _cocoeval_summarize(
                        cocoEval,
                        ap=1,
                        iouThr=0.5,
                        catIdx=idx,
                        areaRng="medium",
                        maxDets=max_detections,
                        catName=nm["name"],
                        nameStrLen=max_cat_name_len,
                    )
                    ap50_l = _cocoeval_summarize(
                        cocoEval,
                        ap=1,
                        iouThr=0.5,
                        catIdx=idx,
                        areaRng="large",
                        maxDets=max_detections,
                        catName=nm["name"],
                        nameStrLen=max_cat_name_len,
                    )
                    results_per_category.append((f"{metric}_{nm['name']}_mAP", f"{float(ap):0.3f}"))
                    results_per_category.append((f"{metric}_{nm['name']}_mAP_s", f"{float(ap_s):0.3f}"))
                    results_per_category.append((f"{metric}_{nm['name']}_mAP_m", f"{float(ap_m):0.3f}"))
                    results_per_category.append((f"{metric}_{nm['name']}_mAP_l", f"{float(ap_l):0.3f}"))
                    results_per_category.append((f"{metric}_{nm['name']}_mAP50", f"{float(ap50):0.3f}"))
                    results_per_category.append((f"{metric}_{nm['name']}_mAP50_s", f"{float(ap50_s):0.3f}"))
                    results_per_category.append((f"{metric}_{nm['name']}_mAP50_m", f"{float(ap50_m):0.3f}"))
                    results_per_category.append((f"{metric}_{nm['name']}_mAP50_l", f"{float(ap50_l):0.3f}"))

                # print the per-category numbers as a (category, AP) ascii table
                num_columns = min(6, len(results_per_category) * 2)
                results_flatten = list(itertools.chain(*results_per_category))
                headers = ["category", "AP"] * (num_columns // 2)
                results_2d = itertools.zip_longest(*[results_flatten[i::num_columns] for i in range(num_columns)])
                table_data = [headers]
                table_data += [result for result in results_2d]
                table = AsciiTable(table_data)
                print("\n" + table.table)

            if metric_items is None:
                metric_items = ["mAP", "mAP50", "mAP75", "mAP_s", "mAP_m", "mAP_l", "mAP50_s", "mAP50_m", "mAP50_l"]

            for metric_item in metric_items:
                key = f"{metric}_{metric_item}"
                val = float(f"{cocoEval.stats[coco_metric_names[metric_item]]:.3f}")
                eval_results[key] = val
            ap = cocoEval.stats
            eval_results[f"{metric}_mAP_copypaste"] = (
                f"{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} "
                f"{ap[4]:.3f} {ap[5]:.3f} {ap[6]:.3f} {ap[7]:.3f} "
                f"{ap[8]:.3f}"
            )
            if classwise:
                eval_results["results_per_category"] = {key: value for key, value in results_per_category}
    finally:
        # Clean up temporary file
        os.unlink(temp_dataset_path)

    # set save path
    if not out_dir:
        out_dir = Path(result_path).parent
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    export_path = str(Path(out_dir) / "eval.json")
    # export as json
    with open(export_path, "w", encoding="utf-8") as outfile:
        json.dump(eval_results, outfile, indent=4, separators=(",", ":"))
    print(f"COCO evaluation results are successfully exported to {export_path}")
    return {"eval_results": eval_results, "export_path": export_path}
slice_coco
Functions
slicer(image_dir, dataset_json_path, slice_size=512, overlap_ratio=0.2, ignore_negative_samples=False, output_dir='runs/slice_coco', min_area_ratio=0.1)

Parameters:

Name Type Description Default
image_dir str

directory for coco images

required
dataset_json_path str

file path for the coco dataset json file

required
overlap_ratio float

slice overlap ratio

0.2
ignore_negative_samples bool

ignore images without annotation

False
output_dir str

output export dir

'runs/slice_coco'
min_area_ratio float

If the cropped annotation area to original annotation ratio is smaller than this value, the annotation is filtered out. Default 0.1.

0.1
Source code in sahi/scripts/slice_coco.py
def slicer(
    image_dir: str,
    dataset_json_path: str,
    slice_size: int = 512,
    overlap_ratio: float = 0.2,
    ignore_negative_samples: bool = False,
    output_dir: str = "runs/slice_coco",
    min_area_ratio: float = 0.1,
):
    """Slice the images and annotations of a COCO dataset into smaller tiles.

    Args:
        image_dir (str): directory for coco images
        dataset_json_path (str): file path for the coco dataset json file
        slice_size (int): slice size; a single number or a list of numbers
        overlap_ratio (float): slice overlap ratio
        ignore_negative_samples (bool): ignore images without annotation
        output_dir (str): output export dir
        min_area_ratio (float): If the cropped annotation area to original
            annotation ratio is smaller than this value, the annotation
            is filtered out. Default 0.1.
    """
    # a bare number becomes a one-element list so a single loop handles both cases
    slice_sizes = [slice_size] if isinstance(slice_size, (int, float)) else slice_size

    # slice coco dataset images and annotations
    print("Slicing step is starting...")
    for size in slice_sizes:
        overlap_tag = str(overlap_ratio).replace(".", "")
        # output image folder in format: train_images_512_01
        output_images_folder_name = Path(dataset_json_path).stem + f"_images_{size!s}_{overlap_tag}"
        output_images_dir = str(Path(output_dir) / output_images_folder_name)
        sliced_coco_name = Path(dataset_json_path).name.replace(".json", f"_{size!s}_{overlap_tag}")
        coco_dict, _ = slice_coco(
            coco_annotation_file_path=dataset_json_path,
            image_dir=image_dir,
            output_coco_annotation_file_name="",
            output_dir=output_images_dir,
            ignore_negative_samples=ignore_negative_samples,
            slice_height=size,
            slice_width=size,
            min_area_ratio=min_area_ratio,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
            out_ext=".jpg",
            verbose=False,
        )
        # the sliced annotation json goes next to the sliced image folders
        save_json(coco_dict, os.path.join(output_dir, sliced_coco_name + ".json"))
        print(f"Sliced dataset for 'slice_size: {size}' is exported to {output_dir}")

slicing

Classes
SliceImageResult
Source code in sahi/slicing.py
class SliceImageResult:
    """Collects the SlicedImage objects produced by slicing one source image."""

    def __init__(self, original_image_size: list[int], image_dir: str | None = None):
        """
        Args:
            original_image_size: Size of the unsliced original image in [height, width].
            image_dir: Directory of the sliced image exports.
        """
        self.original_image_height = original_image_size[0]
        self.original_image_width = original_image_size[1]
        self.image_dir = image_dir

        # populated via add_sliced_image
        self._sliced_image_list: list[SlicedImage] = []

    def add_sliced_image(self, sliced_image: SlicedImage):
        """Append a slice; raises TypeError for anything but a SlicedImage."""
        if not isinstance(sliced_image, SlicedImage):
            raise TypeError("sliced_image must be a SlicedImage instance")

        self._sliced_image_list.append(sliced_image)

    @property
    def sliced_image_list(self):
        return self._sliced_image_list

    @property
    def images(self):
        """Returns sliced images.

        Returns:
            images: a list of np.array
        """
        return [sliced_image.image for sliced_image in self._sliced_image_list]

    @property
    def coco_images(self) -> list[CocoImage]:
        """Returns CocoImage representation of SliceImageResult.

        Returns:
            coco_images: a list of CocoImage
        """
        return [sliced_image.coco_image for sliced_image in self._sliced_image_list]

    @property
    def starting_pixels(self) -> list[int]:
        """Returns a list of starting pixels for each slice.

        Returns:
            starting_pixels: a list of starting pixel coords [x,y]
        """
        return [sliced_image.starting_pixel for sliced_image in self._sliced_image_list]

    @property
    def filenames(self) -> list[str]:
        """Returns a list of filenames for each slice.

        Returns:
            filenames: a list of filenames as str
        """
        # NOTE: return annotation fixed from list[int] to list[str];
        # coco_image.file_name values are strings
        return [sliced_image.coco_image.file_name for sliced_image in self._sliced_image_list]

    def __getitem__(self, i):
        """Index by int, slice, list/tuple or np.ndarray; returns dict(s) of slice data."""

        def _prepare_ith_dict(idx):
            return {
                "image": self.images[idx],
                "coco_image": self.coco_images[idx],
                "starting_pixel": self.starting_pixels[idx],
                "filename": self.filenames[idx],
            }

        if isinstance(i, np.ndarray):
            i = i.tolist()

        if isinstance(i, int):
            return _prepare_ith_dict(i)
        elif isinstance(i, slice):
            start, stop, step = i.indices(len(self))
            return [_prepare_ith_dict(idx) for idx in range(start, stop, step)]
        elif isinstance(i, (tuple, list)):
            return list(map(_prepare_ith_dict, i))
        else:
            raise NotImplementedError(f"{type(i)}")

    def __len__(self):
        return len(self._sliced_image_list)
Attributes
coco_images property

Returns CocoImage representation of SliceImageResult.

Returns:

Name Type Description
coco_images list[CocoImage]

a list of CocoImage

filenames property

Returns a list of filenames for each slice.

Returns:

Name Type Description
filenames list[str]

a list of filenames as str

images property

Returns sliced images.

Returns:

Name Type Description
images

a list of np.array

starting_pixels property

Returns a list of starting pixels for each slice.

Returns:

Name Type Description
starting_pixels list[int]

a list of starting pixel coords [x,y]

Functions
__init__(original_image_size, image_dir=None)
image_dir: str

Directory of the sliced image exports.

original_image_size: list of int — Size of the unsliced original image in [height, width]

Source code in sahi/slicing.py
def __init__(self, original_image_size: list[int], image_dir: str | None = None):
    """Create an empty slice-result container.

    Args:
        original_image_size: size of the unsliced original image as
            [height, width].
        image_dir: directory of the sliced image exports, if any.
    """
    self.original_image_height = original_image_size[0]
    self.original_image_width = original_image_size[1]
    self.image_dir = image_dir

    # Filled elsewhere with one SlicedImage per generated slice.
    self._sliced_image_list: list[SlicedImage] = []
SlicedImage
Source code in sahi/slicing.py
class SlicedImage:
    """Container pairing one image slice with its COCO metadata."""

    def __init__(self, image, coco_image, starting_pixel):
        """
        image: np.array
            Sliced image pixels.
        coco_image: CocoImage
            Coco styled image object that belongs to the sliced image.
        starting_pixel:
            Starting pixel coordinates of the sliced image
            (presumably an [x, y] pair — original doc said
            "list of list of int"; confirm against callers).
        """
        self.image, self.coco_image, self.starting_pixel = image, coco_image, starting_pixel
Functions
__init__(image, coco_image, starting_pixel)
np.array

Sliced image.

coco_image: CocoImage Coco styled image object that belong to sliced image. starting_pixel: list of list of int Starting pixel coordinates of the sliced image.

Source code in sahi/slicing.py
def __init__(self, image, coco_image, starting_pixel):
    """Store one image slice together with its COCO metadata.

    image: np.array
        Sliced image.
    coco_image: CocoImage
        Coco styled image object that belong to sliced image.
    starting_pixel: list of list of int
        Starting pixel coordinates of the sliced image.
        NOTE(review): usage elsewhere suggests a single [x, y] pair
        rather than a list of lists — confirm.
    """
    self.image = image
    self.coco_image = coco_image
    self.starting_pixel = starting_pixel
Functions
annotation_inside_slice(annotation, slice_bbox)

Check whether annotation coordinates lie inside slice coordinates.

Parameters:

Name Type Description Default
annotation dict

Single annotation entry in COCO format.

required
slice_bbox List[int]

Generated from get_slice_bboxes. Format for each slice bbox: [x_min, y_min, x_max, y_max].

required

Returns:

Type Description
bool

True if any annotation coordinate lies inside slice.

Source code in sahi/slicing.py
def annotation_inside_slice(annotation: dict, slice_bbox: list[int]) -> bool:
    """Check whether annotation coordinates lie inside slice coordinates.

    Args:
        annotation (dict): Single annotation entry in COCO format
            (bbox stored as [left, top, width, height]).
        slice_bbox (List[int]): Generated from `get_slice_bboxes`.
            Format for each slice bbox: [x_min, y_min, x_max, y_max].

    Returns:
        (bool): True if any annotation coordinate lies inside slice.
    """
    box_left, box_top, box_width, box_height = annotation["bbox"]
    box_right = box_left + box_width
    box_bottom = box_top + box_height

    # Two axis-aligned boxes intersect iff neither lies entirely on one
    # side of the other along either axis.
    return (
        box_left < slice_bbox[2]
        and box_top < slice_bbox[3]
        and box_right > slice_bbox[0]
        and box_bottom > slice_bbox[1]
    )
calc_aspect_ratio_orientation(width, height)

Parameters:

Name Type Description Default
width int
required
height int
required

Returns:

Type Description
str

image capture orientation

Source code in sahi/slicing.py
def calc_aspect_ratio_orientation(width: int, height: int) -> str:
    """Classify the image capture orientation from its dimensions.

    Args:
        width: image width in pixels.
        height: image height in pixels.

    Returns:
        "vertical", "horizontal" or "square".
    """
    if width == height:
        return "square"
    return "vertical" if width < height else "horizontal"
calc_ratio_and_slice(orientation, slide=1, ratio=0.1)

According to image resolution calculation overlap params Args: orientation: image capture angle slide: sliding window ratio: buffer value

Returns:

Type Description

overlap params

Source code in sahi/slicing.py
def calc_ratio_and_slice(orientation: Literal["vertical", "horizontal", "square"], slide: int = 1, ratio: float = 0.1):
    """Derive slicing/overlap parameters for a given capture orientation.

    Args:
        orientation: image capture angle ("vertical", "horizontal" or "square").
        slide: sliding-window base count.
        ratio: overlap buffer value.

    Returns:
        (slice_row, slice_col, overlap_height_ratio, overlap_width_ratio)

    Raises:
        ValueError: for an unrecognized orientation.
    """
    # Row/column multipliers applied to `slide` per orientation.
    multipliers = {"vertical": (1, 2), "horizontal": (2, 1), "square": (1, 1)}
    if orientation not in multipliers:
        raise ValueError(f"Invalid orientation: {orientation}. Must be one of 'vertical', 'horizontal', or 'square'.")
    row_mult, col_mult = multipliers[orientation]
    return slide * row_mult, slide * col_mult, ratio, ratio
calc_resolution_factor(resolution)

According to image resolution calculate power(2,n) and return the closest smaller n. Args: resolution: the width and height of the image multiplied. such as 1024x720 = 737280

Returns:

Source code in sahi/slicing.py
def calc_resolution_factor(resolution: int) -> int:
    """
    According to image resolution calculate power(2,n) and return the closest smaller `n`.

    Equivalent to: one less than the smallest `n` with 2**n >= resolution.

    Args:
        resolution: the width and height of the image multiplied. such as 1024x720 = 737280

    Returns:
        Largest exponent n such that 2**n < resolution; -1 when resolution <= 1.
    """
    # The previous loop computed np.power(2, expo), which uses fixed-width
    # int64 arithmetic: for resolutions above 2**63 it silently overflows
    # (wrapping negative), so the loop could run far past the right answer.
    # int.bit_length() gives the identical result with exact Python ints.
    if resolution <= 1:
        # Mirrors the loop's behavior for degenerate inputs (expo stays 0).
        return -1
    return (resolution - 1).bit_length() - 1
calc_slice_and_overlap_params(resolution, height, width, orientation)

This function calculates slice and overlap parameters according to the image resolution. Args: resolution: str height: int width: int orientation: str

Returns:

Type Description
tuple[int, int, int, int]

x_overlap, y_overlap, slice_width, slice_height

Source code in sahi/slicing.py
def calc_slice_and_overlap_params(
    resolution: str, height: int, width: int, orientation: str
) -> tuple[int, int, int, int]:
    """Compute slice sizes and pixel overlaps for a resolution class.

    Args:
        resolution: one of "low", "medium", "high", "ultra-high".
        height: image height in pixels.
        width: image width in pixels.
        orientation: capture orientation ("vertical", "horizontal", "square").

    Returns:
        x_overlap, y_overlap, slice_width, slice_height
    """
    # (slide, ratio) presets per resolution class; anything else falls
    # back to a single full-frame slice (the "low" behavior).
    presets = {"medium": (1, 0.8), "high": (2, 0.4), "ultra-high": (4, 0.4)}

    if resolution in presets:
        slide, ratio = presets[resolution]
        split_row, split_col, overlap_height_ratio, overlap_width_ratio = calc_ratio_and_slice(
            orientation, slide=slide, ratio=ratio
        )
    else:  # low resolution: one slice covering the whole image
        split_row = split_col = 1
        overlap_height_ratio = overlap_width_ratio = 1

    # NOTE(review): height is divided by split_col and width by split_row,
    # mirroring the original implementation exactly.
    slice_height = height // split_col
    slice_width = width // split_row

    x_overlap = int(slice_width * overlap_width_ratio)
    y_overlap = int(slice_height * overlap_height_ratio)

    return x_overlap, y_overlap, slice_width, slice_height
get_auto_slice_params(height, width)

According to Image HxW calculate overlap sliding window and buffer params factor is the power value of 2 closest to the image resolution. factor <= 18: low resolution image such as 300x300, 640x640 18 < factor <= 21: medium resolution image such as 1024x1024, 1336x960 21 < factor <= 24: high resolution image such as 2048x2048, 2048x4096, 4096x4096 factor > 24: ultra-high resolution image such as 6380x6380, 4096x8192 Args: height: width:

Returns:

Type Description
tuple[int, int, int, int]

slicing overlap params x_overlap, y_overlap, slice_width, slice_height

Source code in sahi/slicing.py
def get_auto_slice_params(height: int, width: int) -> tuple[int, int, int, int]:
    """Pick slicing overlap params automatically from the image size.

    `factor` is the power of 2 closest (from below) to height*width:
        factor <= 18:      low resolution image such as 300x300, 640x640
        18 < factor <= 21: medium resolution image such as 1024x1024, 1336x960
        21 < factor <= 24: high resolution image such as 2048x2048, 2048x4096, 4096x4096
        factor > 24:       ultra-high resolution image such as 6380x6380, 4096x8192

    Args:
        height: image height in pixels.
        width: image width in pixels.

    Returns:
        slicing overlap params x_overlap, y_overlap, slice_width, slice_height
    """
    factor = calc_resolution_factor(height * width)
    if factor <= 18:
        resolution_class = "low"
    elif factor < 21:
        resolution_class = "medium"
    elif factor < 24:
        resolution_class = "high"
    else:
        resolution_class = "ultra-high"
    return get_resolution_selector(resolution_class, height=height, width=width)
get_resolution_selector(res, height, width)

Parameters:

Name Type Description Default
res str

resolution of image such as low, medium

required
height int
required
width int
required

Returns:

Type Description
tuple[int, int, int, int]

trigger slicing params function and return overlap params

Source code in sahi/slicing.py
def get_resolution_selector(res: str, height: int, width: int) -> tuple[int, int, int, int]:
    """Resolve overlap/slice params for a named resolution class.

    Args:
        res: resolution class of the image such as "low", "medium".
        height: image height in pixels.
        width: image width in pixels.

    Returns:
        overlap params (x_overlap, y_overlap, slice_width, slice_height)
    """
    # Orientation depends only on the aspect ratio of the image.
    orientation = calc_aspect_ratio_orientation(width=width, height=height)
    return calc_slice_and_overlap_params(resolution=res, height=height, width=width, orientation=orientation)
get_slice_bboxes(image_height, image_width, slice_height=None, slice_width=None, auto_slice_resolution=True, overlap_height_ratio=0.2, overlap_width_ratio=0.2)

Generate bounding boxes for slicing an image into crops.

The function calculates the coordinates for each slice based on the provided image dimensions, slice size, and overlap ratios. If slice size is not provided and auto_slice_resolution is True, the function will automatically determine appropriate slice parameters.

Parameters:

Name Type Description Default
image_height int

Height of the original image.

required
image_width int

Width of the original image.

required
slice_height int

Height of each slice. Default None.

None
slice_width int

Width of each slice. Default None.

None
overlap_height_ratio float

Fractional overlap in height of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels). Default 0.2.

0.2
overlap_width_ratio float

Fractional overlap in width of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels). Default 0.2.

0.2
auto_slice_resolution bool

if not set slice parameters such as slice_height and slice_width, it enables automatically calculate these parameters from image resolution and orientation.

True

Returns:

Type Description
list[list[int]]

List[List[int]]: List of 4 corner coordinates for each N slices. [ [slice_0_left, slice_0_top, slice_0_right, slice_0_bottom], ... [slice_N_left, slice_N_top, slice_N_right, slice_N_bottom] ]

Source code in sahi/slicing.py
def get_slice_bboxes(
    image_height: int,
    image_width: int,
    slice_height: int | None = None,
    slice_width: int | None = None,
    auto_slice_resolution: bool | None = True,
    overlap_height_ratio: float | None = 0.2,
    overlap_width_ratio: float | None = 0.2,
) -> list[list[int]]:
    """Generate bounding boxes for slicing an image into crops.

    The function calculates the coordinates for each slice based on the provided
    image dimensions, slice size, and overlap ratios. If slice size is not provided
    and auto_slice_resolution is True, the function will automatically determine
    appropriate slice parameters.

    Args:
        image_height (int): Height of the original image.
        image_width (int): Width of the original image.
        slice_height (int, optional): Height of each slice. Default None.
        slice_width (int, optional): Width of each slice. Default None.
        overlap_height_ratio (float, optional): Fractional overlap in height of each
            slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
            overlap of 20 pixels). Default 0.2.
        overlap_width_ratio(float, optional): Fractional overlap in width of each
            slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
            overlap of 20 pixels). Default 0.2.
        auto_slice_resolution (bool, optional): if not set slice parameters such as slice_height and slice_width,
            it enables automatically calculate these parameters from image resolution and orientation.

    Returns:
        List[List[int]]: List of 4 corner coordinates for each N slices.
            [
                [slice_0_left, slice_0_top, slice_0_right, slice_0_bottom],
                ...
                [slice_N_left, slice_N_top, slice_N_right, slice_N_bottom]
            ]
    """
    slice_bboxes = []
    y_max = y_min = 0

    if slice_height and slice_width:
        if overlap_height_ratio is not None and overlap_height_ratio >= 1.0:
            raise ValueError("Overlap ratio must be less than 1.0")
        if overlap_width_ratio is not None and overlap_width_ratio >= 1.0:
            raise ValueError("Overlap ratio must be less than 1.0")
        y_overlap = int((overlap_height_ratio if overlap_height_ratio is not None else 0.2) * slice_height)
        x_overlap = int((overlap_width_ratio if overlap_width_ratio is not None else 0.2) * slice_width)
    elif auto_slice_resolution:
        x_overlap, y_overlap, slice_width, slice_height = get_auto_slice_params(height=image_height, width=image_width)
    else:
        raise ValueError("Compute type is not auto and slice width and height are not provided.")

    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            if y_max > image_height or x_max > image_width:
                xmax = min(image_width, x_max)
                ymax = min(image_height, y_max)
                xmin = max(0, xmax - slice_width)
                ymin = max(0, ymax - slice_height)
                slice_bboxes.append([xmin, ymin, xmax, ymax])
            else:
                slice_bboxes.append([x_min, y_min, x_max, y_max])
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes
process_coco_annotations(coco_annotation_list, slice_bbox, min_area_ratio)

Slices and filters given list of CocoAnnotation objects with given 'slice_bbox' and 'min_area_ratio'.

Parameters:

Name Type Description Default
slice_bbox List[int]

Generated from get_slice_bboxes. Format for each slice bbox: [x_min, y_min, x_max, y_max].

required
min_area_ratio float

If the cropped annotation area to original annotation ratio is smaller than this value, the annotation is filtered out. Default 0.1.

required

Returns:

Type Description
List[CocoAnnotation]

Sliced annotations.

Source code in sahi/slicing.py
def process_coco_annotations(
    coco_annotation_list: list[CocoAnnotation], slice_bbox: list[int], min_area_ratio
) -> list[CocoAnnotation]:
    """Slices and filters given list of CocoAnnotation objects with given 'slice_bbox' and 'min_area_ratio'.

    Args:
        coco_annotation_list (List[CocoAnnotation]): annotations of the full image.
        slice_bbox (List[int]): Generated from `get_slice_bboxes`.
            Format for each slice bbox: [x_min, y_min, x_max, y_max].
        min_area_ratio (float): If the cropped annotation area to original
            annotation ratio is smaller than this value, the annotation is
            filtered out. Default 0.1.

    Returns:
        (List[CocoAnnotation]): Sliced annotations.
    """
    result: list[CocoAnnotation] = []
    for annotation in coco_annotation_list:
        # Skip annotations that do not intersect the slice window at all.
        if not annotation_inside_slice(annotation.json, slice_bbox):
            continue
        sliced = annotation.get_sliced_coco_annotation(slice_bbox)
        # Keep only crops retaining a meaningful fraction of the original area.
        if sliced.area / annotation.area >= min_area_ratio:
            result.append(sliced)
    return result
shift_bboxes(bboxes, offset)

Shift bboxes w.r.t offset.

Supports Tensor, np.ndarray and list inputs.

Parameters:

Name Type Description Default
bboxes (Tensor, ndarray, list)

The bboxes need to be translated. Its shape can be (n, 4), which means (x, y, x, y).

required
offset Sequence[int]

The translation offsets with shape of (2, ).

required

Returns: Tensor, np.ndarray, list: Shifted bboxes.

Source code in sahi/slicing.py
def shift_bboxes(bboxes, offset: Sequence[int]):
    """Shift bboxes w.r.t offset.

    Args:
        bboxes (Tensor, np.ndarray, list): The bboxes need to be translated. Its shape can
            be (n, 4), which means (x, y, x, y).
        offset (Sequence[int]): The translation offsets with shape of (2, ).
    Returns:
        Tensor, np.ndarray, list: Shifted bboxes, in the same container kind
        as the input.
    """
    # Detect torch tensors via module name to avoid importing torch here.
    is_torch_tensor = type(bboxes).__module__ == "torch"

    shifted = []
    for box in bboxes:
        # Tensor / ndarray rows must become plain lists for BoundingBox.
        coords = box.tolist() if is_torch_tensor or isinstance(box, np.ndarray) else box
        shifted.append(BoundingBox(coords, shift_amount=offset).get_shifted_box().to_xyxy())

    if isinstance(bboxes, np.ndarray):
        return np.stack(shifted, axis=0)
    if is_torch_tensor:
        return bboxes.new_tensor(shifted)
    return shifted
shift_masks(masks, offset, full_shape)

Shift masks to the original image.

Parameters:

Name Type Description Default
masks ndarray

masks that need to be shifted.

required
offset Sequence[int]

The offset to translate with shape of (2, ).

required
full_shape Sequence[int]

A (height, width) tuple of the huge image's shape.

required

Returns: np.ndarray: Shifted masks.

Source code in sahi/slicing.py
def shift_masks(masks: np.ndarray, offset: Sequence[int], full_shape: Sequence[int]) -> np.ndarray:
    """Shift masks to the original image.

    Args:
        masks (np.ndarray): masks that need to be shifted.
        offset (Sequence[int]): The offset to translate with shape of (2, ).
        full_shape (Sequence[int]): A (height, width) tuple of the huge image's shape.
    Returns:
        np.ndarray: Shifted masks (or the input unchanged when it is None).
    """
    if masks is None:
        # Nothing to shift (e.g. detection-only predictions).
        return masks

    shifted = [
        Mask(segmentation=m, shift_amount=offset, full_shape=full_shape).get_shifted_mask().bool_mask
        for m in masks
    ]
    return np.stack(shifted, axis=0)
slice_coco(coco_annotation_file_path, image_dir, output_coco_annotation_file_name, output_dir=None, ignore_negative_samples=False, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2, min_area_ratio=0.1, out_ext=None, verbose=False, exif_fix=True)

Slice large images given in a directory, into smaller windows. If output_dir is given, export sliced images and coco file.

Parameters:

Name Type Description Default
coco_annotation_file_path str

Location of the coco annotation file

required
image_dir str

Base directory for the images

required
output_coco_annotation_file_name str

File name of the exported coco dataset json.

required
output_dir str

Output directory

None
ignore_negative_samples bool

If True, images without annotations are ignored. Defaults to False.

False
slice_height int

Height of each slice. Default 512.

512
slice_width int

Width of each slice. Default 512.

512
overlap_height_ratio float

Fractional overlap in height of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels). Default 0.2.

0.2
overlap_width_ratio float

Fractional overlap in width of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels). Default 0.2.

0.2
min_area_ratio float

If the cropped annotation area to original annotation ratio is smaller than this value, the annotation is filtered out. Default 0.1.

0.1
out_ext str

Extension of saved images. Default is the original suffix.

None
verbose bool

Switch to print relevant values to screen.

False
exif_fix bool

Whether to apply an EXIF fix to the image.

True

Returns:

Name Type Description
coco_dict list[dict | str]

dict COCO dict for sliced images and annotations

save_path list[dict | str]

str Path to the saved coco file

Source code in sahi/slicing.py
def slice_coco(
    coco_annotation_file_path: str,
    image_dir: str,
    output_coco_annotation_file_name: str,
    output_dir: str | None = None,
    ignore_negative_samples: bool | None = False,
    slice_height: int | None = 512,
    slice_width: int | None = 512,
    overlap_height_ratio: float | None = 0.2,
    overlap_width_ratio: float | None = 0.2,
    min_area_ratio: float | None = 0.1,
    out_ext: str | None = None,
    verbose: bool | None = False,
    exif_fix: bool = True,
) -> tuple[dict, str | Path]:
    """Slice large images given in a directory, into smaller windows. If output_dir is given, export sliced images and
    coco file.

    Args:
        coco_annotation_file_path (str): Location of the coco annotation file
        image_dir (str): Base directory for the images
        output_coco_annotation_file_name (str): File name of the exported coco
            dataset json.
        output_dir (str, optional): Output directory
        ignore_negative_samples (bool, optional): If True, images without annotations
            are ignored. Defaults to False.
        slice_height (int, optional): Height of each slice. Default 512.
        slice_width (int, optional): Width of each slice. Default 512.
        overlap_height_ratio (float, optional): Fractional overlap in height of each
            slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
            overlap of 20 pixels). Default 0.2.
        overlap_width_ratio (float, optional): Fractional overlap in width of each
            slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
            overlap of 20 pixels). Default 0.2.
        min_area_ratio (float): If the cropped annotation area to original annotation
            ratio is smaller than this value, the annotation is filtered out. Default 0.1.
        out_ext (str, optional): Extension of saved images. Default is the
            original suffix.
        verbose (bool, optional): Switch to print relevant values to screen.
        exif_fix (bool, optional): Whether to apply an EXIF fix to the image.

    Returns:
        coco_dict: dict
            COCO dict for sliced images and annotations
        save_path: str
            Path to the saved coco file; empty string when no export was requested.
            NOTE(review): when a file is written this is actually a pathlib.Path,
            not a str — confirm whether callers rely on str.
    """

    # read coco file
    coco_dict: dict = load_json(coco_annotation_file_path)
    # create image_id_to_annotation_list mapping
    coco = Coco.from_coco_dict_or_path(coco_dict)
    # init sliced coco_utils.CocoImage list
    sliced_coco_images: list = []

    # iterate over images and slice
    for idx, coco_image in enumerate(tqdm(coco.images)):
        # get image path
        image_path: str = os.path.join(image_dir, coco_image.file_name)
        # get annotation json list corresponding to selected coco image
        # slice image
        try:
            slice_image_result = slice_image(
                image=image_path,
                coco_annotation_list=coco_image.annotations,
                # idx suffix keeps slice names unique across same-named files
                output_file_name=f"{Path(coco_image.file_name).stem}_{idx}",
                output_dir=output_dir,
                slice_height=slice_height,
                slice_width=slice_width,
                overlap_height_ratio=overlap_height_ratio,
                overlap_width_ratio=overlap_width_ratio,
                min_area_ratio=min_area_ratio,
                out_ext=out_ext,
                verbose=verbose,
                exif_fix=exif_fix,
            )
            # append slice outputs
            sliced_coco_images.extend(slice_image_result.coco_images)
        except TopologicalError:
            # Degenerate annotation geometry: best-effort skip of the image.
            logger.warning(f"Invalid annotation found, skipping this image: {image_path}")

    # create and save coco dict
    coco_dict = create_coco_dict(
        sliced_coco_images, coco_dict["categories"], ignore_negative_samples=ignore_negative_samples
    )
    save_path = ""
    # Export only when both a file name and an output directory are given.
    if output_coco_annotation_file_name and output_dir:
        save_path = Path(output_dir) / (output_coco_annotation_file_name + "_coco.json")
        save_json(coco_dict, save_path)

    return coco_dict, save_path
slice_image(image, coco_annotation_list=None, output_file_name=None, output_dir=None, slice_height=None, slice_width=None, overlap_height_ratio=0.2, overlap_width_ratio=0.2, auto_slice_resolution=True, min_area_ratio=0.1, out_ext=None, verbose=False, exif_fix=True)

Slice a large image into smaller windows. If output_file_name and output_dir is given, export sliced images.

Parameters:

Name Type Description Default
image str or Image

File path of image or Pillow Image to be sliced.

required
coco_annotation_list List[CocoAnnotation]

List of CocoAnnotation objects.

None
output_file_name str

Root name of output files (coordinates will be appended to this)

None
output_dir str

Output directory

None
slice_height int

Height of each slice. Default None.

None
slice_width int

Width of each slice. Default None.

None
overlap_height_ratio float

Fractional overlap in height of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels). Default 0.2.

0.2
overlap_width_ratio float

Fractional overlap in width of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels). Default 0.2.

0.2
auto_slice_resolution bool

if not set slice parameters such as slice_height and slice_width, it enables automatically calculate these params from image resolution and orientation.

True
min_area_ratio float

If the cropped annotation area to original annotation ratio is smaller than this value, the annotation is filtered out. Default 0.1.

0.1
out_ext str

Extension of saved images. Default is the original suffix for lossless image formats and png for lossy formats ('.jpg','.jpeg').

None
verbose bool

Switch to print relevant values to screen. Default 'False'.

False
exif_fix bool

Whether to apply an EXIF fix to the image.

True

Returns:

Name Type Description
sliced_image_result SliceImageResult

SliceImageResult: sliced_image_list: list of SlicedImage image_dir: str Directory of the sliced image exports. original_image_size: list of int Size of the unsliced original image in [height, width]

Source code in sahi/slicing.py
def slice_image(
    image: str | Image.Image,
    coco_annotation_list: list[CocoAnnotation] | None = None,
    output_file_name: str | None = None,
    output_dir: str | None = None,
    slice_height: int | None = None,
    slice_width: int | None = None,
    overlap_height_ratio: float | None = 0.2,
    overlap_width_ratio: float | None = 0.2,
    auto_slice_resolution: bool | None = True,
    min_area_ratio: float | None = 0.1,
    out_ext: str | None = None,
    verbose: bool | None = False,
    exif_fix: bool = True,
) -> SliceImageResult:
    """Slice a large image into smaller windows. If output_file_name and output_dir are given, export sliced images.

    Args:
        image (str or PIL.Image): File path of image or Pillow Image to be sliced.
        coco_annotation_list (List[CocoAnnotation], optional): List of CocoAnnotation objects.
        output_file_name (str, optional): Root name of output files (coordinates will
            be appended to this)
        output_dir (str, optional): Output directory
        slice_height (int, optional): Height of each slice. Default None.
        slice_width (int, optional): Width of each slice. Default None.
        overlap_height_ratio (float, optional): Fractional overlap in height of each
            slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
            overlap of 20 pixels). Default 0.2.
        overlap_width_ratio (float, optional): Fractional overlap in width of each
            slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
            overlap of 20 pixels). Default 0.2.
        auto_slice_resolution (bool, optional): if slice parameters such as slice_height and
            slice_width are not set, they are automatically calculated from image resolution
            and orientation.
        min_area_ratio (float, optional): If the cropped annotation area to original annotation
            ratio is smaller than this value, the annotation is filtered out. Default 0.1.
        out_ext (str, optional): Extension of saved images. Default is the
            original suffix for lossless image formats and png for lossy formats ('.jpg','.jpeg').
        verbose (bool, optional): Switch to print relevant values to screen.
            Default 'False'.
        exif_fix (bool): Whether to apply an EXIF fix to the image.

    Returns:
        sliced_image_result: SliceImageResult:
                                sliced_image_list: list of SlicedImage
                                image_dir: str
                                    Directory of the sliced image exports.
                                original_image_size: list of int
                                    Size of the unsliced original image in [height, width]
    """

    # define verboseprint: no-op when verbose is disabled
    verboselog = logger.info if verbose else lambda *a, **k: None

    def _export_single_slice(image: np.ndarray, output_dir: str, slice_file_name: str):
        # save one slice array to disk as an image file
        image_pil = read_image_as_pil(image, exif_fix=exif_fix)
        slice_file_path = str(Path(output_dir) / slice_file_name)
        # export sliced image
        image_pil.save(slice_file_path)
        image_pil.close()  # to fix https://github.com/obss/sahi/issues/565
        verboselog("sliced image path: " + slice_file_path)

    # create outdir if not present
    if output_dir is not None:
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    # read image
    image_pil = read_image_as_pil(image, exif_fix=exif_fix)
    verboselog("image.shape: " + str(image_pil.size))

    image_width, image_height = image_pil.size
    if not (image_width != 0 and image_height != 0):
        raise RuntimeError(f"invalid image size: {image_pil.size} for 'slice_image'.")
    slice_bboxes = get_slice_bboxes(
        image_height=image_height,
        image_width=image_width,
        auto_slice_resolution=auto_slice_resolution,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
    )

    n_ims = 0

    # init images and annotations lists
    sliced_image_result = SliceImageResult(original_image_size=[image_height, image_width], image_dir=output_dir)

    # determine the output suffix once: it depends only on out_ext / the source
    # image, not on the individual slice, so hoist it out of the loop
    if out_ext:
        suffix = out_ext
    elif hasattr(image_pil, "filename"):
        suffix = Path(image_pil.filename).suffix
        # re-encode lossy formats as png so repeated slicing does not degrade quality;
        # lossless (and unknown) formats keep their original suffix
        if suffix in IMAGE_EXTENSIONS_LOSSY:
            suffix = ".png"
    else:
        suffix = ".png"

    image_pil_arr = np.asarray(image_pil)
    # track the dims of the last processed slice for the summary log without
    # clobbering the slice_height/slice_width parameters
    last_slice_height, last_slice_width = slice_height, slice_width
    # iterate over slices
    for slice_bbox in slice_bboxes:
        n_ims += 1

        # extract image region [tly:bry, tlx:brx]
        tlx, tly, brx, bry = slice_bbox
        image_pil_slice = image_pil_arr[tly:bry, tlx:brx]

        # set image file name and path; coordinates are encoded in the name
        slice_suffixes = "_".join(map(str, slice_bbox))
        slice_file_name = f"{output_file_name}_{slice_suffixes}{suffix}"

        # create coco image
        last_slice_width = brx - tlx
        last_slice_height = bry - tly
        coco_image = CocoImage(file_name=slice_file_name, height=last_slice_height, width=last_slice_width)

        # append coco annotations (if present) to coco image
        if coco_annotation_list is not None:
            for sliced_coco_annotation in process_coco_annotations(coco_annotation_list, slice_bbox, min_area_ratio):
                coco_image.add_annotation(sliced_coco_annotation)

        # create sliced image and append to sliced_image_result
        sliced_image = SlicedImage(image=image_pil_slice, coco_image=coco_image, starting_pixel=[tlx, tly])
        sliced_image_result.add_sliced_image(sliced_image)

    # export slices if output directory is provided
    if output_file_name and output_dir:
        # Use a context-managed ThreadPoolExecutor for clean shutdown and
        # limit workers based on CPU count to avoid oversubscription.
        max_workers = max(1, min(MAX_WORKERS, len(sliced_image_result)))
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # map will schedule tasks and wait for completion when the context exits
            list(
                executor.map(
                    _export_single_slice,
                    sliced_image_result.images,
                    [output_dir] * len(sliced_image_result),
                    sliced_image_result.filenames,
                )
            )

    verboselog(
        "Num slices: "
        + str(n_ims)
        + " slice_height: "
        + str(last_slice_height)
        + " slice_width: "
        + str(last_slice_width)
    )

    return sliced_image_result

utils

Modules
coco
Classes
Coco
Source code in sahi/utils/coco.py
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
class Coco:
    def __init__(
        self,
        name: str | None = None,
        image_dir: str | None = None,
        remapping_dict: dict[int, int] | None = None,
        ignore_negative_samples: bool = False,
        clip_bboxes_to_img_dims: bool = False,
        image_id_setting: Literal["auto", "manual"] = "auto",
    ):
        """Creates Coco object.

        Args:
            name: str
                Name of the Coco dataset, it determines exported json name.
            image_dir: str
                Base file directory that contains dataset images. Required for dataset merging.
            remapping_dict: dict
                {1:0, 2:1} maps category id 1 to 0 and category id 2 to 1
            ignore_negative_samples: bool
                If True ignores images without annotations in all operations.
            image_id_setting: str
                how to assign image ids while exporting can be
                auto -> will assign id from scratch (<CocoImage>.id will be ignored)
                manual -> you will need to provide image ids in <CocoImage> instances (<CocoImage>.id can not be None)
        """
        if image_id_setting not in ["auto", "manual"]:
            raise ValueError("image_id_setting must be either 'auto' or 'manual'")
        self.name: str | None = name
        self.image_dir: str | None = image_dir
        self.remapping_dict: dict[int, int] | None = remapping_dict
        self.ignore_negative_samples = ignore_negative_samples
        self.categories: list[CocoCategory] = []
        self.images = []
        self._stats = None
        self.clip_bboxes_to_img_dims = clip_bboxes_to_img_dims
        self.image_id_setting = image_id_setting

    def add_categories_from_coco_category_list(self, coco_category_list):
        """Creates and adds CocoCategory objects from a COCO-format category list.

        If `self.remapping_dict` is set, each category id is remapped before the
        category is added.

        Args:
            coco_category_list: List[Dict]
                [
                    {"supercategory": "person", "id": 1, "name": "person"},
                    {"supercategory": "vehicle", "id": 2, "name": "bicycle"}
                ]
        """

        for coco_category in coco_category_list:
            # Remap each id with a single dict lookup. The previous nested-loop
            # version compared against the already-remapped id, so a remapping
            # whose target equaled a later source id could be applied twice.
            if self.remapping_dict is not None and coco_category["id"] in self.remapping_dict:
                coco_category["id"] = self.remapping_dict[coco_category["id"]]

            self.add_category(CocoCategory.from_coco_category(coco_category))

    def add_category(self, category):
        """Append a category to this dataset.

        Args:
            category: CocoCategory

        Raises:
            TypeError: if `category` is not a CocoCategory instance.
        """
        if isinstance(category, CocoCategory):
            self.categories.append(category)
        else:
            raise TypeError("category must be a CocoCategory instance")

    def add_image(self, image):
        """Adds image to this Coco instance.

        Args:
            image: CocoImage
        """

        if self.image_id_setting == "manual" and image.id is None:
            raise ValueError("image id should be manually set for image_id_setting='manual'")
        self.images.append(image)

    def update_categories(self, desired_name2id: dict[str, int], update_image_filenames: bool = False):
        """Rearranges category mapping of given COCO object based on given desired_name2id. Can also be used to filter
        some of the categories.

        Categories absent from `desired_name2id` are dropped together with their
        annotations. The instance is updated in place (its __dict__ is replaced).

        Args:
            desired_name2id: dict
                {"big_vehicle": 1, "car": 2, "human": 3}
            update_image_filenames: bool
                If True, updates coco image file_names with absolute file paths.
        """
        # init vars
        currentid2desiredid_mapping: dict[int, int | None] = {}
        updated_coco = Coco(
            name=self.name,
            image_dir=self.image_dir,
            remapping_dict=self.remapping_dict,
            ignore_negative_samples=self.ignore_negative_samples,
            # fix: propagate these settings; they were previously reset to their
            # defaults when self.__dict__ was overwritten at the end of this method
            clip_bboxes_to_img_dims=self.clip_bboxes_to_img_dims,
            image_id_setting=self.image_id_setting,
        )
        # create category id mapping (currentid2desiredid_mapping);
        # ids not present in desired_name2id map to None (filtered out below)
        for coco_category in self.categories:
            current_category_id = coco_category.id
            current_category_name = coco_category.name
            if not current_category_name:
                logger.warning("no category name provided to update categories")
                continue
            if current_category_name in desired_name2id.keys():
                currentid2desiredid_mapping[current_category_id] = desired_name2id[current_category_name]
            else:
                # ignore categories that are not included in desired_name2id
                currentid2desiredid_mapping[current_category_id] = None

        # add updated categories
        for name in desired_name2id.keys():
            updated_coco_category = CocoCategory(id=desired_name2id[name], name=name, supercategory=name)
            updated_coco.add_category(updated_coco_category)

        # add updated images & annotations
        for coco_image in copy.deepcopy(self.images):
            updated_coco_image = CocoImage.from_coco_image_dict(coco_image.json)
            # update filename to abspath
            file_name_is_abspath = True if os.path.abspath(coco_image.file_name) == coco_image.file_name else False
            if update_image_filenames and not file_name_is_abspath:
                if not self.image_dir:
                    logger.error("image directory not set")
                else:
                    updated_coco_image.file_name = str(Path(os.path.abspath(self.image_dir)) / coco_image.file_name)
            # update annotations
            for coco_annotation in coco_image.annotations:
                current_category_id = coco_annotation.category_id
                desired_category_id = currentid2desiredid_mapping[current_category_id]
                # append annotations with category id present in desired_name2id
                if desired_category_id is not None:
                    # update category id
                    coco_annotation.category_id = desired_category_id
                    # append updated annotation to target coco dict
                    updated_coco_image.add_annotation(coco_annotation)
            updated_coco.add_image(updated_coco_image)

        # overwrite instance state with the rebuilt dataset
        self.__dict__ = updated_coco.__dict__

    def merge(self, coco, desired_name2id=None, verbose=1):
        """Combines the images/annotations/categories of given coco object with current one.

        Args:
            coco : sahi.utils.coco.Coco instance
                A COCO dataset object
            desired_name2id : dict
                {"human": 1, "car": 2, "big_vehicle": 3}
            verbose: bool
                If True, merging info is printed
        """
        if self.image_dir is None or coco.image_dir is None:
            raise ValueError("image_dir should be provided for merging.")
        if verbose:
            if not desired_name2id:
                print("'desired_name2id' is not specified, combining all categories.")

        # create desired_name2id by combining all categories, if desired_name2id is not specified
        coco1 = self
        coco2 = coco
        category_ind = 0
        if desired_name2id is None:
            desired_name2id = {}
            for coco in [coco1, coco2]:
                temp_categories = copy.deepcopy(coco.json_categories)
                for temp_category in temp_categories:
                    if temp_category["name"] not in desired_name2id:
                        desired_name2id[temp_category["name"]] = category_ind
                        category_ind += 1
                    else:
                        continue

        # update categories and image paths
        for coco in [coco1, coco2]:
            coco.update_categories(desired_name2id=desired_name2id, update_image_filenames=True)

        # combine images and categories
        coco1.images.extend(coco2.images)
        self.images: list[CocoImage] = coco1.images
        self.categories = coco1.categories

        # print categories
        if verbose:
            print(
                "Categories are formed as:\n",
                self.json_categories,
            )

    @classmethod
    def from_coco_dict_or_path(
        cls,
        coco_dict_or_path: dict | str,
        image_dir: str | None = None,
        remapping_dict: dict | None = None,
        ignore_negative_samples: bool = False,
        clip_bboxes_to_img_dims: bool = False,
        use_threads: bool = False,
        num_threads: int = 10,
    ):
        """Creates coco object from COCO formatted dict or COCO dataset file path.

        Args:
            coco_dict_or_path: dict/str
                COCO formatted dict or COCO dataset file path
            image_dir: str
                Base file directory that contains dataset images. Required for merging and yolov5 conversion.
            remapping_dict: dict
                {1:0, 2:1} maps category id 1 to 0 and category id 2 to 1
            ignore_negative_samples: bool
                If True ignores images without annotations in all operations.
            clip_bboxes_to_img_dims: bool = False
                Limits bounding boxes to image dimensions.
            use_threads: bool = False
                Use threads when processing the json image list, defaults to False
            num_threads: int = 10
                Slice the image list to given number of chunks, defaults to 10

        Properties:
            images: list of CocoImage
            category_mapping: dict

        Raises:
            TypeError: if coco_dict_or_path is neither a dict nor a str.
        """
        # init coco object
        coco = cls(
            image_dir=image_dir,
            remapping_dict=remapping_dict,
            ignore_negative_samples=ignore_negative_samples,
            clip_bboxes_to_img_dims=clip_bboxes_to_img_dims,
        )

        if not isinstance(coco_dict_or_path, (str, dict)):
            raise TypeError("coco_dict_or_path should be a dict or str")

        # load coco dict if path is given
        if isinstance(coco_dict_or_path, str):
            coco_dict = load_json(coco_dict_or_path)
        else:
            coco_dict = coco_dict_or_path

        dict_size = len(coco_dict["images"])

        # arrange image id to annotation id mapping
        coco.add_categories_from_coco_category_list(coco_dict["categories"])
        image_id_to_annotation_list = get_imageid2annotationlist_mapping(coco_dict)
        category_mapping = coco.category_mapping

        # https://github.com/obss/sahi/issues/98
        image_id_set: set = set()

        lock = Lock()

        def fill_image_id_set(start, finish, image_list, _image_id_set, _image_id_to_annotation_list, _coco, lock):
            # worker: load image_list[start:finish] (and their annotations) into _coco
            for coco_image_dict in tqdm(
                image_list[start:finish], f"Loading coco annotations between {start} and {finish}"
            ):
                coco_image = CocoImage.from_coco_image_dict(coco_image_dict)
                image_id = coco_image_dict["id"]
                # https://github.com/obss/sahi/issues/98
                # the check-and-add must be atomic, otherwise two workers can both
                # pass the membership test for the same duplicate id
                with lock:
                    if image_id in _image_id_set:
                        print(f"duplicate image_id: {image_id}, will be ignored.")
                        continue
                    _image_id_set.add(image_id)

                # select annotations of the image
                annotation_list = _image_id_to_annotation_list[image_id]
                for coco_annotation_dict in annotation_list:
                    # apply category remapping if remapping_dict is provided
                    if _coco.remapping_dict is not None:
                        # apply category remapping (id:id)
                        category_id = _coco.remapping_dict[coco_annotation_dict["category_id"]]
                        # update category id
                        coco_annotation_dict["category_id"] = category_id
                    else:
                        category_id = coco_annotation_dict["category_id"]
                    # get category name (id:name)
                    category_name = category_mapping[category_id]
                    coco_annotation = CocoAnnotation.from_coco_annotation_dict(
                        category_name=category_name, annotation_dict=coco_annotation_dict
                    )
                    coco_image.add_annotation(coco_annotation)
                _coco.add_image(coco_image)

        if use_threads is True:
            # fix: the former float chunk size (dict_size / num_threads) produced
            # float slice indices, which raise TypeError; use integer ceil-division
            # so all images are covered by at most num_threads chunks
            chunk_size = max(1, -(-dict_size // num_threads))
            threads = []
            for i in range(num_threads):
                start = i * chunk_size
                finish = min(start + chunk_size, dict_size)
                if start >= finish:
                    break  # all images already assigned to earlier chunks
                t = Thread(
                    target=fill_image_id_set,
                    args=(start, finish, coco_dict["images"], image_id_set, image_id_to_annotation_list, coco, lock),
                )
                t.start()
                threads.append(t)

            # join only the threads we started (threading.currentThread() was
            # removed in Python 3.12; scanning threading.enumerate() is unnecessary)
            for t in threads:
                t.join()

        else:
            for coco_image_dict in tqdm(coco_dict["images"], "Loading coco annotations"):
                coco_image = CocoImage.from_coco_image_dict(coco_image_dict)
                image_id = coco_image_dict["id"]
                # https://github.com/obss/sahi/issues/98
                if image_id in image_id_set:
                    print(f"duplicate image_id: {image_id}, will be ignored.")
                    continue
                else:
                    image_id_set.add(image_id)
                # select annotations of the image
                annotation_list = image_id_to_annotation_list[image_id]
                # TODO: coco_annotation_dict is of type CocoAnnotation according to how image_id_to_annotation_list
                # was created. Either image_id_to_annotation_list is not defined correctly or the following
                # loop is wrong as it expects a dict.
                for coco_annotation_dict in annotation_list:
                    # apply category remapping if remapping_dict is provided
                    if coco.remapping_dict is not None:
                        # apply category remapping (id:id)
                        category_id = coco.remapping_dict[coco_annotation_dict["category_id"]]
                        # update category id
                        coco_annotation_dict["category_id"] = category_id
                    else:
                        category_id = coco_annotation_dict["category_id"]
                    # get category name (id:name)
                    category_name = category_mapping[category_id]
                    coco_annotation = CocoAnnotation.from_coco_annotation_dict(
                        category_name=category_name, annotation_dict=coco_annotation_dict
                    )
                    coco_image.add_annotation(coco_annotation)
                coco.add_image(coco_image)

        if clip_bboxes_to_img_dims:
            coco = coco.get_coco_with_clipped_bboxes()
        return coco

    @property
    def json_categories(self):
        categories = []
        for category in self.categories:
            categories.append(category.json)
        return categories

    @property
    def category_mapping(self):
        category_mapping = {}
        for category in self.categories:
            category_mapping[category.id] = category.name
        return category_mapping

    @property
    def json(self):
        """Export the whole dataset as a COCO-format dict."""
        coco_dict = create_coco_dict(
            images=self.images,
            categories=self.json_categories,
            ignore_negative_samples=self.ignore_negative_samples,
            image_id_setting=self.image_id_setting,
        )
        return coco_dict

    @property
    def prediction_array(self):
        """Export annotations as a COCO prediction (results) array."""
        predictions = create_coco_prediction_array(
            images=self.images,
            ignore_negative_samples=self.ignore_negative_samples,
            image_id_setting=self.image_id_setting,
        )
        return predictions

    @property
    def stats(self):
        if not self._stats:
            self.calculate_stats()
        return self._stats

    def calculate_stats(self):
        """Iterates over all annotations and calculates total number of."""
        # init all stats
        num_annotations = 0
        num_images = len(self.images)
        num_negative_images = 0
        num_categories = len(self.json_categories)
        category_name_to_zero = {category["name"]: 0 for category in self.json_categories}
        category_name_to_inf = {category["name"]: float("inf") for category in self.json_categories}
        num_images_per_category = copy.deepcopy(category_name_to_zero)
        num_annotations_per_category = copy.deepcopy(category_name_to_zero)
        min_annotation_area_per_category = copy.deepcopy(category_name_to_inf)
        max_annotation_area_per_category = copy.deepcopy(category_name_to_zero)
        min_num_annotations_in_image = float("inf")
        max_num_annotations_in_image = 0
        total_annotation_area = 0
        min_annotation_area = 1e10
        max_annotation_area = 0
        for image in self.images:
            image_contains_category = {}
            for annotation in image.annotations:
                annotation_area = annotation.area
                total_annotation_area += annotation_area
                num_annotations_per_category[annotation.category_name] += 1
                image_contains_category[annotation.category_name] = 1
                # update min&max annotation area
                if annotation_area > max_annotation_area:
                    max_annotation_area = annotation_area
                if annotation_area < min_annotation_area:
                    min_annotation_area = annotation_area
                if annotation_area > max_annotation_area_per_category[annotation.category_name]:
                    max_annotation_area_per_category[annotation.category_name] = annotation_area
                if annotation_area < min_annotation_area_per_category[annotation.category_name]:
                    min_annotation_area_per_category[annotation.category_name] = annotation_area
            # update num_negative_images
            if len(image.annotations) == 0:
                num_negative_images += 1
            # update num_annotations
            num_annotations += len(image.annotations)
            # update num_images_per_category
            num_images_per_category = dict(Counter(num_images_per_category) + Counter(image_contains_category))
            # update min&max_num_annotations_in_image
            num_annotations_in_image = len(image.annotations)
            if num_annotations_in_image > max_num_annotations_in_image:
                max_num_annotations_in_image = num_annotations_in_image
            if num_annotations_in_image < min_num_annotations_in_image:
                min_num_annotations_in_image = num_annotations_in_image
        if (num_images - num_negative_images) > 0:
            avg_num_annotations_in_image = num_annotations / (num_images - num_negative_images)
            avg_annotation_area = total_annotation_area / num_annotations
        else:
            avg_num_annotations_in_image = 0
            avg_annotation_area = 0

        self._stats = {
            "num_images": num_images,
            "num_annotations": num_annotations,
            "num_categories": num_categories,
            "num_negative_images": num_negative_images,
            "num_images_per_category": num_images_per_category,
            "num_annotations_per_category": num_annotations_per_category,
            "min_num_annotations_in_image": min_num_annotations_in_image,
            "max_num_annotations_in_image": max_num_annotations_in_image,
            "avg_num_annotations_in_image": avg_num_annotations_in_image,
            "min_annotation_area": min_annotation_area,
            "max_annotation_area": max_annotation_area,
            "avg_annotation_area": avg_annotation_area,
            "min_annotation_area_per_category": min_annotation_area_per_category,
            "max_annotation_area_per_category": max_annotation_area_per_category,
        }

    def split_coco_as_train_val(self, train_split_rate=0.9, numpy_seed=0):
        """Split images into train-val and returns them as sahi.utils.coco.Coco objects.

        Args:
            train_split_rate: float
                Fraction (0-1) of images assigned to the train split.
            numpy_seed: int
                random seed. Actually, this doesn't use numpy, but the random package
                from the standard library, but it is called numpy for compatibility.

        Returns:
            result : dict
                {
                    "train_coco": "",
                    "val_coco": "",
                }
        """
        # shuffle a deep copy so self.images keeps its original order
        num_images = len(self.images)
        shuffled_images = copy.deepcopy(self.images)
        random.seed(numpy_seed)
        random.shuffle(shuffled_images)
        num_train = int(num_images * train_split_rate)
        train_images = shuffled_images[:num_train]
        val_images = shuffled_images[num_train:]

        # form train val coco objects
        # BUGFIX: the old `self.name if self.name else "split" + "_train"` bound
        # the `+` before the conditional, so the "_train"/"_val" suffix was
        # dropped whenever self.name was set.
        base_name = self.name if self.name else "split"
        train_coco = Coco(
            name=base_name + "_train",
            image_dir=self.image_dir,
        )
        train_coco.images = train_images
        train_coco.categories = self.categories

        val_coco = Coco(name=base_name + "_val", image_dir=self.image_dir)
        val_coco.images = val_images
        val_coco.categories = self.categories

        # return result
        return {
            "train_coco": train_coco,
            "val_coco": val_coco,
        }

    def export_as_yolov5(
        self,
        output_dir: str | Path,
        train_split_rate: float = 1.0,
        numpy_seed: int = 0,
        mp: bool = False,
        disable_symlink: bool = False,
    ):
        """Deprecated.

        Please use export_as_yolo instead. Calls export_as_yolo with the same arguments.
        """
        # stacklevel=2 makes the warning point at the caller instead of this wrapper
        warnings.warn(
            "export_as_yolov5 is deprecated. Please use export_as_yolo instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.export_as_yolo(
            output_dir=output_dir,
            train_split_rate=train_split_rate,
            numpy_seed=numpy_seed,
            mp=mp,
            disable_symlink=disable_symlink,
        )

    def export_as_yolo(
        self,
        output_dir: str | Path,
        train_split_rate: float = 1.0,
        numpy_seed: int = 0,
        mp: bool = False,
        disable_symlink: bool = False,
    ):
        """Exports current COCO dataset in ultralytics/yolo format. Creates train val folders with image symlinks and
        txt files and a data yaml file.

        Args:
            output_dir: str
                Export directory.
            train_split_rate: float
                If given 1, will be exported as train split.
                If given 0, will be exported as val split.
                If in between 0-1, both train/val splits will be calculated and exported.
            numpy_seed: int
                To fix the numpy seed.
            mp: bool
                If True, multiprocess mode is on.
                Should be called in 'if __name__ == __main__:' block.
            disable_symlink: bool
                If True, symlinks will not be created. Instead, images will be copied.

        Raises:
            ImportError: If pyyaml is not installed.
            ValueError: If train_split_rate is outside the [0, 1] range.
        """
        try:
            import yaml
        except ImportError:
            raise ImportError('Please run "pip install -U pyyaml" to install yaml first for yolo formatted exporting.')

        # set split_mode (chained comparison replaces the old
        # `0 < train_split_rate and train_split_rate < 1`)
        if 0 < train_split_rate < 1:
            split_mode = "TRAINVAL"
        elif train_split_rate == 0:
            split_mode = "VAL"
        elif train_split_rate == 1:
            split_mode = "TRAIN"
        else:
            raise ValueError("train_split_rate cannot be <0 or >1")

        # split dataset
        if split_mode == "TRAINVAL":
            result = self.split_coco_as_train_val(
                train_split_rate=train_split_rate,
                numpy_seed=numpy_seed,
            )
            train_coco = result["train_coco"]
            val_coco = result["val_coco"]
        elif split_mode == "TRAIN":
            train_coco = self
            val_coco = None
        elif split_mode == "VAL":
            train_coco = None
            val_coco = self

        # create train val image dirs; an unused split keeps "" and is written
        # as an empty path into the data yaml
        train_dir = ""
        val_dir = ""
        if split_mode in ["TRAINVAL", "TRAIN"]:
            train_dir = Path(os.path.abspath(output_dir)) / "train/"
            train_dir.mkdir(parents=True, exist_ok=True)  # create dir
        if split_mode in ["TRAINVAL", "VAL"]:
            val_dir = Path(os.path.abspath(output_dir)) / "val/"
            val_dir.mkdir(parents=True, exist_ok=True)  # create dir

        # create image symlinks and annotation txts
        if split_mode in ["TRAINVAL", "TRAIN"]:
            export_yolo_images_and_txts_from_coco_object(
                output_dir=train_dir,
                coco=train_coco,
                ignore_negative_samples=self.ignore_negative_samples,
                mp=mp,
                disable_symlink=disable_symlink,
            )
        if split_mode in ["TRAINVAL", "VAL"]:
            export_yolo_images_and_txts_from_coco_object(
                output_dir=val_dir,
                coco=val_coco,
                ignore_negative_samples=self.ignore_negative_samples,
                mp=mp,
                disable_symlink=disable_symlink,
            )

        # create yolov5 data yaml
        data = {
            "train": str(train_dir),
            "val": str(val_dir),
            "nc": len(self.category_mapping),
            "names": list(self.category_mapping.values()),
        }
        yaml_path = str(Path(output_dir) / "data.yml")
        with open(yaml_path, "w") as outfile:
            yaml.dump(data, outfile, default_flow_style=None)

    def get_subsampled_coco(self, subsample_ratio: int = 2, category_id: int | None = None):
        """Subsamples images with subsample_ratio and returns as sahi.utils.coco.Coco object.

        Args:
            subsample_ratio: int
                10 means take every 10th image with its annotations
            category_id: int
                subsample only images containing given category_id, if -1 then subsamples negative samples
        Returns:
            subsampled_coco: sahi.utils.coco.Coco
        """
        subsampled_coco = Coco(
            name=self.name,
            image_dir=self.image_dir,
            remapping_dict=self.remapping_dict,
            ignore_negative_samples=self.ignore_negative_samples,
        )
        subsampled_coco.add_categories_from_coco_category_list(self.json_categories)

        # BUGFIX: the old code gated the filtered path with `if category_id:`,
        # so a valid category_id of 0 computed the filtered lists and then
        # silently ignored them. Test explicitly against None instead.
        if category_id is not None:

            def _matches(image) -> bool:
                # image matches when it contains the requested category, or when
                # category_id == -1 and the image is a negative (annotation-free) sample
                if any(annotation.category_id == category_id for annotation in image.annotations):
                    return True
                return category_id == -1 and len(image.annotations) == 0

            selected_images = [image for image in self.images if _matches(image)]
            # non-matching images are kept without subsampling (same order as before)
            for image in self.images:
                if not _matches(image):
                    subsampled_coco.add_image(image)
        else:
            selected_images = self.images

        # take every subsample_ratio'th image from the selected list
        for image_ind in range(0, len(selected_images), subsample_ratio):
            subsampled_coco.add_image(selected_images[image_ind])

        return subsampled_coco

    def get_upsampled_coco(self, upsample_ratio: int = 2, category_id: int | None = None):
        """Upsamples images with upsample_ratio and returns as sahi.utils.coco.Coco object.

        Args:
            upsample_ratio: int
                10 means copy each sample 10 times
            category_id: int
                upsample only images containing given category_id, if -1 then upsamples negative samples
        Returns:
            upsampled_coco: sahi.utils.coco.Coco
        """
        upsampled_coco = Coco(
            name=self.name,
            image_dir=self.image_dir,
            remapping_dict=self.remapping_dict,
            ignore_negative_samples=self.ignore_negative_samples,
        )
        upsampled_coco.add_categories_from_coco_category_list(self.json_categories)

        for pass_ind in range(upsample_ratio):
            for image in self.images:
                if category_id is None:
                    # no filter: every image is duplicated on every pass
                    should_add = True
                elif any(annotation.category_id == category_id for annotation in image.annotations):
                    # image contains the requested category -> duplicate it
                    should_add = True
                elif category_id == -1 and len(image.annotations) == 0:
                    # category_id -1 selects negative (annotation-free) samples
                    should_add = True
                else:
                    # non-matching images are still included exactly once (first pass)
                    should_add = pass_ind == 0

                if should_add:
                    upsampled_coco.add_image(image)

        return upsampled_coco

    def get_area_filtered_coco(self, min=0, max_val=float("inf"), intervals_per_category=None):
        """Filters annotation areas with given min and max values and returns remaining images as sahi.utils.coco.Coco
        object.

        Args:
            min: int
                minimum allowed area
            max_val: int
                maximum allowed area
            intervals_per_category: dict of dicts
                {
                    "human": {"min": 20, "max": 10000},
                    "vehicle": {"min": 50, "max": 15000},
                }
        Returns:
            area_filtered_coco: sahi.utils.coco.Coco
        """
        area_filtered_coco = Coco(
            name=self.name,
            image_dir=self.image_dir,
            remapping_dict=self.remapping_dict,
            ignore_negative_samples=self.ignore_negative_samples,
        )
        area_filtered_coco.add_categories_from_coco_category_list(self.json_categories)

        def _annotation_in_bounds(annotation) -> bool:
            # a category-specific interval applies in addition to the global [min, max_val]
            if intervals_per_category is not None and annotation.category_name in intervals_per_category:
                bounds = intervals_per_category[annotation.category_name]
                if not (bounds["min"] <= annotation.area <= bounds["max"]):
                    return False
            return min <= annotation.area <= max_val

        for image in self.images:
            # an image survives only if every one of its annotations is within bounds
            if all(_annotation_in_bounds(annotation) for annotation in image.annotations):
                area_filtered_coco.add_image(image)

        return area_filtered_coco

    def get_coco_with_clipped_bboxes(self):
        """Limits overflowing bounding boxes to image dimensions."""
        from sahi.slicing import annotation_inside_slice

        clipped_coco = Coco(
            name=self.name,
            image_dir=self.image_dir,
            remapping_dict=self.remapping_dict,
            ignore_negative_samples=self.ignore_negative_samples,
        )
        clipped_coco.add_categories_from_coco_category_list(self.json_categories)

        for source_image in self.images:
            # full-image bounds used as the clipping slice
            # (assumes [xmin, ymin, xmax, ymax] order — matches sahi slice_bbox usage)
            image_bounds = [0, 0, source_image.width, source_image.height]
            target_image = CocoImage(
                file_name=source_image.file_name,
                height=source_image.height,
                width=source_image.width,
                id=source_image.id,
            )
            for source_ann in source_image.annotations:
                # annotations entirely outside the image are dropped
                if not annotation_inside_slice(annotation=source_ann.json, slice_bbox=image_bounds):
                    continue
                clipped_ann = source_ann.get_sliced_coco_annotation(image_bounds)
                clipped_bbox = ShapelyAnnotation.to_xywh(clipped_ann._shapely_annotation)
                target_image.add_annotation(
                    CocoAnnotation(
                        bbox=clipped_bbox,
                        category_id=source_ann.category_id,
                        category_name=source_ann.category_name,
                        image_id=source_ann.image_id,
                    )
                )
            clipped_coco.add_image(target_image)
        return clipped_coco
Functions
__init__(name=None, image_dir=None, remapping_dict=None, ignore_negative_samples=False, clip_bboxes_to_img_dims=False, image_id_setting='auto')

Creates Coco object.

Parameters:

Name Type Description Default
name str | None

str Name of the Coco dataset, it determines exported json name.

None
image_dir str | None

str Base file directory that contains dataset images. Required for dataset merging.

None
remapping_dict dict[int, int] | None

dict {1:0, 2:1} maps category id 1 to 0 and category id 2 to 1

None
ignore_negative_samples bool

bool If True ignores images without annotations in all operations.

False
image_id_setting Literal['auto', 'manual']

str how to assign image ids while exporting can be auto -> will assign id from scratch (.id will be ignored) manual -> you will need to provide image ids in instances (.id can not be None)

'auto'
Source code in sahi/utils/coco.py
def __init__(
    self,
    name: str | None = None,
    image_dir: str | None = None,
    remapping_dict: dict[int, int] | None = None,
    ignore_negative_samples: bool = False,
    clip_bboxes_to_img_dims: bool = False,
    image_id_setting: Literal["auto", "manual"] = "auto",
):
    """Creates Coco object.

    Args:
        name: str
            Name of the Coco dataset, it determines exported json name.
        image_dir: str
            Base file directory that contains dataset images. Required for dataset merging.
        remapping_dict: dict
            {1:0, 2:1} maps category id 1 to 0 and category id 2 to 1
        ignore_negative_samples: bool
            If True ignores images without annotations in all operations.
        image_id_setting: str
            how to assign image ids while exporting can be
            auto -> will assign id from scratch (<CocoImage>.id will be ignored)
            manual -> you will need to provide image ids in <CocoImage> instances (<CocoImage>.id can not be None)
    """
    # guard clause: only the two documented id-assignment modes are accepted
    if image_id_setting not in ("auto", "manual"):
        raise ValueError("image_id_setting must be either 'auto' or 'manual'")

    # dataset identity / location
    self.name: str | None = name
    self.image_dir: str | None = image_dir
    # behavior flags
    self.remapping_dict: dict[int, int] | None = remapping_dict
    self.ignore_negative_samples = ignore_negative_samples
    self.clip_bboxes_to_img_dims = clip_bboxes_to_img_dims
    self.image_id_setting = image_id_setting
    # containers, filled later via add_category / add_image
    self.categories: list[CocoCategory] = []
    self.images = []
    # lazily-computed statistics cache
    self._stats = None
add_categories_from_coco_category_list(coco_category_list)

Creates CocoCategory object using coco category list.

Parameters:

Name Type Description Default
coco_category_list

List[Dict] [ {"supercategory": "person", "id": 1, "name": "person"}, {"supercategory": "vehicle", "id": 2, "name": "bicycle"} ]

required
Source code in sahi/utils/coco.py
def add_categories_from_coco_category_list(self, coco_category_list):
    """Creates CocoCategory object using coco category list.

    Args:
        coco_category_list: List[Dict]
            [
                {"supercategory": "person", "id": 1, "name": "person"},
                {"supercategory": "vehicle", "id": 2, "name": "bicycle"}
            ]
    """
    for coco_category in coco_category_list:
        # Single dict lookup instead of scanning every remapping key.
        # Also fixes chained remapping: with the old key-by-key scan and a
        # remapping like {1: 2, 2: 3}, an id remapped to 2 early in the scan
        # could be remapped again to 3 later in the same pass.
        # NOTE: mutates the caller's dict in place (same as the original code).
        if self.remapping_dict is not None and coco_category["id"] in self.remapping_dict:
            coco_category["id"] = self.remapping_dict[coco_category["id"]]

        self.add_category(CocoCategory.from_coco_category(coco_category))
add_category(category)

Adds category to this Coco instance.

Parameters:

Name Type Description Default
category

CocoCategory

required
Source code in sahi/utils/coco.py
def add_category(self, category):
    """Adds category to this Coco instance.

    Args:
        category: CocoCategory

    Raises:
        TypeError: If `category` is not a CocoCategory instance.
    """
    # fail fast on anything that is not a CocoCategory
    if not isinstance(category, CocoCategory):
        raise TypeError("category must be a CocoCategory instance")
    self.categories.append(category)
add_image(image)

Adds image to this Coco instance.

Parameters:

Name Type Description Default
image

CocoImage

required
Source code in sahi/utils/coco.py
def add_image(self, image):
    """Adds image to this Coco instance.

    Args:
        image: CocoImage

    Raises:
        ValueError: If image_id_setting is 'manual' and the image has no id.
    """
    # in 'manual' mode the caller must have assigned an explicit image id
    manual_mode = self.image_id_setting == "manual"
    if manual_mode and image.id is None:
        raise ValueError("image id should be manually set for image_id_setting='manual'")
    self.images.append(image)
calculate_stats()

Iterates over all images and annotations and calculates dataset statistics such as counts, per-category totals, and min/max/average annotation areas.

Source code in sahi/utils/coco.py
def calculate_stats(self):
    """Iterates over all images and annotations and caches dataset statistics in self._stats.

    Collected values include image/annotation/category counts, per-category
    image and annotation counts, min/max/average annotations per image, and
    min/max/average annotation areas (overall and per category).
    """
    # init all stats
    num_annotations = 0
    num_images = len(self.images)
    num_negative_images = 0
    num_categories = len(self.json_categories)
    # template dicts keyed by category name, deep-copied into each accumulator
    category_name_to_zero = {category["name"]: 0 for category in self.json_categories}
    category_name_to_inf = {category["name"]: float("inf") for category in self.json_categories}
    num_images_per_category = copy.deepcopy(category_name_to_zero)
    num_annotations_per_category = copy.deepcopy(category_name_to_zero)
    min_annotation_area_per_category = copy.deepcopy(category_name_to_inf)
    max_annotation_area_per_category = copy.deepcopy(category_name_to_zero)
    min_num_annotations_in_image = float("inf")
    max_num_annotations_in_image = 0
    total_annotation_area = 0
    # NOTE(review): global minimum starts at 1e10 while the per-category one
    # starts at float("inf") — inconsistent sentinels; confirm intent
    min_annotation_area = 1e10
    max_annotation_area = 0
    for image in self.images:
        # categories present in this image (used for num_images_per_category)
        image_contains_category = {}
        for annotation in image.annotations:
            annotation_area = annotation.area
            total_annotation_area += annotation_area
            num_annotations_per_category[annotation.category_name] += 1
            image_contains_category[annotation.category_name] = 1
            # update min&max annotation area
            if annotation_area > max_annotation_area:
                max_annotation_area = annotation_area
            if annotation_area < min_annotation_area:
                min_annotation_area = annotation_area
            if annotation_area > max_annotation_area_per_category[annotation.category_name]:
                max_annotation_area_per_category[annotation.category_name] = annotation_area
            if annotation_area < min_annotation_area_per_category[annotation.category_name]:
                min_annotation_area_per_category[annotation.category_name] = annotation_area
        # update num_negative_images
        if len(image.annotations) == 0:
            num_negative_images += 1
        # update num_annotations
        num_annotations += len(image.annotations)
        # update num_images_per_category
        # Counter addition merges the per-image presence flags into the running totals
        num_images_per_category = dict(Counter(num_images_per_category) + Counter(image_contains_category))
        # update min&max_num_annotations_in_image
        num_annotations_in_image = len(image.annotations)
        if num_annotations_in_image > max_num_annotations_in_image:
            max_num_annotations_in_image = num_annotations_in_image
        if num_annotations_in_image < min_num_annotations_in_image:
            min_num_annotations_in_image = num_annotations_in_image
    # averages are over positive (annotated) images only; 0 when there are none
    if (num_images - num_negative_images) > 0:
        avg_num_annotations_in_image = num_annotations / (num_images - num_negative_images)
        avg_annotation_area = total_annotation_area / num_annotations
    else:
        avg_num_annotations_in_image = 0
        avg_annotation_area = 0

    self._stats = {
        "num_images": num_images,
        "num_annotations": num_annotations,
        "num_categories": num_categories,
        "num_negative_images": num_negative_images,
        "num_images_per_category": num_images_per_category,
        "num_annotations_per_category": num_annotations_per_category,
        "min_num_annotations_in_image": min_num_annotations_in_image,
        "max_num_annotations_in_image": max_num_annotations_in_image,
        "avg_num_annotations_in_image": avg_num_annotations_in_image,
        "min_annotation_area": min_annotation_area,
        "max_annotation_area": max_annotation_area,
        "avg_annotation_area": avg_annotation_area,
        "min_annotation_area_per_category": min_annotation_area_per_category,
        "max_annotation_area_per_category": max_annotation_area_per_category,
    }
export_as_yolo(output_dir, train_split_rate=1.0, numpy_seed=0, mp=False, disable_symlink=False)

Exports current COCO dataset in ultralytics/yolo format. Creates train val folders with image symlinks and txt files and a data yaml file.

Parameters:

Name Type Description Default
output_dir str | Path

str Export directory.

required
train_split_rate float

float If given 1, will be exported as train split. If given 0, will be exported as val split. If in between 0-1, both train/val splits will be calculated and exported.

1.0
numpy_seed int

int To fix the numpy seed.

0
mp bool

bool If True, multiprocess mode is on. Should be called in an `if __name__ == "__main__":` block.

False
disable_symlink bool

bool If True, symlinks will not be created. Instead, images will be copied.

False
Source code in sahi/utils/coco.py
def export_as_yolo(
    self,
    output_dir: str | Path,
    train_split_rate: float = 1.0,
    numpy_seed: int = 0,
    mp: bool = False,
    disable_symlink: bool = False,
):
    """Exports current COCO dataset in ultralytics/yolo format. Creates train val folders with image symlinks and
    txt files and a data yaml file.

    Args:
        output_dir: str
            Export directory.
        train_split_rate: float
            If given 1, will be exported as train split.
            If given 0, will be exported as val split.
            If in between 0-1, both train/val splits will be calculated and exported.
        numpy_seed: int
            To fix the numpy seed.
        mp: bool
            If True, multiprocess mode is on.
            Should be called in 'if __name__ == __main__:' block.
        disable_symlink: bool
            If True, symlinks will not be created. Instead, images will be copied.

    Raises:
        ImportError: If pyyaml is not installed.
        ValueError: If train_split_rate is outside the [0, 1] range.
    """
    try:
        import yaml
    except ImportError:
        raise ImportError('Please run "pip install -U pyyaml" to install yaml first for yolo formatted exporting.')

    # set split_mode (chained comparison replaces the old
    # `0 < train_split_rate and train_split_rate < 1`)
    if 0 < train_split_rate < 1:
        split_mode = "TRAINVAL"
    elif train_split_rate == 0:
        split_mode = "VAL"
    elif train_split_rate == 1:
        split_mode = "TRAIN"
    else:
        raise ValueError("train_split_rate cannot be <0 or >1")

    # split dataset
    if split_mode == "TRAINVAL":
        result = self.split_coco_as_train_val(
            train_split_rate=train_split_rate,
            numpy_seed=numpy_seed,
        )
        train_coco = result["train_coco"]
        val_coco = result["val_coco"]
    elif split_mode == "TRAIN":
        train_coco = self
        val_coco = None
    elif split_mode == "VAL":
        train_coco = None
        val_coco = self

    # create train val image dirs; an unused split keeps "" and is written
    # as an empty path into the data yaml
    train_dir = ""
    val_dir = ""
    if split_mode in ["TRAINVAL", "TRAIN"]:
        train_dir = Path(os.path.abspath(output_dir)) / "train/"
        train_dir.mkdir(parents=True, exist_ok=True)  # create dir
    if split_mode in ["TRAINVAL", "VAL"]:
        val_dir = Path(os.path.abspath(output_dir)) / "val/"
        val_dir.mkdir(parents=True, exist_ok=True)  # create dir

    # create image symlinks and annotation txts
    if split_mode in ["TRAINVAL", "TRAIN"]:
        export_yolo_images_and_txts_from_coco_object(
            output_dir=train_dir,
            coco=train_coco,
            ignore_negative_samples=self.ignore_negative_samples,
            mp=mp,
            disable_symlink=disable_symlink,
        )
    if split_mode in ["TRAINVAL", "VAL"]:
        export_yolo_images_and_txts_from_coco_object(
            output_dir=val_dir,
            coco=val_coco,
            ignore_negative_samples=self.ignore_negative_samples,
            mp=mp,
            disable_symlink=disable_symlink,
        )

    # create yolov5 data yaml
    data = {
        "train": str(train_dir),
        "val": str(val_dir),
        "nc": len(self.category_mapping),
        "names": list(self.category_mapping.values()),
    }
    yaml_path = str(Path(output_dir) / "data.yml")
    with open(yaml_path, "w") as outfile:
        yaml.dump(data, outfile, default_flow_style=None)
export_as_yolov5(output_dir, train_split_rate=1.0, numpy_seed=0, mp=False, disable_symlink=False)

Deprecated.

Please use export_as_yolo instead. Calls export_as_yolo with the same arguments.

Source code in sahi/utils/coco.py
def export_as_yolov5(
    self,
    output_dir: str | Path,
    train_split_rate: float = 1.0,
    numpy_seed: int = 0,
    mp: bool = False,
    disable_symlink: bool = False,
):
    """Deprecated.

    Please use export_as_yolo instead. Calls export_as_yolo with the same arguments.
    """
    warnings.warn(
        "export_as_yolov5 is deprecated. Please use export_as_yolo instead.",
        DeprecationWarning,
    )
    self.export_as_yolo(
        output_dir=output_dir,
        train_split_rate=train_split_rate,
        numpy_seed=numpy_seed,
        mp=mp,
        disable_symlink=disable_symlink,
    )
from_coco_dict_or_path(coco_dict_or_path, image_dir=None, remapping_dict=None, ignore_negative_samples=False, clip_bboxes_to_img_dims=False, use_threads=False, num_threads=10) classmethod

Creates coco object from COCO formatted dict or COCO dataset file path.

Parameters:

Name Type Description Default
coco_dict_or_path dict | str

dict/str or List[dict/str] COCO formatted dict or COCO dataset file path List of COCO formatted dict or COCO dataset file path

required
image_dir str | None

str Base file directory that contains dataset images. Required for merging and yolov5 conversion.

None
remapping_dict dict | None

dict {1:0, 2:1} maps category id 1 to 0 and category id 2 to 1

None
ignore_negative_samples bool

bool If True ignores images without annotations in all operations.

False
clip_bboxes_to_img_dims bool

bool = False Limits bounding boxes to image dimensions.

False
use_threads bool

bool = False Use threads when processing the json image list, defaults to False

False
num_threads int

int = 10 Slice the image list to given number of chunks, defaults to 10

10
Properties

images: list of CocoImage category_mapping: dict

Source code in sahi/utils/coco.py
@classmethod
def from_coco_dict_or_path(
    cls,
    coco_dict_or_path: dict | str,
    image_dir: str | None = None,
    remapping_dict: dict | None = None,
    ignore_negative_samples: bool = False,
    clip_bboxes_to_img_dims: bool = False,
    use_threads: bool = False,
    num_threads: int = 10,
):
    """Creates coco object from COCO formatted dict or COCO dataset file path.

    Args:
        coco_dict_or_path: dict/str or List[dict/str]
            COCO formatted dict or COCO dataset file path
            List of COCO formatted dict or COCO dataset file path
        image_dir: str
            Base file directory that contains dataset images. Required for merging and yolov5 conversion.
        remapping_dict: dict
            {1:0, 2:1} maps category id 1 to 0 and category id 2 to 1
        ignore_negative_samples: bool
            If True ignores images without annotations in all operations.
        clip_bboxes_to_img_dims: bool = False
            Limits bounding boxes to image dimensions.
        use_threads: bool = False
            Use threads when processing the json image list, defaults to False
        num_threads: int = 10
            Slice the image list to given number of chunks, defaults to 10

    Properties:
        images: list of CocoImage
        category_mapping: dict
    """
    # init coco object
    coco = cls(
        image_dir=image_dir,
        remapping_dict=remapping_dict,
        ignore_negative_samples=ignore_negative_samples,
        clip_bboxes_to_img_dims=clip_bboxes_to_img_dims,
    )

    if type(coco_dict_or_path) not in [str, dict]:
        raise TypeError("coco_dict_or_path should be a dict or str")

    # load coco dict if path is given
    if isinstance(coco_dict_or_path, str):
        coco_dict = load_json(coco_dict_or_path)
    else:
        coco_dict = coco_dict_or_path

    dict_size = len(coco_dict["images"])

    # arrange image id to annotation id mapping
    coco.add_categories_from_coco_category_list(coco_dict["categories"])
    image_id_to_annotation_list = get_imageid2annotationlist_mapping(coco_dict)
    category_mapping = coco.category_mapping

    # https://github.com/obss/sahi/issues/98
    image_id_set: set = set()

    lock = Lock()

    def fill_image_id_set(start, finish, image_list, _image_id_set, _image_id_to_annotation_list, _coco, lock):
        """Worker: loads image_list[start:finish] (with annotations) into _coco."""
        for coco_image_dict in tqdm(
            image_list[start:finish], f"Loading coco annotations between {start} and {finish}"
        ):
            coco_image = CocoImage.from_coco_image_dict(coco_image_dict)
            image_id = coco_image_dict["id"]
            # https://github.com/obss/sahi/issues/98
            # NOTE(review): the membership test runs outside the lock, so two
            # workers could race on the same duplicate id — TODO confirm ids
            # never repeat across chunks
            if image_id in _image_id_set:
                print(f"duplicate image_id: {image_id}, will be ignored.")
                continue
            else:
                lock.acquire()
                _image_id_set.add(image_id)
                lock.release()

            # select annotations of the image
            annotation_list = _image_id_to_annotation_list[image_id]
            for coco_annotation_dict in annotation_list:
                # apply category remapping if remapping_dict is provided
                if _coco.remapping_dict is not None:
                    # apply category remapping (id:id)
                    category_id = _coco.remapping_dict[coco_annotation_dict["category_id"]]
                    # update category id
                    coco_annotation_dict["category_id"] = category_id
                else:
                    category_id = coco_annotation_dict["category_id"]
                # get category name (id:name)
                category_name = category_mapping[category_id]
                coco_annotation = CocoAnnotation.from_coco_annotation_dict(
                    category_name=category_name, annotation_dict=coco_annotation_dict
                )
                coco_image.add_annotation(coco_annotation)
            _coco.add_image(coco_image)

    if use_threads is True:
        # BUGFIX: the old `chunk_size = dict_size / num_threads` produced a float,
        # and float slice bounds raise TypeError in image_list[start:finish].
        # Use integer ceiling division so every image is covered exactly once.
        chunk_size = -(-dict_size // num_threads)
        threads = []
        for i in range(num_threads):
            start = i * chunk_size
            finish = min(start + chunk_size, dict_size)
            t = Thread(
                target=fill_image_id_set,
                args=(start, finish, coco_dict["images"], image_id_set, image_id_to_annotation_list, coco, lock),
            )
            t.start()
            threads.append(t)

        # BUGFIX: join exactly the threads started above; the old code joined every
        # thread returned by threading.enumerate() (possibly unrelated ones) and
        # used threading.currentThread(), which was removed in Python 3.12.
        for t in threads:
            t.join()

    else:
        for coco_image_dict in tqdm(coco_dict["images"], "Loading coco annotations"):
            coco_image = CocoImage.from_coco_image_dict(coco_image_dict)
            image_id = coco_image_dict["id"]
            # https://github.com/obss/sahi/issues/98
            if image_id in image_id_set:
                print(f"duplicate image_id: {image_id}, will be ignored.")
                continue
            else:
                image_id_set.add(image_id)
            # select annotations of the image
            annotation_list = image_id_to_annotation_list[image_id]
            # TODO: coco_annotation_dict is of type CocoAnnotation according to how image_id_to_annotation_list
            # was created. Either image_id_to_annotation_list is not defined correctly or the following
            # loop is wrong as it expects a dict.
            for coco_annotation_dict in annotation_list:
                # apply category remapping if remapping_dict is provided
                if coco.remapping_dict is not None:
                    # apply category remapping (id:id)
                    category_id = coco.remapping_dict[coco_annotation_dict["category_id"]]
                    # update category id
                    coco_annotation_dict["category_id"] = category_id
                else:
                    category_id = coco_annotation_dict["category_id"]
                # get category name (id:name)
                category_name = category_mapping[category_id]
                coco_annotation = CocoAnnotation.from_coco_annotation_dict(
                    category_name=category_name, annotation_dict=coco_annotation_dict
                )
                coco_image.add_annotation(coco_annotation)
            coco.add_image(coco_image)

    if clip_bboxes_to_img_dims:
        coco = coco.get_coco_with_clipped_bboxes()
    return coco
get_area_filtered_coco(min=0, max_val=float('inf'), intervals_per_category=None)

Filters annotation areas with given min and max values and returns remaining images as sahi.utils.coco.Coco object.

Parameters:

Name Type Description Default
min

int minimum allowed area

0
max_val

int maximum allowed area

float('inf')
intervals_per_category

dict of dicts { "human": {"min": 20, "max": 10000}, "vehicle": {"min": 50, "max": 15000}, }

None

Returns: area_filtered_coco: sahi.utils.coco.Coco

Source code in sahi/utils/coco.py
def get_area_filtered_coco(self, min=0, max_val=float("inf"), intervals_per_category=None):
    """Return a new sahi.utils.coco.Coco keeping only images whose annotations all
    fall inside the allowed area range.

    An image is kept only when every one of its annotations satisfies the global
    [min, max_val] interval and, when a per-category interval is defined for that
    annotation's category, that interval as well. Images without annotations pass.

    Args:
        min: int
            minimum allowed annotation area (name shadows the builtin `min`
            inside this function; kept for backward compatibility)
        max_val: int
            maximum allowed annotation area
        intervals_per_category: dict of dicts
            {
                "human": {"min": 20, "max": 10000},
                "vehicle": {"min": 50, "max": 15000},
            }
    Returns:
        area_filtered_coco: sahi.utils.coco.Coco
    """
    area_filtered_coco = Coco(
        name=self.name,
        image_dir=self.image_dir,
        remapping_dict=self.remapping_dict,
        ignore_negative_samples=self.ignore_negative_samples,
    )
    area_filtered_coco.add_categories_from_coco_category_list(self.json_categories)

    def _annotation_within_bounds(annotation) -> bool:
        # per-category interval, when one is defined for this category
        if intervals_per_category is not None and annotation.category_name in intervals_per_category:
            bounds = intervals_per_category[annotation.category_name]
            if annotation.area < bounds["min"] or annotation.area > bounds["max"]:
                return False
        # the global interval always applies
        return min <= annotation.area <= max_val

    for image in self.images:
        if all(_annotation_within_bounds(annotation) for annotation in image.annotations):
            area_filtered_coco.add_image(image)

    return area_filtered_coco
get_coco_with_clipped_bboxes()

Limits overflowing bounding boxes to image dimensions.

Source code in sahi/utils/coco.py
def get_coco_with_clipped_bboxes(self):
    """Limits overflowing bounding boxes to image dimensions."""
    from sahi.slicing import annotation_inside_slice

    coco = Coco(
        name=self.name,
        image_dir=self.image_dir,
        remapping_dict=self.remapping_dict,
        ignore_negative_samples=self.ignore_negative_samples,
    )
    coco.add_categories_from_coco_category_list(self.json_categories)

    for coco_img in self.images:
        img_dims = [0, 0, coco_img.width, coco_img.height]
        coco_image = CocoImage(
            file_name=coco_img.file_name, height=coco_img.height, width=coco_img.width, id=coco_img.id
        )
        for coco_ann in coco_img.annotations:
            ann_dict: dict = coco_ann.json
            if annotation_inside_slice(annotation=ann_dict, slice_bbox=img_dims):
                shapely_ann = coco_ann.get_sliced_coco_annotation(img_dims)
                bbox = ShapelyAnnotation.to_xywh(shapely_ann._shapely_annotation)
                coco_ann_from_shapely = CocoAnnotation(
                    bbox=bbox,
                    category_id=coco_ann.category_id,
                    category_name=coco_ann.category_name,
                    image_id=coco_ann.image_id,
                )
                coco_image.add_annotation(coco_ann_from_shapely)
            else:
                continue
        coco.add_image(coco_image)
    return coco
get_subsampled_coco(subsample_ratio=2, category_id=None)

Subsamples images with subsample_ratio and returns as sahi.utils.coco.Coco object.

Parameters:

Name Type Description Default
subsample_ratio int

int 10 means take every 10th image with its annotations

2
category_id int | None

int subsample only images containing given category_id, if -1 then subsamples negative samples

None

Returns: subsampled_coco: sahi.utils.coco.Coco

Source code in sahi/utils/coco.py
def get_subsampled_coco(self, subsample_ratio: int = 2, category_id: int | None = None):
    """Subsamples images with subsample_ratio and returns as sahi.utils.coco.Coco object.

    When category_id is given, only images containing that category are
    subsampled; all other images are kept unchanged. category_id=-1 targets
    negative samples (images without annotations).

    Args:
        subsample_ratio: int
            10 means take every 10th image with its annotations
        category_id: int
            subsample only images containing given category_id, if -1 then subsamples negative samples
    Returns:
        subsampled_coco: sahi.utils.coco.Coco
    """
    subsampled_coco = Coco(
        name=self.name,
        image_dir=self.image_dir,
        remapping_dict=self.remapping_dict,
        ignore_negative_samples=self.ignore_negative_samples,
    )
    subsampled_coco.add_categories_from_coco_category_list(self.json_categories)

    if category_id is not None:
        # Single pass: partition images into those targeted for subsampling
        # and those kept as-is (the previous version scanned all images twice).
        images_that_contain_category: list[CocoImage] = []
        images_that_doesnt_contain_category: list[CocoImage] = []
        for image in self.images:
            present_category_ids = {annotation.category_id for annotation in image.annotations}
            if category_id in present_category_ids or (category_id == -1 and len(image.annotations) == 0):
                images_that_contain_category.append(image)
            else:
                images_that_doesnt_contain_category.append(image)

        # BUGFIX: previously gated by `if category_id:` which skipped this branch
        # for category_id == 0, ignoring the partition computed above.
        selected_images = images_that_contain_category
        # add images that do not contain the given category without subsampling
        for image in images_that_doesnt_contain_category:
            subsampled_coco.add_image(image)
    else:
        selected_images = self.images

    # keep every subsample_ratio-th image from the targeted set
    for image_ind in range(0, len(selected_images), subsample_ratio):
        subsampled_coco.add_image(selected_images[image_ind])

    return subsampled_coco
get_upsampled_coco(upsample_ratio=2, category_id=None)

Upsamples images with upsample_ratio and returns as sahi.utils.coco.Coco object.

Parameters:

Name Type Description Default
upsample_ratio int

int 10 means copy each sample 10 times

2
category_id int | None

int upsample only images containing given category_id, if -1 then upsamples negative samples

None

Returns: upsampled_coco: sahi.utils.coco.Coco

Source code in sahi/utils/coco.py
def get_upsampled_coco(self, upsample_ratio: int = 2, category_id: int | None = None):
    """Upsamples images with upsample_ratio and returns as sahi.utils.coco.Coco object.

    Args:
        upsample_ratio: int
            10 means copy each sample 10 times
        category_id: int
            upsample only images containing given category_id, if -1 then upsamples negative samples
    Returns:
        upsampled_coco: sahi.utils.coco.Coco
    """
    upsampled_coco = Coco(
        name=self.name,
        image_dir=self.image_dir,
        remapping_dict=self.remapping_dict,
        ignore_negative_samples=self.ignore_negative_samples,
    )
    upsampled_coco.add_categories_from_coco_category_list(self.json_categories)

    for repetition in range(upsample_ratio):
        for image in self.images:
            if category_id is None:
                # no category filter: every image is copied on every repetition
                keep = True
            else:
                present_category_ids = {annotation.category_id for annotation in image.annotations}
                keep = (
                    category_id in present_category_ids
                    # category_id == -1 targets negative samples (no annotations)
                    or (category_id == -1 and len(image.annotations) == 0)
                    # every image appears at least once, regardless of category
                    or repetition == 0
                )
            if keep:
                upsampled_coco.add_image(image)

    return upsampled_coco
merge(coco, desired_name2id=None, verbose=1)

Combines the images/annotations/categories of given coco object with current one.

Parameters:

Name Type Description Default
coco

sahi.utils.coco.Coco instance A COCO dataset object

required
desired_name2id

dict

required
verbose

bool If True, merging info is printed

1
Source code in sahi/utils/coco.py
def merge(self, coco, desired_name2id=None, verbose=1):
    """Combines the images/annotations/categories of given coco object with current one.

    Mutates both self and the given coco object (categories are remapped and
    image file names converted to absolute paths on both).

    Args:
        coco : sahi.utils.coco.Coco instance
            A COCO dataset object
        desired_name2id : dict
            {"human": 1, "car": 2, "big_vehicle": 3}
        verbose: bool
            If True, merging info is printed
    """
    if self.image_dir is None or coco.image_dir is None:
        raise ValueError("image_dir should be provided for merging.")
    if verbose and not desired_name2id:
        print("'desired_name2id' is not specified, combining all categories.")

    base_coco = self
    other_coco = coco

    # when no mapping is given, build one from the union of both category sets
    if desired_name2id is None:
        desired_name2id = {}
        next_category_id = 0
        for dataset in (base_coco, other_coco):
            for category in copy.deepcopy(dataset.json_categories):
                if category["name"] not in desired_name2id:
                    desired_name2id[category["name"]] = next_category_id
                    next_category_id += 1

    # remap categories and convert image file names to absolute paths on both datasets
    for dataset in (base_coco, other_coco):
        dataset.update_categories(desired_name2id=desired_name2id, update_image_filenames=True)

    # merge image lists and adopt the unified category set
    base_coco.images.extend(other_coco.images)
    self.images: list[CocoImage] = base_coco.images
    self.categories = base_coco.categories

    if verbose:
        print(
            "Categories are formed as:\n",
            self.json_categories,
        )
split_coco_as_train_val(train_split_rate=0.9, numpy_seed=0)

Split images into train-val and returns them as sahi.utils.coco.Coco objects.

Parameters:

Name Type Description Default
train_split_rate

float

0.9
numpy_seed

int random seed. Despite the name, this seeds the standard-library random module rather than numpy; the name is kept for backward compatibility.

0

Returns:

Name Type Description
result

dict { "train_coco": "", "val_coco": "", }

Source code in sahi/utils/coco.py
def split_coco_as_train_val(self, train_split_rate=0.9, numpy_seed=0):
    """Split images into train-val and returns them as sahi.utils.coco.Coco objects.

    Args:
        train_split_rate: float
            fraction of images assigned to the train split
        numpy_seed: int
            random seed. Despite the name (kept for backward compatibility),
            this seeds the standard-library random module, not numpy.

    Returns:
        result : dict
            {
                "train_coco": "",
                "val_coco": "",
            }
    """
    # shuffle a deep copy so the original image list is left untouched
    num_images = len(self.images)
    shuffled_images = copy.deepcopy(self.images)
    random.seed(numpy_seed)
    random.shuffle(shuffled_images)
    num_train = int(num_images * train_split_rate)
    train_images = shuffled_images[:num_train]
    val_images = shuffled_images[num_train:]

    # BUGFIX: previously `self.name if self.name else "split" + "_train"` parsed as
    # `self.name if self.name else ("split_train")` due to operator precedence, so
    # when self.name was set both splits shared the same unsuffixed name.
    base_name = self.name if self.name else "split"

    # form train val coco objects
    train_coco = Coco(name=base_name + "_train", image_dir=self.image_dir)
    train_coco.images = train_images
    train_coco.categories = self.categories

    val_coco = Coco(name=base_name + "_val", image_dir=self.image_dir)
    val_coco.images = val_images
    val_coco.categories = self.categories

    # return result
    return {
        "train_coco": train_coco,
        "val_coco": val_coco,
    }
update_categories(desired_name2id, update_image_filenames=False)

Rearranges category mapping of given COCO object based on given desired_name2id. Can also be used to filter some of the categories.

Parameters:

Name Type Description Default
desired_name2id dict[str, int]

dict

required
update_image_filenames bool

bool If True, updates coco image file_names with absolute file paths.

False
Source code in sahi/utils/coco.py
def update_categories(self, desired_name2id: dict[str, int], update_image_filenames: bool = False):
    """Rearranges category mapping of given COCO object based on given desired_name2id.

    Categories absent from desired_name2id are dropped together with their
    annotations, so this can also be used to filter categories. Mutates
    self in place.

    Args:
        desired_name2id: dict
            {"big_vehicle": 1, "car": 2, "human": 3}
        update_image_filenames: bool
            If True, updates coco image file_names with absolute file paths.
    """
    updated_coco = Coco(
        name=self.name,
        image_dir=self.image_dir,
        remapping_dict=self.remapping_dict,
        ignore_negative_samples=self.ignore_negative_samples,
    )

    # map current category id -> desired id (None means: drop this category)
    currentid2desiredid_mapping: dict[int, int | None] = {}
    for coco_category in self.categories:
        if not coco_category.name:
            logger.warning("no category name provided to update categories")
            continue
        currentid2desiredid_mapping[coco_category.id] = desired_name2id.get(coco_category.name)

    # register the desired categories on the new coco object
    for category_name, category_id in desired_name2id.items():
        updated_coco.add_category(
            CocoCategory(id=category_id, name=category_name, supercategory=category_name)
        )

    # rebuild images, keeping only annotations whose category survives the remap
    for coco_image in copy.deepcopy(self.images):
        updated_coco_image = CocoImage.from_coco_image_dict(coco_image.json)
        # convert file name to an absolute path when requested and not already absolute
        if update_image_filenames and os.path.abspath(coco_image.file_name) != coco_image.file_name:
            if not self.image_dir:
                logger.error("image directory not set")
            else:
                updated_coco_image.file_name = str(Path(os.path.abspath(self.image_dir)) / coco_image.file_name)
        for coco_annotation in coco_image.annotations:
            desired_category_id = currentid2desiredid_mapping[coco_annotation.category_id]
            if desired_category_id is not None:
                coco_annotation.category_id = desired_category_id
                updated_coco_image.add_annotation(coco_annotation)
        updated_coco.add_image(updated_coco_image)

    # replace this instance's state with the rebuilt dataset
    self.__dict__ = updated_coco.__dict__
CocoAnnotation

COCO formatted annotation.

Source code in sahi/utils/coco.py
class CocoAnnotation:
    """COCO formatted annotation.

    Wraps a single object annotation (bbox and/or polygon segmentation) and
    keeps an internal ShapelyAnnotation for geometric operations.
    """

    @classmethod
    def from_coco_segmentation(cls, segmentation, category_id, category_name, iscrowd=0):
        """Creates CocoAnnotation object using coco segmentation.

        Args:
            segmentation: List[List]
                [[1, 1, 325, 125, 250, 200, 5, 200]]
            category_id: int
                Category id of the annotation
            category_name: str
                Category name of the annotation
            iscrowd: int
                0 or 1
        """
        return cls(
            segmentation=segmentation,
            category_id=category_id,
            category_name=category_name,
            iscrowd=iscrowd,
        )

    @classmethod
    def from_coco_bbox(cls, bbox, category_id, category_name, iscrowd=0):
        """Creates CocoAnnotation object using coco bbox.

        Args:
            bbox: List
                [xmin, ymin, width, height]
            category_id: int
                Category id of the annotation
            category_name: str
                Category name of the annotation
            iscrowd: int
                0 or 1
        """
        return cls(
            bbox=bbox,
            category_id=category_id,
            category_name=category_name,
            iscrowd=iscrowd,
        )

    @classmethod
    def from_coco_annotation_dict(cls, annotation_dict: dict, category_name: str | None = None):
        """Creates CocoAnnotation object from category name and COCO formatted annotation dict (with fields "bbox",
        "segmentation", "category_id").

        Polygon segmentations take precedence over the bbox when present;
        RLE (dict-typed) segmentations are unsupported and are skipped with a
        warning, falling back to the bbox.

        Args:
            category_name: str
                Category name of the annotation
            annotation_dict: dict
                COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id")
        """
        # RLE segmentations are encoded as dicts, polygon segmentations as list of lists
        has_rle_segmentation = isinstance(annotation_dict.get("segmentation"), dict)
        if has_rle_segmentation:
            # .get so a malformed dict without "id" still warns instead of raising KeyError
            logger.warning(
                f"Segmentation annotation for id {annotation_dict.get('id')} is skipped since "
                "RLE segmentation format is not supported."
            )

        if "segmentation" in annotation_dict and annotation_dict["segmentation"] and not has_rle_segmentation:
            return cls(
                segmentation=annotation_dict["segmentation"],
                category_id=annotation_dict["category_id"],
                category_name=category_name,
            )
        return cls(
            bbox=annotation_dict["bbox"],
            category_id=annotation_dict["category_id"],
            category_name=category_name,
        )

    @classmethod
    def from_shapely_annotation(
        cls,
        shapely_annotation: ShapelyAnnotation,
        category_id: int,
        category_name: str,
        iscrowd: int,
    ):
        """Creates CocoAnnotation object from ShapelyAnnotation object.

        Args:
            shapely_annotation (ShapelyAnnotation)
            category_id (int): Category id of the annotation
            category_name (str): Category name of the annotation
            iscrowd (int): 0 or 1
        """
        # construct with a dummy bbox, then overwrite the geometry in place
        coco_annotation = cls(
            bbox=[0, 0, 0, 0],
            category_id=category_id,
            category_name=category_name,
            iscrowd=iscrowd,
        )
        coco_annotation._segmentation = shapely_annotation.to_coco_segmentation()
        coco_annotation._shapely_annotation = shapely_annotation
        return coco_annotation

    def __init__(
        self,
        category_id: int,
        category_name: str | None = None,
        segmentation=None,
        bbox: list[int] | None = None,
        image_id=None,
        iscrowd=0,
    ):
        """Creates coco annotation object using bbox or segmentation.

        Args:
            segmentation: List[List]
                [[1, 1, 325, 125, 250, 200, 5, 200]]
            bbox: List
                [xmin, ymin, width, height]
            category_id: int
                Category id of the annotation
            category_name: str
                Category name of the annotation
            image_id: int
                Image ID of the annotation
            iscrowd: int
                0 or 1

        Raises:
            ValueError: when neither bbox nor segmentation is given.
            TypeError: when segmentation is falsy and bbox is empty.
        """
        if bbox is None and segmentation is None:
            raise ValueError("you must provide a bbox or polygon")

        self._segmentation = segmentation
        self._category_id = category_id
        self._category_name = category_name
        self._image_id = image_id
        self._iscrowd = iscrowd

        # a non-empty segmentation takes precedence over bbox for the internal geometry
        if self._segmentation:
            shapely_annotation = ShapelyAnnotation.from_coco_segmentation(segmentation=self._segmentation)
        else:
            if not bbox:
                raise TypeError("Coco bounding box not set")
            shapely_annotation = ShapelyAnnotation.from_coco_bbox(bbox=bbox)
        self._shapely_annotation = shapely_annotation

    def get_sliced_coco_annotation(self, slice_bbox: list[int]):
        """Returns a new CocoAnnotation clipped to the given [xmin, ymin, xmax, ymax] window."""
        shapely_polygon = box(slice_bbox[0], slice_bbox[1], slice_bbox[2], slice_bbox[3])
        intersection_shapely_annotation = self._shapely_annotation.get_intersection(shapely_polygon)
        return CocoAnnotation.from_shapely_annotation(
            intersection_shapely_annotation,
            category_id=self.category_id,
            category_name=self.category_name or "",
            iscrowd=self.iscrowd,
        )

    @property
    def area(self):
        """Returns area of annotation polygon (or bbox if no polygon available)"""
        return self._shapely_annotation.area

    @property
    def bbox(self):
        """Returns coco formatted bbox of the annotation as [xmin, ymin, width, height]"""
        return self._shapely_annotation.to_xywh()

    @property
    def segmentation(self):
        """Returns coco formatted segmentation of the annotation as [[1, 1, 325, 125, 250, 200, 5, 200]]"""
        if self._segmentation:
            return self._shapely_annotation.to_coco_segmentation()
        return []

    @property
    def category_id(self):
        """Returns category id of the annotation as int."""
        return self._category_id

    @category_id.setter
    def category_id(self, i):
        if not isinstance(i, int):
            raise Exception("category_id must be an integer")
        self._category_id = i

    @property
    def image_id(self):
        """Returns image id of the annotation as int."""
        return self._image_id

    @image_id.setter
    def image_id(self, i):
        if not isinstance(i, int):
            raise Exception("image_id must be an integer")
        self._image_id = i

    @property
    def category_name(self):
        """Returns category name of the annotation as str."""
        return self._category_name

    @category_name.setter
    def category_name(self, n):
        if not isinstance(n, str):
            raise Exception("category_name must be a string")
        self._category_name = n

    @property
    def iscrowd(self):
        """Returns iscrowd info of the annotation."""
        return self._iscrowd

    @property
    def json(self):
        """Returns the annotation as a COCO formatted dict."""
        return {
            "image_id": self.image_id,
            "bbox": self.bbox,
            "category_id": self.category_id,
            "segmentation": self.segmentation,
            "iscrowd": self.iscrowd,
            "area": self.area,
        }

    def serialize(self):
        """Deprecated: use the `json` property instead."""
        warnings.warn("Use json property instead of serialize method", DeprecationWarning, stacklevel=2)
        return self.json

    def __repr__(self):
        return f"""CocoAnnotation<
    image_id: {self.image_id},
    bbox: {self.bbox},
    segmentation: {self.segmentation},
    category_id: {self.category_id},
    category_name: {self.category_name},
    iscrowd: {self.iscrowd},
    area: {self.area}>"""
Attributes
area property

Returns area of annotation polygon (or bbox if no polygon available)

bbox property

Returns coco formatted bbox of the annotation as [xmin, ymin, width, height]

category_id property writable

Returns category id of the annotation as int.

category_name property writable

Returns category name of the annotation as str.

image_id property writable

Returns image id of the annotation as int.

iscrowd property

Returns iscrowd info of the annotation.

segmentation property

Returns coco formatted segmentation of the annotation as [[1, 1, 325, 125, 250, 200, 5, 200]]

Functions
__init__(category_id, category_name=None, segmentation=None, bbox=None, image_id=None, iscrowd=0)

Creates coco annotation object using bbox or segmentation.

Parameters:

Name Type Description Default
segmentation

List[List], e.g. [[1, 1, 325, 125, 250, 200, 5, 200]]

None
bbox list[int] | None

List [xmin, ymin, width, height]

None
category_id int

int Category id of the annotation

required
category_name str | None

str Category name of the annotation

None
image_id

int Image ID of the annotation

None
iscrowd

int 0 or 1

0
Source code in sahi/utils/coco.py
def __init__(
    self,
    category_id: int,
    category_name: str | None = None,
    segmentation=None,
    bbox: list[int] | None = None,
    image_id=None,
    iscrowd=0,
):
    """Creates coco annotation object using bbox or segmentation.

    A non-empty segmentation takes precedence over the bbox when both
    are supplied.

    Args:
        segmentation: List[List]
            [[1, 1, 325, 125, 250, 200, 5, 200]]
        bbox: List
            [xmin, ymin, width, height]
        category_id: int
            Category id of the annotation
        category_name: str
            Category name of the annotation
        image_id: int
            Image ID of the annotation
        iscrowd: int
            0 or 1
    """
    if bbox is None and segmentation is None:
        raise ValueError("you must provide a bbox or polygon")

    self._segmentation = segmentation
    self._category_id = category_id
    self._category_name = category_name
    self._image_id = image_id
    self._iscrowd = iscrowd

    # build the internal shapely representation from whichever geometry was given
    if self._segmentation:
        self._shapely_annotation = ShapelyAnnotation.from_coco_segmentation(segmentation=self._segmentation)
    elif bbox:
        self._shapely_annotation = ShapelyAnnotation.from_coco_bbox(bbox=bbox)
    else:
        raise TypeError("Coco bounding box not set")
from_coco_annotation_dict(annotation_dict, category_name=None) classmethod

Creates CocoAnnotation object from category name and COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id").

Parameters:

Name Type Description Default
category_name str | None

str Category name of the annotation

None
annotation_dict dict

dict COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id")

required
Source code in sahi/utils/coco.py
@classmethod
def from_coco_annotation_dict(cls, annotation_dict: dict, category_name: str | None = None):
    """Creates CocoAnnotation object from category name and COCO formatted annotation dict (with fields "bbox",
    "segmentation", "category_id").

    Polygon segmentations take precedence over the bbox when present; RLE
    (dict-typed) segmentations are unsupported and are skipped with a
    warning, falling back to the bbox.

    Args:
        category_name: str
            Category name of the annotation
        annotation_dict: dict
            COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id")
    """
    # RLE segmentations are encoded as dicts, polygon segmentations as list of lists
    has_rle_segmentation = isinstance(annotation_dict.get("segmentation"), dict)
    if has_rle_segmentation:
        # .get so a malformed dict without "id" still warns instead of raising KeyError
        logger.warning(
            f"Segmentation annotation for id {annotation_dict.get('id')} is skipped since "
            "RLE segmentation format is not supported."
        )

    if "segmentation" in annotation_dict and annotation_dict["segmentation"] and not has_rle_segmentation:
        return cls(
            segmentation=annotation_dict["segmentation"],
            category_id=annotation_dict["category_id"],
            category_name=category_name,
        )
    return cls(
        bbox=annotation_dict["bbox"],
        category_id=annotation_dict["category_id"],
        category_name=category_name,
    )
from_coco_bbox(bbox, category_id, category_name, iscrowd=0) classmethod

Creates CocoAnnotation object using coco bbox.

Parameters:

Name Type Description Default
bbox

List [xmin, ymin, width, height]

required
category_id

int Category id of the annotation

required
category_name

str Category name of the annotation

required
iscrowd

int 0 or 1

0
Source code in sahi/utils/coco.py
@classmethod
def from_coco_bbox(cls, bbox, category_id, category_name, iscrowd=0):
    """Build a CocoAnnotation from a COCO-style bounding box.

    Args:
        bbox: List
            [xmin, ymin, width, height]
        category_id: int
            Category id of the annotation
        category_name: str
            Category name of the annotation
        iscrowd: int
            0 or 1
    """
    return cls(bbox=bbox, category_id=category_id, category_name=category_name, iscrowd=iscrowd)
from_coco_segmentation(segmentation, category_id, category_name, iscrowd=0) classmethod

Creates CocoAnnotation object using coco segmentation.

Parameters:

Name Type Description Default
segmentation

List[List], e.g. [[1, 1, 325, 125, 250, 200, 5, 200]]

required
category_id

int Category id of the annotation

required
category_name

str Category name of the annotation

required
iscrowd

int 0 or 1

0
Source code in sahi/utils/coco.py
@classmethod
def from_coco_segmentation(cls, segmentation, category_id, category_name, iscrowd=0):
    """Build a CocoAnnotation from a COCO polygon segmentation.

    Args:
        segmentation: List[List]
            [[1, 1, 325, 125, 250, 200, 5, 200]]
        category_id: int
            Category id of the annotation
        category_name: str
            Category name of the annotation
        iscrowd: int
            0 or 1
    """
    return cls(segmentation=segmentation, category_id=category_id, category_name=category_name, iscrowd=iscrowd)
from_shapely_annotation(shapely_annotation, category_id, category_name, iscrowd) classmethod

Creates CocoAnnotation object from ShapelyAnnotation object.

Parameters:

Name Type Description Default
category_id int

Category id of the annotation

required
category_name str

Category name of the annotation

required
iscrowd int

0 or 1

required
Source code in sahi/utils/coco.py
@classmethod
def from_shapely_annotation(
    cls,
    shapely_annotation: ShapelyAnnotation,
    category_id: int,
    category_name: str,
    iscrowd: int,
):
    """Build a CocoAnnotation directly from a ShapelyAnnotation object.

    A placeholder bbox is passed to the constructor; the internal geometry
    is then replaced with the given shapely annotation.

    Args:
        shapely_annotation (ShapelyAnnotation)
        category_id (int): Category id of the annotation
        category_name (str): Category name of the annotation
        iscrowd (int): 0 or 1
    """
    annotation = cls(
        bbox=[0, 0, 0, 0],  # dummy geometry, overwritten below
        category_id=category_id,
        category_name=category_name,
        iscrowd=iscrowd,
    )
    annotation._segmentation = shapely_annotation.to_coco_segmentation()
    annotation._shapely_annotation = shapely_annotation
    return annotation
CocoCategory

COCO formatted category.

Source code in sahi/utils/coco.py
class CocoCategory:
    """COCO formatted category."""

    def __init__(self, id: int = 0, name: str | None = None, supercategory: str | None = None):
        self.id = int(id)
        self.name = name
        self.supercategory = supercategory if supercategory else name

    @classmethod
    def from_coco_category(cls, category):
        """Creates CocoCategory object using coco category.

        Args:
            category: Dict
                {"supercategory": "person", "id": 1, "name": "person"},
        """
        return cls(
            id=category["id"],
            name=category["name"],
            supercategory=category["supercategory"] if "supercategory" in category else category["name"],
        )

    @property
    def json(self):
        return {
            "id": self.id,
            "name": self.name,
            "supercategory": self.supercategory,
        }

    def __repr__(self):
        return f"""CocoCategory<
    id: {self.id},
    name: {self.name},
    supercategory: {self.supercategory}>"""
Functions
from_coco_category(category) classmethod

Creates CocoCategory object using coco category.

Parameters:

Name Type Description Default
category

Dict {"supercategory": "person", "id": 1, "name": "person"},

required
Source code in sahi/utils/coco.py
@classmethod
def from_coco_category(cls, category):
    """Create a CocoCategory from a COCO category dict.

    Args:
        category: Dict
            e.g. {"supercategory": "person", "id": 1, "name": "person"}
    """
    supercategory = category.get("supercategory", category["name"])
    return cls(id=category["id"], name=category["name"], supercategory=supercategory)
CocoImage
Source code in sahi/utils/coco.py
class CocoImage:
    @classmethod
    def from_coco_image_dict(cls, image_dict):
        """Creates CocoImage object from COCO formatted image dict (with fields "id", "file_name", "height" and
        "weight").

        Args:
            image_dict: dict
                COCO formatted image dict (with fields "id", "file_name", "height" and "weight")
        """
        return cls(
            id=image_dict["id"],
            file_name=image_dict["file_name"],
            height=image_dict["height"],
            width=image_dict["width"],
        )

    def __init__(self, file_name: str, height: int, width: int, id: int | None = None):
        """Creates CocoImage object.

        Args:
            id : int
                Image id
            file_name : str
                Image path
            height : int
                Image height in pixels
            width : int
                Image width in pixels
        """
        self.id = int(id) if id else id
        self.file_name = file_name
        self.height = int(height)
        self.width = int(width)
        self.annotations = []  # list of CocoAnnotation that belong to this image
        self.predictions = []  # list of CocoPrediction that belong to this image

    def add_annotation(self, annotation):
        """Adds annotation to this CocoImage instance.

        annotation : CocoAnnotation
        """

        if not isinstance(annotation, CocoAnnotation):
            raise TypeError("annotation must be a CocoAnnotation instance")
        self.annotations.append(annotation)

    def add_prediction(self, prediction):
        """Adds prediction to this CocoImage instance.

        prediction : CocoPrediction
        """

        if not isinstance(prediction, CocoPrediction):
            raise TypeError("prediction must be a CocoPrediction instance")
        self.predictions.append(prediction)

    @property
    def json(self):
        return {
            "id": self.id,
            "file_name": self.file_name,
            "height": self.height,
            "width": self.width,
        }

    def __repr__(self):
        return f"""CocoImage<
    id: {self.id},
    file_name: {self.file_name},
    height: {self.height},
    width: {self.width},
    annotations: List[CocoAnnotation],
    predictions: List[CocoPrediction]>"""
Functions
__init__(file_name, height, width, id=None)

Creates CocoImage object.

Parameters:

Name Type Description Default
id

int Image id

required
file_name

str Image path

required
height

int Image height in pixels

required
width

int Image width in pixels

required
Source code in sahi/utils/coco.py
def __init__(self, file_name: str, height: int, width: int, id: int | None = None):
    """Create a CocoImage.

    Args:
        file_name : str
            Image path
        height : int
            Image height in pixels
        width : int
            Image width in pixels
        id : int
            Image id
    """
    # Falsy ids (e.g. None) are kept as-is; everything else is coerced to int.
    self.id = int(id) if id else id
    self.file_name = file_name
    self.height = int(height)
    self.width = int(width)
    # CocoAnnotation instances that belong to this image.
    self.annotations = []
    # CocoPrediction instances that belong to this image.
    self.predictions = []
add_annotation(annotation)

Adds annotation to this CocoImage instance.

annotation : CocoAnnotation

Source code in sahi/utils/coco.py
def add_annotation(self, annotation):
    """Adds annotation to this CocoImage instance.

    annotation : CocoAnnotation
    """
    if isinstance(annotation, CocoAnnotation):
        self.annotations.append(annotation)
    else:
        raise TypeError("annotation must be a CocoAnnotation instance")
add_prediction(prediction)

Adds prediction to this CocoImage instance.

prediction : CocoPrediction

Source code in sahi/utils/coco.py
def add_prediction(self, prediction):
    """Adds prediction to this CocoImage instance.

    prediction : CocoPrediction
    """
    if isinstance(prediction, CocoPrediction):
        self.predictions.append(prediction)
    else:
        raise TypeError("prediction must be a CocoPrediction instance")
from_coco_image_dict(image_dict) classmethod

Creates CocoImage object from COCO formatted image dict (with fields "id", "file_name", "height" and "width").

Parameters:

Name Type Description Default
image_dict

dict COCO formatted image dict (with fields "id", "file_name", "height" and "width")

required
Source code in sahi/utils/coco.py
@classmethod
def from_coco_image_dict(cls, image_dict):
    """Create a CocoImage from a COCO formatted image dict.

    Args:
        image_dict: dict
            COCO formatted image dict with fields "id", "file_name",
            "height" and "width".
    """
    fields = ("id", "file_name", "height", "width")
    return cls(**{key: image_dict[key] for key in fields})
CocoPrediction

Bases: CocoAnnotation

Class for handling predictions in coco format.

Source code in sahi/utils/coco.py
class CocoPrediction(CocoAnnotation):
    """Class for handling predictions in coco format."""

    @classmethod
    def from_coco_segmentation(cls, segmentation, category_id, category_name, score, iscrowd=0, image_id=None):
        """Creates CocoPrediction object using coco segmentation.

        Args:
            segmentation: List[List]
                [[1, 1, 325, 125, 250, 200, 5, 200]]
            category_id: int
                Category id of the annotation
            category_name: str
                Category name of the annotation
            score: float
                Prediction score between 0 and 1
            iscrowd: int
                0 or 1
            image_id: int
                Image ID of the annotation
        """
        return cls(
            segmentation=segmentation,
            category_id=category_id,
            category_name=category_name,
            score=score,
            iscrowd=iscrowd,
            image_id=image_id,
        )

    @classmethod
    def from_coco_bbox(cls, bbox, category_id, category_name, score, iscrowd=0, image_id=None):
        """Creates CocoPrediction object using coco bbox.

        Args:
            bbox: List
                [xmin, ymin, width, height]
            category_id: int
                Category id of the annotation
            category_name: str
                Category name of the annotation
            score: float
                Prediction score between 0 and 1
            iscrowd: int
                0 or 1
            image_id: int
                Image ID of the annotation
        """
        return cls(
            bbox=bbox,
            category_id=category_id,
            category_name=category_name,
            score=score,
            iscrowd=iscrowd,
            image_id=image_id,
        )

    @classmethod
    def from_coco_annotation_dict(cls, category_name, annotation_dict, score, image_id=None):
        """Creates CocoPrediction object from category name and COCO formatted annotation dict
        (with fields "bbox", "segmentation", "category_id").

        Args:
            category_name: str
                Category name of the annotation
            annotation_dict: dict
                COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id")
            score: float
                Prediction score between 0 and 1
            image_id: int
                Image ID of the annotation
        """
        if annotation_dict["segmentation"]:
            return cls(
                segmentation=annotation_dict["segmentation"],
                category_id=annotation_dict["category_id"],
                category_name=category_name,
                score=score,
                image_id=image_id,
            )
        else:
            return cls(
                bbox=annotation_dict["bbox"],
                category_id=annotation_dict["category_id"],
                category_name=category_name,
                # Bugfix: score was previously dropped in this bbox-only branch,
                # leaving the prediction with score=None.
                score=score,
                image_id=image_id,
            )

    def __init__(
        self,
        segmentation=None,
        bbox=None,
        category_id: int = 0,
        category_name: str = "",
        image_id=None,
        score=None,
        iscrowd=0,
    ):
        """Creates CocoPrediction object.

        Args:
            segmentation: List[List]
                [[1, 1, 325, 125, 250, 200, 5, 200]]
            bbox: List
                [xmin, ymin, width, height]
            category_id: int
                Category id of the annotation
            category_name: str
                Category name of the annotation
            image_id: int
                Image ID of the annotation
            score: float
                Prediction score between 0 and 1
            iscrowd: int
                0 or 1
        """
        # Prediction-specific field; everything else is handled by CocoAnnotation.
        self.score = score
        super().__init__(
            segmentation=segmentation,
            bbox=bbox,
            category_id=category_id,
            category_name=category_name,
            image_id=image_id,
            iscrowd=iscrowd,
        )

    @property
    def json(self):
        # COCO results-style dict representation of this prediction.
        return {
            "image_id": self.image_id,
            "bbox": self.bbox,
            "score": self.score,
            "category_id": self.category_id,
            "category_name": self.category_name,
            "segmentation": self.segmentation,
            "iscrowd": self.iscrowd,
            "area": self.area,
        }

    def serialize(self):
        # Deprecated: kept for backward compatibility; use the `json` property.
        warnings.warn("Use json property instead of serialize method", DeprecationWarning, stacklevel=2)

    def __repr__(self):
        return f"""CocoPrediction<
    image_id: {self.image_id},
    bbox: {self.bbox},
    segmentation: {self.segmentation},
    score: {self.score},
    category_id: {self.category_id},
    category_name: {self.category_name},
    iscrowd: {self.iscrowd},
    area: {self.area}>"""
Functions
__init__(segmentation=None, bbox=None, category_id=0, category_name='', image_id=None, score=None, iscrowd=0)

Parameters:

Name Type Description Default
segmentation

List[List][[1, 1, 325, 125, 250, 200, 5, 200]]

None
bbox

List [xmin, ymin, width, height]

None
category_id int

int Category id of the annotation

0
category_name str

str Category name of the annotation

''
image_id

int Image ID of the annotation

None
score

float Prediction score between 0 and 1

None
iscrowd

int 0 or 1

0
Source code in sahi/utils/coco.py
def __init__(
    self,
    segmentation=None,
    bbox=None,
    category_id: int = 0,
    category_name: str = "",
    image_id=None,
    score=None,
    iscrowd=0,
):
    """Creates CocoPrediction object.

    Args:
        segmentation: List[List]
            [[1, 1, 325, 125, 250, 200, 5, 200]]
        bbox: List
            [xmin, ymin, width, height]
        category_id: int
            Category id of the annotation
        category_name: str
            Category name of the annotation
        image_id: int
            Image ID of the annotation
        score: float
            Prediction score between 0 and 1
        iscrowd: int
            0 or 1
    """
    # Prediction-specific field; everything else is handled by the base class.
    self.score = score
    base_kwargs = {
        "segmentation": segmentation,
        "bbox": bbox,
        "category_id": category_id,
        "category_name": category_name,
        "image_id": image_id,
        "iscrowd": iscrowd,
    }
    super().__init__(**base_kwargs)
from_coco_annotation_dict(category_name, annotation_dict, score, image_id=None) classmethod

Creates CocoAnnotation object from category name and COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id").

Parameters:

Name Type Description Default
category_name

str Category name of the annotation

required
annotation_dict

dict COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id")

required
score

float Prediction score between 0 and 1

required
Source code in sahi/utils/coco.py
@classmethod
def from_coco_annotation_dict(cls, category_name, annotation_dict, score, image_id=None):
    """Creates CocoPrediction object from category name and COCO formatted annotation dict
    (with fields "bbox", "segmentation", "category_id").

    Args:
        category_name: str
            Category name of the annotation
        annotation_dict: dict
            COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id")
        score: float
            Prediction score between 0 and 1
        image_id: int
            Image ID of the annotation
    """
    if annotation_dict["segmentation"]:
        return cls(
            segmentation=annotation_dict["segmentation"],
            category_id=annotation_dict["category_id"],
            category_name=category_name,
            score=score,
            image_id=image_id,
        )
    else:
        return cls(
            bbox=annotation_dict["bbox"],
            category_id=annotation_dict["category_id"],
            category_name=category_name,
            # Bugfix: score was previously dropped in this bbox-only branch,
            # leaving the prediction with score=None.
            score=score,
            image_id=image_id,
        )
from_coco_bbox(bbox, category_id, category_name, score, iscrowd=0, image_id=None) classmethod

Creates CocoAnnotation object using coco bbox.

Parameters:

Name Type Description Default
bbox

List [xmin, ymin, width, height]

required
category_id

int Category id of the annotation

required
category_name

str Category name of the annotation

required
score

float Prediction score between 0 and 1

required
iscrowd

int 0 or 1

0
Source code in sahi/utils/coco.py
@classmethod
def from_coco_bbox(cls, bbox, category_id, category_name, score, iscrowd=0, image_id=None):
    """Creates CocoPrediction object using coco bbox.

    Args:
        bbox: List
            [xmin, ymin, width, height]
        category_id: int
            Category id of the annotation
        category_name: str
            Category name of the annotation
        score: float
            Prediction score between 0 and 1
        iscrowd: int
            0 or 1
        image_id: int
            Image ID of the annotation
    """
    kwargs = {
        "bbox": bbox,
        "category_id": category_id,
        "category_name": category_name,
        "score": score,
        "iscrowd": iscrowd,
        "image_id": image_id,
    }
    return cls(**kwargs)
from_coco_segmentation(segmentation, category_id, category_name, score, iscrowd=0, image_id=None) classmethod

Creates CocoAnnotation object using coco segmentation.

Parameters:

Name Type Description Default
segmentation

List[List][[1, 1, 325, 125, 250, 200, 5, 200]]

required
category_id

int Category id of the annotation

required
category_name

str Category name of the annotation

required
score

float Prediction score between 0 and 1

required
iscrowd

int 0 or 1

0
Source code in sahi/utils/coco.py
@classmethod
def from_coco_segmentation(cls, segmentation, category_id, category_name, score, iscrowd=0, image_id=None):
    """Creates CocoPrediction object using coco segmentation.

    Args:
        segmentation: List[List]
            [[1, 1, 325, 125, 250, 200, 5, 200]]
        category_id: int
            Category id of the annotation
        category_name: str
            Category name of the annotation
        score: float
            Prediction score between 0 and 1
        iscrowd: int
            0 or 1
        image_id: int
            Image ID of the annotation
    """
    kwargs = {
        "segmentation": segmentation,
        "category_id": category_id,
        "category_name": category_name,
        "score": score,
        "iscrowd": iscrowd,
        "image_id": image_id,
    }
    return cls(**kwargs)
CocoVid
Source code in sahi/utils/coco.py
class CocoVid:
    """CocoVid (video COCO) dataset container holding categories and videos."""

    def __init__(self, name=None, remapping_dict=None):
        """Creates CocoVid object.

        Args:
            name: str
                Name of the CocoVid dataset, it determines exported json name.
            remapping_dict: dict
                {1:0, 2:1} maps category id 1 to 0 and category id 2 to 1
        """
        self.name = name
        self.remapping_dict = remapping_dict
        self.categories = []
        self.videos = []

    def add_categories_from_coco_category_list(self, coco_category_list):
        """Creates CocoCategory objects from a coco category list and registers them.

        Args:
            coco_category_list: List[Dict]
                [
                    {"supercategory": "person", "id": 1, "name": "person"},
                    {"supercategory": "vehicle", "id": 2, "name": "bicycle"}
                ]
        """
        for coco_category in coco_category_list:
            if self.remapping_dict is not None and coco_category["id"] in self.remapping_dict:
                # Bugfix: apply the remapping once per category. The previous loop over
                # all remapping keys could cascade (e.g. {1: 2, 2: 3} remapped 1 -> 3,
                # dependent on dict iteration order).
                coco_category["id"] = self.remapping_dict[coco_category["id"]]
            self.add_category(CocoCategory.from_coco_category(coco_category))

    def add_category(self, category: CocoCategory):
        """Adds category to this CocoVid instance.

        Args:
            category: CocoCategory
        """
        if not isinstance(category, CocoCategory):
            raise TypeError("category must be a CocoCategory instance")  # type: ignore
        self.categories.append(category)

    @property
    def json_categories(self):
        # COCO-style list of category dicts.
        return [category.json for category in self.categories]

    @property
    def category_mapping(self):
        # Mapping of category id -> category name.
        return {category.id: category.name for category in self.categories}

    def add_video(self, video: CocoVideo):
        """Adds video to this CocoVid instance.

        Args:
            video: CocoVideo
        """
        if not isinstance(video, CocoVideo):
            raise TypeError("video must be a CocoVideo instance")  # type: ignore
        self.videos.append(video)

    @property
    def json(self):
        """Exports the dataset as a CocoVid-formatted dict, assigning sequential
        video/image/annotation ids and globally unique instance ids.

        Note: mutates the contained videos, images and annotations in place.
        """
        coco_dict = {
            "videos": [],
            "images": [],
            "annotations": [],
            "categories": self.json_categories,
        }
        annotation_id = 1
        image_id = 1
        video_id = 1
        global_instance_id = 1
        for coco_video in self.videos:
            coco_video.id = video_id
            coco_dict["videos"].append(coco_video.json)

            frame_id = 0
            instance_id_set = set()
            for cocovid_image in coco_video.images:
                cocovid_image.id = image_id
                cocovid_image.frame_id = frame_id
                cocovid_image.video_id = coco_video.id
                coco_dict["images"].append(cocovid_image.json)

                for cocovid_annotation in cocovid_image.annotations:
                    instance_id_set.add(cocovid_annotation.instance_id)
                    # Offset per-video instance ids into a globally unique range.
                    cocovid_annotation.instance_id += global_instance_id

                    cocovid_annotation.id = annotation_id
                    cocovid_annotation.image_id = cocovid_image.id
                    coco_dict["annotations"].append(cocovid_annotation.json)

                    # Plain increments; ints are immutable so copy.deepcopy was redundant.
                    annotation_id += 1
                image_id += 1
                frame_id += 1
            video_id += 1
            global_instance_id += len(instance_id_set)

        return coco_dict
Functions
__init__(name=None, remapping_dict=None)

Creates CocoVid object.

Parameters:

Name Type Description Default
name

str Name of the CocoVid dataset, it determines exported json name.

None
remapping_dict

dict {1:0, 2:1} maps category id 1 to 0 and category id 2 to 1

None
Source code in sahi/utils/coco.py
def __init__(self, name=None, remapping_dict=None):
    """Creates CocoVid object.

    Args:
        name: str
            Name of the CocoVid dataset, it determines exported json name.
        remapping_dict: dict
            {1:0, 2:1} maps category id 1 to 0 and category id 2 to 1
    """
    self.name = name
    self.remapping_dict = remapping_dict
    # Registered CocoCategory instances.
    self.categories = []
    # Registered CocoVideo instances.
    self.videos = []
add_categories_from_coco_category_list(coco_category_list)

Creates CocoCategory object using coco category list.

Parameters:

Name Type Description Default
coco_category_list

List[Dict] [ {"supercategory": "person", "id": 1, "name": "person"}, {"supercategory": "vehicle", "id": 2, "name": "bicycle"} ]

required
Source code in sahi/utils/coco.py
def add_categories_from_coco_category_list(self, coco_category_list):
    """Creates CocoCategory objects from a coco category list and registers them.

    Args:
        coco_category_list: List[Dict]
            [
                {"supercategory": "person", "id": 1, "name": "person"},
                {"supercategory": "vehicle", "id": 2, "name": "bicycle"}
            ]
    """
    for coco_category in coco_category_list:
        if self.remapping_dict is not None and coco_category["id"] in self.remapping_dict:
            # Bugfix: apply the remapping once per category. The previous loop over
            # all remapping keys could cascade (e.g. {1: 2, 2: 3} remapped 1 -> 3,
            # dependent on dict iteration order).
            coco_category["id"] = self.remapping_dict[coco_category["id"]]
        self.add_category(CocoCategory.from_coco_category(coco_category))
add_category(category)

Adds category to this CocoVid instance.

Parameters:

Name Type Description Default
category CocoCategory

CocoCategory

required
Source code in sahi/utils/coco.py
def add_category(self, category: CocoCategory):
    """Adds category to this CocoVid instance.

    Args:
        category: CocoCategory
    """
    if isinstance(category, CocoCategory):
        self.categories.append(category)
    else:
        raise TypeError("category must be a CocoCategory instance")
add_video(video)

Adds video to this CocoVid instance.

Parameters:

Name Type Description Default
video CocoVideo

CocoVideo

required
Source code in sahi/utils/coco.py
def add_video(self, video: CocoVideo):
    """Adds video to this CocoVid instance.

    Args:
        video: CocoVideo
    """
    if isinstance(video, CocoVideo):
        self.videos.append(video)
    else:
        raise TypeError("video must be a CocoVideo instance")
CocoVidAnnotation

Bases: CocoAnnotation

COCOVid formatted annotation.

https://github.com/open-mmlab/mmtracking/blob/master/docs/tutorials/customize_dataset.md#the-cocovid-annotation-file

Source code in sahi/utils/coco.py
class CocoVidAnnotation(CocoAnnotation):
    """COCOVid formatted annotation.

    https://github.com/open-mmlab/mmtracking/blob/master/docs/tutorials/customize_dataset.md#the-cocovid-annotation-file
    """

    def __init__(
        self,
        category_id: int,
        category_name: str,
        bbox: list[int],
        image_id=None,
        instance_id=None,
        iscrowd=0,
        id=None,
    ):
        """Creates CocoVidAnnotation object.

        Args:
            bbox: List
                [xmin, ymin, width, height]
            category_id: int
                Category id of the annotation
            category_name: str
                Category name of the annotation
            image_id: int
                Image ID of the annotation
            instance_id: int
                Used for tracking
            iscrowd: int
                0 or 1
            id: int
                Annotation id
        """
        super().__init__(
            bbox=bbox,
            category_id=category_id,
            category_name=category_name,
            image_id=image_id,
            iscrowd=iscrowd,
        )
        # Tracking-specific fields not handled by the base class.
        self.instance_id = instance_id
        self.id = id

    @property
    def json(self):
        # CocoVid-style dict representation, including the tracking instance_id.
        return {
            "id": self.id,
            "image_id": self.image_id,
            "bbox": self.bbox,
            "segmentation": self.segmentation,
            "category_id": self.category_id,
            "category_name": self.category_name,
            "instance_id": self.instance_id,
            "iscrowd": self.iscrowd,
            "area": self.area,
        }

    def __repr__(self):
        # Bugfix: repr previously reported the base-class name "CocoAnnotation".
        return f"""CocoVidAnnotation<
    id: {self.id},
    image_id: {self.image_id},
    bbox: {self.bbox},
    segmentation: {self.segmentation},
    category_id: {self.category_id},
    category_name: {self.category_name},
    instance_id: {self.instance_id},
    iscrowd: {self.iscrowd},
    area: {self.area}>"""
Functions
__init__(category_id, category_name, bbox, image_id=None, instance_id=None, iscrowd=0, id=None)

Parameters:

Name Type Description Default
bbox list[int]

List [xmin, ymin, width, height]

required
category_id int

int Category id of the annotation

required
category_name str

str Category name of the annotation

required
image_id

int Image ID of the annotation

None
instance_id

int Used for tracking

None
iscrowd

int 0 or 1

0
id

int Annotation id

None
Source code in sahi/utils/coco.py
def __init__(
    self,
    category_id: int,
    category_name: str,
    bbox: list[int],
    image_id=None,
    instance_id=None,
    iscrowd=0,
    id=None,
):
    """Creates CocoVidAnnotation object.

    Args:
        bbox: List
            [xmin, ymin, width, height]
        category_id: int
            Category id of the annotation
        category_name: str
            Category name of the annotation
        image_id: int
            Image ID of the annotation
        instance_id: int
            Used for tracking
        iscrowd: int
            0 or 1
        id: int
            Annotation id
    """
    base_kwargs = {
        "bbox": bbox,
        "category_id": category_id,
        "category_name": category_name,
        "image_id": image_id,
        "iscrowd": iscrowd,
    }
    super().__init__(**base_kwargs)
    # Tracking-specific fields not handled by the base class.
    self.instance_id = instance_id
    self.id = id
CocoVidImage

Bases: CocoImage

COCOVid formatted image.

https://github.com/open-mmlab/mmtracking/blob/master/docs/tutorials/customize_dataset.md#the-cocovid-annotation-file

Source code in sahi/utils/coco.py
class CocoVidImage(CocoImage):
    """COCOVid formatted image.

    https://github.com/open-mmlab/mmtracking/blob/master/docs/tutorials/customize_dataset.md#the-cocovid-annotation-file
    """

    def __init__(
        self,
        file_name,
        height,
        width,
        video_id=None,
        frame_id=None,
        id=None,
    ):
        """Creates CocoVidImage object.

        Args:
            file_name: str
                Image path
            height: int
                Image height in pixels
            width: int
                Image width in pixels
            video_id: int
                Video id
            frame_id: int
                0-indexed frame id
            id: int
                Image id
        """
        super().__init__(file_name=file_name, height=height, width=width, id=id)
        # Video linkage fields specific to CocoVid datasets.
        self.frame_id = frame_id
        self.video_id = video_id

    @classmethod
    def from_coco_image(cls, coco_image, video_id=None, frame_id=None):
        """Creates CocoVidImage object using CocoImage object.

        Args:
            coco_image: CocoImage
            video_id: int
                Video id
            frame_id: int
                0-indexed frame id
        """
        attrs = {name: getattr(coco_image, name) for name in ("file_name", "height", "width", "id")}
        return cls(video_id=video_id, frame_id=frame_id, **attrs)

    def add_annotation(self, annotation):
        """Adds annotation to this CocoVidImage instance.

        annotation : CocoVidAnnotation
        """
        if isinstance(annotation, CocoVidAnnotation):
            self.annotations.append(annotation)
        else:
            raise TypeError("annotation must be a CocoVidAnnotation instance")

    @property
    def json(self):
        # CocoVid-style dict representation including video/frame linkage.
        return {
            "file_name": self.file_name,
            "height": self.height,
            "width": self.width,
            "id": self.id,
            "video_id": self.video_id,
            "frame_id": self.frame_id,
        }

    def __repr__(self):
        return f"""CocoVidImage<
    file_name: {self.file_name},
    height: {self.height},
    width: {self.width},
    id: {self.id},
    video_id: {self.video_id},
    frame_id: {self.frame_id},
    annotations: List[CocoVidAnnotation]>"""
Functions
__init__(file_name, height, width, video_id=None, frame_id=None, id=None)

Creates CocoVidImage object.

Parameters:

Name Type Description Default
id

int Image id

None
file_name

str Image path

required
height

int Image height in pixels

required
width

int Image width in pixels

required
frame_id

int 0-indexed frame id

None
video_id

int Video id

None
Source code in sahi/utils/coco.py
def __init__(
    self,
    file_name,
    height,
    width,
    video_id=None,
    frame_id=None,
    id=None,
):
    """Creates CocoVidImage object.

    Args:
        file_name: str
            Image path
        height: int
            Image height in pixels
        width: int
            Image width in pixels
        video_id: int
            Video id
        frame_id: int
            0-indexed frame id
        id: int
            Image id
    """
    super().__init__(file_name=file_name, height=height, width=width, id=id)
    # Video linkage fields specific to CocoVid datasets.
    self.frame_id = frame_id
    self.video_id = video_id
add_annotation(annotation)

Adds annotation to this CocoImage instance annotation : CocoVidAnnotation

Source code in sahi/utils/coco.py
def add_annotation(self, annotation):
    """Adds annotation to this CocoVidImage instance.

    annotation : CocoVidAnnotation
    """
    if isinstance(annotation, CocoVidAnnotation):
        self.annotations.append(annotation)
    else:
        raise TypeError("annotation must be a CocoVidAnnotation instance")
from_coco_image(coco_image, video_id=None, frame_id=None) classmethod

Creates CocoVidImage object using CocoImage object.

Parameters:

Name Type Description Default
coco_image

CocoImage

required
frame_id

int 0-indexed frame id

None
video_id

int Video id

None
Source code in sahi/utils/coco.py
@classmethod
def from_coco_image(cls, coco_image, video_id=None, frame_id=None):
    """Creates CocoVidImage object using CocoImage object.

    Args:
        coco_image: CocoImage
        video_id: int
            Video id
        frame_id: int
            0-indexed frame id
    """
    attrs = {name: getattr(coco_image, name) for name in ("file_name", "height", "width", "id")}
    return cls(video_id=video_id, frame_id=frame_id, **attrs)
CocoVideo

COCO formatted video.

https://github.com/open-mmlab/mmtracking/blob/master/docs/tutorials/customize_dataset.md#the-cocovid-annotation-file

Source code in sahi/utils/coco.py
class CocoVideo:
    """COCO formatted video.

    https://github.com/open-mmlab/mmtracking/blob/master/docs/tutorials/customize_dataset.md#the-cocovid-annotation-file
    """

    def __init__(
        self,
        name: str,
        id: int | None = None,
        fps: float | None = None,
        height: int | None = None,
        width: int | None = None,
    ):
        """Creates CocoVideo object.

        Args:
            name: str
                Video name
            id: int
                Video id
            fps: float
                Video fps
            height: int
                Video height in pixels
            width: int
                Video width in pixels
        """
        self.name = name
        self.id = id
        self.fps = fps
        self.height = height
        self.width = width
        self.images = []  # list of CocoImage that belong to this video

    def add_image(self, image):
        """
        Adds image to this CocoVideo instance
        Args:
            image: CocoImage
        """

        if not isinstance(image, CocoImage):
            raise TypeError("image must be a CocoImage instance")
        self.images.append(CocoVidImage.from_coco_image(image))

    def add_cocovidimage(self, cocovidimage):
        """
        Adds CocoVidImage to this CocoVideo instance
        Args:
            cocovidimage: CocoVidImage
        """

        if not isinstance(cocovidimage, CocoVidImage):
            raise TypeError("cocovidimage must be a CocoVidImage instance")
        self.images.append(cocovidimage)

    @property
    def json(self):
        return {
            "name": self.name,
            "id": self.id,
            "fps": self.fps,
            "height": self.height,
            "width": self.width,
        }

    def __repr__(self):
        return f"""CocoVideo<
    id: {self.id},
    name: {self.name},
    fps: {self.fps},
    height: {self.height},
    width: {self.width},
    images: List[CocoVidImage]>"""
Functions
__init__(name, id=None, fps=None, height=None, width=None)

Creates CocoVideo object.

Parameters:

Name Type Description Default
name str

str Video name

required
id int | None

int Video id

None
fps float | None

float Video fps

None
height int | None

int Video height in pixels

None
width int | None

int Video width in pixels

None
Source code in sahi/utils/coco.py
def __init__(
    self,
    name: str,
    id: int | None = None,
    fps: float | None = None,
    height: int | None = None,
    width: int | None = None,
):
    """Create a CocoVideo.

    Args:
        name: str
            Video name
        id: int
            Video id
        fps: float
            Video fps
        height: int
            Video height in pixels
        width: int
            Video width in pixels
    """
    self.name, self.id, self.fps = name, id, fps
    self.height, self.width = height, width
    # CocoVidImage instances that belong to this video
    self.images = []
add_cocovidimage(cocovidimage)

Adds CocoVidImage to this CocoVideo instance Args: cocovidimage: CocoVidImage

Source code in sahi/utils/coco.py
def add_cocovidimage(self, cocovidimage):
    """Attach an already-built CocoVidImage to this CocoVideo.

    Args:
        cocovidimage: CocoVidImage

    Raises:
        TypeError: if `cocovidimage` is not a CocoVidImage instance.
    """
    if isinstance(cocovidimage, CocoVidImage):
        self.images.append(cocovidimage)
    else:
        raise TypeError("cocovidimage must be a CocoVidImage instance")
add_image(image)

Adds image to this CocoVideo instance Args: image: CocoImage

Source code in sahi/utils/coco.py
def add_image(self, image):
    """Convert a CocoImage into a CocoVidImage and attach it to this CocoVideo.

    Args:
        image: CocoImage

    Raises:
        TypeError: if `image` is not a CocoImage instance.
    """
    if isinstance(image, CocoImage):
        self.images.append(CocoVidImage.from_coco_image(image))
    else:
        raise TypeError("image must be a CocoImage instance")
DatasetClassCounts dataclass

Stores the number of images that include each category in a dataset.

Source code in sahi/utils/coco.py
@dataclass
class DatasetClassCounts:
    """Stores the number of images that include each category in a dataset."""

    # category id -> number of images containing at least one instance of that category
    counts: dict
    # total number of images considered
    total_images: int

    def frequencies(self):
        """Return, per category id, the fraction of images containing that category."""
        total = self.total_images
        return {category_id: n_images / total for category_id, n_images in self.counts.items()}

    def __add__(self, o):
        """Combine two DatasetClassCounts by summing per-category counts and image totals."""
        summed = {cid: n + o.counts.get(cid, 0) for cid, n in self.counts.items()}
        for cid, n in o.counts.items():
            if cid not in summed:
                summed[cid] = n
        return DatasetClassCounts(summed, self.total_images + o.total_images)
Functions
frequencies()

Calculates the frequency of images that contain each category.

Source code in sahi/utils/coco.py
def frequencies(self):
    """Return, per category id, the fraction of images that contain that category."""
    total = self.total_images
    return {category_id: self.counts[category_id] / total for category_id in self.counts}
Functions
add_bbox_and_area_to_coco(source_coco_path='', target_coco_path='', add_bbox=True, add_area=True)

Takes single coco dataset file path, calculates and fills bbox and area fields of the annotations and exports the updated coco dict.

coco_dict : dict Updated coco dict

Source code in sahi/utils/coco.py
def add_bbox_and_area_to_coco(
    source_coco_path: str = "",
    target_coco_path: str = "",
    add_bbox: bool = True,
    add_area: bool = True,
) -> dict:
    """Takes a single coco dataset file path, calculates and fills the bbox and area fields of the
    annotations from their segmentations, and exports the updated coco dict.

    Args:
        source_coco_path: str
            Path of the source coco json file.
        target_coco_path: str
            Path the updated coco json will be written to.
        add_bbox: bool
            If True, (re)computes the xywh "bbox" field from segmentation points.
        add_area: bool
            If True, (re)computes the "area" field from the segmentation polygons.

    Returns:
        coco_dict: dict
            Updated coco dict.
    """
    # deep-copy so the dict returned by load_json is never mutated in place
    coco_dict = copy.deepcopy(load_json(source_coco_path))

    for annotation in coco_dict["annotations"]:
        if add_bbox:
            # flatten [[x1, y1, x2, y2, ...], ...] polygons into one coordinate list
            # (was a list comprehension used purely for its extend() side effect)
            coords = [point for coco_polygon in annotation["segmentation"] for point in coco_polygon]
            xs, ys = coords[0::2], coords[1::2]
            minx, miny = min(xs), min(ys)
            annotation["bbox"] = [minx, miny, max(xs) - minx, max(ys) - miny]

        if add_area:
            shapely_multipolygon = get_shapely_multipolygon(coco_segmentation=annotation["segmentation"])
            annotation["area"] = shapely_multipolygon.area

    save_json(coco_dict, target_coco_path)
    return coco_dict
count_images_with_category(coco_file_path)

Reads a coco dataset file and returns a DatasetClassCounts object that stores the number of images that include each category in the dataset. Returns: DatasetClassCounts object. coco_file_path: str — path to the coco dataset file.

Source code in sahi/utils/coco.py
def count_images_with_category(coco_file_path):
    """Reads a coco dataset file and returns a DatasetClassCounts object that stores
    the number of images that include each category.

    Args:
        coco_file_path: str
            Path to coco dataset file.

    Returns:
        DatasetClassCounts object.
    """
    coco = load_json(coco_file_path)

    # per-image annotation counts, keyed first by image id then by category id
    per_image_counts: dict = defaultdict(lambda: defaultdict(int))
    for annotation in coco["annotations"]:
        per_image_counts[annotation["image_id"]][annotation["category_id"]] += 1

    # number of images containing at least one annotation of each category
    images_per_category = defaultdict(int)
    for category_counts in per_image_counts.values():
        for category_id, num_annotations in category_counts.items():
            if num_annotations > 0:
                images_per_category[category_id] += 1

    # total_images counts only images that appear in at least one annotation
    return DatasetClassCounts(dict(images_per_category), len(per_image_counts))
create_coco_dict(images, categories, ignore_negative_samples=False, image_id_setting='auto')

Creates COCO dict with fields "images", "annotations", "categories".

Args

images : List of CocoImage containing a list of CocoAnnotation
categories : List of Dict
    COCO categories
ignore_negative_samples : Bool
    If True, images without annotations are ignored
image_id_setting: str
    how to assign image ids while exporting can be
        auto --> will assign id from scratch (<CocoImage>.id will be ignored)
        manual --> you will need to provide image ids in <CocoImage> instances (<CocoImage>.id can not be None)

Returns

coco_dict : Dict
    COCO dict with fields "images", "annotations", "categories"
Source code in sahi/utils/coco.py
def create_coco_dict(images, categories, ignore_negative_samples=False, image_id_setting="auto"):
    """Creates a COCO dict with fields "images", "annotations", "categories".

    Args:
        images: list of CocoImage, each containing a list of CocoAnnotation
        categories: list of dict
            COCO categories.
        ignore_negative_samples: bool
            If True, images without annotations are ignored.
        image_id_setting: str
            How to assign image ids while exporting. One of:
                "auto"   -> ids are assigned from scratch (<CocoImage>.id is ignored)
                "manual" -> ids are taken from <CocoImage>.id (must not be None)

    Returns:
        coco_dict: dict
            COCO dict with fields "images", "annotations", "categories".

    Raises:
        ValueError: on an unknown image_id_setting, or a missing manual id.
    """
    if image_id_setting not in ["auto", "manual"]:
        raise ValueError("'image_id_setting' should be one of ['auto', 'manual']")

    # running counters for auto-assigned ids (both are 1-based)
    next_image_id = 1
    next_annotation_id = 1
    coco_dict = {"images": [], "annotations": [], "categories": categories}

    for coco_image in images:
        coco_annotations = coco_image.annotations
        # guard clause: skip negative samples when requested
        if ignore_negative_samples and not coco_annotations:
            continue

        if image_id_setting == "auto":
            image_id = next_image_id
            next_image_id += 1
        else:  # "manual" (validated above)
            if coco_image.id is None:
                raise ValueError("'coco_image.id' should be set manually when image_id_setting == 'manual'")
            image_id = coco_image.id

        coco_dict["images"].append(
            {
                "height": coco_image.height,
                "width": coco_image.width,
                "id": image_id,
                "file_name": coco_image.file_name,
            }
        )

        for coco_annotation in coco_annotations:
            coco_dict["annotations"].append(
                {
                    "iscrowd": 0,
                    "image_id": image_id,
                    "bbox": coco_annotation.bbox,
                    "segmentation": coco_annotation.segmentation,
                    "category_id": coco_annotation.category_id,
                    "id": next_annotation_id,
                    "area": coco_annotation.area,
                }
            )
            next_annotation_id += 1

    return coco_dict
create_coco_prediction_array(images, ignore_negative_samples=False, image_id_setting='auto')

Creates COCO prediction array which is list of predictions.

Args

images : List of CocoImage containing a list of CocoAnnotation
ignore_negative_samples : Bool
    If True, images without predictions are ignored
image_id_setting: str
    how to assign image ids while exporting can be
        auto --> will assign id from scratch (<CocoImage>.id will be ignored)
        manual --> you will need to provide image ids in <CocoImage> instances (<CocoImage>.id can not be None)

Returns

coco_prediction_array : List
    COCO predictions array
Source code in sahi/utils/coco.py
def create_coco_prediction_array(images, ignore_negative_samples=False, image_id_setting="auto"):
    """Creates a COCO prediction array (a flat list of prediction dicts).

    Args:
        images: list of CocoImage, each containing a list of predictions
        ignore_negative_samples: bool
            If True, images without predictions are ignored.
        image_id_setting: str
            How to assign image ids while exporting. One of:
                "auto"   -> ids are assigned from scratch (<CocoImage>.id is ignored)
                "manual" -> ids are taken from <CocoImage>.id (must not be None)

    Returns:
        predictions_array: list
            COCO predictions array.

    Raises:
        ValueError: on an unknown image_id_setting, or a missing manual id.
    """
    if image_id_setting not in ["auto", "manual"]:
        raise ValueError("'image_id_setting' should be one of ['auto', 'manual']")

    # running counters for auto-assigned ids (both are 1-based)
    next_image_id = 1
    next_prediction_id = 1
    predictions_array = []

    for coco_image in images:
        coco_predictions = coco_image.predictions
        # guard clause: skip negative samples when requested
        if ignore_negative_samples and not coco_predictions:
            continue

        if image_id_setting == "auto":
            image_id = next_image_id
            next_image_id += 1
        else:  # "manual" (validated above)
            if coco_image.id is None:
                raise ValueError("'coco_image.id' should be set manually when image_id_setting == 'manual'")
            image_id = coco_image.id

        for coco_prediction in coco_predictions:
            predictions_array.append(
                {
                    "id": next_prediction_id,
                    "image_id": image_id,
                    "bbox": coco_prediction.bbox,
                    "score": coco_prediction.score,
                    "category_id": coco_prediction.category_id,
                    "segmentation": coco_prediction.segmentation,
                    "iscrowd": coco_prediction.iscrowd,
                    "area": coco_prediction.area,
                }
            )
            next_prediction_id += 1

    return predictions_array
export_coco_as_yolo(output_dir, train_coco=None, val_coco=None, train_split_rate=0.9, numpy_seed=0, disable_symlink=False)

Exports current COCO dataset in ultralytics/YOLO format. Creates train val folders with image symlinks and txt files and a data yaml file.

Parameters:

Name Type Description Default
output_dir str

str Export directory.

required
train_coco Coco | None

Coco coco object for training

None
val_coco Coco | None

Coco coco object for val

None
train_split_rate float

float train split rate between 0 and 1. will be used when val_coco is None.

0.9
numpy_seed

int To fix the numpy seed.

0
disable_symlink

bool If True, copy images instead of creating symlinks.

False

Returns:

Name Type Description
yaml_path

str Path for the exported YOLO data.yml

Source code in sahi/utils/coco.py
def export_coco_as_yolo(
    output_dir: str,
    train_coco: Coco | None = None,
    val_coco: Coco | None = None,
    train_split_rate: float = 0.9,
    numpy_seed=0,
    disable_symlink=False,
):
    """Exports a COCO dataset in ultralytics/YOLO format: train/ and val/ folders with
    image symlinks (or copies) plus annotation txt files, and a data yaml file.

    Args:
        output_dir: str
            Export directory.
        train_coco: Coco
            Coco object for training. Required.
        val_coco: Coco
            Coco object for validation. When None, `train_coco` is split instead.
        train_split_rate: float
            Train split rate between 0 and 1; used only when val_coco is None.
        numpy_seed: int
            To fix the numpy seed used for splitting.
        disable_symlink: bool
            If True, copy images instead of creating symlinks.

    Returns:
        yaml_path: str
            Path for the exported YOLO data.yml
    """
    try:
        import yaml
    except ImportError:
        raise ImportError('Please run "pip install -U pyyaml" to install yaml first for YOLO formatted exporting.')

    if not train_coco:
        raise ValueError("'train_coco' have to be provided")
    # when no val set is given, carve one out of the train set
    split_mode = not val_coco

    if split_mode:
        if not (0 < train_split_rate < 1):
            raise ValueError("train_split_rate cannot be <0 or >1")
        split_result = train_coco.split_coco_as_train_val(
            train_split_rate=train_split_rate,
            numpy_seed=numpy_seed,
        )
        train_coco, val_coco = split_result["train_coco"], split_result["val_coco"]

    # create one subdirectory per split
    output_root = Path(os.path.abspath(output_dir))
    train_dir = output_root / "train/"
    val_dir = output_root / "val/"
    for split_dir in (train_dir, val_dir):
        split_dir.mkdir(parents=True, exist_ok=True)

    # create image symlinks (or copies) and annotation txts
    export_yolo_images_and_txts_from_coco_object(
        output_dir=train_dir,
        coco=train_coco,
        ignore_negative_samples=train_coco.ignore_negative_samples,
        mp=False,
        disable_symlink=disable_symlink,
    )
    assert val_coco, "Validation Coco object not set"
    export_yolo_images_and_txts_from_coco_object(
        output_dir=val_dir,
        coco=val_coco,
        ignore_negative_samples=val_coco.ignore_negative_samples,
        mp=False,
        disable_symlink=disable_symlink,
    )

    # write the YOLO data yaml (forward slashes for cross-platform paths)
    data = {
        "train": str(train_dir).replace("\\", "/"),
        "val": str(val_dir).replace("\\", "/"),
        "nc": len(train_coco.category_mapping),
        "names": list(train_coco.category_mapping.values()),
    }
    yaml_path = str(Path(output_dir) / "data.yml")
    with open(yaml_path, "w") as outfile:
        yaml.dump(data, outfile, default_flow_style=False)

    return yaml_path
export_coco_as_yolo_via_yml(yml_path, output_dir, train_split_rate=0.9, numpy_seed=0, disable_symlink=False)

Exports current COCO dataset in ultralytics/YOLO format. Creates train val folders with image symlinks and txt files and a data yaml file. Uses a yml file as input.

Parameters:

Name Type Description Default
yml_path str

str file should contain these fields: train_json_path: str train_image_dir: str val_json_path: str val_image_dir: str

required
output_dir str

str Export directory.

required
train_split_rate float

float train split rate between 0 and 1. will be used when val_json_path is None.

0.9
numpy_seed

int To fix the numpy seed.

0
disable_symlink

bool If True, copy images instead of creating symlinks.

False

Returns:

Name Type Description
yaml_path

str Path for the exported YOLO data.yml

Source code in sahi/utils/coco.py
def export_coco_as_yolo_via_yml(
    yml_path: str, output_dir: str, train_split_rate: float = 0.9, numpy_seed=0, disable_symlink=False
):
    """Exports a COCO dataset in ultralytics/YOLO format, configured via a yml file.

    Args:
        yml_path: str
            File should contain these fields:
                train_json_path: str
                train_image_dir: str
                val_json_path: str
                val_image_dir: str
        output_dir: str
            Export directory.
        train_split_rate: float
            Train split rate between 0 and 1; used only when val_json_path is None.
        numpy_seed: int
            To fix the numpy seed used for splitting.
        disable_symlink: bool
            If True, copy images instead of creating symlinks.

    Returns:
        yaml_path: str
            Path for the exported YOLO data.yml
    """
    try:
        import yaml
    except ImportError:
        raise ImportError('Please run "pip install -U pyyaml" to install yaml first for YOLO formatted exporting.')

    with open(yml_path) as stream:
        config_dict = yaml.safe_load(stream)

    def _load_coco(json_key, image_dir_key):
        # returns None when the split is not configured; otherwise requires its image dir
        if not config_dict[json_key]:
            return None
        if not config_dict[image_dir_key]:
            raise ValueError(f"{yml_path} is missing `{image_dir_key}`")
        return Coco.from_coco_dict_or_path(config_dict[json_key], image_dir=config_dict[image_dir_key])

    train_coco = _load_coco("train_json_path", "train_image_dir")
    val_coco = _load_coco("val_json_path", "val_image_dir")

    return export_coco_as_yolo(
        output_dir=output_dir,
        train_coco=train_coco,
        val_coco=val_coco,
        train_split_rate=train_split_rate,
        numpy_seed=numpy_seed,
        disable_symlink=disable_symlink,
    )
export_coco_as_yolov5(output_dir, train_coco=None, val_coco=None, train_split_rate=0.9, numpy_seed=0, disable_symlink=False)

Deprecated.

Please use export_coco_as_yolo instead. Calls export_coco_as_yolo with the same arguments.

Source code in sahi/utils/coco.py
def export_coco_as_yolov5(
    output_dir: str,
    train_coco: Coco | None = None,
    val_coco: Coco | None = None,
    train_split_rate: float = 0.9,
    numpy_seed=0,
    disable_symlink=False,
):
    """Deprecated.

    Please use export_coco_as_yolo instead. Calls export_coco_as_yolo with the same
    arguments and forwards its result.

    Returns:
        yaml_path: str
            Path for the exported YOLO data.yml
    """
    warnings.warn(
        "export_coco_as_yolov5 is deprecated. Please use export_coco_as_yolo instead.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller, not this shim
    )
    # bug fix: the wrapped function's return value (the data.yml path) was
    # previously discarded, so callers always received None
    return export_coco_as_yolo(
        output_dir=output_dir,
        train_coco=train_coco,
        val_coco=val_coco,
        train_split_rate=train_split_rate,
        numpy_seed=numpy_seed,
        disable_symlink=disable_symlink,
    )
export_coco_as_yolov5_via_yml(yml_path, output_dir, train_split_rate=0.9, numpy_seed=0, disable_symlink=False)

Deprecated.

Please use export_coco_as_yolo_via_yml instead. Calls export_coco_as_yolo_via_yml with the same arguments.

Source code in sahi/utils/coco.py
def export_coco_as_yolov5_via_yml(
    yml_path: str, output_dir: str, train_split_rate: float = 0.9, numpy_seed=0, disable_symlink=False
):
    """Deprecated.

    Please use export_coco_as_yolo_via_yml instead. Calls export_coco_as_yolo_via_yml
    with the same arguments and forwards its result.

    Returns:
        yaml_path: str
            Path for the exported YOLO data.yml
    """
    warnings.warn(
        "export_coco_as_yolov5_via_yml is deprecated. Please use export_coco_as_yolo_via_yml instead.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller, not this shim
    )
    # bug fix: the wrapped function's return value (the data.yml path) was
    # previously discarded, so callers always received None
    return export_coco_as_yolo_via_yml(
        yml_path=yml_path,
        output_dir=output_dir,
        train_split_rate=train_split_rate,
        numpy_seed=numpy_seed,
        disable_symlink=disable_symlink,
    )
export_single_yolo_image_and_corresponding_txt(coco_image, coco_image_dir, output_dir, ignore_negative_samples=False, disable_symlink=False)

Generates YOLO formatted image symlink and annotation txt file.

Parameters:

Name Type Description Default
coco_image

sahi.utils.coco.CocoImage

required
coco_image_dir

str

required
output_dir

str Export directory.

required
ignore_negative_samples

bool If True ignores images without annotations in all operations.

False
Source code in sahi/utils/coco.py
def export_single_yolo_image_and_corresponding_txt(
    coco_image, coco_image_dir, output_dir, ignore_negative_samples=False, disable_symlink=False
):
    """Generates a YOLO formatted image symlink (or copy) and annotation txt file.

    Args:
        coco_image: sahi.utils.coco.CocoImage
        coco_image_dir: str
            Directory containing the coco images; used when file_name is not an
            existing absolute/relative path on its own.
        output_dir: str
            Export directory.
        ignore_negative_samples: bool
            If True ignores images without annotations in all operations.
        disable_symlink: bool
            If True, copies the image instead of creating a symlink.
    """
    # if coco_image contains any invalid annotations, skip it
    if any(len(coco_annotation.bbox) != 4 for coco_annotation in coco_image.annotations):
        return
    # skip images without annotations
    if len(coco_image.annotations) == 0 and ignore_negative_samples:
        return
    # skip images without suffix
    # https://github.com/obss/sahi/issues/114
    file_suffix = Path(coco_image.file_name).suffix
    if file_suffix == "":
        print(f"image file has no suffix, skipping it: '{coco_image.file_name}'")
        return
    elif file_suffix in [".txt"]:  # TODO: extend this list
        print(f"image file has incorrect suffix, skipping it: '{coco_image.file_name}'")
        return
    # set coco and yolo image paths
    if Path(coco_image.file_name).is_file():
        coco_image_path = os.path.abspath(coco_image.file_name)
    else:
        if coco_image_dir is None:
            raise ValueError("You have to specify image_dir of Coco object for yolo conversion.")
        coco_image_path = os.path.abspath(str(Path(coco_image_dir) / coco_image.file_name))

    yolo_image_path_temp = str(Path(output_dir) / Path(coco_image.file_name).name)
    # increment target file name if already present
    # (bug fix: copy.deepcopy of an immutable str was pointless; plain assignment suffices)
    yolo_image_path = yolo_image_path_temp
    name_increment = 2
    while Path(yolo_image_path).is_file():
        parent_dir = Path(yolo_image_path_temp).parent
        filename = Path(yolo_image_path_temp).stem + "_" + str(name_increment)
        filesuffix = Path(yolo_image_path_temp).suffix
        yolo_image_path = str(parent_dir / (filename + filesuffix))
        name_increment += 1
    # create a symbolic link pointing to coco_image_path named yolo_image_path
    if disable_symlink:
        import shutil

        shutil.copy(coco_image_path, yolo_image_path)
    else:
        os.symlink(coco_image_path, yolo_image_path)
    # calculate annotation normalization ratios
    # NOTE(review): assumes coco_image.width/height are nonzero — verify upstream
    dw = 1.0 / coco_image.width
    dh = 1.0 / coco_image.height
    # set annotation filepath
    # (bug fix: str.replace(suffix, ".txt") replaced EVERY occurrence of the suffix
    # substring in the path — e.g. a directory named "foo.jpg" got mangled;
    # Path.with_suffix swaps only the final extension)
    yolo_annotation_path = str(Path(yolo_image_path).with_suffix(".txt"))
    # create annotation file
    with open(yolo_annotation_path, "w") as outfile:
        for annotation in coco_image.annotations:
            # convert coco xywh bbox to yolo normalized center-x, center-y, width, height
            x_center = (annotation.bbox[0] + annotation.bbox[2] / 2.0) * dw
            y_center = (annotation.bbox[1] + annotation.bbox[3] / 2.0) * dh
            bbox_width = annotation.bbox[2] * dw
            bbox_height = annotation.bbox[3] * dh
            yolo_bbox = (x_center, y_center, bbox_width, bbox_height)
            # save yolo annotation: "<category_id> <cx> <cy> <w> <h>"
            outfile.write(str(annotation.category_id) + " " + " ".join([str(value) for value in yolo_bbox]) + "\n")
export_yolo_images_and_txts_from_coco_object(output_dir, coco, ignore_negative_samples=False, mp=False, disable_symlink=False)

Creates image symlinks and annotation txts in yolo format from coco dataset.

Parameters:

Name Type Description Default
output_dir

str Export directory.

required
coco

sahi.utils.coco.Coco Initialized Coco object that contains images and categories.

required
ignore_negative_samples

bool If True ignores images without annotations in all operations.

False
mp

bool If True, multiprocess mode is on. Should be called inside an `if __name__ == "__main__":` block.

False
disable_symlink

bool If True, symlinks are not created. Instead images are copied.

False
Source code in sahi/utils/coco.py
def export_yolo_images_and_txts_from_coco_object(
    output_dir, coco, ignore_negative_samples=False, mp=False, disable_symlink=False
):
    """Creates image symlinks and annotation txts in yolo format from a coco dataset.

    Args:
        output_dir: str
            Export directory.
        coco: sahi.utils.coco.Coco
            Initialized Coco object that contains images and categories.
        ignore_negative_samples: bool
            If True ignores images without annotations in all operations.
        mp: bool
            If True, multiprocess mode is on.
            Should be called inside an `if __name__ == "__main__":` block.
        disable_symlink: bool
            If True, symlinks are not created. Instead images are copied.
    """
    logger.info("generating image symlinks and annotation files for yolo...")
    # colab's filesystem does not support symlinks; fall back to copying
    if is_colab() and not disable_symlink:
        logger.warning("symlink is not supported in colab, disabling it...")
        disable_symlink = True

    export_args = [
        (coco_image, coco.image_dir, output_dir, ignore_negative_samples, disable_symlink)
        for coco_image in coco.images
    ]
    if mp:
        with Pool(processes=48) as pool:
            pool.starmap(
                export_single_yolo_image_and_corresponding_txt,
                tqdm(export_args, total=len(export_args)),
            )
    else:
        for single_args in tqdm(export_args):
            export_single_yolo_image_and_corresponding_txt(*single_args)
get_imageid2annotationlist_mapping(coco_dict)

Get image_id to annotationlist mapping for faster indexing.

Args

coco_dict : dict
    coco dict with fields "images", "annotations", "categories"

Returns

image_id_to_annotation_list : dict
{
    1: [CocoAnnotation, CocoAnnotation, CocoAnnotation],
    2: [CocoAnnotation]
}

where
CocoAnnotation = {
    'area': 2795520,
    'bbox': [491.0, 1035.0, 153.0, 182.0],
    'category_id': 1,
    'id': 1,
    'image_id': 1,
    'iscrowd': 0,
    'segmentation': [[491.0, 1035.0, 644.0, 1035.0, 644.0, 1217.0, 491.0, 1217.0]]
}
Source code in sahi/utils/coco.py
def get_imageid2annotationlist_mapping(coco_dict: dict) -> dict[int, list[CocoAnnotation]]:
    """Index coco annotations by image id for faster lookup.

    Args:
        coco_dict: dict
            coco dict with fields "images", "annotations", "categories".

    Returns:
        image_id_to_annotation_list: dict
            Maps each image id to the list of its annotation dicts, e.g.
            {
                1: [annotation, annotation, annotation],
                2: [annotation]
            }
            where each annotation is a raw COCO annotation dict with keys such as
            'area', 'bbox', 'category_id', 'id', 'image_id', 'iscrowd', 'segmentation'.
    """
    logger.debug("indexing coco dataset annotations...")
    image_id_to_annotation_list: dict = defaultdict(list)
    for annotation in coco_dict["annotations"]:
        image_id_to_annotation_list[annotation["image_id"]].append(annotation)

    return image_id_to_annotation_list
merge(coco_dict1, coco_dict2, desired_name2id=None)

Combines 2 coco formatted annotations dicts, and returns the combined coco dict.

Parameters:

Name Type Description Default
coco_dict1

dict First coco dictionary.

required
coco_dict2

dict Second coco dictionary.

required
desired_name2id

dict

required

Returns: merged_coco_dict : dict Merged COCO dict.

Source code in sahi/utils/coco.py
def merge(coco_dict1: dict, coco_dict2: dict, desired_name2id: dict | None = None) -> dict:
    """Combines 2 coco formatted annotations dicts, and returns the combined coco dict.

    Args:
        coco_dict1 : dict
            First coco dictionary.
        coco_dict2 : dict
            Second coco dictionary.
        desired_name2id : dict
            {"human": 1, "car": 2, "big_vehicle": 3}
    Returns:
        merged_coco_dict : dict
            Merged COCO dict.
    """

    # copy input dicts so that original dicts are not affected
    temp_coco_dict1 = copy.deepcopy(coco_dict1)
    temp_coco_dict2 = copy.deepcopy(coco_dict2)

    # rearrange categories if any desired_name2id mapping is given
    if desired_name2id is not None:
        temp_coco_dict1 = update_categories(desired_name2id, temp_coco_dict1)
        temp_coco_dict2 = update_categories(desired_name2id, temp_coco_dict2)

    # rearrange categories of the second coco based on first, if their categories are not the same
    if temp_coco_dict1["categories"] != temp_coco_dict2["categories"]:
        desired_name2id = {category["name"]: category["id"] for category in temp_coco_dict1["categories"]}
        temp_coco_dict2 = update_categories(desired_name2id, temp_coco_dict2)

    # calculate first image and annotation index of the second coco file
    max_image_id = np.array([image["id"] for image in coco_dict1["images"]]).max()
    max_annotation_id = np.array([annotation["id"] for annotation in coco_dict1["annotations"]]).max()

    merged_coco_dict = temp_coco_dict1

    for image in temp_coco_dict2["images"]:
        image["id"] += max_image_id + 1
        merged_coco_dict["images"].append(image)

    for annotation in temp_coco_dict2["annotations"]:
        annotation["image_id"] += max_image_id + 1
        annotation["id"] += max_annotation_id + 1
        merged_coco_dict["annotations"].append(annotation)

    return merged_coco_dict
merge_from_file(coco_path1, coco_path2, save_path)

Combines 2 coco formatted annotations files given their paths, and saves the combined file to save_path.

Args:

coco_path1 : str
    Path for the first coco file.
coco_path2 : str
    Path for the second coco file.
save_path : str
    "dirname/coco.json"
Source code in sahi/utils/coco.py
def merge_from_file(coco_path1: str, coco_path2: str, save_path: str):
    """Combines 2 coco formatted annotations files given their paths, and saves the combined file to save_path.

    Args:
        coco_path1 : str
            Path for the first coco file.
        coco_path2 : str
            Path for the second coco file.
        save_path : str
            "dirname/coco.json"
    """
    # load both coco files, merge them, then export the combined dict
    first_coco_dict = load_json(coco_path1)
    second_coco_dict = load_json(coco_path2)
    merged_coco_dict = merge(first_coco_dict, second_coco_dict)
    save_json(merged_coco_dict, save_path)
merge_from_list(coco_dict_list, desired_name2id=None, verbose=1)

Combines a list of coco formatted annotations dicts, and returns the combined coco dict.

Args:

coco_dict_list: list of dict
    A list of coco dicts
desired_name2id: dict
    {"human": 1, "car": 2, "big_vehicle": 3}
verbose: bool
    If True, merging info is printed

Returns:

merged_coco_dict: dict
    Merged COCO dict.
Source code in sahi/utils/coco.py
def merge_from_list(coco_dict_list, desired_name2id=None, verbose=1):
    """Combines a list of coco formatted annotations dicts, and returns the combined coco dict.

    Args:

        coco_dict_list: list of dict
            A list of coco dicts
        desired_name2id: dict
            {"human": 1, "car": 2, "big_vehicle": 3}
        verbose: bool
            If True, merging info is printed
    Returns:

        merged_coco_dict: dict
            Merged COCO dict.
    """
    if verbose:
        if not desired_name2id:
            print("'desired_name2id' is not specified, combining all categories.")

    # create desired_name2id by combinin all categories, if desired_name2id is not specified
    if desired_name2id is None:
        desired_name2id = {}
        ind = 0
        for coco_dict in coco_dict_list:
            temp_categories = copy.deepcopy(coco_dict["categories"])
            for temp_category in temp_categories:
                if temp_category["name"] not in desired_name2id:
                    desired_name2id[temp_category["name"]] = ind
                    ind += 1
                else:
                    continue

    for ind, coco_dict in enumerate(coco_dict_list):
        if ind == 0:
            merged_coco_dict = copy.deepcopy(coco_dict)
        else:
            merged_coco_dict = merge(merged_coco_dict, coco_dict, desired_name2id)

    # print categories
    if verbose:
        print(
            "Categories are formed as:\n",
            merged_coco_dict["categories"],
        )

    return merged_coco_dict
remove_invalid_coco_results(result_list_or_path, dataset_dict_or_path=None)
Removes invalid predictions from coco result such as
  • negative bbox value
  • extreme bbox value

Parameters:

Name Type Description Default
result_list_or_path list | str

path or list for coco result json

required
dataset_dict_or_path optional

path or dict for coco dataset json

None
Source code in sahi/utils/coco.py
def remove_invalid_coco_results(result_list_or_path: list | str, dataset_dict_or_path: dict | str | None = None):
    """
    Removes invalid predictions from coco result such as:
        - negative bbox value
        - extreme bbox value

    Args:
        result_list_or_path: path or list for coco result json
        dataset_dict_or_path (optional): path or dict for coco dataset json
    """

    # prepare coco results
    if isinstance(result_list_or_path, str):
        result_list = load_json(result_list_or_path)
    elif isinstance(result_list_or_path, list):
        result_list = result_list_or_path
    else:
        raise TypeError('incorrect type for "result_list_or_path"')  # type: ignore

    # prepare image info from coco dataset
    if dataset_dict_or_path is not None:
        if isinstance(dataset_dict_or_path, str):
            dataset_dict = load_json(dataset_dict_or_path)
        elif isinstance(dataset_dict_or_path, dict):
            dataset_dict = dataset_dict_or_path
        else:
            raise TypeError('incorrect type for "dataset_dict"')  # type: ignore
        image_id_to_height = {}
        image_id_to_width = {}
        for coco_image in dataset_dict["images"]:
            image_id_to_height[coco_image["id"]] = coco_image["height"]
            image_id_to_width[coco_image["id"]] = coco_image["width"]

    # remove invalid predictions
    fixed_result_list = []
    for coco_result in result_list:
        bbox = coco_result["bbox"]
        # ignore invalid predictions
        if not bbox:
            print("ignoring invalid prediction with empty bbox")
            continue
        if bbox[0] < 0 or bbox[1] < 0 or bbox[2] < 0 or bbox[3] < 0:
            print(f"ignoring invalid prediction with bbox: {bbox}")
            continue
        if dataset_dict_or_path is not None:
            if (
                bbox[1] > image_id_to_height[coco_result["image_id"]]
                or bbox[3] > image_id_to_height[coco_result["image_id"]]
                or bbox[0] > image_id_to_width[coco_result["image_id"]]
                or bbox[2] > image_id_to_width[coco_result["image_id"]]
            ):
                print(f"ignoring invalid prediction with bbox: {bbox}")
                continue
        fixed_result_list.append(coco_result)
    return fixed_result_list
update_categories(desired_name2id, coco_dict)

Rearranges category mapping of given COCO dictionary based on given category_mapping. Can also be used to filter some of the categories.

Args:

desired_name2id : dict
    {"big_vehicle": 1, "car": 2, "human": 3}
coco_dict : dict
    COCO formatted dictionary.

Returns:

Name Type Description
coco_target dict

dict COCO dict with updated/filtered categories.

Source code in sahi/utils/coco.py
def update_categories(desired_name2id: dict, coco_dict: dict) -> dict:
    """Rearranges category mapping of given COCO dictionary based on given category_mapping. Can also be used to filter
    some of the categories.

    Categories missing from desired_name2id are dropped together with all of
    their annotations.

    Args:
        desired_name2id : dict
            {"big_vehicle": 1, "car": 2, "human": 3}
        coco_dict : dict
            COCO formatted dictionary.

    Returns:
        coco_target : dict
            COCO dict with updated/filtered categories.
    """
    # work on a copy so the input dict is left untouched
    coco_source = copy.deepcopy(coco_dict)

    # map each current category id to its desired id; -1 marks "drop"
    currentid2desiredid_mapping = {}
    for category in coco_source["categories"]:
        currentid2desiredid_mapping[category["id"]] = desired_name2id.get(category["name"], -1)

    # keep and remap only the annotations whose category survives
    updated_annotations = []
    for annotation in coco_source["annotations"]:
        desired_category_id = currentid2desiredid_mapping[annotation["category_id"]]
        if desired_category_id != -1:
            annotation["category_id"] = desired_category_id
            updated_annotations.append(annotation)

    # rebuild the category list directly from desired_name2id
    updated_categories = [
        {"name": name, "supercategory": name, "id": category_id} for name, category_id in desired_name2id.items()
    ]

    return {
        "images": coco_source["images"],
        "annotations": updated_annotations,
        "categories": updated_categories,
    }
update_categories_from_file(desired_name2id, coco_path, save_path)

Rearranges category mapping of a COCO dictionary in coco_path based on given category_mapping. Can also be used to filter some of the categories.

Parameters:

Name Type Description Default
desired_name2id

dict

required
coco_path

str "dirname/coco.json"

required
Source code in sahi/utils/coco.py
def update_categories_from_file(desired_name2id: dict, coco_path: str, save_path: str) -> None:
    """Rearranges category mapping of a COCO dictionary in coco_path based on given category_mapping. Can also be used
    to filter some of the categories.

    Args:
        desired_name2id : dict
            {"human": 1, "car": 2, "big_vehicle": 3}
        coco_path : str
            "dirname/coco.json"
        save_path : str
            Output path for the updated coco file.
    """
    # read the source file, remap its categories, then export the result
    source_coco_dict = load_json(coco_path)
    updated_coco_dict = update_categories(desired_name2id, source_coco_dict)
    save_json(updated_coco_dict, save_path)
cv
Classes
Colors
Source code in sahi/utils/cv.py
class Colors:
    """Fixed 20-color palette used for drawing detection visuals."""

    def __init__(self):
        hex_colors = (
            "FF3838 2C99A8 FF701F 6473FF CFD231 48F90A 92CC17 3DDB86 1A9334 00D4BB "
            "FF9D97 00C2FF 344593 FFB21D 0018EC 8438FF 520085 CB38FF FF95C8 FF37C7"
        )

        self.palette = [self.hex_to_rgb(f"#{code}") for code in hex_colors.split()]
        self.n = len(self.palette)

    def __call__(self, ind, bgr: bool = False):
        """Return the palette color for the given index.

        Args:
            ind (int): The index to convert; wraps around the palette size.
            bgr (bool, optional): Whether to return the color in BGR order. Defaults to False.

        Returns:
            tuple: The color code in RGB or BGR format, depending on the value of `bgr`.
        """
        red, green, blue = self.palette[int(ind) % self.n]
        if bgr:
            return (blue, green, red)
        return (red, green, blue)

    @staticmethod
    def hex_to_rgb(hex_code):
        """Convert a '#RRGGBB' hexadecimal color code to an (R, G, B) tuple.

        Args:
            hex_code (str): The hexadecimal color code to convert, with leading '#'.

        Returns:
            tuple: A tuple representing the RGB values in the order (R, G, B).
        """
        # positions 1, 3, 5 are the starts of the RR, GG, BB hex pairs
        return tuple(int(hex_code[pos : pos + 2], 16) for pos in (1, 3, 5))
Functions
__call__(ind, bgr=False)

Convert an index to a color code.

Parameters:

Name Type Description Default
ind int

The index to convert.

required
bgr bool

Whether to return the color code in BGR format. Defaults to False.

False

Returns:

Name Type Description
tuple

The color code in RGB or BGR format, depending on the value of bgr.

Source code in sahi/utils/cv.py
def __call__(self, ind, bgr: bool = False):
    """Convert an index to a color code.

    Args:
        ind (int): The index to convert.
        bgr (bool, optional): Whether to return the color code in BGR format. Defaults to False.

    Returns:
        tuple: The color code in RGB or BGR format, depending on the value of `bgr`.
    """
    # modulo wraps any integer index onto the fixed-size palette
    color_codes = self.palette[int(ind) % self.n]
    return (color_codes[2], color_codes[1], color_codes[0]) if bgr else color_codes
hex_to_rgb(hex_code) staticmethod

Converts a hexadecimal color code to RGB format.

Parameters:

Name Type Description Default
hex_code str

The hexadecimal color code to convert.

required

Returns:

Name Type Description
tuple

A tuple representing the RGB values in the order (R, G, B).

Source code in sahi/utils/cv.py
@staticmethod
def hex_to_rgb(hex_code):
    """Converts a hexadecimal color code to RGB format.

    Args:
        hex_code (str): The hexadecimal color code to convert.

    Returns:
        tuple: A tuple representing the RGB values in the order (R, G, B).
    """
    rgb = []
    # offsets 0, 2, 4 skip past the leading '#' to the RR, GG, BB hex pairs
    for i in (0, 2, 4):
        rgb.append(int(hex_code[1 + i : 1 + i + 2], 16))
    return tuple(rgb)
Functions
apply_color_mask(image, color)

Applies color mask to given input image.

Parameters:

Name Type Description Default
image ndarray

The input image to apply the color mask to.

required
color tuple

The RGB color tuple to use for the mask.

required

Returns:

Type Description

np.ndarray: The resulting image with the applied color mask.

Source code in sahi/utils/cv.py
def apply_color_mask(image: np.ndarray, color: tuple[int, int, int]):
    """Applies color mask to given input image.

    Args:
        image (np.ndarray): The input image to apply the color mask to.
        color (tuple): The RGB color tuple to use for the mask.

    Returns:
        np.ndarray: The resulting image with the applied color mask.
    """
    # build one uint8 channel per color component, painting only pixels where mask == 1
    channels = []
    for channel_value in color:
        channel = np.zeros_like(image).astype(np.uint8)
        channel[image == 1] = channel_value
        channels.append(channel)
    return np.stack(channels, axis=2)
convert_image_to(read_path, extension='jpg', grayscale=False)

Reads an image from the given path and saves it with the specified extension.

Parameters:

Name Type Description Default
read_path str

The path to the image file.

required
extension str

The desired file extension for the saved image. Defaults to "jpg".

'jpg'
grayscale bool

Whether to convert the image to grayscale. Defaults to False.

False
Source code in sahi/utils/cv.py
def convert_image_to(read_path, extension: str = "jpg", grayscale: bool = False):
    """Reads an image from the given path and saves it with the specified extension.

    Args:
        read_path (str): The path to the image file.
        extension (str, optional): The desired file extension for the saved image. Defaults to "jpg".
        grayscale (bool, optional): Whether to convert the image to grayscale. Defaults to False.
    """
    image = cv2.imread(read_path)
    # reuse the original path minus its extension as the output stem
    stem, _ = os.path.splitext(read_path)
    if grayscale:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        stem += "_gray"
    cv2.imwrite(stem + "." + extension, image)
crop_object_predictions(image, object_prediction_list, output_dir='', file_name='prediction_visual', export_format='png')

Crops bounding boxes over the source image and exports it to the output folder.

Parameters:

Name Type Description Default
image ndarray

The source image to crop bounding boxes from.

required
object_prediction_list

A list of object predictions.

required
output_dir str

The directory where the resulting visualizations will be exported. Defaults to an empty string.

''
file_name str

The name of the exported file. The exported file will be saved as output_dir + file_name + ".png". Defaults to "prediction_visual".

'prediction_visual'
export_format str

The format of the exported file. Can be specified as 'jpg' or 'png'. Defaults to "png".

'png'
Source code in sahi/utils/cv.py
def crop_object_predictions(
    image: np.ndarray,
    object_prediction_list,
    output_dir: str = "",
    file_name: str = "prediction_visual",
    export_format: str = "png",
):
    """Crops bounding boxes over the source image and exports it to the output folder.

    Args:
        image (np.ndarray): The source image to crop bounding boxes from.
        object_prediction_list: A list of object predictions.
        output_dir (str): The directory where the resulting visualizations will be exported. Defaults to an empty string.
        file_name (str): The name of the exported file. The exported file will be saved as `output_dir + file_name + ".png"`. Defaults to "prediction_visual".
        export_format (str): The format of the exported file. Can be specified as 'jpg' or 'png'. Defaults to "png".
    """  # noqa

    # make sure the export folder exists
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    for ind, object_prediction in enumerate(object_prediction_list):
        # deepcopy the prediction so the original list entry is untouched
        prediction = object_prediction.deepcopy()
        xmin, ymin, xmax, ymax = prediction.bbox.to_xyxy()
        category_id = prediction.category.id
        # deepcopy the crop so the source image stays unmodified
        cropped_img = copy.deepcopy(image[int(ymin) : int(ymax), int(xmin) : int(xmax), :])
        save_path = os.path.join(
            output_dir,
            file_name + "_box" + str(ind) + "_class" + str(category_id) + "." + export_format,
        )
        cv2.imwrite(save_path, cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR))
get_bbox_from_bool_mask(bool_mask)

Generate VOC bounding box [xmin, ymin, xmax, ymax] from given boolean mask.

Parameters:

Name Type Description Default
bool_mask ndarray

2D boolean mask.

required

Returns:

Type Description
list[int] | None

Optional[List[int]]: VOC bounding box [xmin, ymin, xmax, ymax] or None if no bounding box is found.

Source code in sahi/utils/cv.py
def get_bbox_from_bool_mask(bool_mask: np.ndarray) -> list[int] | None:
    """Generate VOC bounding box [xmin, ymin, xmax, ymax] from given boolean mask.

    Args:
        bool_mask (np.ndarray): 2D boolean mask.

    Returns:
        Optional[List[int]]: VOC bounding box [xmin, ymin, xmax, ymax] or None if no bounding box is found.
    """
    rows = np.any(bool_mask, axis=1)
    cols = np.any(bool_mask, axis=0)

    if not np.any(rows) or not np.any(cols):
        return None

    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]
    width = xmax - xmin
    height = ymax - ymin

    if width == 0 or height == 0:
        return None

    return [xmin, ymin, xmax, ymax]
get_bbox_from_coco_segmentation(coco_segmentation)

Generate voc box ([xmin, ymin, xmax, ymax]) from given coco segmentation.

Source code in sahi/utils/cv.py
def get_bbox_from_coco_segmentation(coco_segmentation):
    """Generate voc box ([xmin, ymin, xmax, ymax]) from given coco segmentation."""
    # flat polygons alternate x and y coordinates
    all_x = [coord for polygon in coco_segmentation for coord in polygon[::2]]
    all_y = [coord for polygon in coco_segmentation for coord in polygon[1::2]]
    if not all_x or not all_y:
        return None
    return [min(all_x), min(all_y), max(all_x), max(all_y)]
get_bool_mask_from_coco_segmentation(coco_segmentation, width, height)

Convert coco segmentation to 2D boolean mask of given height and width.

Parameters: - coco_segmentation: list of points representing the coco segmentation - width: width of the boolean mask - height: height of the boolean mask

Returns: - bool_mask: 2D boolean mask of size (height, width)

Source code in sahi/utils/cv.py
def get_bool_mask_from_coco_segmentation(coco_segmentation: list[list[float]], width: int, height: int) -> np.ndarray:
    """Convert coco segmentation to 2D boolean mask of given height and width.

    Args:
        coco_segmentation: List of flat point lists representing the coco segmentation polygons.
        width: Width of the boolean mask.
        height: Height of the boolean mask.

    Returns:
        bool_mask: 2D boolean mask of size (height, width).
    """
    size = [height, width]
    points = [np.array(point).reshape(-1, 2).round().astype(int) for point in coco_segmentation]
    bool_mask = np.zeros(size)
    bool_mask = cv2.fillPoly(bool_mask, points, (1.0,))
    # np.ndarray.astype returns a new array; the original code discarded its
    # result and returned a float mask despite the documented bool contract
    return bool_mask.astype(bool)
get_coco_segmentation_from_bool_mask(bool_mask)

Convert boolean mask to coco segmentation format [ [x1, y1, x2, y2, x3, y3, ...], [x1, y1, x2, y2, x3, y3, ...], ... ]

Source code in sahi/utils/cv.py
def get_coco_segmentation_from_bool_mask(bool_mask: np.ndarray) -> list[list[float]]:
    """
    Convert boolean mask to coco segmentation format
    [
        [x1, y1, x2, y2, x3, y3, ...],
        [x1, y1, x2, y2, x3, y3, ...],
        ...
    ]
    """
    # pad the mask by one pixel so contours touching the border are closed,
    # then shift the detected coordinates back with offset=(-1, -1)
    mask = np.squeeze(bool_mask).astype(np.uint8)
    padded_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=(0, 0, 0))
    contours = cv2.findContours(padded_mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE, offset=(-1, -1))
    # findContours returns (contours, hierarchy) or (image, contours, hierarchy)
    # depending on the OpenCV version
    contours = contours[0] if len(contours) == 2 else contours[1]
    # keep only polygons with at least 3 points (6 flat coordinates)
    return [contour.flatten().tolist() for contour in contours if contour.size >= 6]
get_coco_segmentation_from_obb_points(obb_points)

Convert OBB (Oriented Bounding Box) points to COCO polygon format.

Parameters:

Name Type Description Default
obb_points ndarray

np.ndarray OBB points tensor from ultralytics.engine.results.OBB Shape: (4, 2) containing 4 points with (x,y) coordinates each

required

Returns:

Type Description
list[list[float]]

List[List[float]]: Polygon points in COCO format [[x1, y1, x2, y2, x3, y3, x4, y4], [...], ...]

Source code in sahi/utils/cv.py
def get_coco_segmentation_from_obb_points(obb_points: np.ndarray) -> list[list[float]]:
    """Convert OBB (Oriented Bounding Box) points to COCO polygon format.

    Args:
        obb_points: np.ndarray
            OBB points tensor from ultralytics.engine.results.OBB
            Shape: (4, 2) containing 4 points with (x,y) coordinates each

    Returns:
        List[List[float]]: Polygon points in COCO format
            [[x1, y1, x2, y2, x3, y3, x4, y4, x1, y1]]
    """
    # flatten from (4, 2) to [x1, y1, x2, y2, x3, y3, x4, y4]
    flat_coords = obb_points.reshape(-1).tolist()
    # close the polygon by repeating the first (x, y) pair at the end
    return [flat_coords + flat_coords[:2]]
get_video_reader(source, save_dir, frame_skip_interval, export_visual=False, view_visual=False)

Creates OpenCV video capture object from given video file path.

Parameters:

Name Type Description Default
source str

Video file path

required
save_dir str

Video export directory

required
frame_skip_interval int

Frame skip interval

required
export_visual bool

Set True if you want to export visuals

False
view_visual bool

Set True if you want to render visual

False

Returns:

Name Type Description
iterator Generator[Image]

Pillow Image

video_writer VideoWriter | None

cv2.VideoWriter

video_file_name str

video name with extension

Source code in sahi/utils/cv.py
def get_video_reader(
    source: str,
    save_dir: str,
    frame_skip_interval: int,
    export_visual: bool = False,
    view_visual: bool = False,
) -> tuple[Generator[Image.Image], cv2.VideoWriter | None, str, int]:
    """Creates OpenCV video capture object from given video file path.

    Args:
        source: Video file path
        save_dir: Video export directory
        frame_skip_interval: Frame skip interval (frames skipped between consecutive reads)
        export_visual: Set True if you want to export visuals
        view_visual: Set True if you want to render visual

    Returns:
        iterator: Generator yielding frames as Pillow Images (RGB)
        video_writer: cv2.VideoWriter if export_visual is True, else None
        video_file_name: video name with extension
        num_frames: total frame count (reduced by the skip interval when view_visual is True)
    """
    # get video name with extension
    video_file_name = os.path.basename(source)
    # get video from video path
    video_capture = cv2.VideoCapture(source)

    num_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    if view_visual:
        # when viewing, frames are skipped, so fewer frames are actually shown
        num_frames /= frame_skip_interval + 1
        num_frames = int(num_frames)

    def read_video_frame(video_capture, frame_skip_interval) -> Generator[Image.Image]:
        if view_visual:
            window_name = f"Prediction of {video_file_name!s}"
            cv2.namedWindow(window_name, cv2.WINDOW_AUTOSIZE)
            # show a black placeholder frame until the first prediction arrives
            default_image = np.zeros((480, 640, 3), dtype=np.uint8)
            cv2.imshow(window_name, default_image)

            # NOTE(review): isOpened is referenced without parentheses, so this
            # evaluates the (always truthy) bound method; the loop exits only
            # via the break below or the Esc handler
            while video_capture.isOpened:
                frame_num = video_capture.get(cv2.CAP_PROP_POS_FRAMES)
                video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_num + frame_skip_interval)

                # poll the keyboard for interactive seeking within the video
                k = cv2.waitKey(20)
                frame_num = video_capture.get(cv2.CAP_PROP_POS_FRAMES)

                if k == 27:
                    print(
                        "\n===========================Closing==========================="
                    )  # Exit the prediction, Key = Esc
                    exit()
                if k == 100:
                    frame_num += 100  # Skip 100 frames, Key = d
                if k == 97:
                    frame_num -= 100  # Prev 100 frames, Key = a
                if k == 103:
                    frame_num += 20  # Skip 20 frames, Key = g
                if k == 102:
                    frame_num -= 20  # Prev 20 frames, Key = f
                video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_num)

                ret, frame = video_capture.read()
                if not ret:
                    print("\n=========================== Video Ended ===========================")
                    break
                # OpenCV reads BGR; convert to RGB before wrapping in a PIL image
                yield Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        else:
            # see NOTE above: isOpened is not actually called here either
            while video_capture.isOpened:
                frame_num = video_capture.get(cv2.CAP_PROP_POS_FRAMES)
                video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_num + frame_skip_interval)

                ret, frame = video_capture.read()
                if not ret:
                    print("\n=========================== Video Ended ===========================")
                    break
                yield Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    if export_visual:
        # get video properties and create VideoWriter object
        if frame_skip_interval != 0:
            fps = video_capture.get(cv2.CAP_PROP_FPS)  # original fps of video
            # The fps of export video is increasing during view_image because frame is skipped
            fps = (
                fps / frame_skip_interval
            )  # How many time_interval equals to original fps. One time_interval skip x frames.
        else:
            fps = video_capture.get(cv2.CAP_PROP_FPS)

        w = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        size = (w, h)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # pyright: ignore[reportAttributeAccessIssue]
        video_writer = cv2.VideoWriter(os.path.join(save_dir, video_file_name), fourcc, fps, size)
    else:
        video_writer = None

    return read_video_frame(video_capture, frame_skip_interval), video_writer, video_file_name, num_frames
ipython_display(image)

Displays numpy image in notebook.

If the input image is in range 0..1, please first multiply it by 255. Assumes the image is an ndarray of shape [height, width, channels] where channels can be 1, 3 or 4.

Source code in sahi/utils/cv.py
def ipython_display(image: np.ndarray):
    """Displays numpy image in notebook.

    If input image is in range 0..1, please first multiply img by 255
    Assumes image is ndarray of shape [height, width, channels] where channels can be 1, 3 or 4
    """
    import IPython

    # encode the image as PNG bytes for inline notebook rendering
    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    _, encoded = cv2.imencode(".png", bgr_image)
    display_image = IPython.display.Image(data=encoded)  # type: ignore
    IPython.display.display(display_image)  # type: ignore
normalize_numpy_image(image)

Normalizes numpy image.

Source code in sahi/utils/cv.py
def normalize_numpy_image(image: np.ndarray):
    """Scale a numpy image by its maximum value so the largest entry becomes 1."""
    return image / image.max()
read_image(image_path)

Loads image as a numpy array from the given path.

Parameters:

Name Type Description Default
image_path str

The path to the image file.

required

Returns:

Type Description
ndarray

numpy.ndarray: The loaded image as a numpy array.

Source code in sahi/utils/cv.py
def read_image(image_path: str) -> np.ndarray:
    """Loads image as a numpy array from the given path.

    Args:
        image_path (str): The path to the image file.

    Returns:
        numpy.ndarray: The loaded image as a numpy array (RGB channel order).
    """
    # cv2 loads in BGR order; convert to RGB before returning
    bgr_image = cv2.imread(image_path)
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
read_image_as_pil(image, exif_fix=True)

Loads an image as PIL.Image.Image.

Parameters:

Name Type Description Default
image Union[Image, str, ndarray]

The image to be loaded. It can be an image path or URL (str), a numpy image (np.ndarray), or a PIL.Image object.

required
exif_fix bool

Whether to apply an EXIF fix to the image. Defaults to False.

True

Returns:

Type Description
Image

PIL.Image.Image: The loaded image as a PIL.Image object.

Source code in sahi/utils/cv.py
def read_image_as_pil(image: Image.Image | str | np.ndarray, exif_fix: bool = True) -> Image.Image:
    """Loads an image as PIL.Image.Image.

    Args:
        image (Union[Image.Image, str, np.ndarray]): The image to be loaded. It can be an image path or URL (str),
            a numpy image (np.ndarray), or a PIL.Image object.
        exif_fix (bool, optional): Whether to apply an EXIF fix to the image. Defaults to True.

    Returns:
        PIL.Image.Image: The loaded image as a PIL.Image object.

    Raises:
        TypeError: If the input type, or the fallback-read array shape, is not supported.
    """
    # https://stackoverflow.com/questions/56174099/how-to-load-images-larger-than-max-image-pixels-with-pil
    Image.MAX_IMAGE_PIXELS = None

    if isinstance(image, Image.Image):
        image_pil = image
    elif isinstance(image, str):
        # read image if str image path is provided
        try:
            image_pil = Image.open(
                BytesIO(requests.get(image, stream=True).content) if str(image).startswith("http") else image
            ).convert("RGB")
            if exif_fix:
                ImageOps.exif_transpose(image_pil, in_place=True)
        except Exception as e:  # handle large/tiff image reading
            logger.error(f"PIL failed reading image with error {e}, trying skimage instead")
            try:
                import skimage.io
            except ImportError:
                raise ImportError("Please run 'pip install -U scikit-image imagecodecs' for large image handling.")
            image_sk = skimage.io.imread(image).astype(np.uint8)
            if len(image_sk.shape) == 2:  # b&w
                image_pil = Image.fromarray(image_sk, mode="1")
            elif image_sk.shape[2] == 4:  # rgba
                image_pil = Image.fromarray(image_sk, mode="RGBA")
            elif image_sk.shape[2] == 3:  # rgb
                image_pil = Image.fromarray(image_sk, mode="RGB")
            else:
                # bugfix: formatting image_sk.shape[3] raised IndexError here
                # (shape has only 3 entries); report the full shape instead
                raise TypeError(f"image with shape: {image_sk.shape} is not supported.")
    elif isinstance(image, np.ndarray):
        # check if image is in CHW format (Channels, Height, Width)
        # heuristic: 3 dimensions, first dim (channels) < 5, last dim (width) > 4
        if image.ndim == 3 and image.shape[0] < 5:  # image in CHW
            if image.shape[2] > 4:
                # convert CHW to HWC (Height, Width, Channels)
                image = np.transpose(image, (1, 2, 0))
        image_pil = Image.fromarray(image)
    else:
        raise TypeError("read image with 'pillow' using 'Image.open()'")
    return image_pil
read_large_image(image_path)

Reads a large image from the specified image path.

Parameters:

Name Type Description Default
image_path str

The path to the image file.

required

Returns:

Name Type Description
tuple

A tuple containing the image data and a flag indicating whether cv2 was used to read the image. The image data is a numpy array representing the image in RGB format. The flag is True if cv2 was used, False otherwise.

Source code in sahi/utils/cv.py
def read_large_image(image_path: str):
    """Reads a large image from the specified image path.

    Args:
        image_path (str): The path to the image file.

    Returns:
        tuple: A tuple containing the image data and a flag indicating whether cv2 was used to read the image.
            The image data is a numpy array representing the image in RGB format.
            The flag is True if cv2 was used, False otherwise.
    """
    use_cv2 = True
    # read image, cv2 fails on large files
    try:
        # convert to rgb (cv2 reads in bgr)
        img_cv2 = cv2.imread(image_path, 1)
        image0 = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
    except Exception as e:
        logger.error(f"OpenCV failed reading image with error {e}, trying skimage instead")
        try:
            import skimage.io
        except ImportError:
            raise ImportError(
                'Please run "pip install -U scikit-image" to install scikit-image first for large image handling.'
            )
        # bugfix: the parameter was renamed as_grey -> as_gray in scikit-image
        # 0.15 and the old spelling was removed, so as_grey raised a TypeError
        image0 = skimage.io.imread(image_path, as_gray=False).astype(np.uint8)  # [::-1]
        use_cv2 = False
    return image0, use_cv2
select_random_color()

Selects a random color from a predefined list of colors.

Returns:

Name Type Description
list

A list representing the RGB values of the selected color.

Source code in sahi/utils/cv.py
def select_random_color():
    """Selects a random color from a predefined list of colors.

    Returns:
        list: A list representing the RGB values of the selected color.
    """
    colors = [
        [0, 255, 0],
        [0, 0, 255],
        [255, 0, 0],
        [0, 255, 255],
        [255, 255, 0],
        [255, 0, 255],
        [80, 70, 180],
        [250, 80, 190],
        [245, 145, 50],
        [70, 150, 250],
        [50, 190, 190],
    ]
    # bugfix: randrange(0, 10) could never pick the last palette entry;
    # random.choice draws uniformly over the whole list
    return random.choice(colors)
visualize_object_predictions(image, object_prediction_list, rect_th=None, text_size=None, text_th=None, color=None, hide_labels=False, hide_conf=False, output_dir=None, file_name='prediction_visual', export_format='png')

Visualizes prediction category names, bounding boxes over the source image and exports it to output folder.

Parameters:

Name Type Description Default
object_prediction_list

a list of prediction.ObjectPrediction

required
rect_th int | None

rectangle thickness

None
text_size float | None

size of the category name over box

None
text_th int | None

text thickness

None
color tuple | None

annotation color in the form: (0, 255, 0)

None
hide_labels bool

hide labels

False
hide_conf bool

hide confidence

False
output_dir str | None

directory for resulting visualization to be exported

None
file_name str | None

exported file will be saved as: output_dir+file_name+".png"

'prediction_visual'
export_format str | None

can be specified as 'jpg' or 'png'

'png'
Source code in sahi/utils/cv.py
def visualize_object_predictions(
    image: np.ndarray,
    object_prediction_list,
    rect_th: int | None = None,
    text_size: float | None = None,
    text_th: int | None = None,
    color: tuple | None = None,
    hide_labels: bool = False,
    hide_conf: bool = False,
    output_dir: str | None = None,
    file_name: str | None = "prediction_visual",
    export_format: str | None = "png",
):
    """Visualizes prediction category names, bounding boxes over the source image and exports it to output folder.

    Args:
        image: source image as a numpy array; treated as RGB (it is converted
            to BGR only when writing to disk)
        object_prediction_list: a list of prediction.ObjectPrediction
        rect_th: rectangle thickness; derived from the image size when None
        text_size: size of the category name over box; derived from rect_th when None
        text_th: text thickness; derived from rect_th when None
        color: annotation color in the form: (0, 255, 0); when None a
            per-category color palette is used instead
        hide_labels: hide labels
        hide_conf: hide confidence
        output_dir: directory for resulting visualization to be exported
        file_name: exported file will be saved as: output_dir+file_name+".png"
        export_format: can be specified as 'jpg' or 'png'

    Returns:
        dict: {"image": annotated image as a numpy array,
            "elapsed_time": visualization duration in seconds}
    """
    elapsed_time = time.time()
    # deepcopy image so that original is not altered
    image = copy.deepcopy(image)
    # select predefined classwise color palette if not specified
    if color is None:
        colors = Colors()
    else:
        colors = None
    # set rect_th for boxes (scaled from the image dimensions, minimum 2)
    rect_th = rect_th or max(round(sum(image.shape) / 2 * 0.003), 2)
    # set text_th for category names
    text_th = text_th or max(rect_th - 1, 1)
    # set text_size for category names
    text_size = text_size or rect_th / 3

    # add masks or obb polygons to image if present
    for object_prediction in object_prediction_list:
        # deepcopy object_prediction_list so that original is not altered
        object_prediction = object_prediction.deepcopy()
        # arange label to be displayed
        label = f"{object_prediction.category.name}"
        if not hide_conf:
            label += f" {object_prediction.score.value:.2f}"
        # set color
        if colors is not None:
            color = colors(object_prediction.category.id)
        # visualize masks or obb polygons if present
        has_mask = object_prediction.mask is not None
        is_obb_pred = False
        if has_mask:
            segmentation = object_prediction.mask.segmentation
            # a single 4-point polygon (8 coordinates) is treated as an
            # oriented bounding box rather than a free-form mask
            if len(segmentation) == 1 and len(segmentation[0]) == 8:
                is_obb_pred = True

            if is_obb_pred:
                points = np.array(segmentation).reshape((-1, 1, 2)).astype(np.int32)
                cv2.polylines(image, [points], isClosed=True, color=color or (0, 0, 0), thickness=rect_th)

                if not hide_labels:
                    # anchor the label at the polygon vertex with the largest y
                    lowest_point = points[points[:, :, 1].argmax()][0]
                    box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0]
                    # place label above the anchor when it fits inside the image
                    outside = lowest_point[1] - box_height - 3 >= 0
                    text_bg_point1 = (
                        lowest_point[0],
                        lowest_point[1] - box_height - 3 if outside else lowest_point[1] + 3,
                    )
                    text_bg_point2 = (lowest_point[0] + box_width, lowest_point[1])
                    cv2.rectangle(
                        image, text_bg_point1, text_bg_point2, color or (0, 0, 0), thickness=-1, lineType=cv2.LINE_AA
                    )
                    cv2.putText(
                        image,
                        label,
                        (lowest_point[0], lowest_point[1] - 2 if outside else lowest_point[1] + box_height + 2),
                        0,
                        text_size,
                        (255, 255, 255),
                        thickness=text_th,
                    )
            else:
                # draw mask blended onto the image at 0.6 opacity
                rgb_mask = apply_color_mask(object_prediction.mask.bool_mask, color or (0, 0, 0))
                image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)

        # add bboxes to image if is_obb_pred=False
        if not is_obb_pred:
            bbox = object_prediction.bbox.to_xyxy()

            # set bbox points
            point1, point2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
            # visualize boxes
            cv2.rectangle(
                image,
                point1,
                point2,
                color=color or (0, 0, 0),
                thickness=rect_th,
            )

            if not hide_labels:
                box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
                    0
                ]  # label width, height
                outside = point1[1] - box_height - 3 >= 0  # label fits outside box
                point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
                # add bounding box text
                cv2.rectangle(image, point1, point2, color or (0, 0, 0), -1, cv2.LINE_AA)  # filled
                cv2.putText(
                    image,
                    label,
                    (point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
                    0,
                    text_size,
                    (255, 255, 255),
                    thickness=text_th,
                )

    # export if output_dir is present
    if output_dir is not None:
        # export image with predictions
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        # save inference result (convert back to BGR for cv2.imwrite)
        save_path = str(Path(output_dir) / ((file_name or "") + "." + (export_format or "")))
        cv2.imwrite(save_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    elapsed_time = time.time() - elapsed_time
    return {"image": image, "elapsed_time": elapsed_time}
visualize_prediction(image, boxes, classes, masks=None, rect_th=None, text_size=None, text_th=None, color=None, hide_labels=False, output_dir=None, file_name='prediction_visual')

Visualizes prediction classes, bounding boxes over the source image and exports it to output folder.

Parameters:

Name Type Description Default
image ndarray

The source image.

required
boxes List[List]

List of bounding boxes coordinates.

required
classes List[str]

List of class labels corresponding to each bounding box.

required
masks Optional[List[ndarray]]

List of masks corresponding to each bounding box. Defaults to None.

None
rect_th int

Thickness of the bounding box rectangle. Defaults to None.

None
text_size float

Size of the text for class labels. Defaults to None.

None
text_th int

Thickness of the text for class labels. Defaults to None.

None
color tuple

Color of the bounding box and text. Defaults to None.

None
hide_labels bool

Whether to hide the class labels. Defaults to False.

False
output_dir Optional[str]

Output directory to save the visualization. Defaults to None.

None
file_name Optional[str]

File name for the saved visualization. Defaults to "prediction_visual".

'prediction_visual'

Returns:

Name Type Description
dict

A dictionary containing the visualized image and the elapsed time for the visualization process.

Source code in sahi/utils/cv.py
def visualize_prediction(
    image: np.ndarray,
    boxes: list[list],
    classes: list[str],
    masks: list[np.ndarray] | None = None,
    rect_th: int | None = None,
    text_size: float | None = None,
    text_th: int | None = None,
    color: tuple | None = None,
    hide_labels: bool = False,
    output_dir: str | None = None,
    file_name: str | None = "prediction_visual",
):
    """Visualizes prediction classes, bounding boxes over the source image and exports it to output folder.

    Args:
        image (np.ndarray): The source image.
        boxes (List[List]): List of bounding boxes coordinates.
        classes (List[str]): List of class labels corresponding to each bounding box.
        masks (Optional[List[np.ndarray]], optional): List of masks corresponding to each bounding box. Defaults to None.
        rect_th (int, optional): Thickness of the bounding box rectangle. Defaults to None.
        text_size (float, optional): Size of the text for class labels. Defaults to None.
        text_th (int, optional): Thickness of the text for class labels. Defaults to None.
        color (tuple, optional): Color of the bounding box and text. When None, a
            per-class palette is used instead. Defaults to None.
        hide_labels (bool, optional): Whether to hide the class labels. Defaults to False.
        output_dir (Optional[str], optional): Output directory to save the visualization. Defaults to None.
        file_name (Optional[str], optional): File name for the saved visualization. Defaults to "prediction_visual".

    Returns:
        dict: A dictionary containing the visualized image and the elapsed time for the visualization process.
    """  # noqa

    elapsed_time = time.time()
    # deepcopy image so that original is not altered
    image = copy.deepcopy(image)
    # select predefined classwise color palette if not specified
    if color is None:
        colors = Colors()
    else:
        colors = None
    # set rect_th for boxes
    rect_th = rect_th or max(round(sum(image.shape) / 2 * 0.003), 2)
    # set text_th for category names
    text_th = text_th or max(rect_th - 1, 1)
    # set text_size for category names
    text_size = text_size or rect_th / 3

    # add masks to image if present
    if masks is not None and color is None:
        logger.error("Cannot add mask, no color tuple given")
    elif masks is not None and color is not None:
        for mask in masks:
            # deepcopy mask so that original is not altered
            mask = copy.deepcopy(mask)
            # draw mask
            rgb_mask = apply_color_mask(np.squeeze(mask), color)
            image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)

    # add bboxes to image if present
    for box_indice in range(len(boxes)):
        # deepcopy boxso that original is not altered
        box = copy.deepcopy(boxes[box_indice])
        class_ = classes[box_indice]

        # set color
        if colors is not None:
            mycolor = colors(class_)
        elif color is not None:
            mycolor = color
        else:
            logger.error("color cannot be defined")
            continue

        # set bbox points
        point1, point2 = [int(box[0]), int(box[1])], [int(box[2]), int(box[3])]
        # visualize boxes
        cv2.rectangle(
            image,
            point1,
            point2,
            color=mycolor,
            thickness=rect_th,
        )

        if not hide_labels:
            # arange bounding box text location
            label = f"{class_}"
            box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
                0
            ]  # label width, height
            outside = point1[1] - box_height - 3 >= 0  # label fits outside box
            point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
            # bugfix: use mycolor (per-class color) for the label background so it
            # matches the box color, as done in visualize_object_predictions;
            # previously this fell back to black whenever color was None
            cv2.rectangle(image, point1, point2, mycolor, -1, cv2.LINE_AA)  # filled
            cv2.putText(
                image,
                label,
                (point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
                0,
                text_size,
                (255, 255, 255),
                thickness=text_th,
            )
    if output_dir:
        # create output folder if not present
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        # save inference result (convert RGB back to BGR for cv2.imwrite)
        save_path = os.path.join(output_dir, (file_name or "unknown") + ".png")
        cv2.imwrite(save_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    elapsed_time = time.time() - elapsed_time
    return {"image": image, "elapsed_time": elapsed_time}
detectron2
Functions
export_cfg_as_yaml(cfg, export_path='config.yaml')

Exports Detectron2 config object in yaml format so that it can be used later.

Parameters:

Name Type Description Default
cfg CfgNode

Detectron2 config object.

required
export_path str

Path to export the Detectron2 config.

'config.yaml'

Related Detectron2 doc: https://detectron2.readthedocs.io/en/stable/modules/config.html#detectron2.config.CfgNode.dump

Source code in sahi/utils/detectron2.py
def export_cfg_as_yaml(cfg, export_path: str = "config.yaml"):
    """Write a Detectron2 config object to disk as yaml for later reuse.

    Args:
        cfg (detectron2.config.CfgNode): Detectron2 config object.
        export_path (str): Destination path for the yaml dump.
    Related Detectron2 doc: https://detectron2.readthedocs.io/en/stable/modules/config.html#detectron2.config.CfgNode.dump
    """
    target = Path(export_path)
    # make sure the destination folder exists before writing
    target.parent.mkdir(exist_ok=True, parents=True)
    target.write_text(cfg.dump())
file
Functions
download_from_url(from_url, to_path)

Downloads a file from the given URL and saves it to the specified path.

Parameters:

Name Type Description Default
from_url str

The URL of the file to download.

required
to_path str

The path where the downloaded file should be saved.

required

Returns:

Type Description

None

Source code in sahi/utils/file.py
def download_from_url(from_url: str, to_path: str):
    """Fetch a file from the given URL and store it at the given local path.

    The parent directory of ``to_path`` is created when missing, and the
    download is skipped entirely when a file already exists at ``to_path``.

    Args:
        from_url (str): The URL of the file to download.
        to_path (str): The path where the downloaded file should be saved.

    Returns:
        None
    """
    destination = Path(to_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    # do not re-download an already-present file
    if not os.path.exists(to_path):
        urllib.request.urlretrieve(from_url, to_path)
get_base_filename(path)

Takes a file path, returns (base_filename_with_extension, base_filename_without_extension)

Source code in sahi/utils/file.py
def get_base_filename(path: str):
    """Split a file path into (filename_with_extension, filename_without_extension)."""
    filename = ntpath.basename(path)
    stem, _extension = os.path.splitext(filename)
    return filename, stem
get_file_extension(path)

Get the file extension from a given file path.

Parameters:

Name Type Description Default
path str

The file path.

required

Returns:

Name Type Description
str

The file extension.

Source code in sahi/utils/file.py
def get_file_extension(path: str):
    """Return the extension (including the leading dot) of the given file path.

    Args:
        path (str): The file path.

    Returns:
        str: The file extension, e.g. ".json"; empty string when there is none.
    """
    return os.path.splitext(path)[1]
import_model_class(model_type, class_name)

Imports a predefined detection class by class name.

Parameters:

Name Type Description Default
model_type

str "yolov5", "detectron2", "mmdet", "huggingface" etc

required
class_name

str Name of the detection model class (example: "MmdetDetectionModel")

required

Returns: class_: class with given path

Source code in sahi/utils/file.py
def import_model_class(model_type, class_name):
    """Dynamically resolve a detection model class from the ``sahi.models`` package.

    Args:
        model_type: str
            Framework name such as "yolov5", "detectron2", "mmdet", "huggingface".
        class_name: str
            Name of the detection model class (example: "MmdetDetectionModel")
    Returns:
        class_: class with given path
    """
    module_path = f"sahi.models.{model_type}"
    module = __import__(module_path, fromlist=[class_name])
    return getattr(module, class_name)
increment_path(path, exist_ok=True, sep='')

Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.

Parameters:

Name Type Description Default
path str | Path

str The base path to increment.

required
exist_ok bool

bool If True, return the path as is if it already exists. If False, increment the path.

True
sep str

str The separator to use between the base path and the increment number.

''

Returns:

Name Type Description
str str

The incremented path.

Example

increment_path("runs/exp", exist_ok=False, sep="_") returns 'runs/exp_2' when runs/exp already exists; once runs/exp_2 also exists, it returns 'runs/exp_3', and so on.

Source code in sahi/utils/file.py
def increment_path(path: str | Path, exist_ok: bool = True, sep: str = "") -> str:
    """Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.

    Args:
        path: str
            The base path to increment.
        exist_ok: bool
            If True, return the path as is if it already exists. If False, increment the path.
        sep: str
            The separator to use between the base path and the increment number.

    Returns:
        str: The incremented path.

    Example:
        >>> increment_path("runs/exp", sep="_")
        'runs/exp_0'
        >>> increment_path("runs/exp_0", sep="_")
        'runs/exp_1'
    """
    path = Path(path)  # os-agnostic
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        indices = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(indices) + 1 if indices else 2  # increment number
        return f"{path}{sep}{n}"  # update path
is_colab()

Check if the current environment is a Google Colab instance.

Returns:

Name Type Description
bool

True if the environment is a Google Colab instance, False otherwise.

Source code in sahi/utils/file.py
def is_colab():
    """Detect whether the code is running inside a Google Colab runtime.

    Returns:
        bool: True if the environment is a Google Colab instance, False otherwise.
    """
    import sys

    # Colab injects the google.colab package into every runtime
    loaded_modules = sys.modules
    return "google.colab" in loaded_modules
list_files(directory, contains=['.json'], verbose=1)

Walk given directory and return a list of file path with desired extension.

Parameters:

Name Type Description Default
directory str

str "data/coco/"

required
contains list

list A list of strings to check if the target file contains them, example: ["coco.png", ".jpg", "jpeg"]

['.json']
verbose int

int 0: no print 1: print number of files

1

Returns:

Name Type Description
filepath_list list[str]

list List of file paths

Source code in sahi/utils/file.py
def list_files(
    directory: str,
    contains: list = [".json"],
    verbose: int = 1,
) -> list[str]:
    """Walk given directory and return a list of file path with desired extension.

    Args:
        directory: str
            "data/coco/"
        contains: list
            A list of strings to check if the target file contains them, example: ["coco.png", ".jpg", "jpeg"]
        verbose: int
            0: no print
            1: print number of files

    Returns:
        filepath_list : list
            List of file paths
    """
    # define verboseprint
    verboseprint = print if verbose else lambda *a, **k: None

    filepath_list: list[str] = []

    for file in os.listdir(directory):
        # check if filename contains any of the terms given in contains list
        if any(strtocheck in file.lower() for strtocheck in contains):
            filepath = str(os.path.join(directory, file))
            filepath_list.append(filepath)

    number_of_files = len(filepath_list)
    folder_name = Path(directory).name

    verboseprint(f"There are {number_of_files!s} listed files in folder: {folder_name}/")

    return filepath_list
list_files_recursively(directory, contains=['.json'], verbose=True)

Walk given directory recursively and return a list of file path with desired extension.

Parameters:

Name Type Description Default
directory

str "data/coco/"

required
contains

list A list of strings to check if the target file contains them, example: ["coco.png", ".jpg", "jpeg"]

required
verbose

bool If true, prints some results

required

Returns:

Name Type Description
relative_filepath_list list

list List of file paths relative to given directory

abs_filepath_list list

list List of absolute file paths

Source code in sahi/utils/file.py
def list_files_recursively(directory: str, contains: list = [".json"], verbose: bool = True) -> tuple[list, list]:
    """Walk given directory recursively and return a list of file path with desired extension.

    Args:
        directory : str
            "data/coco/"
        contains : list
            A list of strings to check if the target file contains them, example: ["coco.png", ".jpg", "jpeg"]
        verbose : bool
            If true, prints some results

    Returns:
        relative_filepath_list : list
            List of file paths relative to given directory
        abs_filepath_list : list
            List of absolute file paths
    """

    # define verboseprint
    verboseprint = print if verbose else lambda *a, **k: None

    # walk directories recursively and find json files
    abs_filepath_list = []
    relative_filepath_list = []

    # r=root, d=directories, f=files
    for r, _, f in os.walk(directory):
        for file in f:
            # check if filename contains any of the terms given in contains list
            if any(strtocheck in file.lower() for strtocheck in contains):
                abs_filepath = os.path.join(r, file)
                abs_filepath_list.append(abs_filepath)
                relative_filepath = abs_filepath.split(directory)[-1]
                relative_filepath_list.append(relative_filepath)

    number_of_files = len(relative_filepath_list)
    folder_name = directory.split(os.sep)[-1]

    verboseprint(f"There are {number_of_files} listed files in folder {folder_name}.")

    return relative_filepath_list, abs_filepath_list
load_json(load_path, encoding='utf-8')

Loads json formatted data (given as "data") from load_path Encoding type can be specified with 'encoding' argument.

Parameters:

Name Type Description Default
load_path str

str "dirname/coco.json"

required
encoding str

str Encoding type, default is 'utf-8'

'utf-8'
Example inputs

load_path: "dirname/coco.json"

Source code in sahi/utils/file.py
def load_json(load_path: str, encoding: str = "utf-8"):
    """Read a json file from disk and return the parsed data.

    Args:
        load_path: str
            "dirname/coco.json"
        encoding: str
            Encoding type, default is 'utf-8'

    Example inputs:
        load_path: "dirname/coco.json"
    """
    with open(load_path, encoding=encoding) as json_file:
        return json.load(json_file)
load_pickle(load_path)

Loads pickle formatted data (given as "data") from load_path

Parameters:

Name Type Description Default
load_path

str "dirname/coco.pickle"

required
Example inputs

load_path: "dirname/coco.pickle"

Source code in sahi/utils/file.py
def load_pickle(load_path):
    """Read pickle-serialized data from disk and return the deserialized object.

    Args:
        load_path: str
            "dirname/coco.pickle"

    Example inputs:
        load_path: "dirname/coco.pickle"
    """
    with open(load_path, "rb") as pickle_file:
        return pickle.load(pickle_file)
save_json(data, save_path, indent=None)

Saves json formatted data (given as "data") as save_path

Parameters:

Name Type Description Default
data

dict Data to be saved as json

required
save_path

str "dirname/coco.json"

required
indent int | None

int or None Indentation level for pretty-printing the JSON data. If None, the most compact representation will be used. If an integer is provided, it specifies the number of spaces to use for indentation. Example: indent=4 will format the JSON data with an indentation of 4 spaces per level.

None
Example inputs

data: {"image_id": 5} save_path: "dirname/coco.json" indent: Train json files with indent=None, val json files with indent=4

Source code in sahi/utils/file.py
def save_json(data, save_path, indent: int | None = None):
    """Serialize ``data`` to disk as json at ``save_path``.

    Args:
        data: dict
            Data to be saved as json
        save_path: str
            "dirname/coco.json"
        indent: int or None
            Pretty-printing indentation width. None produces the most compact
            representation; an integer indents each nesting level by that many
            spaces (e.g. indent=4).

    Example inputs:
        data: {"image_id": 5}
        save_path: "dirname/coco.json"
        indent: Train json files with indent=None, val json files with indent=4
    """
    # make sure the destination folder exists
    destination = Path(save_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    # NumpyEncoder handles numpy scalar/array values transparently
    with open(save_path, "w", encoding="utf-8") as outfile:
        json.dump(data, outfile, separators=(",", ":"), cls=NumpyEncoder, indent=indent)
save_pickle(data, save_path)

Saves pickle formatted data (given as "data") as save_path

Parameters:

Name Type Description Default
data

dict Data to be saved as pickle

required
save_path

str "dirname/coco.pickle"

required
Example inputs

data: {"image_id": 5} save_path: "dirname/coco.pickle"

Source code in sahi/utils/file.py
def save_pickle(data, save_path):
    """Serialize ``data`` to disk as a pickle file at ``save_path``.

    Args:
        data: dict
            Data to be saved as pickle
        save_path: str
            "dirname/coco.pickle"

    Example inputs:
        data: {"image_id": 5}
        save_path: "dirname/coco.pickle"
    """
    # make sure the destination folder exists
    destination = Path(save_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    # serialize with pickle
    with open(save_path, "wb") as outfile:
        pickle.dump(data, outfile)
unzip(file_path, dest_dir)

Unzips compressed .zip file.

Example inputs

file_path: 'data/01_alb_id.zip' dest_dir: 'data/'

Source code in sahi/utils/file.py
def unzip(file_path: str, dest_dir: str):
    """Extract every member of a compressed .zip archive into ``dest_dir``.

    Example inputs:
        file_path: 'data/01_alb_id.zip'
        dest_dir: 'data/'
    """
    # extractall creates dest_dir (and member subdirectories) as needed
    with zipfile.ZipFile(file_path) as archive:
        archive.extractall(dest_dir)
import_utils
Functions
check_package_minimum_version(package_name, minimum_version, verbose=False)

Return False if the installed module version is lower than the given minimum version; otherwise return True.

Source code in sahi/utils/import_utils.py
def check_package_minimum_version(package_name: str, minimum_version: str, verbose=False):
    """Return False when the installed package is older than ``minimum_version``.

    Returns True when the package is not installed, when its version cannot
    be determined (a warning is logged), or when the installed version
    satisfies the minimum.
    """
    from packaging import version

    package_available, installed_version = get_package_info(package_name, verbose=verbose)
    if not package_available:
        return True
    if installed_version == "unknown":
        logger.warning(
            f"Could not determine version of {package_name}. Assuming version {minimum_version} is compatible."
        )
        return True
    return version.parse(installed_version) >= version.parse(minimum_version)
check_requirements(package_names)

Raise error if module is not installed.

Source code in sahi/utils/import_utils.py
def check_requirements(package_names):
    """Raise ImportError if any of ``package_names`` cannot be found.

    NOTE(review): the trailing ``yield`` makes this a generator function —
    the check only runs once the returned generator is advanced; presumably
    intended for context-manager style usage — confirm against callers.
    """
    missing = [name for name in package_names if importlib.util.find_spec(name) is None]
    if missing:
        raise ImportError(f"The following packages are required to use this module: {missing}")
    yield
ensure_package_minimum_version(package_name, minimum_version, verbose=False)

Raise error if module version is not compatible.

Source code in sahi/utils/import_utils.py
def ensure_package_minimum_version(package_name: str, minimum_version: str, verbose=False):
    """Raise ImportError when the installed package is older than ``minimum_version``.

    Missing packages and undeterminable versions are tolerated (the latter
    logs a warning).

    NOTE(review): the trailing ``yield`` makes this a generator function —
    the check only runs once the returned generator is advanced — confirm
    against callers.
    """
    from packaging import version

    package_available, installed_version = get_package_info(package_name, verbose=verbose)
    if package_available:
        if installed_version == "unknown":
            logger.warning(
                f"Could not determine version of {package_name}. Assuming version {minimum_version} is compatible."
            )
        elif version.parse(installed_version) < version.parse(minimum_version):
            raise ImportError(
                f"Please upgrade {package_name} to version {minimum_version} or higher to use this module."
            )
    yield
get_package_info(package_name, verbose=True)

Returns a tuple of the package availability (bool) and the package version as a string ("N/A" when the package is not installed).

Source code in sahi/utils/import_utils.py
def get_package_info(package_name: str, verbose: bool = True):
    """Return ``(is_available, version_string)`` for ``package_name``.

    ``version_string`` is "N/A" when the package is not installed and
    "unknown" when it is installed but exposes no discoverable version.
    """
    package_available = is_available(package_name)

    if not package_available:
        return package_available, "N/A"

    # prefer installation metadata; fall back to the module's __version__
    try:
        import importlib.metadata as _importlib_metadata

        pkg_version = _importlib_metadata.version(package_name)
    except (ModuleNotFoundError, AttributeError):
        try:
            pkg_version = importlib.import_module(package_name).__version__
        except AttributeError:
            pkg_version = "unknown"
    if verbose:
        logger.pkg_info(f"{package_name} version {pkg_version} is available.")

    return package_available, pkg_version
mmdet
Functions
download_mmdet_config(model_name='cascade_rcnn', config_file_name='cascade_mask_rcnn_r50_fpn_1x_coco.py', verbose=True)

Merges config files starting from given main config file name. Saves as single file.

Parameters:

Name Type Description Default
model_name str

mmdet model name. check https://github.com/open-mmlab/mmdetection/tree/master/configs.

'cascade_rcnn'
config_file_name str

mmdet config file name.

'cascade_mask_rcnn_r50_fpn_1x_coco.py'
verbose bool

if True, print save path.

True

Returns:

Type Description
str

(str) abs path of the downloaded config file.

Source code in sahi/utils/mmdet.py
def download_mmdet_config(
    model_name: str = "cascade_rcnn",
    config_file_name: str = "cascade_mask_rcnn_r50_fpn_1x_coco.py",
    verbose: bool = True,
) -> str:
    """Merges config files starting from given main config file name. Saves as single file.

    Downloads the main config (and up to two levels of its ``_base_``
    dependencies) for the locally installed mmdet version from GitHub,
    then dumps the merged result as a single file under
    ``mmdet_configs/<version>/<model_name>/``. Cached: if the final file
    already exists, no download happens.

    Args:
        model_name (str): mmdet model name. check https://github.com/open-mmlab/mmdetection/tree/master/configs.
        config_file_name (str): mmdet config file name.
        verbose (bool): if True, print save path.

    Returns:
        (str) abs path of the downloaded config file.
    """

    # get mmdet version (configs are fetched from the matching release tag)
    from mmdet import __version__

    mmdet_ver = "v" + __version__

    # set main config url
    base_config_url = (
        "https://raw.githubusercontent.com/open-mmlab/mmdetection/" + mmdet_ver + "/configs/" + model_name + "/"
    )
    main_config_url = base_config_url + config_file_name

    # set final config dirs
    configs_dir = Path("mmdet_configs") / mmdet_ver
    model_config_dir = configs_dir / model_name

    # create final config dir
    configs_dir.mkdir(parents=True, exist_ok=True)
    model_config_dir.mkdir(parents=True, exist_ok=True)

    # get final config file name
    filename = Path(main_config_url).name

    # set final config file path
    final_config_path = str(model_config_dir / filename)

    # only download/merge when the merged config is not already cached
    if not Path(final_config_path).exists():
        # set config dirs
        temp_configs_dir = Path("temp_mmdet_configs")
        main_config_dir = temp_configs_dir / model_name

        # create config dirs
        temp_configs_dir.mkdir(parents=True, exist_ok=True)
        main_config_dir.mkdir(parents=True, exist_ok=True)

        # get main config file name
        filename = Path(main_config_url).name

        # set main config file path
        main_config_path = str(main_config_dir / filename)

        # download main config file
        urllib.request.urlretrieve(
            main_config_url,
            main_config_path,
        )

        # read main config file by importing it as a temporary Python module
        # (mmdet configs are executable .py files); sys.path is restored after
        sys.path.insert(0, str(main_config_dir))
        temp_module_name = path.splitext(filename)[0]
        mod = import_module(temp_module_name)
        sys.path.pop(0)
        config_dict = {name: value for name, value in mod.__dict__.items() if not name.startswith("__")}

        # handle when config_dict["_base_"] is string
        if not isinstance(config_dict["_base_"], list):
            config_dict["_base_"] = [config_dict["_base_"]]

        # iterate over secondary config files referenced via _base_
        for secondary_config_file_path in config_dict["_base_"]:
            # set config url
            config_url = base_config_url + secondary_config_file_path
            config_path = main_config_dir / secondary_config_file_path

            # create secondary config dir
            config_path.parent.mkdir(parents=True, exist_ok=True)

            # download secondary config files
            urllib.request.urlretrieve(
                config_url,
                str(config_path),
            )

            # read secondary config file (same temporary-module trick as above)
            secondary_config_dir = config_path.parent
            sys.path.insert(0, str(secondary_config_dir))
            temp_module_name = path.splitext(Path(config_path).name)[0]
            mod = import_module(temp_module_name)
            sys.path.pop(0)
            secondary_config_dict = {name: value for name, value in mod.__dict__.items() if not name.startswith("__")}

            # go deeper if there are more steps (third level is downloaded but
            # not parsed further; deeper nesting is not followed)
            if secondary_config_dict.get("_base_") is not None:
                # handle when config_dict["_base_"] is string
                if not isinstance(secondary_config_dict["_base_"], list):
                    secondary_config_dict["_base_"] = [secondary_config_dict["_base_"]]

                # iterate over third config files
                for third_config_file_path in secondary_config_dict["_base_"]:
                    # set config url
                    config_url = base_config_url + third_config_file_path
                    config_path = main_config_dir / third_config_file_path

                    # create secondary config dir
                    config_path.parent.mkdir(parents=True, exist_ok=True)
                    # download secondary config files
                    urllib.request.urlretrieve(
                        config_url,
                        str(config_path),
                    )

        from mmengine import Config
        # dump final config as single file (mmengine resolves _base_ inheritance)

        config = Config.fromfile(main_config_path)
        config.dump(final_config_path)

        if verbose:
            print(f"mmdet config file has been downloaded to {path.abspath(final_config_path)}")

        # remove temp config dir
        shutil.rmtree(temp_configs_dir)

    return path.abspath(final_config_path)
shapely
Classes
ShapelyAnnotation

Creates ShapelyAnnotation (as shapely MultiPolygon).

Can convert this instance annotation to various formats.

Source code in sahi/utils/shapely.py
class ShapelyAnnotation:
    """Creates ShapelyAnnotation (as shapely MultiPolygon).

    Can convert this instance annotation to various formats.

    When ``slice_bbox`` is set, every exported coordinate is shifted by
    (-slice_bbox[0], -slice_bbox[1]) so results become slice-relative.
    """

    @classmethod
    def from_coco_segmentation(cls, segmentation, slice_bbox=None):
        """Init ShapelyAnnotation from coco segmentation.

        segmentation : List[List]
            [[1, 1, 325, 125, 250, 200, 5, 200]]
        slice_bbox (List[int]): [xmin, ymin, width, height]
            Should have the same format as the output of the get_bbox_from_shapely function.
            Is used to calculate sliced coco coordinates.
        """
        shapely_multipolygon = get_shapely_multipolygon(segmentation)
        return cls(multipolygon=shapely_multipolygon, slice_bbox=slice_bbox)

    @classmethod
    def from_coco_bbox(cls, bbox: list[int], slice_bbox: list[int] | None = None):
        """Init ShapelyAnnotation from coco bbox.

        bbox (List[int]): [xmin, ymin, width, height]
        slice_bbox (List[int]): [x_min, y_min, x_max, y_max]
            Is used to calculate sliced coco coordinates.
        """
        shapely_polygon = get_shapely_box(x=bbox[0], y=bbox[1], width=bbox[2], height=bbox[3])
        shapely_multipolygon = MultiPolygon([shapely_polygon])
        return cls(multipolygon=shapely_multipolygon, slice_bbox=slice_bbox)

    def __init__(self, multipolygon: MultiPolygon, slice_bbox=None):
        # assignment goes through the multipolygon setter below, which also
        # caches the total area
        self.multipolygon = multipolygon
        self.slice_bbox = slice_bbox

    @property
    def multipolygon(self):
        return self.__multipolygon

    @property
    def area(self):
        # cached by the multipolygon setter; truncated to int
        return int(self.__area)

    @multipolygon.setter
    def multipolygon(self, multipolygon: MultiPolygon):
        self.__multipolygon = multipolygon
        # calculate areas of all polygons
        area = 0
        for shapely_polygon in multipolygon.geoms:
            area += shapely_polygon.area
        # set instance area
        self.__area = area

    def to_list(self):
        """Return one list of (x, y) tuples per polygon.

        [
            [(x1, y1), (x2, y2), (x3, y3), ...],
            [(x1, y1), (x2, y2), (x3, y3), ...],
            ...
        ]

        Zero-area polygons contribute an empty list.
        """
        list_of_list_of_points: list = []
        for shapely_polygon in self.multipolygon.geoms:
            # create list_of_points for selected shapely_polygon
            if shapely_polygon.area != 0:
                x_coords = shapely_polygon.exterior.coords.xy[0]
                y_coords = shapely_polygon.exterior.coords.xy[1]
                # fix coord by slice_bbox
                if self.slice_bbox:
                    minx = self.slice_bbox[0]
                    miny = self.slice_bbox[1]
                    x_coords = [x_coord - minx for x_coord in x_coords]
                    y_coords = [y_coord - miny for y_coord in y_coords]
                list_of_points = list(zip(x_coords, y_coords))
            else:
                list_of_points = []
            # append list_of_points to list_of_list_of_points
            list_of_list_of_points.append(list_of_points)
        # return result
        return list_of_list_of_points

    def to_coco_segmentation(self):
        """Return coco-style segmentation: one flat [x1, y1, x2, y2, ...] list per polygon.

        [
            [x1, y1, x2, y2, x3, y3, ...],
            [x1, y1, x2, y2, x3, y3, ...],
            ...
        ]

        Coordinates are truncated to int; zero-area polygons contribute an
        empty list.
        """
        coco_segmentation: list = []
        for shapely_polygon in self.multipolygon.geoms:
            # create list_of_points for selected shapely_polygon
            if shapely_polygon.area != 0:
                x_coords = shapely_polygon.exterior.coords.xy[0]
                y_coords = shapely_polygon.exterior.coords.xy[1]
                # fix coord by slice_bbox
                if self.slice_bbox:
                    minx = self.slice_bbox[0]
                    miny = self.slice_bbox[1]
                    x_coords = [x_coord - minx for x_coord in x_coords]
                    y_coords = [y_coord - miny for y_coord in y_coords]
                # convert intersection to coco style segmentation annotation
                # (interleave: even indices x, odd indices y)
                coco_polygon: list[None | int] = [None] * (len(x_coords) * 2)
                coco_polygon[0::2] = [int(coord) for coord in x_coords]
                coco_polygon[1::2] = [int(coord) for coord in y_coords]
            else:
                coco_polygon = []
            # remove if first and last points are duplicate
            # (shapely exterior rings are closed; for an empty list this
            # compares [] == [] and the del is a no-op)
            if coco_polygon[:2] == coco_polygon[-2:]:
                del coco_polygon[-2:]
            # append coco_polygon to coco_segmentation
            # NOTE(review): this comprehension only shallow-copies the list;
            # it does not transform values
            coco_polygon = [point for point in coco_polygon] if coco_polygon else coco_polygon
            coco_segmentation.append(coco_polygon)
        return coco_segmentation

    def to_opencv_contours(self):
        """Return OpenCV-style contours, one per polygon, each point wrapped as [[x, y]].

        [ [[[1, 1]], [[325, 125]], [[250, 200]], [[5, 200]]], [[[1, 1]], [[325, 125]], [[250, 200]], [[5, 200]]] ]
        """
        opencv_contours: list = []
        for shapely_polygon in self.multipolygon.geoms:
            # create opencv_contour for selected shapely_polygon
            if shapely_polygon.area != 0:
                x_coords = shapely_polygon.exterior.coords.xy[0]
                y_coords = shapely_polygon.exterior.coords.xy[1]
                # fix coord by slice_bbox
                if self.slice_bbox:
                    minx = self.slice_bbox[0]
                    miny = self.slice_bbox[1]
                    x_coords = [x_coord - minx for x_coord in x_coords]
                    y_coords = [y_coord - miny for y_coord in y_coords]
                opencv_contour = [[[int(x_coords[ind]), int(y_coords[ind])]] for ind in range(len(x_coords))]
            else:
                opencv_contour: list = []
            # append opencv_contour to opencv_contours
            opencv_contours.append(opencv_contour)
        # return result
        return opencv_contours

    def to_xywh(self):
        """Return bounding box as [xmin, ymin, width, height]; empty list for zero-area geometry."""
        if self.multipolygon.area != 0:
            coco_bbox, _ = get_bbox_from_shapely(self.multipolygon)
            # fix coord by slice box
            if self.slice_bbox:
                minx = self.slice_bbox[0]
                miny = self.slice_bbox[1]
                coco_bbox[0] = coco_bbox[0] - minx
                coco_bbox[1] = coco_bbox[1] - miny
        else:
            coco_bbox: list = []
        return coco_bbox

    def to_coco_bbox(self):
        """Alias of to_xywh(): [xmin, ymin, width, height]"""
        return self.to_xywh()

    def to_xyxy(self):
        """Return bounding box as [xmin, ymin, xmax, ymax]; empty list for zero-area geometry."""
        if self.multipolygon.area != 0:
            _, voc_bbox = get_bbox_from_shapely(self.multipolygon)
            # fix coord by slice box
            if self.slice_bbox:
                minx = self.slice_bbox[0]
                miny = self.slice_bbox[1]
                voc_bbox[0] = voc_bbox[0] - minx
                voc_bbox[2] = voc_bbox[2] - minx
                voc_bbox[1] = voc_bbox[1] - miny
                voc_bbox[3] = voc_bbox[3] - miny
        else:
            voc_bbox = []
        return voc_bbox

    def to_voc_bbox(self):
        """Alias of to_xyxy(): [xmin, ymin, xmax, ymax]"""
        return self.to_xyxy()

    def get_convex_hull_shapely_annotation(self):
        """Return a new ShapelyAnnotation wrapping the convex hull of this geometry.

        NOTE: the result does not inherit slice_bbox.
        """
        shapely_multipolygon = MultiPolygon([self.multipolygon.convex_hull])
        shapely_annotation = ShapelyAnnotation(shapely_multipolygon)
        return shapely_annotation

    def get_simplified_shapely_annotation(self, tolerance=1):
        """Return a new ShapelyAnnotation with the geometry simplified by ``tolerance``.

        NOTE: the result does not inherit slice_bbox.
        """
        shapely_multipolygon = MultiPolygon([self.multipolygon.simplify(tolerance)])
        shapely_annotation = ShapelyAnnotation(shapely_multipolygon)
        return shapely_annotation

    def get_buffered_shapely_annotation(
        self,
        distance=3,
        resolution=16,
        quadsegs=None,
        cap_style=CAP_STYLE.round,
        join_style=JOIN_STYLE.round,
        mitre_limit=5.0,
        single_sided=False,
    ):
        """Approximates the present polygon to have a valid polygon shape.

        For more, check: https://shapely.readthedocs.io/en/stable/manual.html#object.buffer

        NOTE: the result does not inherit slice_bbox.
        """
        buffered_polygon = self.multipolygon.buffer(
            distance=distance,
            resolution=resolution,
            quadsegs=quadsegs,
            cap_style=cap_style,
            join_style=join_style,
            mitre_limit=mitre_limit,
            single_sided=single_sided,
        )
        shapely_annotation = ShapelyAnnotation(MultiPolygon([buffered_polygon]))
        return shapely_annotation

    def get_intersection(self, polygon: Polygon):
        """Accepts shapely polygon object and returns the intersection in ShapelyAnnotation format."""
        # convert intersection polygon to list of tuples
        intersection = self.multipolygon.intersection(polygon)
        # if polygon is box then set slice_box property
        # (heuristic: a closed rectangular ring has 5 coords with pairwise
        # equal x values in this order — presumably matching shapely's box()
        # vertex ordering; verify if boxes are built elsewhere)
        if (
            len(polygon.exterior.xy[0]) == 5
            and polygon.exterior.xy[0][0] == polygon.exterior.xy[0][1]
            and polygon.exterior.xy[0][2] == polygon.exterior.xy[0][3]
        ):
            coco_bbox, _ = get_bbox_from_shapely(polygon)
            slice_bbox = coco_bbox
        else:
            slice_bbox = None
        # convert intersection to multipolygon
        if intersection.geom_type == "Polygon":
            intersection_multipolygon = MultiPolygon([intersection])
        elif intersection.geom_type == "MultiPolygon":
            intersection_multipolygon = intersection
        else:
            # points/lines/empty results collapse to an empty multipolygon
            intersection_multipolygon = MultiPolygon([])
        # create shapely annotation from intersection multipolygon
        intersection_shapely_annotation = ShapelyAnnotation(intersection_multipolygon, slice_bbox)

        return intersection_shapely_annotation
Functions
from_coco_bbox(bbox, slice_bbox=None) classmethod

Init ShapelyAnnotation from coco bbox.

bbox (List[int]): [xmin, ymin, width, height] slice_bbox (List[int]): [x_min, y_min, x_max, y_max] Is used to calculate sliced coco coordinates.

Source code in sahi/utils/shapely.py
@classmethod
def from_coco_bbox(cls, bbox: list[int], slice_bbox: list[int] | None = None):
    """Init ShapelyAnnotation from coco bbox.

    bbox (List[int]): [xmin, ymin, width, height] slice_bbox (List[int]): [x_min, y_min, x_max, y_max] Is used
    to calculate sliced coco coordinates.
    """
    shapely_polygon = get_shapely_box(x=bbox[0], y=bbox[1], width=bbox[2], height=bbox[3])
    shapely_multipolygon = MultiPolygon([shapely_polygon])
    return cls(multipolygon=shapely_multipolygon, slice_bbox=slice_bbox)
from_coco_segmentation(segmentation, slice_bbox=None) classmethod

Init ShapelyAnnotation from coco segmentation.

List[List]

[[1, 1, 325, 125, 250, 200, 5, 200]]

slice_bbox (List[int]): [xmin, ymin, width, height] Should have the same format as the output of the get_bbox_from_shapely function. Is used to calculate sliced coco coordinates.

Source code in sahi/utils/shapely.py
@classmethod
def from_coco_segmentation(cls, segmentation, slice_bbox=None):
    """Init ShapelyAnnotation from coco segmentation.

    segmentation : List[List]
        [[1, 1, 325, 125, 250, 200, 5, 200]]
    slice_bbox (List[int]): [xmin, ymin, width, height]
        Should have the same format as the output of the get_bbox_from_shapely function.
        Is used to calculate sliced coco coordinates.
    """
    shapely_multipolygon = get_shapely_multipolygon(segmentation)
    return cls(multipolygon=shapely_multipolygon, slice_bbox=slice_bbox)
get_buffered_shapely_annotation(distance=3, resolution=16, quadsegs=None, cap_style=CAP_STYLE.round, join_style=JOIN_STYLE.round, mitre_limit=5.0, single_sided=False)

Approximates the present polygon to have a valid polygon shape.

For more, check: https://shapely.readthedocs.io/en/stable/manual.html#object.buffer

Source code in sahi/utils/shapely.py
def get_buffered_shapely_annotation(
    self,
    distance=3,
    resolution=16,
    quadsegs=None,
    cap_style=CAP_STYLE.round,
    join_style=JOIN_STYLE.round,
    mitre_limit=5.0,
    single_sided=False,
):
    """Approximates the present polygon to have a valid polygon shape.

    For more, check: https://shapely.readthedocs.io/en/stable/manual.html#object.buffer
    """
    buffered_polygon = self.multipolygon.buffer(
        distance=distance,
        resolution=resolution,
        quadsegs=quadsegs,
        cap_style=cap_style,
        join_style=join_style,
        mitre_limit=mitre_limit,
        single_sided=single_sided,
    )
    shapely_annotation = ShapelyAnnotation(MultiPolygon([buffered_polygon]))
    return shapely_annotation
get_intersection(polygon)

Accepts shapely polygon object and returns the intersection in ShapelyAnnotation format.

Source code in sahi/utils/shapely.py
def get_intersection(self, polygon: Polygon):
    """Accepts shapely polygon object and returns the intersection in ShapelyAnnotation format."""
    # convert intersection polygon to list of tuples
    intersection = self.multipolygon.intersection(polygon)
    # if polygon is box then set slice_box property
    if (
        len(polygon.exterior.xy[0]) == 5
        and polygon.exterior.xy[0][0] == polygon.exterior.xy[0][1]
        and polygon.exterior.xy[0][2] == polygon.exterior.xy[0][3]
    ):
        coco_bbox, _ = get_bbox_from_shapely(polygon)
        slice_bbox = coco_bbox
    else:
        slice_bbox = None
    # convert intersection to multipolygon
    if intersection.geom_type == "Polygon":
        intersection_multipolygon = MultiPolygon([intersection])
    elif intersection.geom_type == "MultiPolygon":
        intersection_multipolygon = intersection
    else:
        intersection_multipolygon = MultiPolygon([])
    # create shapely annotation from intersection multipolygon
    intersection_shapely_annotation = ShapelyAnnotation(intersection_multipolygon, slice_bbox)

    return intersection_shapely_annotation
to_coco_bbox()

[xmin, ymin, width, height]

Source code in sahi/utils/shapely.py
def to_coco_bbox(self):
    """[xmin, ymin, width, height]"""
    return self.to_xywh()
to_coco_segmentation()

[ [x1, y1, x2, y2, x3, y3, ...], [x1, y1, x2, y2, x3, y3, ...], ... ]

Source code in sahi/utils/shapely.py
def to_coco_segmentation(self):
    """
    [
        [x1, y1, x2, y2, x3, y3, ...],
        [x1, y1, x2, y2, x3, y3, ...],
        ...
    ]
    """
    coco_segmentation: list = []
    for shapely_polygon in self.multipolygon.geoms:
        # create list_of_points for selected shapely_polygon
        if shapely_polygon.area != 0:
            x_coords = shapely_polygon.exterior.coords.xy[0]
            y_coords = shapely_polygon.exterior.coords.xy[1]
            # fix coord by slice_bbox
            if self.slice_bbox:
                minx = self.slice_bbox[0]
                miny = self.slice_bbox[1]
                x_coords = [x_coord - minx for x_coord in x_coords]
                y_coords = [y_coord - miny for y_coord in y_coords]
            # convert intersection to coco style segmentation annotation
            coco_polygon: list[None | int] = [None] * (len(x_coords) * 2)
            coco_polygon[0::2] = [int(coord) for coord in x_coords]
            coco_polygon[1::2] = [int(coord) for coord in y_coords]
        else:
            coco_polygon = []
        # remove if first and last points are duplicate
        if coco_polygon[:2] == coco_polygon[-2:]:
            del coco_polygon[-2:]
        # append coco_polygon to coco_segmentation
        coco_polygon = [point for point in coco_polygon] if coco_polygon else coco_polygon
        coco_segmentation.append(coco_polygon)
    return coco_segmentation
to_list()

[ [(x1, y1), (x2, y2), (x3, y3), ...], [(x1, y1), (x2, y2), (x3, y3), ...], ... ]

Source code in sahi/utils/shapely.py
def to_list(self):
    """
    [
        [(x1, y1), (x2, y2), (x3, y3), ...],
        [(x1, y1), (x2, y2), (x3, y3), ...],
        ...
    ]
    """
    list_of_list_of_points: list = []
    for shapely_polygon in self.multipolygon.geoms:
        # create list_of_points for selected shapely_polygon
        if shapely_polygon.area != 0:
            x_coords = shapely_polygon.exterior.coords.xy[0]
            y_coords = shapely_polygon.exterior.coords.xy[1]
            # fix coord by slice_bbox
            if self.slice_bbox:
                minx = self.slice_bbox[0]
                miny = self.slice_bbox[1]
                x_coords = [x_coord - minx for x_coord in x_coords]
                y_coords = [y_coord - miny for y_coord in y_coords]
            list_of_points = list(zip(x_coords, y_coords))
        else:
            list_of_points = []
        # append list_of_points to list_of_list_of_points
        list_of_list_of_points.append(list_of_points)
    # return result
    return list_of_list_of_points
to_opencv_contours()

[ [[[1, 1]], [[325, 125]], [[250, 200]], [[5, 200]]], [[[1, 1]], [[325, 125]], [[250, 200]], [[5, 200]]] ]

Source code in sahi/utils/shapely.py
def to_opencv_contours(self):
    """[ [[[1, 1]], [[325, 125]], [[250, 200]], [[5, 200]]], [[[1, 1]], [[325, 125]], [[250, 200]], [[5, 200]]] ]"""
    opencv_contours: list = []
    for shapely_polygon in self.multipolygon.geoms:
        # create opencv_contour for selected shapely_polygon
        if shapely_polygon.area != 0:
            x_coords = shapely_polygon.exterior.coords.xy[0]
            y_coords = shapely_polygon.exterior.coords.xy[1]
            # fix coord by slice_bbox
            if self.slice_bbox:
                minx = self.slice_bbox[0]
                miny = self.slice_bbox[1]
                x_coords = [x_coord - minx for x_coord in x_coords]
                y_coords = [y_coord - miny for y_coord in y_coords]
            opencv_contour = [[[int(x_coords[ind]), int(y_coords[ind])]] for ind in range(len(x_coords))]
        else:
            opencv_contour: list = []
        # append opencv_contour to opencv_contours
        opencv_contours.append(opencv_contour)
    # return result
    return opencv_contours
to_voc_bbox()

[xmin, ymin, xmax, ymax]

Source code in sahi/utils/shapely.py
def to_voc_bbox(self):
    """[xmin, ymin, xmax, ymax]"""
    return self.to_xyxy()
to_xywh()

[xmin, ymin, width, height]

Source code in sahi/utils/shapely.py
def to_xywh(self):
    """[xmin, ymin, width, height]"""
    if self.multipolygon.area != 0:
        coco_bbox, _ = get_bbox_from_shapely(self.multipolygon)
        # fix coord by slice box
        if self.slice_bbox:
            minx = self.slice_bbox[0]
            miny = self.slice_bbox[1]
            coco_bbox[0] = coco_bbox[0] - minx
            coco_bbox[1] = coco_bbox[1] - miny
    else:
        coco_bbox: list = []
    return coco_bbox
to_xyxy()

[xmin, ymin, xmax, ymax]

Source code in sahi/utils/shapely.py
def to_xyxy(self):
    """[xmin, ymin, xmax, ymax]"""
    if self.multipolygon.area != 0:
        _, voc_bbox = get_bbox_from_shapely(self.multipolygon)
        # fix coord by slice box
        if self.slice_bbox:
            minx = self.slice_bbox[0]
            miny = self.slice_bbox[1]
            voc_bbox[0] = voc_bbox[0] - minx
            voc_bbox[2] = voc_bbox[2] - minx
            voc_bbox[1] = voc_bbox[1] - miny
            voc_bbox[3] = voc_bbox[3] - miny
    else:
        voc_bbox = []
    return voc_bbox
Functions
get_bbox_from_shapely(shapely_object)

Accepts shapely box/poly object and returns its bounding box in coco and voc formats.

Source code in sahi/utils/shapely.py
def get_bbox_from_shapely(shapely_object):
    """Accepts shapely box/poly object and returns its bounding box in coco and voc formats.

    Returns:
        tuple: ([xmin, ymin, width, height], [xmin, ymin, xmax, ymax])
    """
    xmin, ymin, xmax, ymax = shapely_object.bounds
    # coco uses width/height, voc uses the raw corner coordinates
    return [xmin, ymin, xmax - xmin, ymax - ymin], [xmin, ymin, xmax, ymax]
get_shapely_box(x, y, width, height)

Accepts coco style bbox coords and converts it to shapely box object.

Source code in sahi/utils/shapely.py
def get_shapely_box(x: int, y: int, width: int, height: int) -> Polygon:
    """Accepts coco style bbox coords and converts it to shapely box object."""
    # box() takes (minx, miny, maxx, maxy)
    return box(x, y, x + width, y + height)
get_shapely_multipolygon(coco_segmentation)

Accepts coco style polygon coords and converts it to valid shapely multipolygon object.

Source code in sahi/utils/shapely.py
def get_shapely_multipolygon(coco_segmentation: list[list]) -> MultiPolygon:
    """Accepts coco style polygon coords and converts it to valid shapely multipolygon object.

    Args:
        coco_segmentation: list of flat coordinate lists, e.g.
            [[x1, y1, x2, y2, ...], ...]

    Returns:
        MultiPolygon: invalid input geometry is repaired via make_valid and
        reduced to its polygonal components.
    """

    def filter_polygons(geometry):
        """Filters out and returns only Polygon or MultiPolygon components of a geometry.

        If geometry is a Polygon, it converts it into a MultiPolygon. If it's a GeometryCollection, it filters to create
        a MultiPolygon from any Polygons in the collection. Returns an empty MultiPolygon if no Polygon or MultiPolygon
        components are found.

        Args:
            geometry: A shapely geometry object (Polygon, MultiPolygon, GeometryCollection, etc.)

        Returns: MultiPolygon
        """
        if isinstance(geometry, Polygon):
            return MultiPolygon([geometry])
        elif isinstance(geometry, MultiPolygon):
            return geometry
        elif isinstance(geometry, GeometryCollection):
            # flatten MultiPolygon members into individual Polygons: the
            # MultiPolygon constructor expects Polygon elements, so passing a
            # raw `.geoms` sequence as a single element (as before) would fail
            polygons = []
            for geom in geometry.geoms:
                if isinstance(geom, Polygon):
                    polygons.append(geom)
                elif isinstance(geom, MultiPolygon):
                    polygons.extend(geom.geoms)
            return MultiPolygon(polygons) if polygons else MultiPolygon()
        return MultiPolygon()

    polygon_list = []
    for coco_polygon in coco_segmentation:
        # pair up the flat [x1, y1, x2, y2, ...] coordinate list
        point_list = list(zip(coco_polygon[0::2], coco_polygon[1::2]))
        shapely_polygon = Polygon(point_list)
        polygon_list.append(shapely_polygon)
    shapely_multipolygon = MultiPolygon(polygon_list)

    if not shapely_multipolygon.is_valid:
        shapely_multipolygon = filter_polygons(make_valid(shapely_multipolygon))

    return shapely_multipolygon
torch_utils
Functions
select_device(device=None)

Selects torch device.

Parameters:

Name Type Description Default
device str | None

"cpu", "mps", "cuda", "cuda:0", "cuda:1", etc. When no device string is given, the order of preference to try is: cuda:0 > mps > cpu

None

Returns:

Type Description
device

torch.device

Inspired by https://github.com/ultralytics/yolov5/blob/6371de8879e7ad7ec5283e8b95cc6dd85d6a5e72/utils/torch_utils.py#L107

Source code in sahi/utils/torch_utils.py
def select_device(device: str | None = None) -> torch.device:
    """Select and return a torch device.

    Args:
        device: "cpu", "mps", "cuda", "cuda:0", "cuda:1", etc.
                When no device string is given, the order of preference
                to try is: cuda:0 > mps > cpu

    Returns:
        torch.device

    Inspired by https://github.com/ultralytics/yolov5/blob/6371de8879e7ad7ec5283e8b95cc6dd85d6a5e72/utils/torch_utils.py#L107
    """
    import torch

    # A bare "cuda" (or no device at all) means "first CUDA device".
    requested = "cuda:0" if device in ("cuda", None) else device
    # Normalize to a bare identifier: 'cuda:0' -> '0', stray 'none' -> ''.
    requested = str(requested).strip().lower().replace("cuda:", "").replace("none", "")

    want_cpu = requested == "cpu"
    want_mps = requested == "mps"  # Apple Metal Performance Shaders (MPS)
    if want_cpu or want_mps:
        environ["CUDA_VISIBLE_DEVICES"] = "-1"  # force torch.cuda.is_available() = False
    elif requested:  # non-cpu device requested
        # Must be set before torch.cuda.is_available() is consulted.
        environ["CUDA_VISIBLE_DEVICES"] = requested

    # A valid CUDA index is a non-negative integer with no leading zeros.
    is_cuda_index = bool(re.fullmatch(r"^(0|[1-9]\d*)$", requested))

    if not (want_cpu or want_mps) and torch.cuda.is_available() and is_cuda_index:  # prefer GPU if available
        chosen = f"cuda:{requested}" if requested else "cuda:0"
    elif want_mps and getattr(torch, "has_mps", False) and torch.backends.mps.is_available():  # then MPS
        chosen = "mps"
    else:  # revert to CPU
        chosen = "cpu"

    return torch.device(chosen)
to_float_tensor(img)

Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W).

Parameters:

img (ndarray | Image, required): PIL.Image or numpy array.
Returns: torch.tensor

Source code in sahi/utils/torch_utils.py
def to_float_tensor(img: np.ndarray | Image) -> torch.Tensor:
    """Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C
    x H x W).

    Args:
        img: PIL.Image or numpy array
    Returns:
        torch.tensor
    """
    import torch

    # PIL images are materialized as arrays; ndarrays pass through as-is.
    arr: np.ndarray = img if isinstance(img, np.ndarray) else np.array(img)
    # HWC -> CHW. The np.array(...) call copies the transposed view so the
    # in-place division below cannot mutate the caller's buffer.
    chw = np.array(arr.transpose((2, 0, 1)))
    tensor = torch.from_numpy(chw).float()
    # Rescale 0-255 pixel values into [0, 1].
    if tensor.max() > 1:
        tensor /= 255
    return tensor