From 5071d7214330befc09da9289096c8fa4982e2641 Mon Sep 17 00:00:00 2001 From: ffreyer Date: Thu, 26 Dec 2024 23:15:15 +0100 Subject: [PATCH] switch to Vector of connections --- src/utilities/RenderPipeline.jl | 117 +++++++++++++++++++++----------- 1 file changed, 78 insertions(+), 39 deletions(-) diff --git a/src/utilities/RenderPipeline.jl b/src/utilities/RenderPipeline.jl index bd09765f55a..d6473f89d7d 100644 --- a/src/utilities/RenderPipeline.jl +++ b/src/utilities/RenderPipeline.jl @@ -8,6 +8,8 @@ _promote_type(::Type{Float8}, ::Type{Float16}) = Float16 _promote_type(::Type{Float16}, ::Type{Float8}) = Float16 +# TODO: consider adding a "reuse immediately" flag so Stage can communicate that +# it allows output = input struct Format dims::Int type::DataType @@ -50,13 +52,14 @@ is_compatible_with(::Type, ::Type) = false # Connections can have multiple inputs and outputs # e.g. multiple Renders write to objectid and FXAA, SSAO, Display/pick read it struct ConnectionT{T} - inputs::Dict{T, Symbol} # stage => output name - outputs::Dict{T, Symbol} # stage => input name + inputs::Dict{T, Int} # stage => output Index + outputs::Dict{T, Int} # stage => input Index format::Format # derived from inputs & outputs formats end struct Stage - name::String + name::Symbol + # order matters for outputs inputs::Dict{Symbol, Int} outputs::Dict{Symbol, Int} @@ -65,16 +68,17 @@ struct Stage # ^ technically all of these are constants # v these are not - input_connections::Dict{Symbol, ConnectionT{Stage}} - output_connections::Dict{Symbol, ConnectionT{Stage}} + input_connections::Vector{ConnectionT{Stage}} + output_connections::Vector{ConnectionT{Stage}} end const Connection = ConnectionT{Stage} function Stage(name; inputs = NTuple{0, Pair{Symbol, Format}}(), outputs = NTuple{0, Pair{Symbol, Format}}()) - stage = Stage(name, - Dict{Symbol, Int}(), Dict{Symbol, Int}(), Format[], Format[], - Dict{Symbol, Connection}(), Dict{Symbol, Connection}() + stage = Stage(Symbol(name), + Dict{Symbol, Int}(), Dict{Symbol, Int}(), + Format[], Format[], + Connection[], Connection[] ) foreach(enumerate(inputs)) do (i, (k, v)) stage.inputs[k] = i @@ -87,11 +91,8 @@ function Stage(name; inputs = NTuple{0, Pair{Symbol, Format}}(), outputs = NTupl return stage end -function Connection(source::Stage, input::Symbol, target::Stage, output::Symbol) - format = Format( - source.output_formats[source.outputs[input]], - target.input_formats[target.inputs[output]] - ) +function Connection(source::Stage, input::Integer, target::Stage, output::Integer) + format = Format(source.output_formats[input], target.input_formats[output]) return Connection(Dict(source => input), Dict(target => output), format) end @@ -149,33 +150,40 @@ function connect!(pipeline::Pipeline, src::Integer, output::Symbol, trg::Integer haskey(source.outputs, output) || error("output $output does not exist in source stage") haskey(target.inputs, input) || error("input $input does not exist in target stage") + # intialize if not yet initialized + isempty(source.output_connections) && resize!(source.output_connections, length(source.output_formats)) + isempty(target.input_connections) && resize!(target.input_connections, length(target.input_formats)) + + output_idx = source.outputs[output] + input_idx = target.inputs[input] + # create requested connection - connection = Connection(source, output, target, input) + connection = Connection(source, output_idx, target, input_idx) # if the input or output already has an edge, merge it with the create edge # e.g. the color output of source is used for a second stage # or the color input of target is written to by second stage - if haskey(source.output_connections, output) - old = source.output_connections[output] + if isassigned(source.output_connections, output_idx) + old = source.output_connections[output_idx] # There should be exactly one matching connection, and it's probably # near the end? idx = findlast(c -> c === old, pipeline.connections)::Int deleteat!(pipeline.connections, idx) connection = merge(connection, old) end - if haskey(target.input_connections, input) - old = target.input_connections[input] + if isassigned(target.input_connections, input_idx) + old = target.input_connections[input_idx] idx = findlast(c -> c === old, pipeline.connections)::Int deleteat!(pipeline.connections, idx) connection = merge(connection, old) end # attach connection to every input and output - for (stage, key) in connection.inputs - stage.output_connections[key] = connection + for (stage, idx) in connection.inputs + stage.output_connections[idx] = connection end - for (stage, key) in connection.outputs - stage.input_connections[key] = connection + for (stage, idx) in connection.outputs + stage.input_connections[idx] = connection end push!(pipeline.connections, connection) @@ -202,8 +210,8 @@ function verify(pipeline::Pipeline) end if earliest_input >= earliest_output - inputs = join(("$(stage.name).$key" for (stage, key) in connection.inputs), ", ") - outputs = join(("$(stage.name).$key" for (stage, key) in connection.outputs), ", ") + inputs = join(("$(stage.name).$idx" for (stage, idx) in connection.inputs), ", ") + outputs = join(("$(stage.name).$idx" for (stage, idx) in connection.outputs), ", ") error("Connection ($inputs) -> ($outputs) is read before being written to. Not allowed. ($earliest_input ≥ $earliest_output)") end end @@ -333,21 +341,23 @@ function Base.show(io::IO, ::MIME"text/plain", stage::Stage) if !isempty(stage.inputs) print(io, "\ninputs:") - ks = keys(stage.inputs) + ks = collect(keys(stage.inputs)) + sort!(ks, by = k -> stage.inputs[k]) pad = mapreduce(k -> length(string(k)), max, ks) - for k in ks - mark = haskey(stage.input_connections, k) ? 'x' : ' ' - print(io, "\n [$mark] ", lpad(string(k), pad), "::", stage.input_formats[stage.inputs[k]]) + for (i, k) in enumerate(ks) + mark = isassigned(stage.input_connections, i) ? 'x' : ' ' + print(io, "\n [$mark] ", lpad(string(k), pad), "::", stage.input_formats[i]) end end if !isempty(stage.outputs) print(io, "\noutputs:") - ks = keys(stage.outputs) + ks = collect(keys(stage.outputs)) + sort!(ks, by = k -> stage.outputs[k]) pad = mapreduce(k -> length(string(k)), max, ks) - for k in ks - mark = haskey(stage.output_connections, k) ? 'x' : ' ' - print(io, "\n [$mark] ", lpad(string(k), pad), "::", stage.output_formats[stage.outputs[k]]) + for (i, k) in enumerate(ks) + mark = isassigned(stage.output_connections, i) ? 'x' : ' ' + print(io, "\n [$mark] ", lpad(string(k), pad), "::", stage.output_formats[i]) end end @@ -355,9 +365,24 @@ function Base.show(io::IO, ::MIME"text/plain", stage::Stage) end function _names(connection::Connection) - inputs = unique(values(connection.inputs)) - outputs = unique(values(connection.outputs)) - return unique(vcat(inputs, outputs)) + names = Set{Symbol}() + for (stage, idx) in connection.inputs + for (k, i) in stage.outputs + if idx == i + push!(names, k) + break + end + end + end + for (stage, idx) in connection.outputs + for (k, i) in stage.inputs + if idx == i + push!(names, k) + break + end + end + end + return collect(names) end function Base.show(io::IO, connection::Connection) @@ -376,8 +401,15 @@ function Base.show(io::IO, ::MIME"text/plain", connection::Connection) if !isempty(connection.inputs) print(io, "\ninputs:") - elements = map(keys(connection.inputs), values(connection.inputs)) do stage, k - (stage.name, string(k), stage.output_formats[stage.outputs[k]]) + elements = map(keys(connection.inputs), values(connection.inputs)) do stage, idx + key = :temp + for (k, i) in stage.outputs + if idx == i + key = k + break + end + end + return (string(stage.name), string(key), stage.output_formats[idx]) end pad1 = mapreduce(x -> length(x[1]), max, elements) pad2 = mapreduce(x -> length(x[2]), max, elements) @@ -388,8 +420,15 @@ function Base.show(io::IO, ::MIME"text/plain", connection::Connection) if !isempty(connection.outputs) print(io, "\noutputs:") - elements = map(keys(connection.outputs), values(connection.outputs)) do stage, k - (stage.name, string(k), stage.input_formats[stage.inputs[k]]) + elements = map(keys(connection.outputs), values(connection.outputs)) do stage, idx + key = :temp + for (k, i) in stage.inputs + if idx == i + key = k + break + end + end + return (string(stage.name), string(key), stage.input_formats[idx]) end pad1 = mapreduce(x -> length(x[1]), max, elements) pad2 = mapreduce(x -> length(x[2]), max, elements)