Skip to content

CV

sahi.utils.cv

Classes

Colors

Source code in sahi/utils/cv.py
class Colors:
    def __init__(self):
        hex = (
            "FF3838",
            "2C99A8",
            "FF701F",
            "6473FF",
            "CFD231",
            "48F90A",
            "92CC17",
            "3DDB86",
            "1A9334",
            "00D4BB",
            "FF9D97",
            "00C2FF",
            "344593",
            "FFB21D",
            "0018EC",
            "8438FF",
            "520085",
            "CB38FF",
            "FF95C8",
            "FF37C7",
        )
        self.palette = [self.hex_to_rgb("#" + c) for c in hex]
        self.n = len(self.palette)

    def __call__(self, ind, bgr: bool = False):
        """
        Convert an index to a color code.

        Args:
            ind (int): The index to convert.
            bgr (bool, optional): Whether to return the color code in BGR format. Defaults to False.

        Returns:
            tuple: The color code in RGB or BGR format, depending on the value of `bgr`.
        """
        color_codes = self.palette[int(ind) % self.n]
        return (color_codes[2], color_codes[1], color_codes[0]) if bgr else color_codes

    @staticmethod
    def hex_to_rgb(hex_code):
        """
        Converts a hexadecimal color code to RGB format.

        Args:
            hex_code (str): The hexadecimal color code to convert.

        Returns:
            tuple: A tuple representing the RGB values in the order (R, G, B).
        """
        rgb = []
        for i in (0, 2, 4):
            rgb.append(int(hex_code[1 + i : 1 + i + 2], 16))
        return tuple(rgb)
Functions
__call__(ind, bgr=False)

Convert an index to a color code.

Parameters:

Name Type Description Default
ind int

The index to convert.

required
bgr bool

Whether to return the color code in BGR format. Defaults to False.

False

Returns:

Name Type Description
tuple

The color code in RGB or BGR format, depending on the value of bgr.

Source code in sahi/utils/cv.py
def __call__(self, ind, bgr: bool = False):
    """
    Convert an index to a color code.

    Args:
        ind (int): The index to convert.
        bgr (bool, optional): Whether to return the color code in BGR format. Defaults to False.

    Returns:
        tuple: The color code in RGB or BGR format, depending on the value of `bgr`.
    """
    color_codes = self.palette[int(ind) % self.n]
    return (color_codes[2], color_codes[1], color_codes[0]) if bgr else color_codes
hex_to_rgb(hex_code) staticmethod

Converts a hexadecimal color code to RGB format.

Parameters:

Name Type Description Default
hex_code str

The hexadecimal color code to convert.

required

Returns:

Name Type Description
tuple

A tuple representing the RGB values in the order (R, G, B).

Source code in sahi/utils/cv.py
@staticmethod
def hex_to_rgb(hex_code):
    """
    Converts a hexadecimal color code to RGB format.

    Args:
        hex_code (str): The hexadecimal color code to convert.

    Returns:
        tuple: A tuple representing the RGB values in the order (R, G, B).
    """
    rgb = []
    for i in (0, 2, 4):
        rgb.append(int(hex_code[1 + i : 1 + i + 2], 16))
    return tuple(rgb)

Functions

apply_color_mask(image, color)

Applies color mask to given input image.

Parameters:

Name Type Description Default
image
ndarray

The input image to apply the color mask to.

required
color
tuple

The RGB color tuple to use for the mask.

required

Returns:

Type Description

np.ndarray: The resulting image with the applied color mask.

Source code in sahi/utils/cv.py
def apply_color_mask(image: np.ndarray, color: Tuple[int, int, int]):
    """
    Applies color mask to given input image.

    Args:
        image (np.ndarray): The input image to apply the color mask to.
        color (tuple): The RGB color tuple to use for the mask.

    Returns:
        np.ndarray: The resulting image with the applied color mask.
    """
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)

    (r[image == 1], g[image == 1], b[image == 1]) = color
    colored_mask = np.stack([r, g, b], axis=2)
    return colored_mask

convert_image_to(read_path, extension='jpg', grayscale=False)

Reads an image from the given path and saves it with the specified extension.

Parameters:

Name Type Description Default
read_path
str

The path to the image file.

required
extension
str

The desired file extension for the saved image. Defaults to "jpg".

'jpg'
grayscale
bool

