@@ -341,6 +341,30 @@ def top_n_sobol_params(
341
341
.nlargest (top_n )
342
342
.index .tolist ()
343
343
)
344
+ def plot_sa_heatmap (
345
+ self ,
346
+ results : pd .DataFrame ,
347
+ index : str = "ST" ,
348
+ top_n : int | None = None ,
349
+ cmap : str = "coolwarm" ,
350
+ normalize : bool = True ,
351
+ figsize : tuple | None = None ,
352
+ ):
353
+ """
354
+ Plot a normalized Sobol sensitivity analysis heatmap.
355
+
356
+ Parameters:
357
+ si_df (pd.DataFrame): Sensitivity index dataframe with columns ['index', 'parameter', 'output', 'value'].
358
+ index (str): The type of sensitivity index to plot (e.g., 'ST').
359
+ top_n (int, optional): Number of top parameters to include. Defaults to all.
360
+ cmap (str, optional): Matplotlib colormap. Defaults to 'coolwarm'.
361
+ normalize (bool, optional): Whether to normalize values to [0, 1]. Defaults to True.
362
+ figsize (tuple, optional): Figure size as (width, height) in inches. If None,
363
+ """
364
+ # Determine which parameters to include
365
+ parameter_list = self .top_n_sobol_params (results , top_n = len (results ['parameter' ].unique ()) if top_n is None else top_n )
366
+
367
+ return _plot_sa_heatmap (results , index , parameter_list , cmap , normalize , fig_size = figsize )
344
368
345
369
346
370
def _sobol_results_to_df (results : dict [str , ResultDict ]) -> pd .DataFrame :
@@ -754,3 +778,72 @@ def _create_morris_plot(
754
778
ax .set_ylabel ("μ* (Modified Mean)" )
755
779
ax .set_title (f"Output: { output_name } " )
756
780
ax .grid (True , alpha = 0.3 )
781
+
782
+
783
+
784
+ def _plot_sa_heatmap (si_df , index , parameters , cmap = 'coolwarm' , normalize = True , fig_size = None ):
785
+ """
786
+ Plot a sensitivity analysis heatmap for a given index.
787
+
788
+ Parameters:
789
+ si_df (pd.DataFrame): Sensitivity index dataframe with columns ['index', 'parameter', 'output', 'value'].
790
+ index (str): The type of sensitivity index to plot (e.g., 'ST').
791
+ parameters (list): List of parameters to include in the grid.
792
+ cmap (str, optional): Matplotlib colormap. Defaults to 'coolwarm'.
793
+ normalize (bool, optional): Whether to normalize values to [0, 1]. Defaults to True.
794
+ fig_size (tuple, optional): Figure size as (width, height) in inches. If None,
795
+ """
796
+
797
+ # Filter the dataframe for the specified index
798
+ df = si_df [si_df ['index' ] == index ]
799
+
800
+ # Pivot the dataframe to get a matrix: rows = outputs, cols = parameters
801
+ heatmap_df = (
802
+ df [df ['parameter' ].isin (parameters )]
803
+ .pivot_table (index = 'output' , columns = 'parameter' , values = 'value' , fill_value = np .nan )
804
+ .reindex (columns = parameters ) # Ensure column order
805
+ )
806
+
807
+ # Normalize if requested
808
+ if normalize :
809
+ min_value = heatmap_df .min ().min ()
810
+ max_value = heatmap_df .max ().max ()
811
+ value_range = max_value - min_value if max_value != min_value else 1
812
+ heatmap_df = (heatmap_df - min_value ) / value_range
813
+
814
+ # Convert to NumPy array
815
+ data_np = heatmap_df .to_numpy ()
816
+
817
+ # layout - add space for legend
818
+ n_rows , n_cols = _calculate_layout (data_np .shape [1 ], data_np .shape [0 ])
819
+ fig_size = fig_size or (4.5 * n_cols , 4.5 * n_rows + 2 ) # Extra width for legend
820
+
821
+ # Plotting
822
+ fig , ax = plt .subplots (figsize = fig_size )
823
+ cax = ax .imshow (data_np , cmap = cmap , aspect = 'auto' )
824
+
825
+ # Colorbar
826
+ cbar = fig .colorbar (cax , ax = ax )
827
+ cbar_label = 'Normalized Sensitivity' if normalize else 'Sensitivity'
828
+ cbar .set_label (cbar_label , rotation = 270 , labelpad = 15 )
829
+
830
+ # Labels and ticks
831
+ ax .set_title (f"{ index } Sensitivity Analysis Heatmap" , fontsize = 14 , pad = 12 )
832
+ ax .set_xlabel ("Parameters" , fontsize = 12 )
833
+ ax .set_ylabel ("Outputs" , fontsize = 12 )
834
+
835
+ ax .set_xticks (np .arange (len (parameters )))
836
+ ax .set_xticklabels (parameters , rotation = 45 , ha = 'right' )
837
+ ax .set_yticks (np .arange (len (heatmap_df .index )))
838
+ ax .set_yticklabels (heatmap_df .index )
839
+
840
+ # Gridlines
841
+ ax .set_xticks (np .arange (- 0.5 , len (parameters ), 1 ), minor = True )
842
+ ax .set_yticks (np .arange (- 0.5 , len (heatmap_df .index ), 1 ), minor = True )
843
+ ax .grid (which = 'minor' , color = 'w' , linestyle = '-' , linewidth = 2 )
844
+ ax .tick_params (which = 'minor' , bottom = False , left = False )
845
+
846
+ plt .tight_layout ()
847
+
848
+ return display_figure (fig )
849
+
0 commit comments