Skip to content

plot_3d

py3dinterpolations.plotting.plot_3d

3D volume plot using plotly.

plot_3d_model(modeler, plot_points=False, scale_points=1.0, volume_kwargs=None)

Plot 3D interpolation result as a plotly Volume.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `modeler` | `Modeler` | Modeler with prediction results. | *required* |
| `plot_points` | `bool` | Whether to overlay training points. | `False` |
| `scale_points` | `float` | Scale factor for point marker sizes. | `1.0` |
| `volume_kwargs` | `dict[str, object] \| None` | Extra keyword arguments passed to `go.Volume`. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `Figure` | Plotly Figure. |

Source code in `py3dinterpolations/plotting/plot_3d.py` (lines 10–77):
def plot_3d_model(
    modeler: Modeler,
    plot_points: bool = False,
    scale_points: float = 1.0,
    volume_kwargs: dict[str, object] | None = None,
) -> go.Figure:
    """Plot 3D interpolation result as a plotly Volume.

    The colour range of the volume (``cmin``/``cmax``) is taken from the
    specs of the grid data after any preprocessing has been reversed.

    Args:
        modeler: Modeler with prediction results. ``modeler.result`` must
            not be None.
        plot_points: Whether to overlay training points.
        scale_points: Scale factor for point marker sizes. It is used as the
            target maximum marker size in plotly's ``sizemode="area"``
            ``sizeref`` formula.
        volume_kwargs: Extra kwargs for go.Volume. Entries named
            ``opacityscale``, ``cmin`` or ``cmax`` override the defaults set
            here; ``x``, ``y``, ``z`` and ``value`` always come from the
            modeler and cannot be passed.

    Returns:
        Plotly Figure.

    Raises:
        ValueError: If ``modeler.result`` is None.
    """
    # Explicit check instead of ``assert`` (stripped under ``python -O``),
    # done before the potentially costly preprocessing reversal.
    if modeler.result is None:
        raise ValueError("modeler.result is None; there are no predictions to plot")

    if modeler.griddata.preprocessing_params is not None:
        gd_reversed = reverse_preprocessing(modeler.griddata)
    else:
        gd_reversed = modeler.griddata

    # Move the leading axis to the end: (a0, a1, a2) -> (a1, a2, a0), so the
    # flattened values line up element-wise with the flattened mesh below.
    # NOTE(review): presumably the leading axis of ``interpolated`` is Z;
    # confirm the resulting order matches the axis order of modeler.grid.mesh.
    values = np.moveaxis(modeler.result.interpolated, 0, -1)

    # Defaults are merged with the caller's kwargs (caller wins), so keys such
    # as ``cmin`` can be overridden without a duplicate-keyword TypeError.
    volume_style: dict[str, object] = {
        "opacityscale": [(0, 0), (1, 1)],
        "cmin": gd_reversed.specs.vmin,
        "cmax": gd_reversed.specs.vmax,
    }
    if volume_kwargs:
        volume_style.update(volume_kwargs)

    data: list[go.Volume | go.Scatter3d] = [
        go.Volume(
            x=modeler.grid.mesh["X"].flatten(),
            y=modeler.grid.mesh["Y"].flatten(),
            z=modeler.grid.mesh["Z"].flatten(),
            value=values.flatten(),
            **volume_style,
        ),
    ]

    if plot_points:
        # gd_reversed is modeler.griddata itself when no preprocessing was
        # applied, so a single source covers both cases. reset_index already
        # returns a new frame, so no explicit copy is needed.
        points = gd_reversed.data.reset_index()
        point_values = points["V"]

        # Series.max() skips NaN, whereas builtin max() may return NaN
        # depending on element order (and raises on empty input).
        v_max = point_values.max()
        # Plotly's area-mode formula: sizeref = 2 * max(size) / max_px**2.
        # Without a positive maximum (all zero/negative/NaN, or no points)
        # that would be meaningless, so fall back to plotly's default of 1.
        sizeref = 2.0 * v_max / (scale_points**2) if v_max > 0 else 1.0

        data.append(
            go.Scatter3d(
                x=points["X"],
                y=points["Y"],
                z=points["Z"],
                mode="markers",
                marker=dict(
                    size=point_values.to_list(),
                    sizemode="area",
                    sizeref=sizeref,
                    color=point_values,
                    sizemin=1,
                ),
            )
        )

    fig = go.Figure(data=data)
    fig.update_scenes(aspectmode="data")
    return fig