Whether to convert the image to grayscale. Defaults to False.

False
Source code in sahi/utils/cv.py
def convert_image_to(read_path, extension: str = "jpg", grayscale: bool = False):
    """
    Reads an image from the given path and saves it with the specified extension.

    Args:
        read_path (str): The path to the image file.
        extension (str, optional): The desired file extension for the saved image. Defaults to "jpg".
        grayscale (bool, optional): Whether to convert the image to grayscale. Defaults to False.
    """
    image = cv2.imread(read_path)
    pre, _ = os.path.splitext(read_path)
    if grayscale:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        pre = pre + "_gray"
    save_path = pre + "." + extension
    cv2.imwrite(save_path, image)

crop_object_predictions(image, object_prediction_list, output_dir='', file_name='prediction_visual', export_format='png')

Crops bounding boxes over the source image and exports it to the output folder.

Parameters:

Name Type Description Default
image
ndarray

The source image to crop bounding boxes from.

required
object_prediction_list

A list of object predictions.

required
output_dir
str

The directory where the resulting visualizations will be exported. Defaults to an empty string.

''
file_name
str

The name of the exported file. The exported file will be saved as output_dir + file_name + ".png". Defaults to "prediction_visual".

'prediction_visual'
export_format
str

The format of the exported file. Can be specified as 'jpg' or 'png'. Defaults to "png".

'png'
Source code in sahi/utils/cv.py
def crop_object_predictions(
    image: np.ndarray,
    object_prediction_list,
    output_dir: str = "",
    file_name: str = "prediction_visual",
    export_format: str = "png",
):
    """
    Crops bounding boxes over the source image and exports it to the output folder.

    Args:
        image (np.ndarray): The source image to crop bounding boxes from.
        object_prediction_list: A list of object predictions.
        output_dir (str): The directory where the resulting visualizations will be exported. Defaults to an empty string.
        file_name (str): The name of the exported file. The exported file will be saved as `output_dir + file_name + ".png"`. Defaults to "prediction_visual".
        export_format (str): The format of the exported file. Can be specified as 'jpg' or 'png'. Defaults to "png".
    """
    # create output folder if not present
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # add bbox and mask to image if present
    for ind, object_prediction in enumerate(object_prediction_list):
        # deepcopy object_prediction_list so that the original is not altered
        object_prediction = object_prediction.deepcopy()
        bbox = object_prediction.bbox.to_xyxy()
        category_id = object_prediction.category.id
        # crop detections
        # deepcopy crops so that the original is not altered
        cropped_img = copy.deepcopy(
            image[
                int(bbox[1]) : int(bbox[3]),
                int(bbox[0]) : int(bbox[2]),
                :,
            ]
        )
        save_path = os.path.join(
            output_dir,
            file_name + "_box" + str(ind) + "_class" + str(category_id) + "." + export_format,
        )
        cv2.imwrite(save_path, cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR))

exif_transpose(image)

Transpose a PIL image accordingly if it has an EXIF Orientation tag. Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()

Parameters:

Name Type Description Default
image
Image

The image to transpose.

required

Returns:

Type Description
Image

Image.Image: The transposed image.

Source code in sahi/utils/cv.py
def exif_transpose(image: Image.Image) -> Image.Image:
    """
    Transpose a PIL image accordingly if it has an EXIF Orientation tag.
    Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()

    Args:
        image (Image.Image): The image to transpose.

    Returns:
        Image.Image: The transposed image.
    """
    exif = image.getexif()
    orientation = exif.get(0x0112, 1)  # default 1
    if orientation > 1:
        method = {
            2: Image.Transpose.FLIP_LEFT_RIGHT,
            3: Image.Transpose.ROTATE_180,
            4: Image.Transpose.FLIP_TOP_BOTTOM,
            5: Image.Transpose.TRANSPOSE,
            6: Image.Transpose.ROTATE_270,
            7: Image.Transpose.TRANSVERSE,
            8: Image.Transpose.ROTATE_90,
        }.get(orientation)
        if method is not None:
            image = image.transpose(method)
            del exif[0x0112]
            image.info["exif"] = exif.tobytes()
    return image

get_bbox_from_bool_mask(bool_mask)

Generate VOC bounding box [xmin, ymin, xmax, ymax] from given boolean mask.

Parameters:

Name Type Description Default
bool_mask
ndarray

2D boolean mask.

