Skip to content

Commit 18fe84f

Browse files
committed
Automatically support NCZarr
1 parent: 6d425db; commit: 18fe84f

File tree

4 files changed: 50 additions and 36 deletions.

xarray/backends/api.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,16 @@
9999

100100

101101
def get_default_netcdf_write_engine(
102+
path_or_file: str | IOBase | None,
102103
format: T_NetcdfTypes | None,
103-
to_fileobject: bool,
104104
) -> Literal["netcdf4", "h5netcdf", "scipy"]:
105105
"""Return the default netCDF library to use for writing a netCDF file."""
106+
106107
module_names = {
107108
"netcdf4": "netCDF4",
108109
"scipy": "scipy",
109110
"h5netcdf": "h5netcdf",
110111
}
111-
112112
candidates = list(OPTIONS["netcdf_engine_order"])
113113

114114
if format is not None:
@@ -128,15 +128,24 @@ def get_default_netcdf_write_engine(
128128
if format not in {"NETCDF3_64BIT", "NETCDF3_CLASSIC"}:
129129
candidates.remove("scipy")
130130

131-
if to_fileobject:
131+
nczarr_mode = isinstance(path_or_file, str) and path_or_file.endswith(
132+
"#mode=nczarr"
133+
)
134+
if nczarr_mode:
135+
candidates[:] = ["netcdf4"]
136+
137+
if isinstance(path_or_file, IOBase):
132138
candidates.remove("netcdf4")
133139

134140
for engine in candidates:
135141
module_name = module_names[engine]
136142
if importlib.util.find_spec(module_name) is not None:
137143
return engine
138144

139-
format_str = f" with {format=}" if format is not None else ""
145+
if nczarr_mode:
146+
format_str = " in NCZarr format"
147+
else:
148+
format_str = f" with {format=}" if format is not None else ""
140149
libraries = ", ".join(module_names[c] for c in candidates)
141150
raise ValueError(
142151
f"cannot write NetCDF files{format_str} because none of the suitable "
@@ -2077,8 +2086,7 @@ def to_netcdf(
20772086
path_or_file = _normalize_path(path_or_file)
20782087

20792088
if engine is None:
2080-
to_fileobject = isinstance(path_or_file, IOBase)
2081-
engine = get_default_netcdf_write_engine(format, to_fileobject)
2089+
engine = get_default_netcdf_write_engine(path_or_file, format)
20822090

20832091
# validate Dataset keys, DataArray names, and attr keys/values
20842092
_validate_dataset_names(dataset)

xarray/core/datatree_io.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,9 @@ def _datatree_to_netcdf(
5555
filepath = _normalize_path(filepath)
5656

5757
if engine is None:
58-
to_fileobject = isinstance(filepath, io.IOBase)
5958
engine = get_default_netcdf_write_engine(
59+
path_or_file=filepath,
6060
format="NETCDF4", # required for supporting groups
61-
to_fileobject=to_fileobject,
6261
) # type: ignore[assignment]
6362

6463
if group is not None:

xarray/tests/test_backends.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7277,7 +7277,8 @@ def _create_nczarr(self, filename):
72777277
# https://github.com/Unidata/netcdf-c/issues/2259
72787278
ds = ds.drop_vars("dim3")
72797279

7280-
ds.to_netcdf(f"file://{filename}#mode=nczarr", engine="netcdf4")
7280+
# engine="netcdf4" is not required for backwards compatibility
7281+
ds.to_netcdf(f"file://{filename}#mode=nczarr")
72817282
return ds
72827283

72837284
def test_open_nczarr(self) -> None:

xarray/tests/test_backends_api.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import io
34
import re
45
import sys
56
from numbers import Number
@@ -23,53 +24,44 @@
2324
@requires_scipy
2425
@requires_h5netcdf
2526
def test_get_default_netcdf_write_engine() -> None:
26-
engine = get_default_netcdf_write_engine(format=None, to_fileobject=False)
27+
engine = get_default_netcdf_write_engine("", format=None)
2728
assert engine == "h5netcdf"
2829

29-
engine = get_default_netcdf_write_engine(format="NETCDF4", to_fileobject=False)
30+
engine = get_default_netcdf_write_engine("", format="NETCDF4")
3031
assert engine == "h5netcdf"
3132

32-
engine = get_default_netcdf_write_engine(
33-
format="NETCDF4_CLASSIC", to_fileobject=False
34-
)
33+
engine = get_default_netcdf_write_engine("", format="NETCDF4_CLASSIC")
3534
assert engine == "netcdf4"
3635

37-
engine = get_default_netcdf_write_engine(format="NETCDF4", to_fileobject=True)
36+
engine = get_default_netcdf_write_engine(io.BytesIO(), format="NETCDF4")
3837
assert engine == "h5netcdf"
3938

40-
engine = get_default_netcdf_write_engine(
41-
format="NETCDF3_CLASSIC", to_fileobject=False
42-
)
39+
engine = get_default_netcdf_write_engine("", format="NETCDF3_CLASSIC")
4340
assert engine == "scipy"
4441

45-
engine = get_default_netcdf_write_engine(
46-
format="NETCDF3_CLASSIC", to_fileobject=True
47-
)
42+
engine = get_default_netcdf_write_engine(io.BytesIO(), format="NETCDF3_CLASSIC")
4843
assert engine == "scipy"
4944

45+
engine = get_default_netcdf_write_engine("path.zarr#mode=nczarr", format=None)
46+
assert engine == "netcdf4"
47+
5048
with xr.set_options(netcdf_engine_order=["netcdf4", "scipy", "h5netcdf"]):
51-
engine = get_default_netcdf_write_engine(format=None, to_fileobject=False)
49+
engine = get_default_netcdf_write_engine("", format=None)
5250
assert engine == "netcdf4"
5351

54-
engine = get_default_netcdf_write_engine(format="NETCDF4", to_fileobject=False)
52+
engine = get_default_netcdf_write_engine("", format="NETCDF4")
5553
assert engine == "netcdf4"
5654

57-
engine = get_default_netcdf_write_engine(
58-
format="NETCDF4_CLASSIC", to_fileobject=False
59-
)
55+
engine = get_default_netcdf_write_engine("", format="NETCDF4_CLASSIC")
6056
assert engine == "netcdf4"
6157

62-
engine = get_default_netcdf_write_engine(format="NETCDF4", to_fileobject=True)
58+
engine = get_default_netcdf_write_engine(io.BytesIO(), format="NETCDF4")
6359
assert engine == "h5netcdf"
6460

65-
engine = get_default_netcdf_write_engine(
66-
format="NETCDF3_CLASSIC", to_fileobject=False
67-
)
61+
engine = get_default_netcdf_write_engine("", format="NETCDF3_CLASSIC")
6862
assert engine == "netcdf4"
6963

70-
engine = get_default_netcdf_write_engine(
71-
format="NETCDF3_CLASSIC", to_fileobject=True
72-
)
64+
engine = get_default_netcdf_write_engine(io.BytesIO(), format="NETCDF3_CLASSIC")
7365
assert engine == "scipy"
7466

7567

@@ -81,17 +73,31 @@ def test_default_engine_h5netcdf(monkeypatch):
8173
monkeypatch.delitem(sys.modules, "scipy", raising=False)
8274
monkeypatch.setattr(sys, "meta_path", [])
8375

84-
engine = get_default_netcdf_write_engine(format=None, to_fileobject=False)
76+
engine = get_default_netcdf_write_engine("", format=None)
8577
assert engine == "h5netcdf"
8678

8779
with pytest.raises(
8880
ValueError,
8981
match=re.escape(
9082
"cannot write NetCDF files with format='NETCDF3_CLASSIC' because "
91-
"none of the suitable backend libraries (netCDF4, scipy) are installed"
83+
"none of the suitable backend libraries (scipy, netCDF4) are installed"
84+
),
85+
):
86+
get_default_netcdf_write_engine("", format="NETCDF3_CLASSIC")
87+
88+
89+
def test_default_engine_nczarr_no_netcdf4_python(monkeypatch):
90+
monkeypatch.delitem(sys.modules, "netCDF4", raising=False)
91+
monkeypatch.setattr(sys, "meta_path", [])
92+
93+
with pytest.raises(
94+
ValueError,
95+
match=re.escape(
96+
"cannot write NetCDF files in NCZarr format because "
97+
"none of the suitable backend libraries (netCDF4) are installed"
9298
),
9399
):
94-
get_default_netcdf_write_engine(format="NETCDF3_CLASSIC", to_fileobject=False)
100+
get_default_netcdf_write_engine("#mode=nczarr", format=None)
95101

96102

97103
def test_custom_engine() -> None:

0 commit comments

Comments
 (0)