Skip to content

Huggingface Model

sahi.models.huggingface

HuggingFace Transformers detection model wrapper for SAHI.

Provides integration with the Hugging Face Transformers library for object detection models, such as DETR variants, and for GroundingDINO-style zero-shot detection models.

Classes

HuggingfaceDetectionModel

Bases: DetectionModel

HuggingFace Transformers object detection model.

Supports DETR-style object detection models and GroundingDINO-style zero-shot detection models.

Source code in sahi/models/huggingface.py
class HuggingfaceDetectionModel(DetectionModel):
    """HuggingFace Transformers object detection model.

    Supports DETR-style object detection models and GroundingDINO-style zero-shot detection models.
    Closed-vocabulary models take their categories from ``model.config.id2label``; zero-shot models
    take them from ``text_labels`` and extend the mapping when new phrases are returned.
    """

    def __init__(
        self,
        model_path: str | None = None,
        model: object | None = None,
        processor: object | None = None,
        config_path: str | None = None,
        device: str | None = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: dict | None = None,
        category_remapping: dict | None = None,
        load_at_init: bool = True,
        image_size: int | None = None,
        token: str | None = None,
        text_prompt: str | None = None,
        text_labels: list[str] | None = None,
        text_threshold: float = 0.25,
    ) -> None:
        """Initialize HuggingFace detection model.

        Args:
            model_path: Hub id or local path given to ``from_pretrained`` by :meth:`load_model`.
            model: Pre-loaded model, forwarded to the base ``DetectionModel``.
            processor: Pre-loaded processor; :meth:`set_model` falls back to it when none is passed there.
            config_path: Forwarded to the base ``DetectionModel``.
            device: Device the model and inputs are moved to.
            mask_threshold: Forwarded to the base ``DetectionModel`` (unused by this wrapper).
            confidence_threshold: Minimum score a detection needs to be kept.
            category_mapping: Forwarded to the base class; replaced by :meth:`set_model`.
            category_remapping: Forwarded to the base ``DetectionModel``.
            load_at_init: Forwarded to the base ``DetectionModel``.
            image_size: When set, :meth:`load_model` configures the processor to resize to this size.
            token: Hugging Face access token; takes precedence over the ``HF_TOKEN`` environment variable.
            text_prompt: Prompt sent to zero-shot models when ``text_labels`` is not given.
            text_labels: Fixed label list for zero-shot models; also defines their category ids.
            text_threshold: Forwarded as ``text_threshold`` to ``post_process_grounded_object_detection``.
        """
        self._processor = processor
        self._original_shapes: list[tuple[int, ...]] = []
        self._token = token
        self.text_prompt = text_prompt
        self.text_labels = text_labels
        self.text_threshold = text_threshold
        self._original_input_ids: Any | None = None
        self._is_zero_shot_model = False
        self._category_name_to_id: dict[str, int] = {}
        # Preserve any required_packages already defined on the class/instance, then add ours.
        existing_packages = getattr(self, "required_packages", None) or []
        self.required_packages = [*list(existing_packages), "torch", "transformers"]
        ensure_package_minimum_version("transformers", "4.42.0")
        super().__init__(
            model_path,
            model,
            config_path,
            device,
            mask_threshold,
            confidence_threshold,
            category_mapping,
            category_remapping,
            load_at_init,
            image_size,
        )

    @property
    def processor(self) -> Any:
        """Return the image processor."""
        return self._processor

    @property
    def image_shapes(self) -> list:
        """Return the shapes of the images seen by the last :meth:`perform_inference` call."""
        # TODO: remove this property in a future release; use _original_shapes directly
        return self._original_shapes

    @property
    def num_categories(self) -> int:
        """Return the number of categories."""
        if self._is_zero_shot_model:
            return len(self.category_mapping)
        return self.model.config.num_labels  # type: ignore[attr-defined]

    def load_model(self) -> None:
        """Load the model and processor from ``model_path`` and install them via :meth:`set_model`.

        GroundingDINO-style configs are loaded with ``AutoModelForZeroShotObjectDetection``
        (requires transformers >= 4.49.0); everything else uses ``AutoModelForObjectDetection``.
        """
        from transformers import AutoConfig, AutoModelForObjectDetection, AutoProcessor

        # An explicitly passed token wins over the HF_TOKEN environment variable.
        hf_token = self._token if self._token is not None else os.getenv("HF_TOKEN")
        assert self.model_path is not None, "model_path must be provided for HuggingFace models"
        config = AutoConfig.from_pretrained(self.model_path, token=hf_token)
        if self._is_zero_shot(config):
            ensure_package_minimum_version("transformers", "4.49.0")
            from transformers import AutoModelForZeroShotObjectDetection

            model_class: Any = AutoModelForZeroShotObjectDetection
        else:
            model_class = AutoModelForObjectDetection
        model = model_class.from_pretrained(self.model_path, token=hf_token)
        if self.image_size is not None:
            # RT-DETR family expects explicit height/width; other models use shortest_edge
            if model.__class__.__name__.startswith("RTDetr"):
                size: dict[str, int | None] = {"height": self.image_size, "width": self.image_size}
            else:
                size = {"shortest_edge": self.image_size, "longest_edge": None}
            # use_fast=True raises error: AttributeError: 'SizeDict' object has no attribute 'keys'
            processor = AutoProcessor.from_pretrained(
                self.model_path, size=size, do_resize=True, use_fast=False, token=hf_token
            )
        else:
            processor = AutoProcessor.from_pretrained(self.model_path, use_fast=False, token=hf_token)
        self.set_model(model, processor)

    def set_model(self, model: Any, processor: Any | None = None, **kwargs: Any) -> None:
        """Validate and install a model/processor pair and rebuild ``category_mapping``.

        Args:
            model: A HuggingFace model whose class name contains ``ObjectDetection``.
            processor: An image processor (class name contains ``ImageProcessor``) or a zero-shot
                processor. Defaults to ``self.processor`` when omitted.

        Raises:
            ValueError: If no processor is available, or model/processor fail the checks above.
        """
        # Explicit None check: a processor object may be falsy (e.g. defines __len__) and
        # must not be silently replaced by self.processor.
        if processor is None:
            processor = self.processor
        if processor is None:
            raise ValueError(f"'processor' is required to be set, got {processor}.")
        self._is_zero_shot_model = self._is_zero_shot(model)
        valid_processor = "ImageProcessor" in processor.__class__.__name__ or self._is_zero_shot(processor)
        if "ObjectDetection" not in model.__class__.__name__ or not valid_processor:
            raise ValueError(
                "Given 'model' is not an ObjectDetectionModel or 'processor' is not a valid ImageProcessor."
            )
        self.model = model
        self.model.to(self.device)  # type: ignore[attr-defined]
        self._processor = processor
        if self._is_zero_shot_model:
            # Zero-shot ids are the positions in text_labels; unseen phrases get new ids later.
            self.category_mapping = {i: name for i, name in enumerate(self.text_labels or [])}
            self._category_name_to_id = {name: i for i, name in self.category_mapping.items()}
        else:
            self.category_mapping = self.model.config.id2label  # type: ignore[attr-defined]

    def perform_inference(self, image: list | np.ndarray) -> None:
        """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

        Args:
            image: np.ndarray or list of np.ndarray
                The image (or list of images for batched inference) to be predicted.
                3 channel images should be in RGB order.

        Raises:
            RuntimeError: If the model or processor has not been loaded.
        """
        import torch

        # Confirm model is loaded
        if self.model is None or self.processor is None:
            raise RuntimeError("Model is not loaded, load it by calling .load_model()")

        with torch.no_grad():
            if self._is_zero_shot_model:
                text = self._get_zero_shot_text_input(len(image) if isinstance(image, list) else 1)
                inputs = self.processor(images=image, text=text, return_tensors="pt")
            else:
                inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) if hasattr(v, "to") else v for k, v in inputs.items()}
            outputs = self.model(**inputs)
        # Kept for zero-shot post-processing; None for closed-vocabulary models.
        self._original_input_ids = inputs.get("input_ids")

        images = image if isinstance(image, list) else [image]
        self._original_shapes = [img.shape for img in images]
        self._original_predictions = outputs

    def perform_batch_inference(self, images: list[np.ndarray]) -> None:
        """Native batch inference: process all images in a single processor + model call.

        Unlike the base-class default (which runs images sequentially), this
        feeds the entire list to the HuggingFace processor at once and executes
        one batched forward pass.  The processor pads images to a uniform size
        internally, so images of different resolutions are handled correctly.

        This avoids setting ``_batch_images`` so
        ``convert_original_predictions`` uses the standard multi-image path
        rather than the sequential fallback.

        Args:
            images: List of numpy arrays (H, W, C) in RGB order.
        """
        self.perform_inference(images)

    # Models using per-class sigmoid (no background class in logits)
    _SIGMOID_CLS_PREFIXES = ("RTDetr", "ConditionalDetr", "DeformableDetr", "Deta", "GroundingDino")

    @property
    def _uses_sigmoid_cls(self) -> bool:
        """True for models that use per-class sigmoid instead of softmax+background."""
        return self.model.__class__.__name__.startswith(self._SIGMOID_CLS_PREFIXES)

    @staticmethod
    def _is_zero_shot(obj: Any) -> bool:
        """Return whether a HuggingFace config/model/processor is a GroundingDINO-style zero-shot detector."""
        if hasattr(obj, "post_process_grounded_object_detection"):
            return True
        return obj.__class__.__name__.startswith("GroundingDino") or getattr(obj, "model_type", "") == "grounding-dino"

    def _get_zero_shot_text_input(self, num_images: int) -> list:
        """Return per-image text input for the HuggingFace zero-shot processor.

        Raises:
            ValueError: If neither ``text_labels`` nor ``text_prompt`` is set.
        """
        prompt = self.text_labels or self.text_prompt
        if not prompt:
            raise ValueError("'text_labels' or 'text_prompt' is required for zero-shot HuggingFace detection models.")
        return [prompt] * num_images

    @staticmethod
    def _clamp_bbox(bbox: list, image_width: int, image_height: int) -> list:
        """Clamp a [x1, y1, x2, y2] box to image bounds."""
        x1, y1, x2, y2 = bbox
        return [max(0, x1), max(0, y1), min(x2, image_width), min(y2, image_height)]

    @staticmethod
    def _shift_and_full_shape(
        shift_amount_list: list[list[int | float]],
        full_shape_list: list[list[int | float]] | None,
        image_ind: int,
    ) -> tuple[list[int], list[int] | None]:
        """Return the int-cast shift amount and full shape for a single image."""
        shift_amount = [int(x) for x in shift_amount_list[image_ind]]
        full_shape = None if full_shape_list is None else [int(x) for x in full_shape_list[image_ind]]
        return shift_amount, full_shape

    def _get_zero_shot_category_id(self, category_name: str) -> int:
        """Return a stable category id for a zero-shot label, assigning a new one for unseen phrases."""
        if category_name not in self._category_name_to_id:
            new_id = len(self.category_mapping)
            self._category_name_to_id[category_name] = new_id
            self.category_mapping[new_id] = category_name
        return self._category_name_to_id[category_name]

    def get_valid_predictions(self, logits: Any, pred_boxes: Any) -> tuple:
        """Get predictions above confidence threshold.

        Args:
            logits: torch.Tensor of per-query class logits for a single image.
            pred_boxes: torch.Tensor of per-query boxes, indexed in parallel with ``logits``.

        Returns:
            scores: torch.Tensor
            cat_ids: torch.Tensor
            boxes: torch.Tensor
        """
        if self._uses_sigmoid_cls:
            # RT-DETR family: per-class sigmoid, logits shape (Q, num_classes) — no background class
            scores, cat_ids = logits.sigmoid().max(-1)
            valid_mask = scores >= self.confidence_threshold
        else:
            # DETR family: softmax over (num_classes + 1), last index is no-object/background;
            # an argmax landing there (index >= num_categories) is discarded.
            # NOTE(review): because the argmax includes the background column, a query whose
            # background probability is highest is dropped even if a class score passes the
            # threshold; HF's post_process_object_detection takes the max over prob[..., :-1].
            scores, cat_ids = logits.softmax(-1).max(-1)
            valid_mask = (cat_ids < self.num_categories) & (scores >= self.confidence_threshold)

        return scores[valid_mask], cat_ids[valid_mask], pred_boxes[valid_mask]

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: list[list[int | float]] | None = None,
        full_shape_list: list[list[int | float]] | None = None,
    ) -> None:
        """Convert predictions to ObjectPrediction list.

        self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.

        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]. Defaults to
                no shift ([[0, 0]]) when None.
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        assert self._original_predictions is not None
        original_predictions: Any = self._original_predictions

        # None replaces the former mutable default [[0, 0]] (shared across calls).
        if shift_amount_list is None:
            shift_amount_list = [[0, 0]]

        # compatibility for sahi v0.8.15
        shift_amount_list_typed: list[list[int | float]] = fix_shift_amount_list(shift_amount_list)
        full_shape_list_typed: list[list[int | float]] | None = fix_full_shape_list(full_shape_list)

        if self._is_zero_shot_model:
            self._create_object_prediction_list_from_zero_shot_predictions(
                original_predictions=original_predictions,
                shift_amount_list=shift_amount_list_typed,
                full_shape_list=full_shape_list_typed,
            )
            return

        from sahi.utils.cv import yolo_bbox_to_voc_bbox

        n_image = original_predictions.logits.shape[0]
        object_prediction_list_per_image = []
        for image_ind in range(n_image):
            image_height, image_width, _ = self.image_shapes[image_ind]
            scores, cat_ids, boxes = self.get_valid_predictions(
                logits=original_predictions.logits[image_ind], pred_boxes=original_predictions.pred_boxes[image_ind]
            )

            shift_amount, full_shape = self._shift_and_full_shape(
                shift_amount_list_typed, full_shape_list_typed, image_ind
            )

            object_prediction_list = []
            # Convert each tensor to Python values once rather than per detection.
            for score, category_id, box in zip(scores.tolist(), cat_ids.tolist(), boxes.tolist()):
                bbox = yolo_bbox_to_voc_bbox(box, image_width=image_width, image_height=image_height)
                bbox = self._clamp_bbox(bbox, image_width, image_height)
                object_prediction_list.append(
                    ObjectPrediction(
                        bbox=bbox,
                        segmentation=None,
                        category_id=category_id,
                        category_name=self.category_mapping[category_id] if self.category_mapping else "",  # type: ignore[index]
                        shift_amount=shift_amount,
                        score=score,
                        full_shape=full_shape,
                    )
                )
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image

    def _create_object_prediction_list_from_zero_shot_predictions(
        self,
        original_predictions: Any,
        shift_amount_list: list[list[int | float]],
        full_shape_list: list[list[int | float]] | None = None,
    ) -> None:
        """Convert HuggingFace zero-shot detection output to ObjectPrediction objects.

        Raises:
            RuntimeError: If no text input ids were stored by :meth:`perform_inference`.
        """
        if self._original_input_ids is None:
            raise RuntimeError("Zero-shot text input ids are missing. Run .perform_inference() before conversion.")

        target_sizes = [(image_shape[0], image_shape[1]) for image_shape in self.image_shapes]
        results = self.processor.post_process_grounded_object_detection(
            original_predictions,
            input_ids=self._original_input_ids,
            threshold=self.confidence_threshold,
            text_threshold=self.text_threshold,
            target_sizes=target_sizes,
        )

        object_prediction_list_per_image = []
        for image_ind, image_predictions in enumerate(results):
            image_height, image_width, _ = self.image_shapes[image_ind]
            shift_amount, full_shape = self._shift_and_full_shape(shift_amount_list, full_shape_list, image_ind)
            # Prefer "text_labels"; fall back to "labels" when it is absent or empty.
            labels = image_predictions.get("text_labels") or image_predictions.get("labels", [])

            object_prediction_list = [
                ObjectPrediction(
                    bbox=self._clamp_bbox(bbox.tolist(), image_width, image_height),
                    segmentation=None,
                    category_id=self._get_zero_shot_category_id(str(name)),
                    category_name=str(name),
                    shift_amount=shift_amount,
                    score=float(score),
                    full_shape=full_shape,
                )
                for score, bbox, name in zip(image_predictions["scores"], image_predictions["boxes"], labels)
                # when fixed text_labels are given, drop combined phrases (e.g. "car truck")
                if not self.text_labels or str(name) in self.text_labels
            ]
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image
Attributes
image_shapes property

Return original image shapes.

num_categories property

Return the number of categories.

processor property

Return the image processor.

Functions
__init__(model_path=None, model=None, processor=None, config_path=None, device=None, mask_threshold=0.5, confidence_threshold=0.3, category_mapping=None, category_remapping=None, load_at_init=True, image_size=None, token=None, text_prompt=None, text_labels=None, text_threshold=0.25)

Initialize HuggingFace detection model.

Source code in sahi/models/huggingface.py
def __init__(
    self,
    model_path: str | None = None,
    model: object | None = None,
    processor: object | None = None,
    config_path: str | None = None,
    device: str | None = None,
    mask_threshold: float = 0.5,
    confidence_threshold: float = 0.3,
    category_mapping: dict | None = None,
    category_remapping: dict | None = None,
    load_at_init: bool = True,
    image_size: int | None = None,
    token: str | None = None,
    text_prompt: str | None = None,
    text_labels: list[str] | None = None,
    text_threshold: float = 0.25,
) -> None:
    """Store HuggingFace-specific settings, then run the base ``DetectionModel`` initializer.

    The processor, access token and zero-shot options (``text_prompt``, ``text_labels``,
    ``text_threshold``) are kept on the instance. ``torch`` and ``transformers`` are appended
    to ``required_packages``, and transformers >= 4.42.0 is enforced before the base
    initializer receives the remaining arguments positionally.
    """
    # Processor / auth state and per-inference caches.
    self._processor = processor
    self._token = token
    self._original_shapes: list[tuple[int, ...]] = []
    self._original_input_ids: Any | None = None

    # Zero-shot (GroundingDINO-style) options and label bookkeeping.
    self.text_prompt = text_prompt
    self.text_labels = text_labels
    self.text_threshold = text_threshold
    self._is_zero_shot_model = False
    self._category_name_to_id: dict[str, int] = {}

    # Preserve any required_packages already visible on self, then add ours.
    inherited = getattr(self, "required_packages", None) or []
    self.required_packages = list(inherited) + ["torch", "transformers"]
    ensure_package_minimum_version("transformers", "4.42.0")

    super().__init__(
        model_path, model, config_path, device,
        mask_threshold, confidence_threshold,
        category_mapping, category_remapping,
        load_at_init, image_size,
    )
get_valid_predictions(logits, pred_boxes)

Get predictions above confidence threshold.

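Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| logits | Any | `torch.Tensor` of per-query class logits for one image | required |
| pred_boxes | Any | `torch.Tensor` of per-query boxes, parallel to `logits` | required |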

Returns:

| Type | Description |
| --- | --- |
| tuple | `(scores, cat_ids, boxes)`, each a `torch.Tensor` restricted to the predictions that pass the confidence threshold. |

Source code in sahi/models/huggingface.py
def get_valid_predictions(self, logits: Any, pred_boxes: Any) -> tuple:
    """Filter one image's queries down to those whose best class score passes the threshold.

    Args:
        logits: torch.Tensor of per-query class logits.
        pred_boxes: torch.Tensor of per-query boxes, indexed in parallel with ``logits``.

    Returns:
        Tuple ``(scores, cat_ids, boxes)`` of torch.Tensor, restricted to the kept queries.
    """
    threshold = self.confidence_threshold

    if not self._uses_sigmoid_cls:
        # Softmax heads: the final column is "no object"; an argmax that lands on it
        # (index >= num_categories) is never a real detection.
        best = logits.softmax(-1).max(-1)
        keep = (best.indices < self.num_categories) & (best.values >= threshold)
        return best.values[keep], best.indices[keep], pred_boxes[keep]

    # Sigmoid heads (RT-DETR and friends): independent per-class scores, no background column.
    best = logits.sigmoid().max(-1)
    keep = best.values >= threshold
    return best.values[keep], best.indices[keep], pred_boxes[keep]
load_model()

Load model from HuggingFace.

Source code in sahi/models/huggingface.py
def load_model(self) -> None:
    """Load the model and processor from ``model_path`` and install them via ``set_model``.

    GroundingDINO-style configs are loaded with ``AutoModelForZeroShotObjectDetection``
    (requires transformers >= 4.49.0); everything else uses ``AutoModelForObjectDetection``.
    When ``image_size`` is set, the processor is configured to resize to it.
    """
    from transformers import AutoConfig, AutoModelForObjectDetection, AutoProcessor

    # An explicitly passed token wins over the HF_TOKEN environment variable.
    hf_token = self._token if self._token is not None else os.getenv("HF_TOKEN")
    assert self.model_path is not None, "model_path must be provided for HuggingFace models"
    config = AutoConfig.from_pretrained(self.model_path, token=hf_token)
    if self._is_zero_shot(config):
        ensure_package_minimum_version("transformers", "4.49.0")
        from transformers import AutoModelForZeroShotObjectDetection

        model_class: Any = AutoModelForZeroShotObjectDetection
    else:
        model_class = AutoModelForObjectDetection
    model = model_class.from_pretrained(self.model_path, token=hf_token)
    if self.image_size is not None:
        # RT-DETR family expects explicit height/width; other models use shortest_edge
        if model.__class__.__name__.startswith("RTDetr"):
            size: dict[str, int | None] = {"height": self.image_size, "width": self.image_size}
        else:
            size = {"shortest_edge": self.image_size, "longest_edge": None}
        # use_fast=True raises error: AttributeError: 'SizeDict' object has no attribute 'keys'
        processor = AutoProcessor.from_pretrained(
            self.model_path, size=size, do_resize=True, use_fast=False, token=hf_token
        )
    else:
        processor = AutoProcessor.from_pretrained(self.model_path, use_fast=False, token=hf_token)
    self.set_model(model, processor)
perform_batch_inference(images)

Native batch inference: process all images in a single processor + model call.

Unlike the base-class default (which runs images sequentially), this feeds the entire list to the HuggingFace processor at once and executes one batched forward pass. The processor pads images to a uniform size internally, so images of different resolutions are handled correctly.

This avoids setting _batch_images so convert_original_predictions uses the standard multi-image path rather than the sequential fallback.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| images | list[ndarray] | List of numpy arrays (H, W, C) in RGB order. | required |
Source code in sahi/models/huggingface.py
def perform_batch_inference(self, images: list[np.ndarray]) -> None:
    """Run a single batched forward pass over ``images`` by delegating to ``perform_inference``.

    The base-class default handles images one at a time; here the whole list goes to the
    HuggingFace processor in one call (it pads images to a uniform size internally, so mixed
    resolutions are fine) and the model runs once for the batch.

    ``_batch_images`` is intentionally left unset, which keeps ``convert_original_predictions``
    on its standard multi-image path rather than the sequential fallback.

    Args:
        images: Numpy arrays shaped (H, W, C), RGB channel order.
    """
    # perform_inference already accepts a list and records one shape per element.
    self.perform_inference(images)
perform_inference(image)

Prediction is performed using self.model and the prediction result is set to self._original_predictions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image | list \| ndarray | A numpy array, or a list of numpy arrays for batched inference, containing the image(s) to predict. 3-channel images should be in RGB order. | required |
Source code in sahi/models/huggingface.py
def perform_inference(self, image: list | np.ndarray) -> None:
    """Prediction is performed using self.model and the prediction result is set to self._original_predictions.

    Args:
        image: np.ndarray
            A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
    """
    import torch

    # Confirm model is loaded
    if self.model is None or self.processor is None:
        raise RuntimeError("Model is not loaded, load it by calling .load_model()")

    with torch.no_grad():
        if self._is_zero_shot_model:
            text = self._get_zero_shot_text_input(len(image) if isinstance(image, list) else 1)
            inputs = self.processor(images=image, text=text, return_tensors="pt")
        else:
            inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) if hasattr(v, "to") else v for k, v in inputs.items()}
        outputs = self.model(**inputs)
    self._original_input_ids = inputs.get("input_ids")

    images = image if isinstance(image, list) else [image]
    self._original_shapes = [img.shape for img in images]
    self._original_predictions = outputs
set_model(model, processor=None, **kwargs)

Set the detection model and processor.

Source code in sahi/models/huggingface.py
def set_model(self, model: Any, processor: Any | None = None, **kwargs: Any) -> None:
    """Set the detection model and processor."""
    processor = processor or self.processor
    if processor is None:
        raise ValueError(f"'processor' is required to be set, got {processor}.")
    self._is_zero_shot_model = self._is_zero_shot(model)
    valid_processor = "ImageProcessor" in processor.__class__.__name__ or self._is_zero_shot(processor)
    if "ObjectDetection" not in model.__class__.__name__ or not valid_processor:
        raise ValueError(
            "Given 'model' is not an ObjectDetectionModel or 'processor' is not a valid ImageProcessor."
        )
    self.model = model
    self.model.to(self.device)  # type: ignore[attr-defined]
    self._processor = processor
    if self._is_zero_shot_model:
        self.category_mapping = {i: name for i, name in enumerate(self.text_labels or [])}
        self._category_name_to_id = {name: i for i, name in self.category_mapping.items()}
    else:
        self.category_mapping = self.model.config.id2label  # type: ignore[attr-defined]

Functions