Skip to content

Commit 98c7f9f

Browse files
committed
adding heatmap plot for sensitivity analysis #632
1 parent fa9bc26 commit 98c7f9f

File tree

1 file changed

+93
-0
lines changed

1 file changed

+93
-0
lines changed

autoemulate/experimental/sensitivity_analysis.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,30 @@ def top_n_sobol_params(
341341
.nlargest(top_n)
342342
.index.tolist()
343343
)
344+
def plot_sa_heatmap(
345+
self,
346+
results: pd.DataFrame,
347+
index: str = "ST",
348+
top_n: int | None = None,
349+
cmap: str = "coolwarm",
350+
normalize: bool = True,
351+
figsize: tuple | None = None,
352+
):
353+
"""
354+
Plot a normalized Sobol sensitivity analysis heatmap.
355+
356+
Parameters:
357+
si_df (pd.DataFrame): Sensitivity index dataframe with columns ['index', 'parameter', 'output', 'value'].
358+
index (str): The type of sensitivity index to plot (e.g., 'ST').
359+
top_n (int, optional): Number of top parameters to include. Defaults to all.
360+
cmap (str, optional): Matplotlib colormap. Defaults to 'coolwarm'.
361+
normalize (bool, optional): Whether to normalize values to [0, 1]. Defaults to True.
362+
figsize (tuple, optional): Figure size as (width, height) in inches. If None,
363+
"""
364+
# Determine which parameters to include
365+
parameter_list = self.top_n_sobol_params(results, top_n=len(results['parameter'].unique()) if top_n is None else top_n)
366+
367+
return _plot_sa_heatmap(results, index, parameter_list, cmap, normalize, fig_size=figsize)
344368

345369

346370
def _sobol_results_to_df(results: dict[str, ResultDict]) -> pd.DataFrame:
@@ -754,3 +778,72 @@ def _create_morris_plot(
754778
ax.set_ylabel("μ* (Modified Mean)")
755779
ax.set_title(f"Output: {output_name}")
756780
ax.grid(True, alpha=0.3)
781+
782+
783+
784+
def _plot_sa_heatmap(si_df, index, parameters, cmap='coolwarm', normalize=True, fig_size=None):
785+
"""
786+
Plot a sensitivity analysis heatmap for a given index.
787+
788+
Parameters:
789+
si_df (pd.DataFrame): Sensitivity index dataframe with columns ['index', 'parameter', 'output', 'value'].
790+
index (str): The type of sensitivity index to plot (e.g., 'ST').
791+
parameters (list): List of parameters to include in the grid.
792+
cmap (str, optional): Matplotlib colormap. Defaults to 'coolwarm'.
793+
normalize (bool, optional): Whether to normalize values to [0, 1]. Defaults to True.
794+
fig_size (tuple, optional): Figure size as (width, height) in inches. If None,
795+
"""
796+
797+
# Filter the dataframe for the specified index
798+
df = si_df[si_df['index'] == index]
799+
800+
# Pivot the dataframe to get a matrix: rows = outputs, cols = parameters
801+
heatmap_df = (
802+
df[df['parameter'].isin(parameters)]
803+
.pivot_table(index='output', columns='parameter', values='value', fill_value=np.nan)
804+
.reindex(columns=parameters) # Ensure column order
805+
)
806+
807+
# Normalize if requested
808+
if normalize:
809+
min_value = heatmap_df.min().min()
810+
max_value = heatmap_df.max().max()
811+
value_range = max_value - min_value if max_value != min_value else 1
812+
heatmap_df = (heatmap_df - min_value) / value_range
813+
814+
# Convert to NumPy array
815+
data_np = heatmap_df.to_numpy()
816+
817+
# layout - add space for legend
818+
n_rows, n_cols = _calculate_layout(data_np.shape[1], data_np.shape[0])
819+
fig_size = fig_size or (4.5 * n_cols, 4.5 * n_rows + 2) # Extra width for legend
820+
821+
# Plotting
822+
fig, ax = plt.subplots(figsize=fig_size)
823+
cax = ax.imshow(data_np, cmap=cmap, aspect='auto')
824+
825+
# Colorbar
826+
cbar = fig.colorbar(cax, ax=ax)
827+
cbar_label = 'Normalized Sensitivity' if normalize else 'Sensitivity'
828+
cbar.set_label(cbar_label, rotation=270, labelpad=15)
829+
830+
# Labels and ticks
831+
ax.set_title(f"{index} Sensitivity Analysis Heatmap", fontsize=14, pad=12)
832+
ax.set_xlabel("Parameters", fontsize=12)
833+
ax.set_ylabel("Outputs", fontsize=12)
834+
835+
ax.set_xticks(np.arange(len(parameters)))
836+
ax.set_xticklabels(parameters, rotation=45, ha='right')
837+
ax.set_yticks(np.arange(len(heatmap_df.index)))
838+
ax.set_yticklabels(heatmap_df.index)
839+
840+
# Gridlines
841+
ax.set_xticks(np.arange(-0.5, len(parameters), 1), minor=True)
842+
ax.set_yticks(np.arange(-0.5, len(heatmap_df.index), 1), minor=True)
843+
ax.grid(which='minor', color='w', linestyle='-', linewidth=2)
844+
ax.tick_params(which='minor', bottom=False, left=False)
845+
846+
plt.tight_layout()
847+
848+
return display_figure(fig)
849+

0 commit comments

Comments
 (0)