Source code for time_stream.qc

from abc import ABC, abstractmethod
from datetime import date, datetime, time
from typing import List, Optional, Tuple, Type, Union

import polars as pl

from time_stream import TimeSeries
from time_stream.enums import ClosedInterval
from time_stream.utils import get_date_filter

# Registry for built-in QC checks
_QC_REGISTRY = {}


def register_qc_check(cls: Type["QCCheck"]) -> Type["QCCheck"]:
    """Decorator to register quality control check classes using their name attribute.

    Args:
        cls: The quality control class to register.

    Returns:
        The decorated class.
    """
    _QC_REGISTRY[cls.name] = cls
    return cls


class QCCheck(ABC):
    """Base class for quality control checks."""

    _ts = None

    @property
    def ts(self) -> TimeSeries:
        if self._ts is None:
            raise AttributeError("TimeSeries has not been initialised for this QC check.")
        return self._ts

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the name of the QC check."""
        pass

    @abstractmethod
    def expr(self, check_column: str) -> pl.Expr:
        """Return the Polars expression for this QC check.

        Args:
            check_column: The column to apply the check to.

        Returns:
            Boolean expression where True indicates values that should be flagged.
        """
        pass

    @classmethod
    def get(cls, check: Union[str, "QCCheck", Type["QCCheck"]], **kwargs) -> "QCCheck":
        """Factory method to get a QC check instance from string names or class type.

        Args:
            check: The QC check specification, which can be:
                - A string name: e.g. "threshold", "range", "spike"
                - A class type: ThresholdCheck, RangeCheck, etc.
                - An instance: ThresholdCheck(), or any QCCheck instance
            **kwargs: Parameters specific to the check type, used to initialise the class object.
                      Ignored if check is already an instance of the class.

        Returns:
            An instance of the appropriate QCCheck subclass.

        Raises:
            KeyError: If a string name is not registered as a QC check.
            TypeError: If the input type is not supported or a class doesn't inherit from QCCheck.

        Examples:
            >>> # From string
            >>> qc = QCCheck.get("threshold")
            >>>
            >>> # From class
            >>> qc = QCCheck.get(ThresholdCheck)
            >>>
            >>> # From instance
            >>> qc = QCCheck.get(ThresholdCheck(arg1, arg2))
        """
        # If it's already an instance, return it
        if isinstance(check, QCCheck):
            return check

        # If it's a string, look it up in the registry
        if isinstance(check, str):
            try:
                return _QC_REGISTRY[check](**kwargs)
            except KeyError:
                raise KeyError(f"Unknown QC check '{check}'.")

        # If it's a class, check the subclass type and return
        elif isinstance(check, type):
            if issubclass(check, QCCheck):
                return check(**kwargs)  # type: ignore[misc]
            else:
                raise TypeError(f"QC check class {check.__name__} must inherit from QCCheck")

        else:
            raise TypeError(f"QC check must be a string or a QCCheck class. Got {type(check).__name__}")

    def apply(
        self,
        ts: TimeSeries,
        check_column: str,
        observation_interval: Optional[datetime | Tuple[datetime, datetime | None]] = None,
    ) -> pl.Series:
        """Apply the QC check to the TimeSeries.

        Args:
            ts: The TimeSeries to check.
            check_column: The column to perform the check on.
            observation_interval: Optional time interval to limit the check to.

        Returns:
            pl.Series: Boolean series of the resolved expression on the TimeSeries.
        """
        # Validate column exists
        if check_column not in ts.columns and check_column != ts.time_name:
            raise KeyError(f"Check column '{check_column}' not found in TimeSeries.")

        # Set the timeseries property, in case class expr method needs access to properties in the object
        self._ts = ts

        # Get the check expression
        check_expr = self.expr(check_column)

        # Apply observation interval filter if specified
        if observation_interval:
            date_filter = get_date_filter(ts.time_name, observation_interval)
            check_expr = check_expr & date_filter

        # Evaluate and return the result of the QC check
        #   Naming the series to an empty string so as not to cause confusion.
        #   Up to user if they want to name it and add on to the TimeSeries dataframe.
        result = ts.df.select(check_expr).to_series().alias("")

        return result


[docs] @register_qc_check class ComparisonCheck(QCCheck): """Compares values against a given value using a comparison operator.""" name = "comparison" def __init__(self, compare_to: float | List, operator: str, flag_na: Optional[bool] = False) -> None: """Initialize comparison check. Args: compare_to: The value for comparison. operator: Comparison operator. One of: '>', '>=', '<', '<=', '==', '!=', 'is_in'. flag_na: If True, also flag NaN/null values as failing the check. Defaults to False. """ self.compare_to = compare_to self.operator = operator self.flag_na = flag_na def expr(self, check_column: str) -> pl.Expr: """Return the Polars expression for threshold checking.""" operator_map = { ">": pl.col(check_column) > self.compare_to, ">=": pl.col(check_column) >= self.compare_to, "<": pl.col(check_column) < self.compare_to, "<=": pl.col(check_column) <= self.compare_to, "==": pl.col(check_column) == self.compare_to, "!=": pl.col(check_column) != self.compare_to, "is_in": pl.col(check_column).is_in(self.compare_to), } if self.operator not in operator_map: raise KeyError(f"Invalid operator '{self.operator}'. Use: {', '.join(operator_map.keys())}") operator_expr = operator_map[self.operator] if self.flag_na: operator_expr = operator_expr | pl.col(check_column).is_null() return operator_expr
[docs] @register_qc_check class RangeCheck(QCCheck): """Check that values fall within an acceptable range.""" name = "range" def __init__( self, min_value: float | time | date | datetime, max_value: float | time | date | datetime, closed: Optional[str | ClosedInterval] = "both", within: Optional[bool] = True, ) -> None: """Initialize range check. Args: min_value: Minimum of the range. max_value: Maximum of the range. closed: Define which sides of the interval are closed (inclusive) {'both', 'left', 'right', 'none'} (default = "both") within: Whether values get flagged when within or outside the range (default = True (within)). """ self.min_value = min_value self.max_value = max_value self.closed = ClosedInterval(closed) self.within = within def expr(self, check_column: str) -> pl.Expr: """Return the Polars expression for range checking.""" if type(self.min_value) is not type(self.max_value): raise TypeError("'min_value' and 'max_value' must be of same type") check_type = type(self.min_value) # Check if we're doing a time-based range check if check_type is time: check_column = pl.col(check_column).dt.time() # Consider ranges that cross midnight, e.g. min_value = 11:00, max_value = 01:00 if self.min_value > self.max_value: # Swap the values so the comparison operators work the correct way around self.min_value, self.max_value = self.max_value, self.min_value # Reverse the within parameter, as we've swapped the min/max logic self.within = not self.within # We also need to swap the close parameter (if "both" or "none") # Don't have to change "left" or "right" as it shakes out the same even when reversing the min/max if self.closed == ClosedInterval.BOTH: self.closed = ClosedInterval.NONE elif self.closed == ClosedInterval.NONE: self.closed = ClosedInterval.BOTH elif check_type is date: # For datetime.date objects (NOT datetime.datetime!), we want to consider the whole date part of the column check_column = pl.col(check_column).dt.date() else: # This should handle numeric objects and datetime.datetime objects check_column = pl.col(check_column) in_range = check_column.is_between( self.min_value, self.max_value, closed=self.closed.value, # type: ignore[arg-type] ignore Literal typing as the enum constrains the values ) return in_range if self.within else ~in_range
@register_qc_check class TimeRangeCheck(RangeCheck): """Flag rows where the primary time column of the time series fall within an acceptable range. This can either be used with min / max values of: - datetime.time : Useful for scenarios where there are consistent errors at a certain time of day, e.g., during an automated sensor calibration time. - datetime.date : Useful for scenarios where a specific date range is known to be bad, e.g., during a time of sensor errors not picked up elsewhere. - datetime.datetime : As above, but where there you need to add a time to the date range as well. Note: This is equivalent to using `RangeCheck` with `check_column = ts.time_name`. However, adding this as a convenience method as it may not be obvious that the `RangeCheck` can be used for this purpose. """ name = "time_range" def expr(self, _: str) -> pl.Expr: return super().expr(self.ts.time_name)
[docs] @register_qc_check class SpikeCheck(QCCheck): """Detect spikes by assessing differences with neighboring values.""" name = "spike" def __init__(self, threshold: float): """Initialize spike detection check. Args: threshold: The spike detection threshold. """ self.threshold = threshold def expr(self, check_column: str) -> pl.Expr: """Return the Polars expression for spike detection. The algorithm: 1. Calculate differences between current value and neighbors 2. Compute total combined difference and skew 3. Flag where (total_difference - skew) > threshold * 2 """ # Calculate differences with temporal neighbors prev_val = pl.col(check_column).shift(1) next_val = pl.col(check_column).shift(-1) diff_prev = pl.col(check_column) - prev_val diff_next = next_val - pl.col(check_column) # Calculate total difference and skew d = (diff_prev - diff_next).abs() skew = (diff_prev.abs() - diff_next.abs()).abs() d_no_skew = d - skew # Double the threshold since we're summing differences return d_no_skew > (self.threshold * 2.0)