From 92bc781637efeab7863506ab34bb02b364026ecc Mon Sep 17 00:00:00 2001 From: Fabian Gans Date: Fri, 7 Oct 2022 13:28:01 +0200 Subject: [PATCH 1/3] Add a getindex with tables interface --- docs/src/expl/expl.md | 2 ++ src/Cubes/Cubes.jl | 77 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/docs/src/expl/expl.md b/docs/src/expl/expl.md index e69de29b..740176e7 100644 --- a/docs/src/expl/expl.md +++ b/docs/src/expl/expl.md @@ -0,0 +1,2 @@ +## Indexing and subsetting + diff --git a/src/Cubes/Cubes.jl b/src/Cubes/Cubes.jl index 3c53908b..f186a5d6 100644 --- a/src/Cubes/Cubes.jl +++ b/src/Cubes/Cubes.jl @@ -13,6 +13,7 @@ using YAXArrayBase: YAXArrayBase, iscompressed, dimnames, iscontdimval import YAXArrayBase: getattributes, iscontdim, dimnames, dimvals, getdata using DiskArrayTools: CFDiskArray using DocStringExtensions +using Tables: istable, schema, columns export concatenatecubes, caxes, subsetcube, readcubedata, renameaxis!, YAXArray, setchunks @@ -213,7 +214,81 @@ setchunks(c::YAXArray,chunks) = YAXArray(c.axes,c.data,c.properties,interpret_cu cubechunks(c) = approx_chunksize(eachchunk(c)) DiskArrays.eachchunk(c::YAXArray) = c.chunks getindex_all(a) = getindex(a, ntuple(_ -> Colon(), ndims(a))...) -Base.getindex(x::YAXArray, i...) = getdata(x)[i...] +function Base.getindex(x::YAXArray, i...) + @show length(i), istable(i) + if length(i)==1 && istable(first(i)) + batchextract(x,first(i)) + else + getdata(x)[i...] 
+ end +end +function batchextract(x,i) + sch = schema(i) + axinds = map(sch.names) do n + findAxis(n,x) + end + + tcols = columns(i) + #Try to find a column denoting new axis name and values + newaxcol = nothing + if any(isnothing,axinds) + allnothings = findall(isnothing,axinds) + if length(allnothings) == 1 + newaxcol = allnothings[1] + end + axinds = filter(!isnothing,axinds) + end + if !all(diff(sort(axinds)).==1) + error("Axes indexed into currently need to be together") #Tofix + end + cartinds = map(axinds,tcols) do iax,col + axcur = caxes(x)[iax] + map(col) do val + axVal2Index(axcur,val) + end + end + indlist = map(cartinds...) do inds... + ind = map(inds,axinds) do i,ai + + end + ind + end + d = getdata(x)[indlist] + cax = caxes(x) + allax = 1:ndims(x) + axrem = setdiff(allax,axinds) + newax = if newaxcol == nothing + outaxis_from_data(cax,axinds,indlist) + else + outaxis_from_column(i,newaxcol) + end + outax = CubeAxis[axcopy(a) for a in cax][axrem] + insert!(outax,minimum(axinds),newax) + YAXArray(outax,d,x.properties) +end + +function outaxis_from_column(tab,icol) + axdata = columns(tab)[icol] + axname = schema(tab).names[icol] + if eltype(axdata) <: AbstractString || + (!issorted(axdata) && !issorted(axdata, rev = true)) + CategoricalAxis(axname, axdata) + else + RangeAxis(axname, axdata) + end +end + +function outaxis_from_data(cax,axinds,indlist) + mergeaxes = getindex.(Ref(cax),axinds) + mergenames = axname.(mergeaxes) + newname = join(mergenames,'_') + mergevals = map(indlist) do i + broadcast(mergeaxes,axinds) do ax,ai + ax.values[i[ai]] + end + end + CategoricalAxis(newname, mergevals) +end chunkoffset(c) = grid_offset(eachchunk(c)) # Implementation for YAXArrayBase interface From 42c1e09258ef8222b1baa0776f61e86f19bf6b8e Mon Sep 17 00:00:00 2001 From: Fabian Gans Date: Mon, 10 Oct 2022 14:16:17 +0200 Subject: [PATCH 2/3] add batch extraction method --- docs/src/expl/expl.md | 10 +++++++ src/Cubes/Cubes.jl | 31 +++++++++++++------- 
test/Cubes/batchextraction.jl | 53 +++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 85 insertions(+), 10 deletions(-) create mode 100644 test/Cubes/batchextraction.jl diff --git a/docs/src/expl/expl.md b/docs/src/expl/expl.md index 740176e7..da5c6c83 100644 --- a/docs/src/expl/expl.md +++ b/docs/src/expl/expl.md @@ -1,2 +1,12 @@ ## Indexing and subsetting +As with most array types, YAXArray provides special indexing behavior when square brackets are used for indexing. +Assuming that `c` is a YAXArray, there are three different semantics for square-bracket indexing, depending on the types of the arguments +provided to getindex. + +1. **Ranges and Integers only**, for example `c[1,4:8,:]`, will access the underlying data according to the provided index in index space and read the data *into memory* as a plain Julia Array. +It is equivalent to `c.data[1,4:8,:]`. +2. **Keyword arguments with values or Intervals**, for example `c[longitude = 30..50, time=Date(2005,6,1), variable="air_temperature"]`. +This always creates a *view* into the specified subset of the data and returns a new YAXArray with new axes without reading the data. Intervals and +values are always interpreted in the units provided by the axis values. +3. **A Tables.jl-compatible object** for irregular extraction of a list of points or sub-arrays at random locations. For example, calling `c[[(lon=30,lat=42),(lon=-50,lat=2.5)]]` will extract data at the specified coordinates and along all additional axes into memory. It returns a new YAXArray with a new Multi-Index axis along the selected longitudes and latitudes. 
diff --git a/src/Cubes/Cubes.jl b/src/Cubes/Cubes.jl index f186a5d6..74e6973b 100644 --- a/src/Cubes/Cubes.jl +++ b/src/Cubes/Cubes.jl @@ -238,8 +238,14 @@ function batchextract(x,i) end axinds = filter(!isnothing,axinds) end - if !all(diff(sort(axinds)).==1) - error("Axes indexed into currently need to be together") #Tofix + allax = 1:ndims(x) + axrem = setdiff(allax,axinds) + ai1, ai2 = extrema(axinds) + if !all(diff(sort(collect(axinds))).==1) + #Axes to be extracted from are not consecutive in cube -> permute + p = [1:(ai1-1);collect(axinds);filter(!in(axinds),ai1:ai2);(ai2+1:ndims(x))] + x_perm = permutedims(x,p) + return batchextract(x_perm,i) end cartinds = map(axinds,tcols) do iax,col axcur = caxes(x)[iax] @@ -247,16 +253,20 @@ function batchextract(x,i) axVal2Index(axcur,val) end end - indlist = map(cartinds...) do inds... - ind = map(inds,axinds) do i,ai - + + before = ntuple(_->Colon(),ai1-1) + after = ntuple(_->Colon(),ndims(x)-ai2) + sp = issorted(axinds) ? nothing : sortperm(collect(axinds)) + function makeindex(sp, inds...) + if sp === nothing + CartesianIndex(inds...) + else + CartesianIndex(inds[sp]...) end - ind end - d = getdata(x)[indlist] + indlist = makeindex.(Ref(sp),cartinds...) + d = getdata(x)[before...,indlist,after...] 
cax = caxes(x) - allax = 1:ndims(x) - axrem = setdiff(allax,axinds) newax = if newaxcol == nothing outaxis_from_data(cax,axinds,indlist) else @@ -282,9 +292,10 @@ function outaxis_from_data(cax,axinds,indlist) mergeaxes = getindex.(Ref(cax),axinds) mergenames = axname.(mergeaxes) newname = join(mergenames,'_') + minai = minimum(axinds) mergevals = map(indlist) do i broadcast(mergeaxes,axinds) do ax,ai - ax.values[i[ai]] + ax.values[i[ai-minai+1]] end end CategoricalAxis(newname, mergevals) diff --git a/test/Cubes/batchextraction.jl b/test/Cubes/batchextraction.jl new file mode 100644 index 00000000..df038985 --- /dev/null +++ b/test/Cubes/batchextraction.jl @@ -0,0 +1,53 @@ +using Test + +@testset "Batch extraction along multiple axes" begin +lons = range(30,35,step=0.25) +lats = range(50,55,step=0.25) +times = Date(2000,1,1):Month(1):Date(2000,12,31) + +data = rand(length(lons),length(lats), length(times)); + +c = YAXArray([RangeAxis("longitude",lons),RangeAxis("latitude",lats),RangeAxis("time",times)],data) +c_perm = permutedims(c,(3,2,1)) + + +sites_names = [(lon = rand()*5+30, lat = rand()*5+50,site = string(i)) for i in 1:200] +sites_pure = [n[[:lon,:lat]] for n in sites_names] +lon,lat = sites_pure[10] + +r = c[sites_names] +@test r isa YAXArray +@test YAXArrays.Cubes.axname.(caxes(r)) == ["site","time"] +@test r.site.values == string.(1:200) +@test all(isequal.(c[lon=lon,lat=lat][:], r[10,:])) + +r = c_perm[sites_names] +@test r isa YAXArray +@test YAXArrays.Cubes.axname.(caxes(r)) == ["time","site"] +@test r.site.values == string.(1:200) +@test all(isequal.(c[lon=lon,lat=lat][:], r[:,10])) + +r = c[sites_pure] +@test r isa YAXArray +@test YAXArrays.Cubes.axname.(caxes(r)) == ["longitude_latitude","time"] +map(r.longitude_latitude.values,[(n.lon,n.lat) for n in sites_pure]) do ll, ll_real + abs(ll[1]-ll_real[1]) <= 0.125 && abs(ll[2]-ll_real[2]) <= 0.125 +end |> all +@test all(isequal.(c[lon=lon,lat=lat][:], r[10,:])) + +r = c_perm[sites_pure] +@test r isa 
YAXArray +@test YAXArrays.Cubes.axname.(caxes(r)) == ["time","longitude_latitude"] +map(r.longitude_latitude.values,[(n.lon,n.lat) for n in sites_pure]) do ll, ll_real + abs(ll[1]-ll_real[1]) <= 0.125 && abs(ll[2]-ll_real[2]) <= 0.125 +end |> all +@test all(isequal.(c[lon=lon,lat=lat][:], r[:,10])) + +othersites = [(lon=32.0,time=Date(2000,6,1),point=3),(lon=33.0,time=Date(2000,7,1),point=5)] +r = c[othersites] +@test r isa YAXArray +@test YAXArrays.Cubes.axname.(caxes(r)) == ["point","latitude"] +@test r.point.values == [3,5] +@test c[lon=33.0,time=Date(2000,7,1)][:] == r[point=5][:] + +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index f1133520..e6e335f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ include("tools.jl") include("Cubes/axes.jl") include("Cubes/cubes.jl") include("Cubes/transformedcubes.jl") +include("Cubes/batchextraction.jl") include("Datasets/datasets.jl") From 9532f38b654b7a286e7d200cd52316d7e97c9ed8 Mon Sep 17 00:00:00 2001 From: Fabian Gans Date: Tue, 11 Oct 2022 11:24:03 +0200 Subject: [PATCH 3/3] Make tests work on 1.6 --- .github/workflows/CI.yml | 1 + src/Cubes/Cubes.jl | 1 - test/Cubes/batchextraction.jl | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1f5f6237..ac36793c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -14,6 +14,7 @@ jobs: matrix: version: - '1.6' + - '1' os: - ubuntu-latest - macOS-latest diff --git a/src/Cubes/Cubes.jl b/src/Cubes/Cubes.jl index 74e6973b..dc1af391 100644 --- a/src/Cubes/Cubes.jl +++ b/src/Cubes/Cubes.jl @@ -215,7 +215,6 @@ cubechunks(c) = approx_chunksize(eachchunk(c)) DiskArrays.eachchunk(c::YAXArray) = c.chunks getindex_all(a) = getindex(a, ntuple(_ -> Colon(), ndims(a))...) function Base.getindex(x::YAXArray, i...) 
- @show length(i), istable(i) if length(i)==1 && istable(first(i)) batchextract(x,first(i)) else diff --git a/test/Cubes/batchextraction.jl b/test/Cubes/batchextraction.jl index df038985..cad691fe 100644 --- a/test/Cubes/batchextraction.jl +++ b/test/Cubes/batchextraction.jl @@ -12,7 +12,7 @@ c_perm = permutedims(c,(3,2,1)) sites_names = [(lon = rand()*5+30, lat = rand()*5+50,site = string(i)) for i in 1:200] -sites_pure = [n[[:lon,:lat]] for n in sites_names] +sites_pure = [(lon = n.lon, lat=n.lat) for n in sites_names] lon,lat = sites_pure[10] r = c[sites_names]