required

Returns:

Type Description
Optional[List[int]]

Optional[List[int]]: VOC bounding box [xmin, ymin, xmax, ymax] or None if no bounding box is found.

Source code in sahi/utils/cv.py
def get_bbox_from_bool_mask(bool_mask: np.ndarray) -> Optional[List[int]]:
    """
    Generate VOC bounding box [xmin, ymin, xmax, ymax] from given boolean mask.

    Args:
        bool_mask (np.ndarray): 2D boolean mask.

    Returns:
        Optional[List[int]]: VOC bounding box [xmin, ymin, xmax, ymax] or None if no bounding box is found.
    """
    rows = np.any(bool_mask, axis=1)
    cols = np.any(bool_mask, axis=0)

    if not np.any(rows) or not np.any(cols):
        return None

    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]
    width = xmax - xmin
    height = ymax - ymin

    if width == 0 or height == 0:
        return None

    return [xmin, ymin, xmax, ymax]

get_bbox_from_coco_segmentation(coco_segmentation)

Generate voc box ([xmin, ymin, xmax, ymax]) from given coco segmentation

Source code in sahi/utils/cv.py
def get_bbox_from_coco_segmentation(coco_segmentation):
    """
    Generate voc box ([xmin, ymin, xmax, ymax]) from given coco segmentation
    """
    xs = []
    ys = []
    for segm in coco_segmentation:
        xs.extend(segm[::2])
        ys.extend(segm[1::2])
    if len(xs) == 0 or len(ys) == 0:
        return None
    xmin = min(xs)
    xmax = max(xs)
    ymin = min(ys)
    ymax = max(ys)
    return [xmin, ymin, xmax, ymax]

get_bool_mask_from_coco_segmentation(coco_segmentation, width, height)

Convert coco segmentation to 2D boolean mask of given height and width

Parameters: - coco_segmentation: list of points representing the coco segmentation - width: width of the boolean mask - height: height of the boolean mask

Returns: - bool_mask: 2D boolean mask of size (height, width)

Source code in sahi/utils/cv.py
def get_bool_mask_from_coco_segmentation(coco_segmentation: List[List[float]], width: int, height: int) -> np.ndarray:
    """
    Convert coco segmentation to 2D boolean mask of given height and width

    Parameters:
    - coco_segmentation: list of points representing the coco segmentation
    - width: width of the boolean mask
    - height: height of the boolean mask

    Returns:
    - bool_mask: 2D boolean mask of size (height, width)
    """
    size = [height, width]
    points = [np.array(point).reshape(-1, 2).round().astype(int) for point in coco_segmentation]
    bool_mask = np.zeros(size)
    bool_mask = cv2.fillPoly(bool_mask, points, (1.0,))
    bool_mask.astype(bool)
    return bool_mask

get_coco_segmentation_from_bool_mask(bool_mask)

Convert boolean mask to coco segmentation format [ [x1, y1, x2, y2, x3, y3, ...], [x1, y1, x2, y2, x3, y3, ...], ... ]

Source code in sahi/utils/cv.py
def get_coco_segmentation_from_bool_mask(bool_mask: np.ndarray) -> List[List[float]]:
    """
    Convert boolean mask to coco segmentation format
    [
        [x1, y1, x2, y2, x3, y3, ...],
        [x1, y1, x2, y2, x3, y3, ...],
        ...
    ]
    """
    # Generate polygons from mask
    mask = np.squeeze(bool_mask)
    mask = mask.astype(np.uint8)
    mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=(0, 0, 0))
    polygons = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE, offset=(-1, -1))
    polygons = polygons[0] if len(polygons) == 2 else polygons[1]
    # Convert polygon to coco segmentation
    coco_segmentation = []
    for polygon in polygons:
        segmentation = polygon.flatten().tolist()
        # at least 3 points needed for a polygon
        if len(segmentation) >= 6:
            coco_segmentation.append(segmentation)
    return coco_segmentation

get_coco_segmentation_from_obb_points(obb_points)

Convert OBB (Oriented Bounding Box) points to COCO polygon format.

Parameters:

Name Type Description Default
obb_points
ndarray

np.ndarray OBB points tensor from ultralytics.engine.results.OBB Shape: (4, 2) containing 4 points with (x,y) coordinates each

required

Returns:

Type Description
List[List[float]]

