Skip to content

plot_2d

py3dinterpolations.plotting.plot_2d

2D slice plots along a given axis.

plot_2d_model(modeler, axis='Z', plot_points=False, annotate_points=False, figure_width=8)

Plot 2D slices of a 3D interpolation along an axis.

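Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `modeler` | `Modeler` | Modeler with prediction results. | *required* |
| `axis` | `str` | Axis to slice along (`"X"`, `"Y"`, or `"Z"`). | `'Z'` |
| `plot_points` | `bool` | Whether to overlay training points on slices. | `False` |
| `annotate_points` | `bool` | Whether to annotate point values (only takes effect when `plot_points` is `True`). | `False` |
| `figure_width` | `float` | Figure width in inches. | `8` |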

Returns:

| Type | Description |
|------|-------------|
| `Figure` | Matplotlib Figure. |

Source code in `py3dinterpolations/plotting/plot_2d.py` (lines 14–146):
def plot_2d_model(
    modeler: Modeler,
    axis: str = "Z",
    plot_points: bool = False,
    annotate_points: bool = False,
    figure_width: float = 8,
) -> Figure:
    """Plot 2D slices of a 3D interpolation along an axis.

    One subplot is drawn per grid coordinate along ``axis``, arranged in the
    grid returned by ``number_of_plots(..., n_cols=2)``, plus a shared
    colorbar column on the right. All slices (and the optional training
    points) share one colour normalisation built from the ``vmin``/``vmax``
    of the training data, with preprocessing reversed when it was applied.

    Args:
        modeler: Modeler with prediction results.
        axis: Axis to slice along ("X", "Y", or "Z").
        plot_points: Whether to overlay training points on slices.
        annotate_points: Whether to annotate point values. Only has an
            effect when ``plot_points`` is True.
        figure_width: Figure width in inches.

    Returns:
        Matplotlib Figure.

    Raises:
        ValueError: If ``modeler.result`` is None (nothing to plot).
        NotImplementedError: If ``axis`` is not "X", "Y" or "Z".
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if modeler.result is None:
        msg = "modeler.result is None: no prediction results to plot"
        raise ValueError(msg)

    # Validate before touching the grid or creating a figure: an unknown
    # axis would otherwise fail on the grid lookup first (making this error
    # unreachable) and leave an empty pyplot figure behind.
    if axis not in ("X", "Y", "Z"):
        keys = list(SLICING_AXIS.keys())
        msg = f"axis {axis} not implemented. Choose from {keys}"
        raise NotImplementedError(msg)

    axis_data = modeler.grid.grid[axis]
    n_slices = len(axis_data)
    num_rows, num_cols = number_of_plots(n_slices, n_cols=2)

    figure_height_ratio = 1.25
    fig = plt.figure(
        dpi=300, figsize=(figure_width, figure_width * figure_height_ratio)
    )
    # One extra, narrow column on the right hosts the shared colorbar.
    gs = gridspec.GridSpec(
        num_rows,
        num_cols + 1,
        width_ratios=[1] * num_cols + [0.1],
    )

    axes = [
        fig.add_subplot(gs[row, col])
        for row in range(num_rows)
        for col in range(num_cols)
    ]

    # The full-height right column is only a frame for the colorbar inset:
    # hide its spines and ticks so just the (half-height) colorbar shows.
    colorbar_ax = fig.add_subplot(gs[:, -1])
    for spine in colorbar_ax.spines.values():
        spine.set_visible(False)
    colorbar_ax.set_xticks([])
    colorbar_ax.set_yticks([])
    colorbar_inset_ax = inset_axes(
        colorbar_ax, width="100%", height="50%", loc="center"
    )

    # Training data with preprocessing undone (when any was applied); its
    # value range fixes the colour scale shared by every slice.
    if modeler.griddata.preprocessing_params is not None:
        gd_reversed = reverse_preprocessing(modeler.griddata)
    else:
        gd_reversed = modeler.griddata
    norm = Normalize(gd_reversed.specs.vmin, gd_reversed.specs.vmax)

    # Loop-invariant values: in-plane column names, image extent, slice
    # thickness and the flattened training points.
    x_col = SLICING_AXIS[axis]["X'"]
    y_col = SLICING_AXIS[axis]["Y'"]
    x_axis = modeler.grid.get_axis(x_col)
    y_axis = modeler.grid.get_axis(y_col)
    extent = (x_axis.min, x_axis.max, y_axis.min, y_axis.max)
    axis_res = modeler.grid.get_axis(axis).res
    points_df = gd_reversed.data.reset_index() if plot_points else None

    interpolated = modeler.result.interpolated
    img = None
    for i, (ax, from_value) in enumerate(zip(axes, axis_data, strict=False)):
        if axis == "Z":
            matrix = interpolated[i, :, :]
        elif axis == "Y":
            matrix = interpolated[:, i, :]
        else:  # "X" -- axis was validated above
            matrix = interpolated[:, :, i]

        img = ax.imshow(
            matrix.squeeze(),
            origin="lower",
            extent=extent,
            cmap="plasma",
            norm=norm,
        )

        to_value = from_value + axis_res

        if points_df is not None:
            # Training points whose coordinate lies in [from_value, to_value).
            in_slice = (points_df[axis] >= from_value) & (
                points_df[axis] < to_value
            )
            # Sorted by value so higher values are drawn last (on top).
            points = points_df[in_slice].sort_values(by=["V"])
            ax.scatter(
                points[x_col],
                points[y_col],
                c=points["V"],
                cmap="plasma",
                norm=norm,
                s=figure_width / 2,
            )
            if annotate_points:
                for _idx, point in points.iterrows():
                    ax.annotate(
                        f"{point['V']:.0f}",
                        xy=(point[x_col], point[y_col]),
                        xytext=(2, 2),
                        textcoords="offset points",
                        fontsize=figure_width / 2,
                    )

        ax.set_title(f"{axis} = {from_value}\u00f7{to_value} m")

    fig.suptitle(f"Along {axis} axis")

    if img is not None:
        fig.colorbar(img, cax=colorbar_inset_ax, format="%.0f")

    # Hide the unused trailing cells of the num_rows x num_cols grid.
    for ax in axes[n_slices:]:
        ax.set_visible(False)

    return fig