Skip to content

Commit

Permalink
switch to Vector of connections
Browse files Browse the repository at this point in the history
  • Loading branch information
ffreyer committed Dec 26, 2024
1 parent ff2eb83 commit 5071d72
Showing 1 changed file with 78 additions and 39 deletions.
117 changes: 78 additions & 39 deletions src/utilities/RenderPipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ _promote_type(::Type{Float8}, ::Type{Float16}) = Float16
_promote_type(::Type{Float16}, ::Type{Float8}) = Float16


# TODO: consider adding a "reuse immediately" flag so Stage can communicate that
# it allows output = input
struct Format
dims::Int
type::DataType
Expand Down Expand Up @@ -50,13 +52,14 @@ is_compatible_with(::Type, ::Type) = false
# Connections can have multiple inputs and outputs
# e.g. multiple Renders write to objectid and FXAA, SSAO, Display/pick read it
struct ConnectionT{T}
inputs::Dict{T, Symbol} # stage => output name
outputs::Dict{T, Symbol} # stage => input name
inputs::Dict{T, Int} # stage => output Index
outputs::Dict{T, Int} # stage => input Index
format::Format # derived from inputs & outputs formats
end

struct Stage
name::String
name::Symbol

# order matters for outputs
inputs::Dict{Symbol, Int}
outputs::Dict{Symbol, Int}
Expand All @@ -65,16 +68,17 @@ struct Stage
# ^ technically all of these are constants

# v these are not
input_connections::Dict{Symbol, ConnectionT{Stage}}
output_connections::Dict{Symbol, ConnectionT{Stage}}
input_connections::Vector{ConnectionT{Stage}}
output_connections::Vector{ConnectionT{Stage}}
end

const Connection = ConnectionT{Stage}

