Skip to content

models

py3dinterpolations.modelling.models

Model registry for 3D interpolation.

BaseModel

Bases: ABC

Interface for interpolation models.

All models must implement fit() and predict() with consistent signatures.

name abstractmethod property

Human-readable model name.

fit(x, y, z, v) abstractmethod

Fit the model to training data.

Parameters:

Name Type Description Default
x ndarray

X coordinates of training points.

required
y ndarray

Y coordinates of training points.

required
z ndarray

Z coordinates of training points.

required
v ndarray

Values at training points.

required
Source code in py3dinterpolations/modelling/models/base.py
16
17
18
19
20
21
22
23
24
25
26
@abstractmethod
def fit(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, v: np.ndarray) -> None:
    """Train the model on scattered 3D observations.

    Concrete models must override this method; the base version does nothing.

    Args:
        x: X coordinate of each training point.
        y: Y coordinate of each training point.
        z: Z coordinate of each training point.
        v: Observed value at each training point.
    """

predict(grid_x, grid_y, grid_z, **kwargs) abstractmethod

Predict on 1D grid arrays.

Parameters:

Name Type Description Default
grid_x ndarray

1D array of X grid coordinates.

required
grid_y ndarray

1D array of Y grid coordinates.

required
grid_z ndarray

1D array of Z grid coordinates.

required

Returns:

Type Description
InterpolationResult

Interpolation result with at least the interpolated field.

Source code in py3dinterpolations/modelling/models/base.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
@abstractmethod
def predict(
    self,
    grid_x: np.ndarray,
    grid_y: np.ndarray,
    grid_z: np.ndarray,
    **kwargs: object,
) -> InterpolationResult:
    """Evaluate the fitted model on a grid described by three 1D axes.

    Concrete models must override this method; the base version does nothing.

    Args:
        grid_x: 1D array of X grid coordinates.
        grid_y: 1D array of Y grid coordinates.
        grid_z: 1D array of Z grid coordinates.
        **kwargs: Model-specific prediction options.

    Returns:
        Interpolation result carrying, at minimum, the interpolated field.
    """

IDWModel(power=1.0, threshold=1e-10)

Bases: BaseModel

Vectorized IDW interpolation.

Uses NumPy broadcasting instead of Python loops, for a roughly 1000x speedup on typical workloads. Query points are processed in batches to keep memory usage bounded.

Parameters:

Name Type Description Default
power float

Power parameter controlling distance decay. Higher values give more weight to nearby points.

1.0
threshold float

Distance below which a point is treated as coincident with a training point (exact interpolation).

1e-10
Source code in py3dinterpolations/modelling/models/idw.py
25
26
27
28
29
def __init__(self, power: float = 1.0, threshold: float = 1e-10):
    """Configure IDW with a distance-decay power and a coincidence threshold.

    Args:
        power: Exponent controlling how quickly weights decay with distance.
        threshold: Distances below this are treated as coinciding with a
            training point.
    """
    self._power, self._threshold = power, threshold
    # Training data stays empty until fit() stores it.
    self._points: np.ndarray | None = None
    self._values: np.ndarray | None = None

fit(x, y, z, v)

Store training data.

Source code in py3dinterpolations/modelling/models/idw.py
31
32
33
34
def fit(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, v: np.ndarray) -> None:
    """Remember the training data; IDW does no up-front computation.

    Args:
        x: X coordinates of training points.
        y: Y coordinates of training points.
        z: Z coordinates of training points.
        v: Values at training points (stored as given, not copied).
    """
    coords = np.column_stack((x, y, z))
    self._points = coords
    self._values = v

predict(grid_x, grid_y, grid_z, **kwargs)

Predict on a regular grid defined by 1D arrays.

Returns:

Type Description
InterpolationResult

InterpolationResult with shape (len(grid_z), len(grid_y), len(grid_x)) to match pykrige's output convention.

Source code in py3dinterpolations/modelling/models/idw.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def predict(
    self,
    grid_x: np.ndarray,
    grid_y: np.ndarray,
    grid_z: np.ndarray,
    **kwargs: object,
) -> InterpolationResult:
    """Evaluate IDW at every node of the regular grid spanned by the 1D axes.

    Args:
        grid_x: 1D array of X grid coordinates.
        grid_y: 1D array of Y grid coordinates.
        grid_z: 1D array of Z grid coordinates.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        InterpolationResult with shape (len(grid_z), len(grid_y), len(grid_x))
        to match pykrige's output convention; ``variance`` is always None.

    Raises:
        RuntimeError: If called before fit().
    """
    if self._points is None:
        msg = "Model must be fit before predicting"
        raise RuntimeError(msg)

    # "ij" indexing enumerates nodes in (X, Y, Z) order, so reshaping the flat
    # results with the meshgrid shape rebuilds the XYZ cube.
    axes = np.meshgrid(grid_x, grid_y, grid_z, indexing="ij")
    cube_shape = axes[0].shape
    query_points = np.column_stack([axis.ravel() for axis in axes])

    # Evaluate _BATCH_SIZE nodes at a time so per-call intermediates stay
    # bounded; NumPy clips the final slice at the array end.
    n_nodes = query_points.shape[0]
    flat = np.empty(n_nodes)
    for lo in range(0, n_nodes, _BATCH_SIZE):
        hi = lo + _BATCH_SIZE
        flat[lo:hi] = self._predict_batch(query_points[lo:hi])

    # Reverse the axis order (X, Y, Z) -> (Z, Y, X), as pykrige returns it.
    interpolated = flat.reshape(cube_shape).transpose(2, 1, 0)
    return InterpolationResult(interpolated=interpolated, variance=None)

KrigingModel(**kriging_params)

Bases: BaseModel

Ordinary Kriging 3D wrapper around pykrige.

Because pykrige fits the model at construction time, fit() is where the OrdinaryKriging3D instance is created.

Parameters:

Name Type Description Default
**kriging_params object

Parameters passed to OrdinaryKriging3D constructor.

{}
Source code in py3dinterpolations/modelling/models/kriging.py
20
21
22
def __init__(self, **kriging_params: object):
    """Capture OrdinaryKriging3D constructor options; no fitting happens here.

    Args:
        **kriging_params: Keyword arguments forwarded to OrdinaryKriging3D
            when fit() is called.
    """
    # The pykrige object needs the training data, so it is built in fit().
    self._model: OrdinaryKriging3D | None = None
    self._params = kriging_params

fit(x, y, z, v)

Fit by constructing the OrdinaryKriging3D model.

Source code in py3dinterpolations/modelling/models/kriging.py
24
25
26
def fit(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, v: np.ndarray) -> None:
    """Build the OrdinaryKriging3D model from the training observations.

    pykrige fits inside its constructor, so this simply creates the instance
    using the options captured in __init__.

    Args:
        x: X coordinates of training points.
        y: Y coordinates of training points.
        z: Z coordinates of training points.
        v: Values at training points.
    """
    kriging = OrdinaryKriging3D(x, y, z, v, **self._params)
    self._model = kriging

predict(grid_x, grid_y, grid_z, **kwargs)

Execute kriging on the given grid arrays.

Returns:

Type Description
InterpolationResult

InterpolationResult with interpolated and variance arrays. Shape is (len(grid_z), len(grid_y), len(grid_x)) per pykrige convention.

Source code in py3dinterpolations/modelling/models/kriging.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def predict(
    self,
    grid_x: np.ndarray,
    grid_y: np.ndarray,
    grid_z: np.ndarray,
    **kwargs: object,
) -> InterpolationResult:
    """Run pykrige's grid-style execution over the supplied axes.

    Args:
        grid_x: 1D array of X grid coordinates.
        grid_y: 1D array of Y grid coordinates.
        grid_z: 1D array of Z grid coordinates.
        **kwargs: Extra keyword arguments forwarded to
            ``OrdinaryKriging3D.execute``.

    Returns:
        InterpolationResult holding both the kriged field and its variance,
        each shaped (len(grid_z), len(grid_y), len(grid_x)) per pykrige.

    Raises:
        RuntimeError: If called before fit().
    """
    model = self._model
    if model is None:
        msg = "Model must be fit before predicting"
        raise RuntimeError(msg)
    field, field_variance = model.execute(
        style="grid", xpoints=grid_x, ypoints=grid_y, zpoints=grid_z, **kwargs
    )
    return InterpolationResult(interpolated=field, variance=field_variance)

SklearnModel(estimator, model_name='sklearn')

Bases: BaseModel

Wrapper for any sklearn estimator with fit/predict interface.

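Works with both regressors and classifiers: grid values always come from predict(), and classifiers additionally provide class probabilities via predict_proba().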

Parameters:

Name Type Description Default
estimator SklearnEstimator

A sklearn estimator instance.

required
model_name str

Human-readable name for this model.

'sklearn'
Source code in py3dinterpolations/modelling/models/sklearn_model.py
19
20
21
def __init__(self, estimator: SklearnEstimator, model_name: str = "sklearn"):
    """Wrap ``estimator`` under the human-readable name ``model_name``.

    Args:
        estimator: A sklearn estimator instance (used as-is, not cloned).
        model_name: Human-readable name for this model.
    """
    self._model_name = model_name
    self._estimator = estimator

fit(x, y, z, v)

Fit the sklearn estimator.

Source code in py3dinterpolations/modelling/models/sklearn_model.py
23
24
25
26
def fit(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, v: np.ndarray) -> None:
    """Train the wrapped estimator with (x, y, z) as features and v as target.

    Args:
        x: X coordinates of training points.
        y: Y coordinates of training points.
        z: Z coordinates of training points.
        v: Values at training points.
    """
    features = np.column_stack((x, y, z))
    self._estimator.fit(features, v)

predict(grid_x, grid_y, grid_z, **kwargs)

Predict on a regular grid.

Returns:

Type Description
InterpolationResult

InterpolationResult with shape (len(grid_z), len(grid_y), len(grid_x)) to match the convention of other models.

Source code in py3dinterpolations/modelling/models/sklearn_model.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def predict(
    self,
    grid_x: np.ndarray,
    grid_y: np.ndarray,
    grid_z: np.ndarray,
    **kwargs: object,
) -> InterpolationResult:
    """Predict on a regular grid.

    Every combination of the 1D grid coordinates is passed to the wrapped
    estimator's ``predict``; classifiers are additionally queried with
    ``predict_proba``.

    Args:
        grid_x: 1D array of X grid coordinates.
        grid_y: 1D array of Y grid coordinates.
        grid_z: 1D array of Z grid coordinates.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        InterpolationResult whose ``interpolated`` field has shape
        (len(grid_z), len(grid_y), len(grid_x)) to match the convention of
        other models. For classifiers, ``probability`` has shape
        (len(grid_z), len(grid_y), len(grid_x), n_classes), so its leading
        axes index the grid exactly like ``interpolated``. For other
        estimators it is None.
    """
    mx, my, mz = np.meshgrid(grid_x, grid_y, grid_z, indexing="ij")
    X = np.column_stack([mx.ravel(), my.ravel(), mz.ravel()])

    predictions = self._estimator.predict(X)
    interpolated = predictions.reshape(mx.shape)
    # Transpose from XYZ to ZYX to match pykrige convention
    interpolated = np.einsum("xyz->zyx", interpolated)

    probability = None
    if isinstance(self._estimator, SklearnClassifier):
        proba = self._estimator.predict_proba(X)
        # Query points were enumerated in XYZ order, so the reshape yields
        # (X, Y, Z, n_classes). Apply the same XYZ -> ZYX permutation as
        # ``interpolated`` so both arrays share one grid axis order.
        probability = np.einsum("xyzc->zyxc", proba.reshape((*mx.shape, -1)))

    return InterpolationResult(
        interpolated=interpolated,
        probability=probability,
    )

get_model(model_type, **kwargs)

Instantiate a model by type.

Parameters:

Name Type Description Default
model_type ModelType | str

Model identifier, either a ModelType enum or its string value.

required
**kwargs object

Parameters passed to the model constructor.

{}

Returns:

Type Description
BaseModel

An instantiated model ready for fit().

Raises:

Type Description
ValueError

If model_type is not in the registry.

Source code in py3dinterpolations/modelling/models/__init__.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def get_model(model_type: ModelType | str, **kwargs: object) -> BaseModel:
    """Instantiate a model by type.

    Args:
        model_type: Model identifier, either a ModelType enum or its string value.
        **kwargs: Parameters passed to the model constructor.

    Returns:
        An instantiated model ready for fit().

    Raises:
        ValueError: If model_type is not a valid ModelType value or is not in
            the registry. In both cases the message lists the registered
            model types.
    """
    try:
        resolved = ModelType(model_type)
    except ValueError as err:
        # The enum's own error does not say what is accepted; report the
        # registry contents instead, keeping the original error as the cause.
        available = list(MODEL_REGISTRY.keys())
        msg = f"Model {model_type!r} not in registry. Available: {available}"
        raise ValueError(msg) from err
    cls = MODEL_REGISTRY.get(resolved)
    if cls is None:
        available = list(MODEL_REGISTRY.keys())
        msg = f"Model {resolved!r} not in registry. Available: {available}"
        raise ValueError(msg)
    return cls(**kwargs)