Skip to content

AutoModel

sahi.auto_model

Classes

AutoDetectionModel

Source code in sahi/auto_model.py
class AutoDetectionModel:
    """Factory that builds a framework-specific DetectionModel from a framework name."""

    @staticmethod
    def from_pretrained(
        model_type: str,
        model_path: Optional[str] = None,
        model: Optional[Any] = None,
        config_path: Optional[str] = None,
        device: Optional[str] = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: Optional[Dict] = None,
        category_remapping: Optional[Dict] = None,
        load_at_init: bool = True,
        image_size: Optional[int] = None,
        **kwargs,
    ) -> DetectionModel:
        """
        Loads a DetectionModel from the given path.

        Args:
            model_type: str
                Name of the detection framework (example: "ultralytics", "huggingface", "torchvision").
                Any name contained in ULTRALYTICS_MODEL_NAMES is treated as "ultralytics".
            model_path: str
                Path of the detection model (ex. 'model.pt')
            model: Any
                Optional model object forwarded unchanged to the DetectionModel constructor
                (e.g. an already-loaded model).
            config_path: str
                Path of the config file (ex. 'mmdet/configs/cascade_rcnn_r50_fpn_1x.py')
            device: str
                Device, "cpu" or "cuda:0"
            mask_threshold: float
                Value to threshold mask pixels, should be between 0 and 1
            confidence_threshold: float
                All predictions with score < confidence_threshold will be discarded
            category_mapping: dict: str to str
                Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
            category_remapping: dict: str to int
                Remap category ids based on category names, after performing inference e.g. {"car": 3}
            load_at_init: bool
                If True, automatically loads the model at initialization
            image_size: int
                Inference input size.
            **kwargs:
                Extra keyword arguments forwarded unchanged to the DetectionModel constructor.

        Returns:
            Returns an instance of a DetectionModel

        Raises:
            KeyError: If model_type is not a recognized framework name.
            ImportError: If the framework for the given model_type is not installed
        """
        # Several Ultralytics model names share a single "ultralytics" backend.
        if model_type in ULTRALYTICS_MODEL_NAMES:
            model_type = "ultralytics"
        try:
            model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type]
        except KeyError:
            # Keep the KeyError type that callers may already catch, but say what
            # went wrong and list the accepted names.
            supported = ", ".join(repr(name) for name in MODEL_TYPE_TO_MODEL_CLASS_NAME)
            raise KeyError(
                f"Unsupported model_type {model_type!r}; expected one of: {supported}"
            ) from None
        # Named model_class (not DetectionModel) to avoid shadowing the
        # module-level DetectionModel type used in the return annotation.
        model_class = import_model_class(model_type, model_class_name)

        return model_class(
            model_path=model_path,
            model=model,
            config_path=config_path,
            device=device,
            mask_threshold=mask_threshold,
            confidence_threshold=confidence_threshold,
            category_mapping=category_mapping,
            category_remapping=category_remapping,
            load_at_init=load_at_init,
            image_size=image_size,
            **kwargs,
        )
Functions
from_pretrained(model_type, model_path=None, model=None, config_path=None, device=None, mask_threshold=0.5, confidence_threshold=0.3, category_mapping=None, category_remapping=None, load_at_init=True, image_size=None, **kwargs) staticmethod

Loads a DetectionModel from the given path.

Parameters:

Name Type Description Default
model_type str

str Name of the detection framework (example: "ultralytics", "huggingface", "torchvision")

required
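model_path Optional[str]

str Path of the detection model (ex. 'model.pt')

None
model Optional[Any]

Any Optional model object forwarded unchanged to the DetectionModel constructor (e.g. an already-loaded model)

None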
config_path Optional[str]

str Path of the config file (ex. 'mmdet/configs/cascade_rcnn_r50_fpn_1x.py')

None
device Optional[str]

str Device, "cpu" or "cuda:0"

None
mask_threshold float

float Value to threshold mask pixels, should be between 0 and 1

0.5
confidence_threshold float

float All predictions with score < confidence_threshold will be discarded

0.3
category_mapping Optional[Dict]

dict: str to str Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}

None
category_remapping Optional[Dict]

dict: str to int Remap category ids based on category names, after performing inference e.g. {"car": 3}

None
load_at_init bool

bool If True, automatically loads the model at initialization

True
image_size Optional[int]

int Inference input size.

None
**kwargs

Additional keyword arguments, forwarded unchanged to the DetectionModel constructor.

{}

Returns:

Type Description
DetectionModel

Returns an instance of a DetectionModel

Raises:

Type Description
ImportError

If the framework for the given model_type is not installed.

KeyError

If model_type is not a recognized framework name.

Source code in sahi/auto_model.py
@staticmethod
def from_pretrained(
    model_type: str,
    model_path: Optional[str] = None,
    model: Optional[Any] = None,
    config_path: Optional[str] = None,
    device: Optional[str] = None,
    mask_threshold: float = 0.5,
    confidence_threshold: float = 0.3,
    category_mapping: Optional[Dict] = None,
    category_remapping: Optional[Dict] = None,
    load_at_init: bool = True,
    image_size: Optional[int] = None,
    **kwargs,
) -> DetectionModel:
    """
    Loads a DetectionModel from the given path.

    Args:
        model_type: str
            Name of the detection framework (example: "ultralytics", "huggingface", "torchvision").
            Any name contained in ULTRALYTICS_MODEL_NAMES is treated as "ultralytics".
        model_path: str
            Path of the detection model (ex. 'model.pt')
        model: Any
            Optional model object forwarded unchanged to the DetectionModel constructor
            (e.g. an already-loaded model).
        config_path: str
            Path of the config file (ex. 'mmdet/configs/cascade_rcnn_r50_fpn_1x.py')
        device: str
            Device, "cpu" or "cuda:0"
        mask_threshold: float
            Value to threshold mask pixels, should be between 0 and 1
        confidence_threshold: float
            All predictions with score < confidence_threshold will be discarded
        category_mapping: dict: str to str
            Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
        category_remapping: dict: str to int
            Remap category ids based on category names, after performing inference e.g. {"car": 3}
        load_at_init: bool
            If True, automatically loads the model at initialization
        image_size: int
            Inference input size.
        **kwargs:
            Extra keyword arguments forwarded unchanged to the DetectionModel constructor.

    Returns:
        Returns an instance of a DetectionModel

    Raises:
        KeyError: If model_type is not a recognized framework name.
        ImportError: If the framework for the given model_type is not installed
    """
    # Several Ultralytics model names share a single "ultralytics" backend.
    if model_type in ULTRALYTICS_MODEL_NAMES:
        model_type = "ultralytics"
    try:
        model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type]
    except KeyError:
        # Keep the KeyError type that callers may already catch, but say what
        # went wrong and list the accepted names.
        supported = ", ".join(repr(name) for name in MODEL_TYPE_TO_MODEL_CLASS_NAME)
        raise KeyError(
            f"Unsupported model_type {model_type!r}; expected one of: {supported}"
        ) from None
    # Named model_class (not DetectionModel) to avoid shadowing the
    # module-level DetectionModel type used in the return annotation.
    model_class = import_model_class(model_type, model_class_name)

    return model_class(
        model_path=model_path,
        model=model,
        config_path=config_path,
        device=device,
        mask_threshold=mask_threshold,
        confidence_threshold=confidence_threshold,
        category_mapping=category_mapping,
        category_remapping=category_remapping,
        load_at_init=load_at_init,
        image_size=image_size,
        **kwargs,
    )

Functions