Skip to content

Commit

Permalink
introduce new legend approach requiring ProcessedLayers in AxisEntries
Browse files Browse the repository at this point in the history
  • Loading branch information
jkrumbiegel committed May 29, 2024
1 parent be83bdb commit b50b94b
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/AlgebraOfGraphics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ include("dict.jl")
include("theme.jl")
include("helpers.jl")
include("scales.jl")
include("algebra/layer.jl")
include("entries.jl")
include("facet.jl")
include("algebra/layer.jl")
include("algebra/layers.jl")
include("algebra/select.jl")
include("algebra/processing.jl")
Expand Down
7 changes: 5 additions & 2 deletions src/algebra/layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ function Base.:*(l::Layer, l′::Layer)
end

## Format for layer after processing
const PlotType = Type{<:Plot}

Base.@kwdef struct ProcessedLayer <: AbstractDrawable
plottype::PlotType=Plot{plot}
Expand Down Expand Up @@ -198,18 +199,20 @@ function compute_grid_positions(categoricalscales, primary=NamedArguments())
end
end

const MultiAesScaleDict{T} = Dictionary{Type{<:Aesthetic},Dictionary{Union{Nothing,Symbol},T}}

function rescale(p::ProcessedLayer, categoricalscales::MultiAesScaleDict{CategoricalScale})
aes_mapping = aesthetic_mapping(p)

primary = map(keys(p.primary), p.primary) do key, values
aes = hardcoded_or_mapped_aes(key, aes_mapping)
aes = hardcoded_or_mapped_aes(p, key, aes_mapping)
scale_id = get(p.scale_mapping, key, nothing)
scale_dict = get(categoricalscales, aes, nothing)
scale = scale_dict === nothing ? nothing : get(scale_dict, scale_id, nothing)
return rescale(values, scale)
end
positional = map(keys(p.positional), p.positional) do key, values
aes = hardcoded_or_mapped_aes(key, aes_mapping)
aes = hardcoded_or_mapped_aes(p, key, aes_mapping)
scale_id = get(p.scale_mapping, key, nothing)
scale_dict = get(categoricalscales, aes, nothing)
scale = scale_dict === nothing ? nothing : get(scale_dict, scale_id, nothing)
Expand Down
14 changes: 10 additions & 4 deletions src/algebra/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,13 @@ end

const AestheticMapping = Dictionary{Union{Int,Symbol},Type{<:Aesthetic}}

function hardcoded_or_mapped_aes(key::Union{Int,Symbol}, aes_mapping::AestheticMapping)
@something hardcoded_visual_scale(key) aes_mapping[key]
function hardcoded_or_mapped_aes(processedlayer, key::Union{Int,Symbol}, aes_mapping::AestheticMapping)
hardcoded = hardcoded_visual_scale(key)
hardcoded !== nothing && return hardcoded
if !haskey(aes_mapping, key)
throw(ArgumentError("ProcessedLayer with plot type $(processedlayer.plottype) did not have $key in its AestheticMapping. The mapping was $aes_mapping"))
end
return aes_mapping[key]
end

