Skip to content

Commit

Permalink
Initial work towards hparam support
Browse files Browse the repository at this point in the history
- added TBHparams wrapper type
- added first attempt at hparams_summary
- updated imports in tensorboard/tensorboard.jl to include hparams plugin

More work on summarizing hparams and hparam configs

small fixes from testing

adding test case for HParamConfig logging

small fixes from testing

adding logger_dispatch_overrides tests for TBHParamsConfig and TBHParams

all tests passing

add log_hparams and log_hparams_config

remove unused SESSION_END_INFO_TAG

comment

remove unnecessary TBHParamsConfig wrapper

remove isnothing (Julia 1.0 compat)

Update LogHParams.jl
  • Loading branch information
PhilipVinc committed Apr 25, 2021
1 parent 658e4b6 commit fcc4b81
Show file tree
Hide file tree
Showing 7 changed files with 296 additions and 4 deletions.
187 changes: 187 additions & 0 deletions src/Loggers/LogHParams.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# Identifiers used when building hparams-plugin summaries. They are declared
# `const` so the bindings have a fixed type and cannot be rebound by accident.
const PLUGIN_NAME = "hparams"
const PLUGIN_DATA_VERSION = 0

# Tags under which the experiment configuration and the session start info
# summaries are written.
const EXPERIMENT_TAG = "_hparams_/experiment"
const SESSION_START_INFO_TAG = "_hparams_/session_start_info"


"""
    DiscreteDomain{T}(values)

Hyperparameter domain given as an explicit collection of allowed `values`
with element type `T`.
"""
struct DiscreteDomain{T}
    values::AbstractVector{T}
end

# Element types for which `hparams_datatype_sym` has a method. Declared `const`
# so the alias is a fixed binding rather than an untyped global.
const DiscreteDomainElem = Union{String, Float64, Bool}

"""
    hparams_datatype_sym(::Type{T})

Return the symbol naming the hparams data type used for values of Julia type
`T`. Methods exist for `String`, `Float64` and `Bool` only.
"""
function hparams_datatype_sym end

hparams_datatype_sym(::Type{Bool}) = :DATA_TYPE_BOOL
hparams_datatype_sym(::Type{Float64}) = :DATA_TYPE_FLOAT64
hparams_datatype_sym(::Type{String}) = :DATA_TYPE_STRING

"""
    hparams_datatype(domain::DiscreteDomain{T})

Look up the `tensorboard.hparams._DataType` entry that corresponds to the
element type `T` of `domain`. Only the type parameter is consulted; the stored
values are not inspected.
"""
function hparams_datatype(domain::DiscreteDomain{T}) where {T<:DiscreteDomainElem}
    type_sym = hparams_datatype_sym(T)
    return tensorboard.hparams._DataType[type_sym]
end

# Convenience constructors wrapping a plain Julia value in a protobuf `Value`.
# `Bool <: Number`, so the `Bool` method is the more specific one and wins for
# booleans.
# NOTE(review): these add methods to a ProtoBuf-owned type for argument types
# this package does not own (type piracy); consider a package-local helper.
ProtoBuf.google.protobuf.Value(x::Bool) = Value(; bool_value = x)
ProtoBuf.google.protobuf.Value(x::AbstractString) = Value(; string_value = x)
ProtoBuf.google.protobuf.Value(x::Number) = Value(; number_value = x)

# Fallback for any other argument type: emit a warning and build a null value.
function ProtoBuf.google.protobuf.Value(x)
    @warn "Cannot create a ProtoBuf.google.protobuf.Value of type $(typeof(x)), defaulting to null."
    return Value(; null_value = Int32(0))
end


"""
    ProtoBuf.google.protobuf.ListValue(domain::DiscreteDomain)

Build a protobuf `ListValue` whose entries are the values of `domain`, each
wrapped in a protobuf `Value`.
"""
function ProtoBuf.google.protobuf.ListValue(domain::DiscreteDomain{T}) where {T<:DiscreteDomainElem}
    wrapped = ProtoBuf.google.protobuf.Value.(domain.values)
    return ProtoBuf.google.protobuf.ListValue(; values = wrapped)
end

"""
    IntervalDomain(min_value, max_value)

Hyperparameter domain described by a lower and an upper `Float64` bound.
"""
struct IntervalDomain
    min_value::Float64
    max_value::Float64
end

"""
    Interval(d::IntervalDomain)

Build an `Interval` message carrying the bounds stored in `d`.
"""
function Interval(d::IntervalDomain)
    return Interval(; min_value = d.min_value, max_value = d.max_value)
end