List[List[float]]: Polygon points in COCO format [[x1, y1, x2, y2, x3, y3, x4, y4], [...], ...]

Source code in sahi/utils/cv.py
def get_coco_segmentation_from_obb_points(obb_points: np.ndarray) -> List[List[float]]:
    """
    Convert OBB (Oriented Bounding Box) points to COCO polygon format.

    Args:
        obb_points: np.ndarray
            OBB points tensor from ultralytics.engine.results.OBB
            Shape: (4, 2) containing 4 points with (x,y) coordinates each

    Returns:
        List[List[float]]: Polygon points in COCO format
            [[x1, y1, x2, y2, x3, y3, x4, y4], [...], ...]
    """
    # Convert from (4,2) to [x1,y1,x2,y2,x3,y3,x4,y4] format
    points = obb_points.reshape(-1).tolist()

    # Create polygon from points and close it by repeating first point
    polygons = []
    # Add first point to end to close polygon
    closed_polygon = points + [points[0], points[1]]
    polygons.append(closed_polygon)

    return polygons

get_video_reader(source, save_dir, frame_skip_interval, export_visual=False, view_visual=False)

Creates OpenCV video capture object from given video file path.

Parameters:

Name Type Description Default
source
str

Video file path

required
save_dir
str

Video export directory

required
frame_skip_interval
int

Frame skip interval

required
export_visual
bool

Set True if you want to export visuals

False
view_visual
bool

Set True if you want to render visual

False

Returns:

Name Type Description
iterator Generator[Image, None, None]

Pillow Image

video_writer Optional[VideoWriter]

cv2.VideoWriter

video_file_name str

video name with extension

Source code in sahi/utils/cv.py
def get_video_reader(
    source: str,
    save_dir: str,
    frame_skip_interval: int,
    export_visual: bool = False,
    view_visual: bool = False,
) -> Tuple[Generator[Image.Image, None, None], Optional[cv2.VideoWriter], str, int]:
    """
    Creates OpenCV video capture object from given video file path.

    Args:
        source: Video file path
        save_dir: Video export directory
        frame_skip_interval: Frame skip interval
        export_visual: Set True if you want to export visuals
        view_visual: Set True if you want to render visual

    Returns:
        iterator: Pillow Image
        video_writer: cv2.VideoWriter
        video_file_name: video name with extension
    """
    # get video name with extension
    video_file_name = os.path.basename(source)
    # get video from video path
    video_capture = cv2.VideoCapture(source)

    num_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    if view_visual:
        num_frames /= frame_skip_interval + 1
        num_frames = int(num_frames)

    def read_video_frame(video_capture, frame_skip_interval) -> Generator[Image.Image, None, None]:
        if view_visual:
            window_name = "Prediction of {}".format(str(video_file_name))
            cv2.namedWindow(window_name, cv2.WINDOW_AUTOSIZE)
            default_image = np.zeros((480, 640, 3), dtype=np.uint8)
            cv2.imshow(window_name, default_image)

            while video_capture.isOpened:
                frame_num = video_capture.get(cv2.CAP_PROP_POS_FRAMES)
                video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_num + frame_skip_interval)

                k = cv2.waitKey(20)
                frame_num = video_capture.get(cv2.CAP_PROP_POS_FRAMES)

                if k == 27:
                    print(
                        "\n===========================Closing==========================="
                    )  # Exit the prediction, Key = Esc
                    exit()
                if k == 100:
                    frame_num += 100  # Skip 100 frames, Key = d
                if k == 97:
                    frame_num -= 100  # Prev 100 frames, Key = a
                if k == 103:
                    frame_num += 20  # Skip 20 frames, Key = g
                if k == 102:
                    frame_num -= 20  # Prev 20 frames, Key = f
                video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_num)

                ret, frame = video_capture.read()
                if not ret:
                    print("\n=========================== Video Ended ===========================")
                    break
                yield Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        else:
            while video_capture.isOpened:
                frame_num = video_capture.get(cv2.CAP_PROP_POS_FRAMES)
                video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_num + frame_skip_interval)

                ret, frame = video_capture.read()
                if not ret:
                    print("\n=========================== Video Ended ===========================")
                    break
                yield Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    if export_visual:
        # get video properties and create VideoWriter object
        if frame_skip_interval != 0:
            fps = video_capture.get(cv2.CAP_PROP_FPS)  # original fps of video
            # The fps of export video is increasing during view_image because frame is skipped
            fps = (
                fps / frame_skip_interval
            )  # How many time_interval equals to original fps. One time_interval skip x frames.
        else:
            fps = video_capture.get(cv2.CAP_PROP_FPS)

        w = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        size = (w, h)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # pyright: ignore[reportAttributeAccessIssue]
        video_writer = cv2.VideoWriter(os.path.join(save_dir, video_file_name), fourcc, fps, size)
    else:
        video_writer = None

    return read_video_frame(video_capture, frame_skip_interval), video_writer, video_file_name, num_frames

