Skip to content

plot_study

This module provides functions to create and customize study plots, including heatmaps and 3D volume renderings.

Functions:

Name Description
_set_style
_add_text_annotation
_smooth
_mask
_add_contours
_add_diagonal_lines
add_QR_code
_set_labels
plot_heatmap
plot_3D

add_QR_code(fig, link, position_qr='top-right')

Adds a QR code pointing to the given link to the figure.

Parameters:

Name Type Description Default
fig Figure

The figure to add the QR code to.

required
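link str

The link to encode in the QR code.

required
position_qr str

The corner of the figure where the QR code is placed: 'top-right', 'bottom-right', 'bottom-left' or 'top-left'. Any other value raises a ValueError.

'top-right'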

Returns:

Type Description
Figure

plt.Figure: The figure with the QR code.

Source code in study_da/plot/plot_study.py
def add_QR_code(fig: plt.Figure, link: str, position_qr: str = "top-right") -> plt.Figure:
    """
    Adds a QR code pointing to the given link to the figure.

    The QR code is drawn in a new, small axes (5% of the figure width and height) placed near
    one of the figure corners. An invisible annotation carrying the link as its URL is added
    to that axes, presumably so the link is clickable in output formats whose backend keeps
    artist URLs (e.g. SVG) — NOTE(review): confirm for the formats actually used.

    Args:
        fig (plt.Figure): The figure to add the QR code to.
        link (str): The link to encode in the QR code.
        position_qr (str, optional): Corner of the figure where the QR code is placed. One of
            "top-right", "bottom-right", "bottom-left" or "top-left". Defaults to "top-right".

    Returns:
        plt.Figure: The figure with the QR code.

    Raises:
        ValueError: If position_qr is not one of the recognized positions.
    """
    # Axes rectangle [left, bottom, width, height] in figure coordinates, and its anchor
    qr_axes_layout = {
        "top-right": ([0.9, 0.9, 0.05, 0.05], "NE"),
        "bottom-right": ([0.9, 0.1, 0.05, 0.05], "SE"),
        "bottom-left": ([0.1, 0.1, 0.05, 0.05], "SW"),
        "top-left": ([0.1, 0.9, 0.05, 0.05], "NW"),
    }
    # Validate the position before doing any work on the QR code or the figure
    try:
        rect, anchor = qr_axes_layout[position_qr]
    except (KeyError, TypeError) as err:
        raise ValueError(f"Position {position_qr} not recognized") from err

    # Build the QR code image for the given link
    qr = qrcode.QRCode(
        box_size=10,
        border=1,
    )
    qr.add_data(link)
    # No version is given, so qrcode still picks the smallest version that fits the data
    qr.make(fit=False)
    im = qr.make_image(fill_color="black", back_color="transparent")

    newax = fig.add_axes(rect, anchor=anchor, zorder=1)
    newax.imshow(im, resample=False, interpolation="none", filternorm=False)

    # Add the link below the QR code as an invisible annotation (alpha=0) carrying the URL
    newax.plot([0, 0], [0, 0], color="white", label="link")
    _ = newax.annotate(
        "lin",
        xy=(0, 300),
        xytext=(0, 300),
        fontsize=30,
        url=link,
        bbox=dict(color="white", alpha=1e-6, url=link),
        alpha=0,
    )
    # Hide X and Y axes label marks
    newax.xaxis.set_tick_params(labelbottom=False)
    newax.yaxis.set_tick_params(labelleft=False)
    # Hide X and Y axes tick marks
    newax.set_xticks([])
    newax.set_yticks([])
    newax.set_axis_off()

    return fig

plot_3D(dataframe_data, x_variable, y_variable, z_variable, color_variable, xlabel=None, ylabel=None, z_label=None, title='', vmin=4.5, vmax=7.5, surface_count=30, opacity=0.2, figsize=(1000, 1000), colormap='RdBu', colorbar_title_text='Minimum DA (σ)', display_colormap=False, output_path='output.png', output_path_html='output.html', display_plot=True, dark_theme=False)

Plots a 3D volume rendering from the given dataframe.

Parameters:

Name Type Description Default
dataframe_data DataFrame

The dataframe containing the data to plot.

required
x_variable str

The variable to plot on the x-axis.

required
y_variable str

The variable to plot on the y-axis.

required
z_variable str

The variable to plot on the z-axis.

required
color_variable str

The variable to use for the color scale.

required
xlabel Optional[str]

The label for the x-axis. Defaults to None.

None
ylabel Optional[str]

The label for the y-axis. Defaults to None.

None
z_label Optional[str]

The label for the z-axis. Defaults to None.

None
title str

The title of the plot. Defaults to "".

''
vmin float

The minimum value for the color scale. Defaults to 4.5.

4.5
vmax float

The maximum value for the color scale. Defaults to 7.5.

7.5
surface_count int

The number of surfaces for volume rendering. Defaults to 30.

30
opacity float

The opacity of the volume rendering. Defaults to 0.2.

0.2
figsize tuple[float, float]

The size of the figure. Defaults to (1000, 1000).

(1000, 1000)
colormap str

The colormap to use. Defaults to "RdBu".

'RdBu'
colorbar_title_text str

The label for the colorbar. Defaults to "Minimum DA (σ)".

'Minimum DA (σ)'
display_colormap bool

Whether to display the colormap. Defaults to False.

False
output_path str

The path to save the plot image. Defaults to "output.png".

'output.png'
output_path_html str

The path to save the plot HTML. Defaults to "output.html".

'output.html'
display_plot bool

Whether to display the plot. Defaults to True.

True
dark_theme bool

Whether to use a dark theme. Defaults to False.

False

Returns:

Type Description
Any

go.Figure: The plotly figure object.

Source code in study_da/plot/plot_study.py
def plot_3D(
    dataframe_data: pd.DataFrame,
    x_variable: str,
    y_variable: str,
    z_variable: str,
    color_variable: str,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    z_label: Optional[str] = None,
    title: str = "",
    vmin: float = 4.5,
    vmax: float = 7.5,
    surface_count: int = 30,
    opacity: float = 0.2,
    figsize: tuple[float, float] = (1000, 1000),
    colormap: str = "RdBu",
    colorbar_title_text: str = "Minimum DA (σ)",
    display_colormap: bool = False,
    output_path: Optional[str] = "output.png",
    output_path_html: Optional[str] = "output.html",
    display_plot: bool = True,
    dark_theme: bool = False,
) -> Any:
    """
    Plots a 3D volume rendering from the given dataframe.

    The columns x_variable, y_variable and z_variable give the point coordinates, and
    color_variable gives the value rendered by a plotly Volume trace between vmin and vmax.

    Args:
        dataframe_data (pd.DataFrame): The dataframe containing the data to plot.
        x_variable (str): The variable to plot on the x-axis.
        y_variable (str): The variable to plot on the y-axis.
        z_variable (str): The variable to plot on the z-axis.
        color_variable (str): The variable to use for the color scale.
        xlabel (Optional[str], optional): The label for the x-axis. Defaults to None.
        ylabel (Optional[str], optional): The label for the y-axis. Defaults to None.
        z_label (Optional[str], optional): The label for the z-axis. Defaults to None.
        title (str, optional): The title of the plot. Defaults to "".
        vmin (float, optional): The minimum value for the color scale. Defaults to 4.5.
        vmax (float, optional): The maximum value for the color scale. Defaults to 7.5.
        surface_count (int, optional): The number of surfaces for volume rendering. Defaults to 30.
        opacity (float, optional): The opacity of the volume rendering. Defaults to 0.2.
        figsize (tuple[float, float], optional): The size of the figure (width, height).
            Defaults to (1000, 1000).
        colormap (str, optional): The colormap to use. Defaults to "RdBu".
        colorbar_title_text (str, optional): The label for the colorbar. Defaults to "Minimum DA (σ)".
        display_colormap (bool, optional): Whether to display the colormap. Defaults to False.
        output_path (Optional[str], optional): The path to save the plot image, or None to skip
            saving it. Defaults to "output.png".
        output_path_html (Optional[str], optional): The path to save the plot HTML, or None to
            skip saving it. Defaults to "output.html".
        display_plot (bool, optional): Whether to display the plot. Defaults to True.
        dark_theme (bool, optional): Whether to use a dark theme. Defaults to False.

    Returns:
        go.Figure: The plotly figure object.

    Raises:
        ImportError: If plotly is not installed.
    """
    # plotly is an optional dependency, only required by this function
    try:
        import plotly.graph_objects as go
    except ImportError as e:
        raise ImportError("Please install plotly to use this function") from e

    X = np.array(dataframe_data[x_variable])
    Y = np.array(dataframe_data[y_variable])
    Z = np.array(dataframe_data[z_variable])
    values = np.array(dataframe_data[color_variable])
    fig = go.Figure(
        data=go.Volume(
            x=X.flatten(),
            y=Y.flatten(),
            z=Z.flatten(),
            value=values.flatten(),
            isomin=vmin,
            isomax=vmax,
            opacity=opacity,  # needs to be small to see through all surfaces
            surface_count=surface_count,  # needs to be a large number for good volume rendering
            colorscale=colormap,
            colorbar_title_text=colorbar_title_text,
        )
    )

    fig.update_layout(
        scene_xaxis_title_text=xlabel,
        scene_yaxis_title_text=ylabel,
        scene_zaxis_title_text=z_label,
        title=title,
    )

    # Get a good initial view, dezoomed
    fig.update_layout(scene_camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)))

    # Center the title
    fig.update_layout(title_x=0.5, title_y=0.9, title_xanchor="center", title_yanchor="top")

    # Specify the width and height of the figure
    fig.update_layout(width=figsize[0], height=figsize[1])

    # Remove margins and padding
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))

    # Display the colormap
    if not display_colormap:
        fig.update_layout(coloraxis_showscale=False)
        fig.update_traces(showscale=False)
    else:
        # Make colorbar smaller. The Volume trace above is not bound to a layout coloraxis,
        # so its colorbar lives on the trace itself and must be resized there; the
        # layout-level setting alone has no visible effect (kept so the layout is unchanged).
        fig.update_traces(colorbar_thickness=10, colorbar_len=0.5)
        fig.update_layout(coloraxis_colorbar=dict(thickness=10, len=0.5))

    # Set the theme
    if dark_theme:
        fig.update_layout(template="plotly_dark")

    # Display/save/return the figure
    if output_path is not None:
        fig.write_image(output_path)

    if output_path_html is not None:
        fig.write_html(output_path_html)

    if display_plot:
        fig.show()

    return fig

plot_heatmap(dataframe_data, horizontal_variable, vertical_variable, color_variable, link=None, position_qr='top-right', plot_contours=True, xlabel=None, ylabel=None, tick_interval=2, round_xticks=None, round_yticks=None, symmetric_missing=False, mask_lower_triangle=False, mask_upper_triangle=False, plot_diagonal_lines=True, shift_diagonal_lines=1, xaxis_ticks_on_top=True, title='', vmin=4.5, vmax=7.5, k_masking=-1, green_contour=6.0, min_level_contours=1, max_level_contours=15, delta_levels_contours=0.5, figsize=None, label_cbar='Minimum DA (' + '$\\sigma$' + ')', colormap='coolwarm_r', style='ggplot', output_path='output.png', display_plot=True, latex_fonts=True, vectorize=False, fill_missing_value_with=None, dpi=300)

Plots a heatmap from the given dataframe.

Parameters:

Name Type Description Default
dataframe_data DataFrame

The dataframe containing the data to plot.

required
horizontal_variable str

The variable to plot on the horizontal axis.

required
vertical_variable str

The variable to plot on the vertical axis.

required
color_variable str

The variable to use for the color scale.

required
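link Optional[str]

A link to encode in a QR code. Defaults to None.

None
position_qr Optional[str]

The corner of the figure where the QR code is placed when a link is given: 'top-right', 'bottom-right', 'bottom-left' or 'top-left'. Defaults to 'top-right'.

'top-right'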
plot_contours bool

Whether to plot contours. Defaults to True.

True
xlabel Optional[str]

The label for the x-axis. Defaults to None.

None
ylabel Optional[str]

The label for the y-axis. Defaults to None.

None
tick_interval int

The interval for the ticks. Defaults to 2.

2
round_xticks Optional[int]

The number of decimal places to round the x-ticks to. Defaults to None.

None
round_yticks Optional[int]

The number of decimal places to round the y-ticks to. Defaults to None.

None
symmetric_missing bool

Whether to make the matrix symmetric by replacing the lower triangle with the upper triangle. Defaults to False.

False
mask_lower_triangle bool

Whether to mask the lower triangle. Defaults to False.

False
mask_upper_triangle bool

Whether to mask the upper triangle. Defaults to False.

False
plot_diagonal_lines bool

Whether to plot diagonal lines. Defaults to True.

True
shift_diagonal_lines int

The shift for the diagonal lines. Defaults to 1.

1
xaxis_ticks_on_top bool

Whether to place the x-axis ticks on top. Defaults to True.

True
title str

The title of the plot. Defaults to "".

''
vmin float

The minimum value for the color scale. Defaults to 4.5.

4.5
vmax float

The maximum value for the color scale. Defaults to 7.5.

7.5
k_masking int

The k parameter for masking. Defaults to -1.

-1
green_contour Optional[float]

The value for the green contour line. Defaults to 6.0.

6.0
min_level_contours float

The minimum level for the contours. Defaults to 1.

1
max_level_contours float

The maximum level for the contours. Defaults to 15.

15
delta_levels_contours float

The delta between contour levels. Defaults to 0.5.

0.5
figsize Optional[tuple[float, float]]

The size of the figure. Defaults to None.

None
label_cbar str

The label for the colorbar. Defaults to "Minimum DA ($\sigma$)".

'Minimum DA (' + '$\\sigma$' + ')'
colormap str

The colormap to use. Defaults to "coolwarm_r".

'coolwarm_r'
style str

The style to use for the plot. Defaults to "ggplot".

'ggplot'
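output_path str

The path to save the plot. If None, the plot is not saved. A ".pdf" path requires vectorize=True, and any other extension requires vectorize=False. Defaults to "output.png".

'output.png'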
display_plot bool

Whether to display the plot. Defaults to True.

True
latex_fonts bool

Whether to use LaTeX fonts. Defaults to True.

True
vectorize bool

Whether to vectorize the plot. Defaults to False.

False
fill_missing_value_with Optional[str | float]

The value to fill missing values with. Can be a number or 'interpolate'. Defaults to None.

None
dpi int

The DPI for the plot. Defaults to 300.

300

Returns:

Type Description
tuple[Figure, Axes]

tuple[plt.Figure, plt.Axes]: The figure and axes of the plot.

Source code in study_da/plot/plot_study.py
def _fill_missing_values(
    data_array: np.ndarray, fill_missing_value_with: Optional[str | float]
) -> np.ndarray:
    """
    Replaces the NaNs of a 2D array with a constant or with interpolated values.

    Args:
        data_array (np.ndarray): The 2D array whose NaNs should be filled. When a constant is
            used, the array is modified in place.
        fill_missing_value_with (Optional[str | float]): None to leave the array untouched, a
            real number (NumPy scalars included) to use as a constant, or "interpolate" to
            replace the NaNs by a cubic interpolation (scipy griddata) of the valid entries.

    Returns:
        np.ndarray: The filled array. With "interpolate", cells outside the convex hull of the
            valid entries stay NaN.

    Raises:
        ValueError: If fill_missing_value_with is not None, a real number or "interpolate".
    """
    import numbers

    if fill_missing_value_with is None:
        return data_array

    nan_mask = np.isnan(data_array)

    # numbers.Real also accepts NumPy scalars such as np.float32 or np.int64, which are not
    # subclasses of the built-in int/float
    if isinstance(fill_missing_value_with, numbers.Real):
        data_array[nan_mask] = fill_missing_value_with
        return data_array

    if isinstance(fill_missing_value_with, str) and fill_missing_value_with == "interpolate":
        if not nan_mask.any():
            # No missing value: nothing to interpolate
            return data_array
        # Use the (column, row) index of each cell as its coordinates for griddata
        xx, yy = np.meshgrid(np.arange(data_array.shape[1]), np.arange(data_array.shape[0]))
        valid = ~nan_mask
        return griddata((xx[valid], yy[valid]), data_array[valid], (xx, yy), method="cubic")

    raise ValueError(
        "fill_missing_value_with must be None, a number or 'interpolate', "
        f"got {fill_missing_value_with!r}"
    )


def plot_heatmap(
    dataframe_data: pd.DataFrame,
    horizontal_variable: str,
    vertical_variable: str,
    color_variable: str,
    link: Optional[str] = None,
    position_qr: Optional[str] = "top-right",
    plot_contours: bool = True,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    tick_interval: int = 2,
    round_xticks: Optional[int] = None,
    round_yticks: Optional[int] = None,
    symmetric_missing: bool = False,
    mask_lower_triangle: bool = False,
    mask_upper_triangle: bool = False,
    plot_diagonal_lines: bool = True,
    shift_diagonal_lines: int = 1,
    xaxis_ticks_on_top: bool = True,
    title: str = "",
    vmin: float = 4.5,
    vmax: float = 7.5,
    k_masking: int = -1,
    green_contour: Optional[float] = 6.0,
    min_level_contours: float = 1,
    max_level_contours: float = 15,
    delta_levels_contours: float = 0.5,
    figsize: Optional[tuple[float, float]] = None,
    label_cbar: str = "Minimum DA (" + r"$\sigma$" + ")",
    colormap: str = "coolwarm_r",
    style: str = "ggplot",
    output_path: Optional[str] = "output.png",
    display_plot: bool = True,
    latex_fonts: bool = True,
    vectorize: bool = False,
    fill_missing_value_with: Optional[str | float] = None,
    dpi: int = 300,
) -> tuple[plt.Figure, plt.Axes]:
    r"""
    Plots a heatmap from the given dataframe.

    The dataframe is pivoted so that vertical_variable values index the rows and
    horizontal_variable values index the columns, with color_variable as the cell value.

    Args:
        dataframe_data (pd.DataFrame): The dataframe containing the data to plot.
        horizontal_variable (str): The variable to plot on the horizontal axis.
        vertical_variable (str): The variable to plot on the vertical axis.
        color_variable (str): The variable to use for the color scale.
        link (Optional[str], optional): A link to encode in a QR code. Defaults to None.
        position_qr (Optional[str], optional): Corner where the QR code is placed when link is
            given: "top-right", "bottom-right", "bottom-left" or "top-left".
            Defaults to "top-right".
        plot_contours (bool, optional): Whether to plot contours. Defaults to True.
        xlabel (Optional[str], optional): The label for the x-axis. Defaults to None.
        ylabel (Optional[str], optional): The label for the y-axis. Defaults to None.
        tick_interval (int, optional): The interval for the ticks. Defaults to 2.
        round_xticks (Optional[int], optional): The number of decimal places to round the x-ticks to.
            Defaults to None.
        round_yticks (Optional[int], optional): The number of decimal places to round the y-ticks to.
            Defaults to None.
        symmetric_missing (bool, optional): Whether to make the matrix symmetric by replacing the
            lower triangle with the upper triangle. Defaults to False.
        mask_lower_triangle (bool, optional): Whether to mask the lower triangle. Defaults to False.
        mask_upper_triangle (bool, optional): Whether to mask the upper triangle. Defaults to False.
        plot_diagonal_lines (bool, optional): Whether to plot diagonal lines. Defaults to True.
        shift_diagonal_lines (int, optional): The shift for the diagonal lines. Defaults to 1.
        xaxis_ticks_on_top (bool, optional): Whether to place the x-axis ticks on top. Defaults to True.
        title (str, optional): The title of the plot. Defaults to "".
        vmin (float, optional): The minimum value for the color scale. Defaults to 4.5.
        vmax (float, optional): The maximum value for the color scale. Defaults to 7.5.
        k_masking (int, optional): The k parameter for masking. Defaults to -1.
        green_contour (Optional[float], optional): The value for the green contour line. Defaults to 6.0.
        min_level_contours (float, optional): The minimum level for the contours. Defaults to 1.
        max_level_contours (float, optional): The maximum level for the contours. Defaults to 15.
        delta_levels_contours (float, optional): The delta between contour levels. Defaults to 0.5.
        figsize (Optional[tuple[float, float]], optional): The size of the figure. Defaults to None.
        label_cbar (str, optional): The label for the colorbar. Defaults to "Minimum DA ($\sigma$)".
        colormap (str, optional): The colormap to use. Defaults to "coolwarm_r".
        style (str, optional): The style to use for the plot. Defaults to "ggplot".
        output_path (Optional[str], optional): The path to save the plot, or None to skip saving.
            A ".pdf" path requires vectorize=True; any other extension requires vectorize=False.
            Defaults to "output.png".
        display_plot (bool, optional): Whether to display the plot. Defaults to True.
        latex_fonts (bool, optional): Whether to use LaTeX fonts. Defaults to True.
        vectorize (bool, optional): Whether to vectorize the plot. Defaults to False.
        fill_missing_value_with (Optional[str | float], optional): The value to fill missing values
            with. Can be a number or 'interpolate'. Defaults to None.
        dpi (int, optional): The DPI for the plot. Defaults to 300.

    Returns:
        tuple[plt.Figure, plt.Axes]: The figure and axes of the plot.

    Raises:
        ValueError: If output_path and vectorize are incompatible, or if
            fill_missing_value_with is neither None, a number nor 'interpolate'.
    """
    # Check the output format first, so an invalid combination fails before a figure is created
    if output_path is not None:
        saving_pdf = output_path.endswith(".pdf")
        if saving_pdf and not vectorize:
            raise ValueError("Please set vectorize=True to save as PDF")
        if not saving_pdf and vectorize:
            raise ValueError("Please set vectorize=False to save as PNG or JPG")

    # Use the requested style
    _set_style(style, latex_fonts, vectorize)

    # Pivot to a grid: rows indexed by vertical_variable, columns by horizontal_variable
    df_to_plot = dataframe_data.pivot(
        index=vertical_variable, columns=horizontal_variable, values=color_variable
    )

    # Get numpy array from dataframe, and replace NaNs with a value if requested
    data_array = df_to_plot.to_numpy(dtype=float)
    data_array = _fill_missing_values(data_array, fill_missing_value_with)

    # Mask the lower or upper triangle (checks are done in the function)
    data_array_masked, mask_main_array = _mask(
        mask_lower_triangle, mask_upper_triangle, data_array, k_masking
    )

    # Define colormap and set NaNs to white
    cmap = matplotlib.colormaps.get_cmap(colormap)
    cmap.set_bad("w")

    # Build heatmap, with inverted y axis
    fig, ax = plt.subplots()
    if figsize is not None:
        fig.set_size_inches(figsize)
    im = ax.imshow(data_array_masked, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.invert_yaxis()

    # Add text annotations
    ax = _add_text_annotation(df_to_plot, data_array, ax, vmin, vmax)

    # Smooth data for contours
    mx = _smooth(data_array, symmetric_missing)

    # Plot contours if requested
    if plot_contours:
        ax = _add_contours(
            ax,
            data_array,
            mx,
            green_contour,
            min_level_contours,
            max_level_contours,
            delta_levels_contours,
            mask_main_array,
        )

    if plot_diagonal_lines:
        # Diagonal lines must be plotted after the contour lines, because of bug in matplotlib
        # Shift might need to be adjusted
        ax = _add_diagonal_lines(ax, shift=shift_diagonal_lines)

    # Define title
    ax.set_title(
        title,
        fontsize=10,
    )

    # Set axis labels
    ax = _set_labels(
        ax,
        df_to_plot,
        data_array,
        horizontal_variable,
        vertical_variable,
        xlabel,
        ylabel,
        xaxis_ticks_on_top,
        tick_interval,
        round_xticks,
        round_yticks,
    )

    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, fraction=0.026, pad=0.04)
    cbar.ax.set_ylabel(label_cbar, rotation=90, va="bottom", labelpad=15)

    # Remove potential grid. Explicitly turn it off: grid(visible=None) would *toggle* it,
    # turning it on for styles without a grid
    ax.grid(False)

    # Add QR code with a link (a bit experimental, might need adjustments)
    if link is not None:
        fig = add_QR_code(fig, link, position_qr)

    # Save this figure explicitly (rather than pyplot's current figure)
    if output_path is not None:
        fig.savefig(output_path, bbox_inches="tight", dpi=dpi)

    if display_plot:
        plt.show()
    return fig, ax