# Any domain an `HParam` may carry. Declared `const` so the alias is a fixed
# binding rather than an untyped global.
const HParamDomain = Union{IntervalDomain, DiscreteDomain}

"""
    HParam(name, domain, display_name, description)

Description of a single hyperparameter: its `name`, the `domain` its values
come from, and the `display_name` / `description` strings attached to it.
"""
struct HParam
    name::AbstractString
    domain::HParamDomain
    display_name::AbstractString
    description::AbstractString
end


# Domain-specific keyword arguments for the `HParamInfo` keyword constructor,
# selected by dispatch on the concrete domain type.
_hparam_domain_kwargs(domain::IntervalDomain) = (; domain_interval = Interval(domain))
function _hparam_domain_kwargs(domain::DiscreteDomain)
    return (_type = hparams_datatype(domain),
            domain_discrete = ProtoBuf.google.protobuf.ListValue(domain))
end

"""
    HParamInfo(hparam::HParam)

Build the `HParamInfo` message for `hparam`. The name, description and display
name are copied over; the domain contributes either `domain_interval` (for an
`IntervalDomain`) or `_type` plus `domain_discrete` (for a `DiscreteDomain`).
"""
function HParamInfo(hparam::HParam)
    return HParamInfo(; name = hparam.name,
                      description = hparam.description,
                      display_name = hparam.display_name,
                      _hparam_domain_kwargs(hparam.domain)...)
end

"""
    Metric(tag, group, display_name, description, dataset_type)

Description of a metric. `dataset_type` must be one of the keys of
`tensorboard.hparams.DatasetType`.

# Throws
- `ArgumentError` if `dataset_type` is not one of those keys.
"""
struct Metric
    tag::AbstractString
    group::AbstractString
    display_name::AbstractString
    description::AbstractString
    dataset_type::Symbol

    function Metric(tag::AbstractString,
                    group::AbstractString,
                    display_name::AbstractString,
                    description::AbstractString,
                    dataset_type::Symbol)
        valid_dataset_types = keys(tensorboard.hparams.DatasetType)
        # Reject unknown dataset types up front so the later lookup in
        # `DatasetType` cannot fail on a bad symbol.
        if !(dataset_type in valid_dataset_types)
            throw(ArgumentError("dataset_type of $(dataset_type) is not one of $(map(string, valid_dataset_types))."))
        end
        return new(tag, group, display_name, description, dataset_type)
    end
end

"""
    MetricInfo(metric::Metric)

Build the `MetricInfo` message for `metric`: the group and tag become a
`MetricName`, and the dataset type symbol is looked up in
`tensorboard.hparams.DatasetType`.
"""
function MetricInfo(metric::Metric)
    metric_name = MetricName(; group = metric.group, tag = metric.tag)
    return MetricInfo(; name = metric_name,
                      display_name = metric.display_name,
                      description = metric.description,
                      dataset_type = tensorboard.hparams.DatasetType[metric.dataset_type])
end

"""
    HParamsConfig(hparams, metrics, time_created_secs)

Experiment configuration: the hyperparameters and metrics of an experiment,
plus its creation time as a `Float64` number of seconds.
"""
struct HParamsConfig
    hparams::AbstractVector{HParam}
    metrics::AbstractVector{Metric}
    time_created_secs::Float64
end

"""
    SummaryMetadata(hparams_plugin_data::HParamsPluginData)

Wrap `hparams_plugin_data` in a `SummaryMetadata` whose plugin data carries the
hparams plugin name and the serialized message as its content.
"""
function SummaryMetadata(hparams_plugin_data::HParamsPluginData)
    plugin_data = SummaryMetadata_PluginData(;
        plugin_name = PLUGIN_NAME,
        content = serialize_proto(hparams_plugin_data))
    return SummaryMetadata(; plugin_data = plugin_data)
end

"""
    Summary_Value(tag, hparams_plugin_data::HParamsPluginData)

Build a `Summary_Value` labelled `tag` whose metadata is derived from
`hparams_plugin_data`. The tensor field is filled with a float tensor of empty
shape; the hparams data itself is placed in the metadata.
"""
function Summary_Value(tag, hparams_plugin_data::HParamsPluginData)
    empty_shape = TensorShapeProto(; dim = [])
    placeholder = TensorProto(; dtype = _DataType.DT_FLOAT, tensor_shape = empty_shape)
    return Summary_Value(;
        tag = tag,
        metadata = SummaryMetadata(hparams_plugin_data),
        tensor = placeholder)
end