ipython_display(image)

Displays numpy image in notebook.

If input image is in range 0..1, please first multiply img by 255 Assumes image is ndarray of shape [height, width, channels] where channels can be 1, 3 or 4

Source code in sahi/utils/cv.py
def ipython_display(image: np.ndarray):
    """
    Displays numpy image in notebook.

    If input image is in range 0..1, please first multiply img by 255
    Assumes image is ndarray of shape [height, width, channels] where channels can be 1, 3 or 4
    """
    import IPython

    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    _, ret = cv2.imencode(".png", image)
    i = IPython.display.Image(data=ret)  # type: ignore
    IPython.display.display(i)  # type: ignore

normalize_numpy_image(image)

Normalizes numpy image

Source code in sahi/utils/cv.py
def normalize_numpy_image(image: np.ndarray):
    """
    Normalizes numpy image
    """
    return image / np.max(image)

read_image(image_path)

Loads image as a numpy array from the given path.

Parameters:

Name Type Description Default
image_path
str

The path to the image file.

required

Returns:

Type Description
ndarray

numpy.ndarray: The loaded image as a numpy array.

Source code in sahi/utils/cv.py
def read_image(image_path: str) -> np.ndarray:
    """
    Loads image as a numpy array from the given path.

    Args:
        image_path (str): The path to the image file.

    Returns:
        numpy.ndarray: The loaded image as a numpy array.
    """
    # read image
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # return image
    return image

read_image_as_pil(image, exif_fix=False)

Loads an image as PIL.Image.Image.

Parameters:

Name Type Description Default
image
Union[Image, str, ndarray]

The image to be loaded. It can be an image path or URL (str), a numpy image (np.ndarray), or a PIL.Image object.

required
exif_fix
bool

Whether to apply an EXIF fix to the image. Defaults to False.

False

Returns:

Type Description
Image

PIL.Image.Image: The loaded image as a PIL.Image object.

Source code in sahi/utils/cv.py
def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool = False) -> Image.Image:
    """
    Loads an image as PIL.Image.Image.

    Args:
        image (Union[Image.Image, str, np.ndarray]): The image to be loaded. It can be an image path or URL (str),
            a numpy image (np.ndarray), or a PIL.Image object.
        exif_fix (bool, optional): Whether to apply an EXIF fix to the image. Defaults to False.

    Returns:
        PIL.Image.Image: The loaded image as a PIL.Image object.
    """
    # https://stackoverflow.com/questions/56174099/how-to-load-images-larger-than-max-image-pixels-with-pil
    Image.MAX_IMAGE_PIXELS = None

    if isinstance(image, Image.Image):
        image_pil = image
    elif isinstance(image, str):
        # read image if str image path is provided
        try:
            image_pil = Image.open(
                requests.get(image, stream=True).raw if str(image).startswith("http") else image  # type: ignore
            ).convert("RGB")
            if exif_fix:
                image_pil = exif_transpose(image_pil)
        except Exception as e:  # handle large/tiff image reading
            logger.debug(f"OpenCV failed reading image with error {e}, trying skimage instead")
            try:
                import skimage.io
            except ImportError:
                raise ImportError("Please run 'pip install -U scikit-image imagecodecs' for large image handling.")
            image_sk = skimage.io.imread(image).astype(np.uint8)
            if len(image_sk.shape) == 2:  # b&w
                image_pil = Image.fromarray(image_sk, mode="1")
            elif image_sk.shape[2] == 4:  # rgba
                image_pil = Image.fromarray(image_sk, mode="RGBA")
            elif image_sk.shape[2] == 3:  # rgb
                image_pil = Image.fromarray(image_sk, mode="RGB")
            else:
                raise TypeError(f"image with shape: {image_sk.shape[3]} is not supported.")
    elif isinstance(image, np.ndarray):
        if image.shape[0] < 5:  # image in CHW
            image = image[:, :, ::-1]
        image_pil = Image.fromarray(image)
    else:
        raise TypeError("read image with 'pillow' using 'Image.open()'")  # pyright: ignore[reportUnreachable]
    return image_pil

