diff --git a/ext/NCDatasetsMPIExt.jl b/ext/NCDatasetsMPIExt.jl index 2af3d4a3..b56a0139 100644 --- a/ext/NCDatasetsMPIExt.jl +++ b/ext/NCDatasetsMPIExt.jl @@ -3,13 +3,16 @@ using MPI using NCDatasets using NCDatasets: NC_COLLECTIVE, + NC_FORMAT_NETCDF4, + NC_FORMAT_NETCDF4_CLASSIC, NC_GLOBAL, NC_INDEPENDENT, Variable, _dataset_ncmode, check, dataset, - libnetcdf + libnetcdf, + nc_inq_format import NCDatasets: NCDataset, @@ -58,14 +61,27 @@ Change the parallel access mode of the variable `ncv` or all variables of the da More information is available in the [NetCDF documentation](https://web.archive.org/web/20240414204638/https://docs.unidata.ucar.edu/netcdf-c/current/parallel_io.html). """ function access(ncv::Variable,par_access::Symbol) + ds = dataset(ncv) + ncid = ds.ncid varid = ncv.varid - ncid = dataset(ncv).ncid - nc_var_par_access(ncid,varid,parallel_access_mode(par_access)) + + if nc_inq_format(ncid) in (NC_FORMAT_NETCDF4, NC_FORMAT_NETCDF4_CLASSIC) + nc_var_par_access(ncid,varid,parallel_access_mode(par_access)) + else + error("The netCDF 3 and 5 formats do not allow different access methods per variable. You need to call this function for the whole data set: NCDatasets.access(ds,$par_access)") + end end # set collective or independent IO globally (for all variables) function access(ds::NCDataset,par_access::Symbol) - nc_var_par_access(ds.ncid,NC_GLOBAL,parallel_access_mode(par_access)) + if nc_inq_format(ds.ncid) in (NC_FORMAT_NETCDF4, NC_FORMAT_NETCDF4_CLASSIC) + for (varname,ncv) in ds + access(ncv.var,par_access) + end + else + # only for PnetCDF + nc_var_par_access(ds.ncid,NC_GLOBAL,parallel_access_mode(par_access)) + end end """ diff --git a/test/test_mpi_script.jl b/test/test_mpi_script.jl index fc91bdd2..e5932911 100644 --- a/test/test_mpi_script.jl +++ b/test/test_mpi_script.jl @@ -25,7 +25,7 @@ ncv = defVar(ds,"temp",Int32,("lon","lat")) # see # https://web.archive.org/web/20240414204638/https://docs.unidata.ucar.edu/netcdf-c/current/parallel_io.html NCDatasets.access(ncv.var,:collective) - +NCDatasets.access(ds,:collective) @debug("rank $(mpi_rank) writing to netCDF variable") ncv[:,i] .= mpi_rank