"""
    log_hparams_config(logger, hparams_config; step=nothing)

Write the experiment configuration `hparams_config` to `logger` as a single
event. `step` is forwarded to `make_event`.
"""
function log_hparams_config(logger::TBLogger,
                            hparams_config::HParamsConfig;
                            step=nothing)
    config_value = hparams_config_summary(hparams_config)
    summary = SummaryCollection(config_value)
    return write_event(logger.file, make_event(logger, summary; step=step))
end

"""
    hparams_config_summary(config::HParamsConfig)

Convert `config` into a `Summary_Value` tagged `EXPERIMENT_TAG`. The
hyperparameters and metrics are turned into `HParamInfo` / `MetricInfo`
messages and stored in an `Experiment` inside the plugin data.
"""
function hparams_config_summary(config::HParamsConfig)
    experiment = Experiment(; hparam_infos = HParamInfo.(config.hparams),
                            metric_infos = MetricInfo.(config.metrics),
                            time_created_secs = config.time_created_secs)
    plugin_data = HParamsPluginData(; version = PLUGIN_DATA_VERSION,
                                    experiment = experiment)
    return Summary_Value(EXPERIMENT_TAG, plugin_data)
end

"""
    log_hparams(logger, hparams_dict, group_name, trial_id, start_time_secs; step=nothing)

Write the hyperparameter values in `hparams_dict` (keyed by `HParam`) to
`logger` as a single session-start event.

# Arguments
- `hparams_dict`: any dictionary with `HParam` keys. The value type is not
  constrained, so `Dict(hparam => 1.2)` is accepted as well as
  `Dict{HParam, Any}`.
- `group_name`, `trial_id`, `start_time_secs`: forwarded to `hparams_summary`.

# Keywords
- `step`: forwarded to `make_event`.
"""
function log_hparams(logger::TBLogger,
                     hparams_dict::AbstractDict{HParam},
                     group_name::AbstractString,
                     trial_id::AbstractString,
                     start_time_secs::Union{Float64, Nothing};
                     step=nothing)
    # Julia type parameters are invariant, so normalise to the value type that
    # `hparams_summary` is declared with. This is a no-op when the argument is
    # already a `Dict{HParam, Any}`.
    normalized = convert(Dict{HParam, Any}, hparams_dict)
    session_value = hparams_summary(normalized, group_name, trial_id, start_time_secs)
    summary = SummaryCollection(session_value)
    return write_event(logger.file, make_event(logger, summary; step=step))
end

"""
    hparams_summary(hparams_dict, group_name, trial_id, start_time_secs=nothing)

Convert the hyperparameter values in `hparams_dict` into a `Summary_Value`
tagged `SESSION_START_INFO_TAG`.

# Arguments
- `hparams_dict`: dictionary keyed by `HParam`; each value is wrapped in a
  protobuf `Value` and stored under the hyperparameter's `name`.
- `group_name`: stored in the `SessionStartInfo` message.
- `trial_id`: accepted but not used by this function.
- `start_time_secs`: session start time; when `nothing`, the current `time()`
  is used.
"""
function hparams_summary(hparams_dict::AbstractDict{HParam},
                         group_name::AbstractString,
                         trial_id::AbstractString,
                         start_time_secs::Union{Float64, Nothing}=nothing)
    # Fall back to the current wall-clock time when no start time was supplied.
    if start_time_secs === nothing
        start_time_secs = time()
    end

    hparam_values = Dict(
        hparam.name => ProtoBuf.google.protobuf.Value(val)
        for (hparam, val) in hparams_dict
    )
    session_start_info = SessionStartInfo(; group_name = group_name,
                                          start_time_secs = start_time_secs,
                                          hparams = hparam_values)
    plugin_data = HParamsPluginData(; version = PLUGIN_DATA_VERSION,
                                    session_start_info = session_start_info)
    return Summary_Value(SESSION_START_INFO_TAG, plugin_data)
end

15 changes: 13 additions & 2 deletions src/TensorBoardLogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ using Base.CoreLogging: CoreLogging, AbstractLogger, LogLevel, Info,

export TBLogger, reset!, set_step!, increment_step!, set_step_increment!
export log_histogram, log_value, log_vector, log_text, log_image, log_images,
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar,
log_hparams, log_hparams_config
export map_summaries, TBReader

export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC,
Expand All @@ -30,7 +31,7 @@ export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC,
export tb_multiline, tb_margin

# Wrapper types
export TBText, TBVector, TBHistogram, TBImage, TBImages, TBAudio, TBAudios
export TBText, TBVector, TBHistogram, TBImage, TBImages, TBAudio, TBAudios, TBHParams