read_large_image(image_path)

Reads a large image from the specified image path.

Parameters:

Name Type Description Default
image_path
str

The path to the image file.

required

Returns:

Name Type Description
tuple

A tuple containing the image data and a flag indicating whether cv2 was used to read the image. The image data is a numpy array representing the image in RGB format. The flag is True if cv2 was used, False otherwise.

Source code in sahi/utils/cv.py
def read_large_image(image_path: str):
    """
    Reads a large image from the specified image path.

    Args:
        image_path (str): The path to the image file.

    Returns:
        tuple: A tuple containing the image data and a flag indicating whether cv2 was used to read the image.
            The image data is a numpy array representing the image in RGB format.
            The flag is True if cv2 was used, False otherwise.
    """
    use_cv2 = True
    # read image, cv2 fails on large files
    try:
        # convert to rgb (cv2 reads in bgr)
        img_cv2 = cv2.imread(image_path, 1)
        image0 = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
    except Exception as e:
        logger.debug(f"OpenCV failed reading image with error {e}, trying skimage instead")
        try:
            import skimage.io
        except ImportError:
            raise ImportError(
                'Please run "pip install -U scikit-image" to install scikit-image first for large image handling.'
            )
        image0 = skimage.io.imread(image_path, as_grey=False).astype(np.uint8)  # [::-1]
        use_cv2 = False
    return image0, use_cv2

select_random_color()

Selects a random color from a predefined list of colors.

Returns:

Name Type Description
list

A list representing the RGB values of the selected color.

Source code in sahi/utils/cv.py
def select_random_color():
    """
    Selects a random color from a predefined list of colors.

    Returns:
        list: A list representing the RGB values of the selected color.

    """
    colors = [
        [0, 255, 0],
        [0, 0, 255],
        [255, 0, 0],
        [0, 255, 255],
        [255, 255, 0],
        [255, 0, 255],
        [80, 70, 180],
        [250, 80, 190],
        [245, 145, 50],
        [70, 150, 250],
        [50, 190, 190],
    ]
    return colors[random.randrange(0, 10)]

visualize_object_predictions(image, object_prediction_list, rect_th=None, text_size=None, text_th=None, color=None, hide_labels=False, hide_conf=False, output_dir=None, file_name='prediction_visual', export_format='png')

Visualizes prediction category names, bounding boxes over the source image and exports it to output folder.

Parameters:

Name Type Description Default
object_prediction_list

a list of prediction.ObjectPrediction

required
rect_th
Optional[int]

rectangle thickness

None
text_size
Optional[float]

size of the category name over box

None
text_th
Optional[int]

text thickness

None
color
Optional[tuple]

annotation color in the form: (0, 255, 0)

None
hide_labels
bool

hide labels

False
hide_conf
bool

hide confidence

False
output_dir
Optional[str]

directory for resulting visualization to be exported

None
file_name
Optional[str]

exported file will be saved as: output_dir+file_name+".png"

'prediction_visual'
export_format
Optional[str]

can be specified as 'jpg' or 'png'