function Stage(name; inputs = NTuple{0, Pair{Symbol, Format}}(), outputs = NTuple{0, Pair{Symbol, Format}}())
stage = Stage(name,
Dict{Symbol, Int}(), Dict{Symbol, Int}(), Format[], Format[],
Dict{Symbol, Connection}(), Dict{Symbol, Connection}()
stage = Stage(Symbol(name),
Dict{Symbol, Int}(), Dict{Symbol, Int}(),
Format[], Format[],
Connection[], Connection[]
)
foreach(enumerate(inputs)) do (i, (k, v))
stage.inputs[k] = i
Expand All @@ -87,11 +91,8 @@ function Stage(name; inputs = NTuple{0, Pair{Symbol, Format}}(), outputs = NTupl
return stage
end

function Connection(source::Stage, input::Symbol, target::Stage, output::Symbol)
format = Format(
source.output_formats[source.outputs[input]],
target.input_formats[target.inputs[output]]
)
function Connection(source::Stage, input::Integer, target::Stage, output::Integer)
format = Format(source.output_formats[input], target.input_formats[output])
return Connection(Dict(source => input), Dict(target => output), format)
end

Expand Down Expand Up @@ -149,33 +150,40 @@ function connect!(pipeline::Pipeline, src::Integer, output::Symbol, trg::Integer
haskey(source.outputs, output) || error("output $output does not exist in source stage")
haskey(target.inputs, input) || error("input $input does not exist in target stage")

# intialize if not yet initialized
isempty(source.output_connections) && resize!(source.output_connections, length(source.output_formats))
isempty(target.input_connections) && resize!(target.input_connections, length(target.input_formats))

output_idx = source.outputs[output]
input_idx = target.inputs[input]

# create requested connection
connection = Connection(source, output, target, input)
connection = Connection(source, output_idx, target, input_idx)

# if the input or output already has an edge, merge it with the create edge
# e.g. the color output of source is used for a second stage
# or the color input of target is written to by second stage
if haskey(source.output_connections, output)
old = source.output_connections[output]
if isassigned(source.output_connections, output_idx)
old = source.output_connections[output_idx]
# There should be exactly one matching connection, and it's probably
# near the end?
idx = findlast(c -> c === old, pipeline.connections)::Int
deleteat!(pipeline.connections, idx)
connection = merge(connection, old)
end
if haskey(target.input_connections, input)
old = target.input_connections[input]
if isassigned(target.input_connections, input_idx)
old = target.input_connections[input_idx]
idx = findlast(c -> c === old, pipeline.connections)::Int
deleteat!(pipeline.connections, idx)
connection = merge(connection, old)
end

# attach connection to every input and output
for (stage, key) in connection.inputs
stage.output_connections[key] = connection
for (stage, idx) in connection.inputs
stage.output_connections[idx] = connection
end
for (stage, key) in connection.outputs
stage.input_connections[key] = connection
for (stage, idx) in connection.outputs
stage.input_connections[idx] = connection
end

push!(pipeline.connections, connection)
Expand All @@ -202,8 +210,8 @@ function verify(pipeline::Pipeline)
end

if earliest_input >= earliest_output
inputs = join(("$(stage.name).$key" for (stage, key) in connection.inputs), ", ")
outputs = join(("$(stage.name).$key" for (stage, key) in connection.outputs), ", ")
inputs = join(("$(stage.name).$idx" for (stage, idx) in connection.inputs), ", ")
outputs = join(("$(stage.name).$idx" for (stage, idx) in connection.outputs), ", ")
error("Connection ($inputs) -> ($outputs) is read before being written to. Not allowed. ($earliest_input$earliest_output)")
end
end
Expand Down Expand Up @@ -333,31 +341,48 @@ function Base.show(io::IO, ::MIME"text/plain", stage::Stage)

if !isempty(stage.inputs)
print(io, "\ninputs:")
ks = keys(stage.inputs)
ks = collect(keys(stage.inputs))
sort!(ks, by = k -> stage.inputs[k])
pad = mapreduce(k -> length(string(k)), max, ks)
for k in ks
mark = haskey(stage.input_connections, k) ? 'x' : ' '
print(io, "\n [$mark] ", lpad(string(k), pad), "::", stage.input_formats[stage.inputs[k]])
for (i, k) in enumerate(ks)
mark = isassigned(stage.input_connections, i) ? 'x' : ' '
print(io, "\n [$mark] ", lpad(string(k), pad), "::", stage.input_formats[i])
end
end

if !isempty(stage.outputs)
print(io, "\noutputs:")
ks = keys(stage.outputs)
ks = collect(keys(stage.outputs))
sort!(ks, by = k -> stage.outputs[k])
pad = mapreduce(k -> length(string(k)), max, ks)
for k in ks
mark = haskey(stage.output_connections, k) ? 'x' : ' '
print(io, "\n [$mark] ", lpad(string(k), pad), "::", stage.output_formats[stage.outputs[k]])
for (i, k) in enumerate(ks)
mark = isassigned(stage.output_connections, i) ? 'x' : ' '
print(io, "\n [$mark] ", lpad(string(k), pad), "::", stage.output_formats[i])
end
end

return
end

function _names(connection::Connection)
inputs = unique(values(connection.inputs))
outputs = unique(values(connection.outputs))
return unique(vcat(inputs, outputs))
names = Set{Symbol}()
for (stage, idx) in connection.inputs
for (k, i) in stage.outputs
if idx == i
push!(names, k)
break
end
end
end
for (stage, idx) in connection.outputs
for (k, i) in stage.inputs
if idx == i
push!(names, k)
break
end
end
end
return collect(names)
end

function Base.show(io::IO, connection::Connection)
Expand All @@ -376,8 +401,15 @@ function Base.show(io::IO, ::MIME"text/plain", connection::Connection)

if !isempty(connection.inputs)
print(io, "\ninputs:")
elements = map(keys(connection.inputs), values(connection.inputs)) do stage, k
(stage.name, string(k), stage.output_formats[stage.outputs[k]])
elements = map(keys(connection.inputs), values(connection.inputs)) do stage, idx
key = :temp
for (k, i) in stage.outputs
if idx == i
key = k
break
end
end
return (string(stage.name), string(key), stage.output_formats[idx])
end
pad1 = mapreduce(x -> length(x[1]), max, elements)
pad2 = mapreduce(x -> length(x[2]), max, elements)
Expand All @@ -388,8 +420,15 @@ function Base.show(io::IO, ::MIME"text/plain", connection::Connection)

if !isempty(connection.outputs)
print(io, "\noutputs:")
elements = map(keys(connection.outputs), values(connection.outputs)) do stage, k
(stage.name, string(k), stage.input_formats[stage.inputs[k]])
elements = map(keys(connection.outputs), values(connection.outputs)) do stage, idx
key = :temp
for (k, i) in stage.inputs
if idx == i
key = k
break
end
end
return (string(stage.name), string(key), stage.input_formats[idx])
end
pad1 = mapreduce(x -> length(x[1]), max, elements)
pad2 = mapreduce(x -> length(x[2]), max, elements)
Expand Down

0 comments on commit 5071d72

Please sign in to comment.