# workaround for FileIO pre 1.6
include("FileIO_workaround.jl")
Expand All @@ -40,6 +41,13 @@ include("protojl/tensorboard/tensorboard.jl")
using .tensorboard: Summary_Value, GraphDef, Summary, Event, SessionLog_SessionStatus, SessionLog
using .tensorboard: TensorShapeProto_Dim, TensorShapeProto, TextPluginData
using .tensorboard: TensorProto, SummaryMetadata, SummaryMetadata_PluginData, _DataType
using .tensorboard.hparams: HParamsPluginData, Experiment, SessionStartInfo, SessionEndInfo, HParamInfo, MetricInfo, HParamInfo, Interval, MetricName, DatasetType
import .tensorboard.hparams
import .tensorboard: SummaryMetadata, Summary
import .tensorboard.hparams: HParamInfo, MetricInfo, Interval

using ProtoBuf
import ProtoBuf.google.protobuf: Value, ListValue

include("PNG.jl")
using .PNGImage
Expand All @@ -61,6 +69,9 @@ include("Loggers/LogEmbeddings.jl")
# Custom Scalar Plugin
include("Loggers/LogCustomScalar.jl")

include("Loggers/LogHParams.jl")


include("logger_dispatch.jl")
include("logger_dispatch_overrides.jl")

Expand Down
19 changes: 19 additions & 0 deletions src/logger_dispatch_overrides.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,22 @@ preprocess(name, val::TBVector{T,N}, data) where {T<:Complex,N} =
push!(data, name*"/re"=>TBVector(real.(content(val))), name*"/im"=>TBVector(imag.(content(val))))
# Summarise a TBVector as a histogram with one bin per element: the edges are
# the integers 0 through length(val.data).
function summary_impl(name, val::TBVector)
    bin_edges = collect(0:length(val.data))
    return histogram_summary(name, bin_edges, val.data)
end

########## Hyperparameters ########################

# Dispatch overrides for logging an experiment configuration directly.
# `HParamsConfig` is not a wrapper type and has no `data` field, so the config
# itself is handed to `hparams_config_summary`.
# FIXME: `name` is not used by `summary_impl` here.
summary_impl(name, val::HParamsConfig) = hparams_config_summary(val)
preprocess(name, val::HParamsConfig, data) = push!(data, name=>val)

"""
    TBHParams(data, group_name, trial_id, start_time_secs)

Wrapper log type carrying hyperparameter values (`data`, keyed by `HParam`)
together with the session's `group_name`, `trial_id` and optional
`start_time_secs`.
"""
struct TBHParams <: WrapperLogType
    # TODO: The hparam's domain constrains which value types are valid here,
    # e.g. an hparam with a discrete domain of ["a", "b"] must map to a string.
    # Consider ways to enforce this relationship in the type system.
    data::Dict{HParam, Any}
    # FIXME: group_name auto generated in the Python implementation (Tensorboard)
    group_name::AbstractString
    trial_id::AbstractString
    start_time_secs::Union{Float64, Nothing}
end
# The wrapped payload of a TBHParams is its hyperparameter dictionary.
content(x::TBHParams) = x.data

# Summarise a TBHParams by forwarding all of its fields to `hparams_summary`;
# `name` is not used.
function summary_impl(name, val::TBHParams)
    return hparams_summary(val.data, val.group_name, val.trial_id, val.start_time_secs)
end
8 changes: 6 additions & 2 deletions src/protojl/tensorboard/tensorboard.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ module tensorboard
include("plugins/custom_scalar/layout_pb.jl")
include("plugins/text/plugin_data_pb.jl")

#include("plugins/hparams/hparams.jl")

# The generated hparams plugin definitions are loaded into their own submodule
# because they export a "_DataType" name that conflicts with another
# "_DataType" export. Code outside refers to these definitions through the
# `hparams` module.
module hparams
include("plugins/hparams/api_pb.jl")
include("plugins/hparams/hparams_util_pb.jl")
include("plugins/hparams/plugin_data_pb.jl")
end
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using TensorBoardLogger, Logging
using TensorBoardLogger: preprocess, summary_impl
using TensorBoardLogger: IntervalDomain, DiscreteDomain, HParam, Metric, HParamsConfig
using Test
using TestImages
using ImageCore
Expand Down Expand Up @@ -271,6 +272,10 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true

close.(values(logger.all_files))
end

# Run the hparams tests from test_hparams.jl inside their own testset. The
# included file relies on names defined in this file (e.g. `test_log_dir`,
# `tb_overwrite`).
@testset "HParams Backend" begin
include("test_hparams.jl")
end