'png'
Source code in sahi/utils/cv.py
def visualize_object_predictions(
    image: np.ndarray,
    object_prediction_list,
    rect_th: Optional[int] = None,
    text_size: Optional[float] = None,
    text_th: Optional[int] = None,
    color: Optional[tuple] = None,
    hide_labels: bool = False,
    hide_conf: bool = False,
    output_dir: Optional[str] = None,
    file_name: Optional[str] = "prediction_visual",
    export_format: Optional[str] = "png",
):
    """
    Visualizes prediction category names, bounding boxes over the source image
    and exports it to output folder.

    Args:
        object_prediction_list: a list of prediction.ObjectPrediction
        rect_th: rectangle thickness
        text_size: size of the category name over box
        text_th: text thickness
        color: annotation color in the form: (0, 255, 0)
        hide_labels: hide labels
        hide_conf: hide confidence
        output_dir: directory for resulting visualization to be exported
        file_name: exported file will be saved as: output_dir+file_name+".png"
        export_format: can be specified as 'jpg' or 'png'
    """
    elapsed_time = time.time()
    # deepcopy image so that original is not altered
    image = copy.deepcopy(image)
    # select predefined classwise color palette if not specified
    if color is None:
        colors = Colors()
    else:
        colors = None
    # set rect_th for boxes
    rect_th = rect_th or max(round(sum(image.shape) / 2 * 0.003), 2)
    # set text_th for category names
    text_th = text_th or max(rect_th - 1, 1)
    # set text_size for category names
    text_size = text_size or rect_th / 3

    # add masks or obb polygons to image if present
    for object_prediction in object_prediction_list:
        # deepcopy object_prediction_list so that original is not altered
        object_prediction = object_prediction.deepcopy()
        # arange label to be displayed
        label = f"{object_prediction.category.name}"
        if not hide_conf:
            label += f" {object_prediction.score.value:.2f}"
        # set color
        if colors is not None:
            color = colors(object_prediction.category.id)
        # visualize masks or obb polygons if present
        has_mask = object_prediction.mask is not None
        is_obb_pred = False
        if has_mask:
            segmentation = object_prediction.mask.segmentation
            if len(segmentation) == 1 and len(segmentation[0]) == 8:
                is_obb_pred = True

            if is_obb_pred:
                points = np.array(segmentation).reshape((-1, 1, 2)).astype(np.int32)
                cv2.polylines(image, [points], isClosed=True, color=color or (0, 0, 0), thickness=rect_th)

                if not hide_labels:
                    lowest_point = points[points[:, :, 1].argmax()][0]
                    box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0]
                    outside = lowest_point[1] - box_height - 3 >= 0
                    text_bg_point1 = (
                        lowest_point[0],
                        lowest_point[1] - box_height - 3 if outside else lowest_point[1] + 3,
                    )
                    text_bg_point2 = (lowest_point[0] + box_width, lowest_point[1])
                    cv2.rectangle(
                        image, text_bg_point1, text_bg_point2, color or (0, 0, 0), thickness=-1, lineType=cv2.LINE_AA
                    )
                    cv2.putText(
                        image,
                        label,
                        (lowest_point[0], lowest_point[1] - 2 if outside else lowest_point[1] + box_height + 2),
                        0,
                        text_size,
                        (255, 255, 255),
                        thickness=text_th,
                    )
            else:
                # draw mask
                rgb_mask = apply_color_mask(object_prediction.mask.bool_mask, color or (0, 0, 0))
                image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)

        # add bboxes to image if is_obb_pred=False
        if not is_obb_pred:
            bbox = object_prediction.bbox.to_xyxy()

            # set bbox points
            point1, point2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
            # visualize boxes
            cv2.rectangle(
                image,
                point1,
                point2,
                color=color or (0, 0, 0),
                thickness=rect_th,
            )

            if not hide_labels:
                box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
                    0
                ]  # label width, height
                outside = point1[1] - box_height - 3 >= 0  # label fits outside box
                point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
                # add bounding box text
                cv2.rectangle(image, point1, point2, color or (0, 0, 0), -1, cv2.LINE_AA)  # filled
                cv2.putText(
                    image,
                    label,
                    (point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
                    0,
                    text_size,
                    (255, 255, 255),
                    thickness=text_th,
                )

    # export if output_dir is present
    if output_dir is not None:
        # export image with predictions
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        # save inference result
        save_path = str(Path(output_dir) / ((file_name or "") + "." + (export_format or "")))
        cv2.imwrite(save_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    elapsed_time = time.time() - elapsed_time
    return {"image": image, "elapsed_time": elapsed_time}

visualize_prediction(image, boxes, classes, masks=None, rect_th=None, text_size=None, text_th=None, color=None, hide_labels=False, output_dir=None, file_name='prediction_visual')

Visualizes prediction classes, bounding boxes over the source image and exports it to output folder.

Parameters:

Name Type Description Default
image
ndarray

The source image.

required
boxes
List[List]

List of bounding boxes coordinates.

required
classes
List[str]

List of class labels corresponding to each bounding box.

required
masks
Optional[List[ndarray]]

List of masks corresponding to each bounding box. Defaults to None.

None
rect_th
int

Thickness of the bounding box rectangle. Defaults to None.

None
text_size
float

Size of the text for class labels. Defaults to None.

None
text_th
int

Thickness of the text for class labels. Defaults to None.

None
color
tuple

Color of the bounding box and text. Defaults to None.

None
hide_labels
bool

Whether to hide the class labels. Defaults to False.

False
output_dir
Optional[str]

Output directory to save the visualization. Defaults to None.

None
file_name
Optional[str]

File name for the saved visualization. Defaults to "prediction_visual".

'prediction_visual'

Returns:

Name Type Description
dict

A dictionary containing the visualized image and the elapsed time for the visualization process.

Source code in sahi/utils/cv.py
def visualize_prediction(
    image: np.ndarray,
    boxes: List[List],
    classes: List[str],
    masks: Optional[List[np.ndarray]] = None,
    rect_th: Optional[int] = None,
    text_size: Optional[float] = None,
    text_th: Optional[int] = None,
    color: Optional[tuple] = None,
    hide_labels: bool = False,
    output_dir: Optional[str] = None,
    file_name: Optional[str] = "prediction_visual",
):
    """
    Visualizes prediction classes, bounding boxes over the source image
    and exports it to output folder.

    Args:
        image (np.ndarray): The source image.
        boxes (List[List]): List of bounding boxes coordinates.
        classes (List[str]): List of class labels corresponding to each bounding box.
        masks (Optional[List[np.ndarray]], optional): List of masks corresponding to each bounding box. Defaults to None.
        rect_th (int, optional): Thickness of the bounding box rectangle. Defaults to None.
        text_size (float, optional): Size of the text for class labels. Defaults to None.
        text_th (int, optional): Thickness of the text for class labels. Defaults to None.
        color (tuple, optional): Color of the bounding box and text. Defaults to None.
        hide_labels (bool, optional): Whether to hide the class labels. Defaults to False.
        output_dir (Optional[str], optional): Output directory to save the visualization. Defaults to None.
        file_name (Optional[str], optional): File name for the saved visualization. Defaults to "prediction_visual".

    Returns:
        dict: A dictionary containing the visualized image and the elapsed time for the visualization process.
    """
    elapsed_time = time.time()
    # deepcopy image so that original is not altered
    image = copy.deepcopy(image)
    # select predefined classwise color palette if not specified
    if color is None:
        colors = Colors()
    else:
        colors = None
    # set rect_th for boxes
    rect_th = rect_th or max(round(sum(image.shape) / 2 * 0.003), 2)
    # set text_th for category names
    text_th = text_th or max(rect_th - 1, 1)
    # set text_size for category names
    text_size = text_size or rect_th / 3

    # add masks to image if present
    if masks is not None and color is None:
        logger.error("Cannot add mask, no color tuple given")
    elif masks is not None and color is not None:
        for mask in masks:
            # deepcopy mask so that original is not altered
            mask = copy.deepcopy(mask)
            # draw mask
            rgb_mask = apply_color_mask(np.squeeze(mask), color)
            image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)

    # add bboxes to image if present
    for box_indice in range(len(boxes)):
        # deepcopy boxso that original is not altered
        box = copy.deepcopy(boxes[box_indice])
        class_ = classes[box_indice]

        # set color
        if colors is not None:
            mycolor = colors(class_)
        elif color is not None:
            mycolor = color
        else:
            logger.error("color cannot be defined")
            continue

        # set bbox points
        point1, point2 = [int(box[0]), int(box[1])], [int(box[2]), int(box[3])]
        # visualize boxes
        cv2.rectangle(
            image,
            point1,
            point2,
            color=mycolor,
            thickness=rect_th,
        )

        if not hide_labels:
            # arange bounding box text location
            label = f"{class_}"
            box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
                0
            ]  # label width, height
            outside = point1[1] - box_height - 3 >= 0  # label fits outside box
            point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
            # add bounding box text
            cv2.rectangle(image, point1, point2, color or (0, 0, 0), -1, cv2.LINE_AA)  # filled
            cv2.putText(
                image,
                label,
                (point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
                0,
                text_size,
                (255, 255, 255),
                thickness=text_th,
            )
    if output_dir:
        # create output folder if not present
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        # save inference result
        save_path = os.path.join(output_dir, (file_name or "unknown") + ".png")
        cv2.imwrite(save_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    elapsed_time = time.time() - elapsed_time
    return {"image": image, "elapsed_time": elapsed_time}