function compute_axes_grid(d::AbstractDrawable;
Expand All @@ -151,7 +156,7 @@ function compute_axes_grid(d::AbstractDrawable;

for (key, scale) in pairs(catscales)
scale_id = get(processedlayer.scale_mapping, key, nothing)
aes = hardcoded_or_mapped_aes(key, aes_mapping)
aes = hardcoded_or_mapped_aes(processedlayer, key, aes_mapping)
if !haskey(categoricalscales, aes)
insert!(categoricalscales, aes, eltype(categoricalscales)())
end
Expand All @@ -178,7 +183,8 @@ function compute_axes_grid(d::AbstractDrawable;
AxisSpec(c, axis),
entries_grid[c],
categoricalscales,
continuousscales_grid[c]
continuousscales_grid[c],
pls_grid[c],
)
end

Expand Down
9 changes: 4 additions & 5 deletions src/entries.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
const PlotType = Type{<:Plot}

"""
Entry(plottype::PlotType, positional::Arguments, named::NamedArguments)
Expand Down Expand Up @@ -44,13 +42,13 @@ end
# each Aesthetic can (potentially) have multiple separate scales associated with it, for example
# two different color scales. For some aesthetics like AesX or AesLayout it doesn't make sense to have more than one.
# Those should trigger meaningful error messages if they are used with multiple scales
const MultiAesScaleDict{T} = Dictionary{Type{<:Aesthetic},Dictionary{Union{Nothing,Symbol},T}}

struct AxisSpecEntries
axis::AxisSpec
entries::Vector{Entry}
categoricalscales::MultiAesScaleDict{CategoricalScale}
continuousscales::MultiAesScaleDict{ContinuousScale}
processedlayers::Vector{ProcessedLayer} # the layers that were used to create the entries, for legend purposes
end

"""
Expand All @@ -65,15 +63,16 @@ struct AxisEntries
entries::Vector{Entry}
categoricalscales::MultiAesScaleDict{CategoricalScale}
continuousscales::MultiAesScaleDict{ContinuousScale}
processedlayers::Vector{ProcessedLayer} # the layers that were used to create the entries, for legend purposes
end

function AxisEntries(ae::AxisSpecEntries, fig)
ax = ae.axis.type(fig[ae.axis.position...]; pairs(ae.axis.attributes)...)
AxisEntries(ax, ae.entries, ae.categoricalscales, ae.continuousscales)
AxisEntries(ax, ae.entries, ae.categoricalscales, ae.continuousscales, ae.processedlayers)
end

function AxisEntries(ae::AxisSpecEntries, ax::Union{Axis, Axis3})
AxisEntries(ax, ae.entries, ae.categoricalscales, ae.continuousscales)
AxisEntries(ax, ae.entries, ae.categoricalscales, ae.continuousscales, ae.processedlayers)
end

function Makie.plot!(ae::AxisEntries)
Expand Down
86 changes: 71 additions & 15 deletions src/guides/legend.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ function compute_legend(grid::Matrix{AxisEntries})
# if no legendable scale is present, return nothing
isempty(scales) && return nothing

processedlayers = first(grid).processedlayers

plottypes, attributes = plottypes_attributes(entries(grid))

# turn dict of dicts into single-level dict
Expand All @@ -66,30 +68,84 @@ function compute_legend(grid::Matrix{AxisEntries})
end
end

titles = unique!(collect(map(getlabel, scales_flattened)))

titles = []
labels = Vector{AbstractString}[]
elements_list = Vector{Vector{LegendElement}}[]

for title in titles
label_attrs = [key for (key, val) in pairs(scales_flattened) if getlabel(val) == title]
uniquevalues = mapreduce(k -> datavalues(scales_flattened[k]), assert_equal, label_attrs)
elements = map(eachindex(uniquevalues)) do idx
local elements = LegendElement[]
for (P, attrs) in zip(plottypes, attributes)
shared_attrs = attrs label_attrs
isempty(shared_attrs) && continue
options = [attr => plotvalues(scales_flattened[attr])[idx] for attr in shared_attrs]
append!(elements, legend_elements(P; options...))
for (aes, scaledict) in pairs(scales)
for (scale_id, scale) in pairs(scaledict)
push!(titles, @show(getlabel(scale)))

display(scale)

datavals = @show datavalues(scale)
plotvals = @show plotvalues(scale)

legend_els = [LegendElement[] for _ in datavals]

for processedlayer in processedlayers
aes_mapping = aesthetic_mapping(processedlayer)
ProcessedLayer
matching_keys = filter(keys(merge(dictionary(processedlayer.positional), processedlayer.primary, processedlayer.named))) do key
get(aes_mapping, key, nothing) === aes &&
get(processedlayer.scale_mapping, key, nothing) === scale_id
end

isempty(matching_keys) && continue

for (i, (dataval, plotval)) in enumerate(zip(datavals, plotvals))
append!(legend_els[i], legend_elements(processedlayer, MixedArguments(map(key -> plotval, matching_keys))))
end

end
return elements

# label_attrs = [key for (key, val) in pairs(scales_flattened) if getlabel(val) == title]
# @show label_attrs
# uniquevalues = mapreduce(k -> datavalues(scales_flattened[k]), assert_equal, label_attrs)
# @show uniquevalues
# elements = map(eachindex(uniquevalues)) do idx
# local elements = LegendElement[]
# for (P, attrs) in zip(plottypes, attributes)
# shared_attrs = attrs ∩ label_attrs
# isempty(shared_attrs) && continue
# options = [attr => plotvalues(scales_flattened[attr])[idx] for attr in shared_attrs]
# append!(elements, legend_elements(P; options...))
# end
# return elements
# end
# push!(labels, map(to_string, uniquevalues))
# push!(elements_list, elements)
push!(labels, string.(datavals))
push!(elements_list, legend_els)
end
push!(labels, map(to_string, uniquevalues))
push!(elements_list, elements)
end
return elements_list, labels, titles
end

function legend_elements(p::ProcessedLayer, scale_args::MixedArguments)
legend_elements(p.plottype, scale_args)
end

function legend_elements(::Type{Scatter}, scale_args::MixedArguments)
[MarkerElement(
color = haskey(scale_args, :color) ? scale_args[:color] : Makie.current_default_theme()[:markercolor],
markerpoints = [Point2f(0.5, 0.5)],
marker = Makie.current_default_theme()[:marker],
)]
end

function legend_elements(::Type{BarPlot}, scale_args::MixedArguments)
[PolyElement(
color = haskey(scale_args, :color) ? scale_args[:color] : Makie.current_default_theme()[:patchcolor],
)]
end

function legend_elements(::Type{HLines}, scale_args::MixedArguments)
[LineElement(
color = haskey(scale_args, :color) ? scale_args[:color] : Makie.current_default_theme()[:linecolor],
)]
end

# Notes

# TODO: correctly handle composite plot types (now fall back to poly)
Expand Down
8 changes: 7 additions & 1 deletion src/transformations/visual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,13 @@ function _aesthetic_mapping(::Type{Scatter}, attributes)
1 => AesX,
2 => AesY,
:color => AesColor,
:strokecolor => AesColor,
])
end

_aesthetic_mapping(::Type{HLines}, attributes) = dictionary([1 => AesY])
function _aesthetic_mapping(::Type{HLines}, attributes)
dictionary([
1 => AesY,
:color => AesColor,
])
end

0 comments on commit b50b94b

Please sign in to comment.