@testset "Embedding Logger" begin
logger = TBLogger(test_log_dir*"t", tb_overwrite)
Expand Down
47 changes: 47 additions & 0 deletions test/test_hparams.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Builds an experiment configuration with one interval and one discrete
# hyperparameter, checks the summary it produces, and logs it.
@testset "HParamConfig Logger" begin
    logger = TBLogger(test_log_dir*"t", tb_overwrite)
    step = 1

    interval_domain = IntervalDomain(0.1, 3.0)
    hparam1 = HParam("interval_hparam", interval_domain, "display_name1", "description1")

    discrete_domain_strs = ["a", "b", "c"]
    discrete_domain = DiscreteDomain(discrete_domain_strs)
    hparam2 = HParam("discrete_domain_hparam", discrete_domain, "display_name2", "description2")

    hparams = [hparam1, hparam2]

    metric = Metric("tag", "group", "display_name", "description", :DATASET_VALIDATION)
    metrics = [metric]
    hparams_config = HParamsConfig(hparams, metrics, 1.2)
    ss = TensorBoardLogger.hparams_config_summary(hparams_config)

    @test isa(ss, TensorBoardLogger.Summary_Value)
    @test ss.tag == TensorBoardLogger.EXPERIMENT_TAG

    # TODO: Deserialize and test more properties

    log_hparams_config(logger, hparams_config; step=step)

    # Release the log files so later testsets can overwrite the same directory.
    close.(values(logger.all_files))
end

# Builds a dictionary of hyperparameter values (one interval, one discrete),
# checks the session-start summary it produces, and logs it.
@testset "HParams Logger" begin
    logger = TBLogger(test_log_dir*"t", tb_overwrite)
    step = 1

    interval_domain = IntervalDomain(0.1, 3.0)
    hparam1 = HParam("interval_hparam", interval_domain, "display_name1", "description1")

    discrete_domain_strs = ["a", "b", "c"]
    discrete_domain = DiscreteDomain(discrete_domain_strs)
    hparam2 = HParam("discrete_domain_hparam", discrete_domain, "display_name2", "description2")

    hparams_dict = Dict(hparam1 => 1.2, hparam2 => "b")

    ss = TensorBoardLogger.hparams_summary(hparams_dict, "group_name", "trial_id", nothing)

    @test isa(ss, TensorBoardLogger.Summary_Value)
    @test ss.tag == TensorBoardLogger.SESSION_START_INFO_TAG

    # TODO: Deserialize and test more properties
    log_hparams(logger, hparams_dict, "group_name", "trial_id", nothing; step=step)

    # Release the log files so later testsets can overwrite the same directory.
    close.(values(logger.all_files))
end
19 changes: 19 additions & 0 deletions test/test_logger_dispatch_overrides.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using TensorBoardLogger, Test
using TensorBoardLogger: preprocess, content
using TensorBoardLogger: TBHParams
using TestImages
using ImageCore
@testset "TBText" begin
Expand Down Expand Up @@ -61,3 +62,21 @@ end
@test first(data) == ("test/1"=>TBAudio(y, sample_rate))
@test last(data) == ("test/2"=>TBAudio(y, sample_rate))
end

# `preprocess` must pass an HParamsConfig through unchanged under its name.
# The hparams types are not exported and this file imports only `TBHParams`,
# so they are referenced through the module to keep the file self-contained.
@testset "HParamsConfig" begin
    data = Vector{Pair{String,Any}}()
    hparam = TensorBoardLogger.HParam("interval_hparam",
                                      TensorBoardLogger.IntervalDomain(0.1, 3.0),
                                      "display_name1", "description1")
    metric = TensorBoardLogger.Metric("tag", "group", "display_name", "description",
                                      :DATASET_VALIDATION)
    params_config = TensorBoardLogger.HParamsConfig([hparam], [metric], 1.2)
    preprocess("test", params_config, data)
    @test first(data) == ("test"=>params_config)
end

# `preprocess` must pass a TBHParams wrapper through unchanged under its name.
# `HParam` and `IntervalDomain` are not exported and not imported by this
# file, so they are referenced through the module.
@testset "TBHParams" begin
    data = Vector{Pair{String,Any}}()
    hparam = TensorBoardLogger.HParam("interval_hparam",
                                      TensorBoardLogger.IntervalDomain(0.1, 3.0),
                                      "display_name1", "description1")
    hparams_dict = Dict(hparam => 1.2)
    tbh_params = TBHParams(hparams_dict, "group_name", "trial_id", nothing)
    preprocess("test", tbh_params, data)
    @test first(data) == ("test"=>tbh_params)
end

0 comments on commit fcc4b81

Please sign in to comment.