From c69f7e62fd8bb6e48e19400590de1c1b4286955f Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Sat, 2 Jan 2021 14:43:12 +0100 Subject: [PATCH 1/4] Initial work towards hparam support - added TBHparams wrapper type - added first attempt at hparams_summary - updated imports in tensorboard/tensorboard.jl to include hparams plugin More work on summarizing hparams and hparam configs small fixes from testing adding test case for HParamConfig logging small fixes from testing adding logger_dispatch_overrides tests for TBHParamsConfig and TBHParams all tests passing add log_hparams and log_hparams_config remove unused SESSION_END_INFO_TAG comment remove unnecessary TBHParamsConfig wrapper remove isnothing (Julia 1.0 compat) Update LogHParams.jl --- src/Loggers/LogHParams.jl | 187 +++++++++++++++++++++++++ src/TensorBoardLogger.jl | 15 +- src/logger_dispatch_overrides.jl | 19 +++ src/protojl/tensorboard/tensorboard.jl | 8 +- test/runtests.jl | 5 + test/test_hparams.jl | 47 +++++++ test/test_logger_dispatch_overrides.jl | 19 +++ 7 files changed, 296 insertions(+), 4 deletions(-) create mode 100644 src/Loggers/LogHParams.jl create mode 100644 test/test_hparams.jl diff --git a/src/Loggers/LogHParams.jl b/src/Loggers/LogHParams.jl new file mode 100644 index 00000000..c992225b --- /dev/null +++ b/src/Loggers/LogHParams.jl @@ -0,0 +1,187 @@ +PLUGIN_NAME = "hparams" +PLUGIN_DATA_VERSION = 0 + +EXPERIMENT_TAG = "_hparams_/experiment" +SESSION_START_INFO_TAG = "_hparams_/session_start_info" + + +struct DiscreteDomain{T} + values::AbstractVector{T} +end + +DiscreteDomainElem = Union{String, Float64, Bool} + +hparams_datatype_sym(::Type{String}) = :DATA_TYPE_STRING +hparams_datatype_sym(::Type{Float64}) = :DATA_TYPE_FLOAT64 +hparams_datatype_sym(::Type{Bool}) = :DATA_TYPE_BOOL + +function hparams_datatype(domain::DiscreteDomain{T}) where T <: DiscreteDomainElem + tensorboard.hparams._DataType[hparams_datatype_sym(T)] +end + +ProtoBuf.google.protobuf.Value(x::Number) = Value(number_value=x) +ProtoBuf.google.protobuf.Value(x::Bool) = Value(bool_value=x) +ProtoBuf.google.protobuf.Value(x::AbstractString) = Value(string_value=x) +function ProtoBuf.google.protobuf.Value(x) + @warn "Cannot create a ProtoBuf.google.protobuf.Value of type $(typeof(x)), defaulting to null." + Value(null_value=Int32(0)) +end + + +function ProtoBuf.google.protobuf.ListValue(domain::DiscreteDomain{T}) where T <: DiscreteDomainElem + ProtoBuf.google.protobuf.ListValue( + values = ProtoBuf.google.protobuf.Value.(domain.values) + ) +end + +struct IntervalDomain + min_value::Float64 + max_value::Float64 +end + +Interval(d::IntervalDomain) = Interval(min_value=d.min_value, max_value=d.max_value) + +HParamDomain = Union{IntervalDomain, DiscreteDomain} + +struct HParam + name::AbstractString + domain::HParamDomain + display_name::AbstractString + description::AbstractString +end + + +function HParamInfo(hparam::HParam) + domain = hparam.domain + domain_kwargs = if isa(domain, IntervalDomain) + (;domain_interval = Interval(domain)) + else + @assert isa(domain, DiscreteDomain) + (_type = hparams_datatype(domain), + domain_discrete = ProtoBuf.google.protobuf.ListValue(domain)) + end + HParamInfo(;name = hparam.name, + description = hparam.description, + display_name = hparam.display_name, + domain_kwargs...) +end + +struct Metric + tag::AbstractString + group::AbstractString + display_name::AbstractString + description::AbstractString + dataset_type::Symbol + + function Metric(tag::AbstractString, + group::AbstractString, + display_name::AbstractString, + description::AbstractString, + dataset_type::Symbol) + valid_dataset_types = keys(tensorboard.hparams.DatasetType) + if dataset_type ∉ valid_dataset_types + throw(ArgumentError("dataset_type of $(dataset_type) is not one of $(map(string, valid_dataset_types)).")) + else + new(tag, group, display_name, description, dataset_type) + end + end +end + +function MetricInfo(metric::Metric) + MetricInfo( + name=MetricName( + group=metric.group, + tag=metric.tag, + ), + display_name=metric.display_name, + description=metric.description, + dataset_type=tensorboard.hparams.DatasetType[metric.dataset_type] + ) +end + +struct HParamsConfig + hparams::AbstractVector{HParam} + metrics::AbstractVector{Metric} + time_created_secs::Float64 +end + +function SummaryMetadata(hparams_plugin_data::HParamsPluginData) + SummaryMetadata( + plugin_data = SummaryMetadata_PluginData( + plugin_name = PLUGIN_NAME, + content = serialize_proto(hparams_plugin_data) + ) + ) +end + +function Summary_Value(tag, hparams_plugin_data::HParamsPluginData) + null_tensor = TensorProto(dtype = _DataType.DT_FLOAT, tensor_shape = TensorShapeProto(dim=[])) + Summary_Value( + tag = tag, + metadata = SummaryMetadata(hparams_plugin_data), + tensor = null_tensor + ) +end + +function log_hparams_config(logger::TBLogger, + hparams_config::HParamsConfig; + step=nothing) + summ = SummaryCollection( + hparams_config_summary(hparams_config) + ) + write_event(logger.file, make_event(logger, summ, step=step)) +end + +function hparams_config_summary(config::HParamsConfig) + Summary_Value( + EXPERIMENT_TAG, + HParamsPluginData( + version = PLUGIN_DATA_VERSION, + experiment = Experiment( + hparam_infos = HParamInfo.(config.hparams), + metric_infos = MetricInfo.(config.metrics), + time_created_secs = config.time_created_secs + ) + ) + ) +end + +function log_hparams(logger::TBLogger, + hparams_dict::Dict{HParam, Any}, + group_name::AbstractString, + trial_id::AbstractString, + start_time_secs::Union{Float64, Nothing}; + step=nothing) + summ = SummaryCollection( + hparams_summary(hparams_dict, + group_name, + trial_id, + start_time_secs) + ) + write_event(logger.file, make_event(logger, summ, step=step)) +end + +function hparams_summary(hparams_dict::Dict{HParam, Any}, + group_name::AbstractString, + trial_id::AbstractString, + start_time_secs=Union{Float64, Nothing}) + if start_time_secs === nothing + start_time_secs = time() + end + + Summary_Value( + SESSION_START_INFO_TAG, + HParamsPluginData( + version = PLUGIN_DATA_VERSION, + session_start_info = SessionStartInfo( + group_name = group_name, + start_time_secs = start_time_secs, + hparams = Dict( + hparam.name => ProtoBuf.google.protobuf.Value(val) + for (hparam, val) ∈ hparams_dict + ) + ) + ) + ) +end + diff --git a/src/TensorBoardLogger.jl b/src/TensorBoardLogger.jl index 4f3c65cd..3bf3b40f 100644 --- a/src/TensorBoardLogger.jl +++ b/src/TensorBoardLogger.jl @@ -20,7 +20,8 @@ using Base.CoreLogging: CoreLogging, AbstractLogger, LogLevel, Info, export TBLogger, reset!, set_step!, increment_step!, set_step_increment! export log_histogram, log_value, log_vector, log_text, log_image, log_images, - log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar + log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar, + log_hparams, log_hparams_config export map_summaries, TBReader export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC, @@ -30,7 +31,7 @@ export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC, export tb_multiline, tb_margin # Wrapper types -export TBText, TBVector, TBHistogram, TBImage, TBImages, TBAudio, TBAudios +export TBText, TBVector, TBHistogram, TBImage, TBImages, TBAudio, TBAudios, TBHParams # workaround for FileIO pre 1.6 include("FileIO_workaround.jl") @@ -40,6 +41,13 @@ include("protojl/tensorboard/tensorboard.jl") using .tensorboard: Summary_Value, GraphDef, Summary, Event, SessionLog_SessionStatus, SessionLog using .tensorboard: TensorShapeProto_Dim, TensorShapeProto, TextPluginData using .tensorboard: TensorProto, SummaryMetadata, SummaryMetadata_PluginData, _DataType +using .tensorboard.hparams: HParamsPluginData, Experiment, SessionStartInfo, SessionEndInfo, HParamInfo, MetricInfo, HParamInfo, Interval, MetricName, DatasetType +import .tensorboard.hparams +import .tensorboard: SummaryMetadata, Summary +import .tensorboard.hparams: HParamInfo, MetricInfo, Interval + +using ProtoBuf +import ProtoBuf.google.protobuf: Value, ListValue include("PNG.jl") using .PNGImage @@ -61,6 +69,9 @@ include("Loggers/LogEmbeddings.jl") # Custom Scalar Plugin include("Loggers/LogCustomScalar.jl") +include("Loggers/LogHParams.jl") + + include("logger_dispatch.jl") include("logger_dispatch_overrides.jl") diff --git a/src/logger_dispatch_overrides.jl b/src/logger_dispatch_overrides.jl index 1eb4af1c..f8b1b289 100644 --- a/src/logger_dispatch_overrides.jl +++ b/src/logger_dispatch_overrides.jl @@ -207,3 +207,22 @@ preprocess(name, val::TBVector{T,N}, data) where {T<:Complex,N} = push!(data, name*"/re"=>TBVector(real.(content(val))), name*"/im"=>TBVector(imag.(content(val)))) summary_impl(name, val::TBVector) = histogram_summary(name, collect(0:length(val.data)), val.data) + +########## Hyperparameters ######################## + +# FIXME: name unused? +summary_impl(name, val::HParamsConfig) = hparams_config_summary(val.data) +preprocess(name, val::HParamsConfig, data) = push!(data, name=>val) + +struct TBHParams <: WrapperLogType + # TODO: The types in the hparam domain and this dict's values are constrained. + # e.g. an hparam with a discrete domain of ["a", "b"] must have string values + # Consider ways to enforce this relationship in the type system. + data::Dict{HParam, Any} + # FIXME: group_name auto generated in the Python implementation (Tensorboard) + group_name::AbstractString + trial_id::AbstractString + start_time_secs::Union{Float64, Nothing} +end +content(x::TBHParams) = x.data +summary_impl(name, val::TBHParams) = hparams_summary(val.data, val.group_name, val.trial_id, val.start_time_secs) diff --git a/src/protojl/tensorboard/tensorboard.jl b/src/protojl/tensorboard/tensorboard.jl index 1851dba2..99a89b60 100644 --- a/src/protojl/tensorboard/tensorboard.jl +++ b/src/protojl/tensorboard/tensorboard.jl @@ -34,6 +34,10 @@ module tensorboard include("plugins/custom_scalar/layout_pb.jl") include("plugins/text/plugin_data_pb.jl") - #include("plugins/hparams/hparams.jl") - + # Needs separate module due to conflicting "_DataType" export + module hparams + include("plugins/hparams/api_pb.jl") + include("plugins/hparams/hparams_util_pb.jl") + include("plugins/hparams/plugin_data_pb.jl") + end end diff --git a/test/runtests.jl b/test/runtests.jl index ad6c36d6..21c52f89 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using TensorBoardLogger, Logging using TensorBoardLogger: preprocess, summary_impl +using TensorBoardLogger: IntervalDomain, DiscreteDomain, HParam, Metric, HParamsConfig using Test using TestImages using ImageCore @@ -270,6 +271,10 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true close.(values(logger.all_files)) end + + @testset "HParams Backend" begin + include("test_hparams.jl") + end @testset "Embedding Logger" begin logger = TBLogger(test_log_dir*"t", tb_overwrite) diff --git a/test/test_hparams.jl b/test/test_hparams.jl new file mode 100644 index 00000000..5b4d6260 --- /dev/null +++ b/test/test_hparams.jl @@ -0,0 +1,47 @@ +@testset "HParamConfig Logger" begin + logger = TBLogger(test_log_dir*"t", tb_overwrite) + step = 1 + + interval_domain = IntervalDomain(0.1, 3.0) + hparam1 = HParam("interval_hparam", interval_domain, "display_name1", "description1") + + discrete_domain_strs = ["a", "b", "c"] + discrete_domain = DiscreteDomain(discrete_domain_strs) + hparam2 = HParam("discrete_domain_hparam", discrete_domain, "display_name2", "description2") + + hparams = [hparam1, hparam2] + + metric = Metric("tag", "group", "display_name", "description", :DATASET_VALIDATION) + metrics = [metric] + hparams_config = HParamsConfig(hparams, metrics, 1.2) + ss = TensorBoardLogger.hparams_config_summary(hparams_config) + + @test isa(ss, TensorBoardLogger.Summary_Value) + @test ss.tag == TensorBoardLogger.EXPERIMENT_TAG + + # TODO: Deserialize and test more properties + + log_hparams_config(logger, hparams_config ;step=step) +end + +@testset "HParams Logger" begin + logger = TBLogger(test_log_dir*"t", tb_overwrite) + step = 1 + + interval_domain = IntervalDomain(0.1, 3.0) + hparam1 = HParam("interval_hparam", interval_domain, "display_name1", "description1") + + discrete_domain_strs = ["a", "b", "c"] + discrete_domain = DiscreteDomain(discrete_domain_strs) + hparam2 = HParam("discrete_domain_hparam", discrete_domain, "display_name2", "description2") + + hparams_dict = Dict(hparam1 => 1.2, hparam2 => "b") + + ss = TensorBoardLogger.hparams_summary(hparams_dict, "group_name", "trial_id", nothing) + + @test isa(ss, TensorBoardLogger.Summary_Value) + @test ss.tag == TensorBoardLogger.SESSION_START_INFO_TAG + + # TODO: Deserialize and test more properties + log_hparams(logger, hparams_dict, "group_name", "trial_id", nothing ;step=step) +end diff --git a/test/test_logger_dispatch_overrides.jl b/test/test_logger_dispatch_overrides.jl index 9717721d..2e67a4c2 100644 --- a/test/test_logger_dispatch_overrides.jl +++ b/test/test_logger_dispatch_overrides.jl @@ -1,5 +1,6 @@ using TensorBoardLogger, Test using TensorBoardLogger: preprocess, content +using TensorBoardLogger: TBHParams using TestImages using ImageCore @testset "TBText" begin @@ -61,3 +62,21 @@ end @test first(data) == ("test/1"=>TBAudio(y, sample_rate)) @test last(data) == ("test/2"=>TBAudio(y, sample_rate)) end + +@testset "HParamsConfig" begin + data = Vector{Pair{String,Any}}() + hparam = HParam("interval_hparam", IntervalDomain(0.1, 3.0), "display_name1", "description1") + metric = Metric("tag", "group", "display_name", "description", :DATASET_VALIDATION) + params_config = HParamsConfig([hparam], [metric], 1.2) + preprocess("test", params_config, data) + @test first(data) == ("test"=>params_config) +end + +@testset "TBHParams" begin + data = Vector{Pair{String,Any}}() + hparam = HParam("interval_hparam", IntervalDomain(0.1, 3.0), "display_name1", "description1") + hparams_dict = Dict(hparam => 1.2) + tbh_params = TBHParams(hparams_dict, "group_name", "trial_id", nothing) + preprocess("test", tbh_params, data) + @test first(data) == ("test"=>tbh_params) +end From 66744216c5654ecf20b56d21a596e0b8efc83723 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Sun, 25 Apr 2021 16:48:37 +0200 Subject: [PATCH 2/4] fixes --- src/Loggers/LogHParams.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Loggers/LogHParams.jl b/src/Loggers/LogHParams.jl index c992225b..3fef01ec 100644 --- a/src/Loggers/LogHParams.jl +++ b/src/Loggers/LogHParams.jl @@ -11,7 +11,7 @@ end DiscreteDomainElem = Union{String, Float64, Bool} -hparams_datatype_sym(::Type{String}) = :DATA_TYPE_STRING +hparams_datatype_sym(::Type{<:AbstractString}) = :DATA_TYPE_STRING hparams_datatype_sym(::Type{Float64}) = :DATA_TYPE_FLOAT64 hparams_datatype_sym(::Type{Bool}) = :DATA_TYPE_BOOL @@ -19,16 +19,16 @@ function hparams_datatype(domain::DiscreteDomain{T}) where T <: DiscreteDomainEl tensorboard.hparams._DataType[hparams_datatype_sym(T)] end -ProtoBuf.google.protobuf.Value(x::Number) = Value(number_value=x) -ProtoBuf.google.protobuf.Value(x::Bool) = Value(bool_value=x) -ProtoBuf.google.protobuf.Value(x::AbstractString) = Value(string_value=x) -function ProtoBuf.google.protobuf.Value(x) +# custom constructors for ProtoBuf.google.protobuf.Value +function _Protobuf_Value(x) @warn "Cannot create a ProtoBuf.google.protobuf.Value of type $(typeof(x)), defaulting to null." Value(null_value=Int32(0)) end +_Protobuf_Value(x::Bool) = Value(bool_value=x) +_Protobuf_Value(x::Number) = Value(number_value=x) +_Protobuf_Value(x::AbstractString) = Value(string_value=x) - -function ProtoBuf.google.protobuf.ListValue(domain::DiscreteDomain{T}) where T <: DiscreteDomainElem +function _Protobuf_ListValue(domain::DiscreteDomain{T})where T <: DiscreteDomainElem ProtoBuf.google.protobuf.ListValue( values = ProtoBuf.google.protobuf.Value.(domain.values) ) @@ -58,7 +58,7 @@ function HParamInfo(hparam::HParam) else @assert isa(domain, DiscreteDomain) (_type = hparams_datatype(domain), - domain_discrete = ProtoBuf.google.protobuf.ListValue(domain)) + domain_discrete = _Protobuf_ListValue(domain)) end HParamInfo(;name = hparam.name, description = hparam.description, @@ -177,7 +177,7 @@ function hparams_summary(hparams_dict::Dict{HParam, Any}, group_name = group_name, start_time_secs = start_time_secs, hparams = Dict( - hparam.name => ProtoBuf.google.protobuf.Value(val) + hparam.name => _ProtoBuf_Value(val) for (hparam, val) ∈ hparams_dict ) ) From e52b37c7a4db31ef058c3f711c6aa4087b25b315 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Sun, 25 Apr 2021 17:02:36 +0200 Subject: [PATCH 3/4] address some revie commnts --- src/Loggers/LogHParams.jl | 33 +++++++++++++++------------------ test/runtests.jl | 2 -- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/Loggers/LogHParams.jl b/src/Loggers/LogHParams.jl index 3fef01ec..85bcdec5 100644 --- a/src/Loggers/LogHParams.jl +++ b/src/Loggers/LogHParams.jl @@ -20,17 +20,17 @@ function hparams_datatype(domain::DiscreteDomain{T}) where T <: DiscreteDomainEl end # custom constructors for ProtoBuf.google.protobuf.Value -function _Protobuf_Value(x) +function _ProtoBuf_Value(x) @warn "Cannot create a ProtoBuf.google.protobuf.Value of type $(typeof(x)), defaulting to null." Value(null_value=Int32(0)) end -_Protobuf_Value(x::Bool) = Value(bool_value=x) -_Protobuf_Value(x::Number) = Value(number_value=x) -_Protobuf_Value(x::AbstractString) = Value(string_value=x) +_ProtoBuf_Value(x::Bool) = Value(bool_value=x) +_ProtoBuf_Value(x::Number) = Value(number_value=x) +_ProtoBuf_Value(x::AbstractString) = Value(string_value=x) -function _Protobuf_ListValue(domain::DiscreteDomain{T})where T <: DiscreteDomainElem +function _ProtoBuf_ListValue(domain::DiscreteDomain{T})where T <: DiscreteDomainElem ProtoBuf.google.protobuf.ListValue( - values = ProtoBuf.google.protobuf.Value.(domain.values) + values = _ProtoBuf_Value.(domain.values) ) end @@ -58,7 +58,7 @@ function HParamInfo(hparam::HParam) else @assert isa(domain, DiscreteDomain) (_type = hparams_datatype(domain), - domain_discrete = _Protobuf_ListValue(domain)) + domain_discrete = _ProtoBuf_ListValue(domain)) end HParamInfo(;name = hparam.name, description = hparam.description, @@ -67,10 +67,10 @@ function HParamInfo(hparam::HParam) end struct Metric - tag::AbstractString - group::AbstractString - display_name::AbstractString - description::AbstractString + tag::String + group::String + display_name::String + description::String dataset_type::Symbol function Metric(tag::AbstractString, @@ -79,11 +79,8 @@ struct Metric description::AbstractString, dataset_type::Symbol) valid_dataset_types = keys(tensorboard.hparams.DatasetType) - if dataset_type ∉ valid_dataset_types - throw(ArgumentError("dataset_type of $(dataset_type) is not one of $(map(string, valid_dataset_types)).")) - else - new(tag, group, display_name, description, dataset_type) - end + dataset_type ∉ valid_dataset_types && throw(ArgumentError("dataset_type of $(dataset_type) is not one of $(map(string, valid_dataset_types)).")) + new(convert(String,tag), convert(String,group), convert(String, display_name), convert(String, description), dataset_type) end end @@ -100,8 +97,8 @@ function MetricInfo(metric::Metric) end struct HParamsConfig - hparams::AbstractVector{HParam} - metrics::AbstractVector{Metric} + hparams::Vector{HParam} + metrics::Vector{Metric} time_created_secs::Float64 end diff --git a/test/runtests.jl b/test/runtests.jl index 21c52f89..c77e7e5d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,8 +11,6 @@ test_log_dir = "test_logs/" ENV["DATADEPS_ALWAYS_ACCEPT"] = true ENV["GKSwstype"] = "100" -ENV["DATADEPS_ALWAYS_ACCEPT"] = true - @testset "TensorBoardLogger" begin @testset "TBLogger" begin From 35467db7bdb870279c7cd72d66a606e5efe6bfb4 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Sun, 25 Apr 2021 17:13:45 +0200 Subject: [PATCH 4/4] simplify tests --- test/runtests.jl | 1 - test/test_hparams.jl | 9 +++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index c77e7e5d..6698cbb0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,5 @@ using TensorBoardLogger, Logging using TensorBoardLogger: preprocess, summary_impl -using TensorBoardLogger: IntervalDomain, DiscreteDomain, HParam, Metric, HParamsConfig using Test using TestImages using ImageCore diff --git a/test/test_hparams.jl b/test/test_hparams.jl index 5b4d6260..000d6df3 100644 --- a/test/test_hparams.jl +++ b/test/test_hparams.jl @@ -1,3 +1,12 @@ +using TensorBoardLogger +using Logging +using TensorBoardLogger: preprocess, summary_impl +using TensorBoardLogger: IntervalDomain, DiscreteDomain, HParam, Metric, HParamsConfig +using Test + +test_log_dir = "test_logs/" +ENV["DATADEPS_ALWAYS_ACCEPT"] = true + @testset "HParamConfig Logger" begin logger = TBLogger(test_log_dir*"t", tb_overwrite) step = 1