(WIP) Add support for logging hyperparameters #87

Closed
wants to merge 13 commits into from
187 changes: 187 additions & 0 deletions src/Loggers/LogHParams.jl
@@ -0,0 +1,187 @@
PLUGIN_NAME = "hparams"
PLUGIN_DATA_VERSION = 0

EXPERIMENT_TAG = "_hparams_/experiment"
SESSION_START_INFO_TAG = "_hparams_/session_start_info"
Comment on lines +4 to +5
Member

You got those from tensorboardx.py?



Member

Could you add a few comments on what the Domain types are used for? I guess they encode the domain of a hyperparameter, but comments would still make the code clearer for future maintenance.

struct DiscreteDomain{T}
    values::AbstractVector{T}
end

DiscreteDomainElem = Union{String, Float64, Bool}

hparams_datatype_sym(::Type{String}) = :DATA_TYPE_STRING
Member

Sometimes a SubString might be passed in. I'd support at least that one too.
Ideally, I'd say that this should be ::Type{<:AbstractString}, with a conversion for exotic types made where appropriate.
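
A minimal sketch of what this suggestion could look like. It is an illustration under stated assumptions, not code from the PR: `normalize_hparam_value` is a hypothetical helper name, and the `DiscreteDomainElem` constraint used by `hparams_datatype` in the diff would also need widening before a `DiscreteDomain{SubString{String}}` is accepted there.

```julia
# Dispatch on the abstract string type so SubString and friends are accepted.
hparams_datatype_sym(::Type{<:AbstractString}) = :DATA_TYPE_STRING
hparams_datatype_sym(::Type{Float64}) = :DATA_TYPE_FLOAT64
hparams_datatype_sym(::Type{Bool}) = :DATA_TYPE_BOOL

# Hypothetical helper: convert exotic string types to String before
# serialisation and pass every other value through unchanged.
normalize_hparam_value(x::AbstractString) = String(x)
normalize_hparam_value(x) = x

# A SubString, as produced by split or by indexing into a String.
sub = SubString("learning_rate", 1, 8)

sym = hparams_datatype_sym(typeof(sub))
normalized = normalize_hparam_value(sub)
```

With this shape the symbol lookup no longer depends on the concrete string type, and the conversion to `String` happens in one place.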

hparams_datatype_sym(::Type{Float64}) = :DATA_TYPE_FLOAT64
hparams_datatype_sym(::Type{Bool}) = :DATA_TYPE_BOOL

function hparams_datatype(domain::DiscreteDomain{T}) where T <: DiscreteDomainElem
    tensorboard.hparams._DataType[hparams_datatype_sym(T)]
end

ProtoBuf.google.protobuf.Value(x::Number) = Value(number_value=x)
ProtoBuf.google.protobuf.Value(x::Bool) = Value(bool_value=x)
ProtoBuf.google.protobuf.Value(x::AbstractString) = Value(string_value=x)
function ProtoBuf.google.protobuf.Value(x)
    @warn "Cannot create a ProtoBuf.google.protobuf.Value of type $(typeof(x)), defaulting to null."
    Value(null_value=Int32(0))
end
Comment on lines +22 to +28
Member

This is all type piracy, and it might stop working in the future when something changes in ProtoBuf.jl.
Is there a specific reason you need this, and cannot create your own Value function that you use locally?

You should not define specialised methods for types we do not own on functions we do not own.
Instead, do not import ProtoBuf.google.protobuf.Value; define your own Value functions that forward to ProtoBuf.google.protobuf.Value.
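
A rough sketch of the forwarding pattern described above. To keep it runnable without ProtoBuf.jl, `ExternalValue` is a stand-in (an assumption for illustration) for a type owned by another package such as `ProtoBuf.google.protobuf.Value`, and `hparam_value` is a hypothetical local function name. The point is only that new methods are attached to a function this package owns.

```julia
# Stand-in for a type owned by another package (assumption for illustration).
struct ExternalValue
    kind::Symbol
    payload::Any
end

# Keyword constructor loosely mimicking a generated protobuf constructor.
function ExternalValue(; number_value=nothing, bool_value=nothing, string_value=nothing)
    number_value !== nothing && return ExternalValue(:number, number_value)
    bool_value !== nothing && return ExternalValue(:bool, bool_value)
    string_value !== nothing && return ExternalValue(:string, string_value)
    return ExternalValue(:null, nothing)
end

# Local function owned by this package. No methods are added to the
# external constructor, so nothing here counts as type piracy.
hparam_value(x::Bool) = ExternalValue(bool_value = x)
hparam_value(x::Real) = ExternalValue(number_value = Float64(x))
hparam_value(x::AbstractString) = ExternalValue(string_value = String(x))
function hparam_value(x)
    @warn "Cannot encode a value of type $(typeof(x)); defaulting to null."
    ExternalValue()
end

encoded = [hparam_value(true), hparam_value(0.5), hparam_value("adam")]
```

Call sites would then use `hparam_value(val)` rather than extending the imported constructor, which also removes the need to `import` `Value` and `ListValue`.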



function ProtoBuf.google.protobuf.ListValue(domain::DiscreteDomain{T}) where T <: DiscreteDomainElem
Member

As above, type piracy. Don't import ListValue; create your own function instead.

    ProtoBuf.google.protobuf.ListValue(
        values = ProtoBuf.google.protobuf.Value.(domain.values)
    )
end

struct IntervalDomain
    min_value::Float64
    max_value::Float64
end

Interval(d::IntervalDomain) = Interval(min_value=d.min_value, max_value=d.max_value)

HParamDomain = Union{IntervalDomain, DiscreteDomain}

struct HParam
    name::AbstractString
    domain::HParamDomain
    display_name::AbstractString
    description::AbstractString
end
Comment on lines +46 to +51
Member

Since we can only serialise standard TensorBoard types (Float32 and C-strings), I think this should be a type-stable struct with String fields instead of AbstractString fields, with proper conversion done in the constructor.

Member

Actually, it depends: is this an internal type used for logging, or the user-accessible API-type?
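
One possible shape for the type-stable variant, as a sketch only. `StableHParam` is a hypothetical name, and making the struct parametric on the domain type is an extra assumption that goes beyond what the comment above asks for. `IntervalDomain` is repeated here so the block runs on its own.

```julia
struct IntervalDomain
    min_value::Float64
    max_value::Float64
end

# Concrete String fields plus a type parameter for the domain, so every
# field of a constructed value has a concrete type.
struct StableHParam{D}
    name::String
    domain::D
    display_name::String
    description::String
end

# Outer constructor: accept any AbstractString and convert explicitly.
function StableHParam(name::AbstractString, domain::D,
                      display_name::AbstractString,
                      description::AbstractString) where {D}
    StableHParam{D}(String(name), domain, String(display_name), String(description))
end

# A SubString name is converted to String on construction.
h = StableHParam(SubString("lr_extra", 1, 2), IntervalDomain(0.1, 3.0),
                 "Learning rate", "Step size of the optimiser")
```

If `HParam` stays a user-facing type, the same conversion could instead live at the boundary where the user type is turned into the protobuf `HParamInfo`.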



function HParamInfo(hparam::HParam)
    domain = hparam.domain
    domain_kwargs = if isa(domain, IntervalDomain)
        (;domain_interval = Interval(domain))
    else
        @assert isa(domain, DiscreteDomain)
        (_type = hparams_datatype(domain),
         domain_discrete = ProtoBuf.google.protobuf.ListValue(domain))
    end
    HParamInfo(;name = hparam.name,
               description = hparam.description,
               display_name = hparam.display_name,
               domain_kwargs...)
end

struct Metric
    tag::AbstractString
    group::AbstractString
    display_name::AbstractString
    description::AbstractString
    dataset_type::Symbol

    function Metric(tag::AbstractString,
Member

Is the inner constructor really necessary?

                    group::AbstractString,
                    display_name::AbstractString,
                    description::AbstractString,
                    dataset_type::Symbol)
        valid_dataset_types = keys(tensorboard.hparams.DatasetType)
        if dataset_type ∉ valid_dataset_types
            throw(ArgumentError("dataset_type of $(dataset_type) is not one of $(map(string, valid_dataset_types))."))
        else
            new(tag, group, display_name, description, dataset_type)
        end
    end
end

function MetricInfo(metric::Metric)
    MetricInfo(
        name=MetricName(
            group=metric.group,
            tag=metric.tag,
        ),
        display_name=metric.display_name,
        description=metric.description,
        dataset_type=tensorboard.hparams.DatasetType[metric.dataset_type]
    )
end

struct HParamsConfig
hparams::AbstractVector{HParam}
metrics::AbstractVector{Metric}
Comment on lines +103 to +104
Member

Suggested change
hparams::AbstractVector{HParam}
metrics::AbstractVector{Metric}
hparams::Vector{HParam}
metrics::Vector{Metric}

time_created_secs::Float64
end

function SummaryMetadata(hparams_plugin_data::HParamsPluginData)
    SummaryMetadata(
        plugin_data = SummaryMetadata_PluginData(
            plugin_name = PLUGIN_NAME,
            content = serialize_proto(hparams_plugin_data)
        )
    )
end

function Summary_Value(tag, hparams_plugin_data::HParamsPluginData)
    null_tensor = TensorProto(dtype = _DataType.DT_FLOAT, tensor_shape = TensorShapeProto(dim=[]))
    Summary_Value(
        tag = tag,
        metadata = SummaryMetadata(hparams_plugin_data),
        tensor = null_tensor
    )
end

function log_hparams_config(logger::TBLogger,
                            hparams_config::HParamsConfig;
                            step=nothing)
    summ = SummaryCollection(
        hparams_config_summary(hparams_config)
    )
    write_event(logger.file, make_event(logger, summ, step=step))
end

function hparams_config_summary(config::HParamsConfig)
    Summary_Value(
        EXPERIMENT_TAG,
        HParamsPluginData(
            version = PLUGIN_DATA_VERSION,
            experiment = Experiment(
                hparam_infos = HParamInfo.(config.hparams),
                metric_infos = MetricInfo.(config.metrics),
                time_created_secs = config.time_created_secs
            )
        )
    )
end

function log_hparams(logger::TBLogger,
                     hparams_dict::Dict{HParam, Any},
                     group_name::AbstractString,
                     trial_id::AbstractString,
                     start_time_secs::Union{Float64, Nothing};
                     step=nothing)
    summ = SummaryCollection(
        hparams_summary(hparams_dict,
                        group_name,
                        trial_id,
                        start_time_secs)
    )
    write_event(logger.file, make_event(logger, summ, step=step))
end

function hparams_summary(hparams_dict::Dict{HParam, Any},
                         group_name::AbstractString,
                         trial_id::AbstractString,
                         start_time_secs::Union{Float64, Nothing})
    if start_time_secs === nothing
        start_time_secs = time()
    end

    Summary_Value(
        SESSION_START_INFO_TAG,
        HParamsPluginData(
            version = PLUGIN_DATA_VERSION,
            session_start_info = SessionStartInfo(
                group_name = group_name,
                start_time_secs = start_time_secs,
                hparams = Dict(
                    hparam.name => ProtoBuf.google.protobuf.Value(val)
                    for (hparam, val) ∈ hparams_dict
                )
            )
        )
    )
end

15 changes: 13 additions & 2 deletions src/TensorBoardLogger.jl
@@ -20,7 +20,8 @@ using Base.CoreLogging: CoreLogging, AbstractLogger, LogLevel, Info,

export TBLogger, reset!, set_step!, increment_step!, set_step_increment!
export log_histogram, log_value, log_vector, log_text, log_image, log_images,
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar,
log_hparams, log_hparams_config
export map_summaries

export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC,
@@ -30,13 +31,20 @@ export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC,
export tb_multiline, tb_margin

# Wrapper types
export TBText, TBVector, TBHistogram, TBImage, TBImages, TBAudio, TBAudios
export TBText, TBVector, TBHistogram, TBImage, TBImages, TBAudio, TBAudios, TBHParams

# Protobuffer definitions for tensorboard
include("protojl/tensorboard/tensorboard.jl")
using .tensorboard: Summary_Value, GraphDef, Summary, Event, SessionLog_SessionStatus, SessionLog
using .tensorboard: TensorShapeProto_Dim, TensorShapeProto, TextPluginData
using .tensorboard: TensorProto, SummaryMetadata, SummaryMetadata_PluginData, _DataType
using .tensorboard.hparams: HParamsPluginData, Experiment, SessionStartInfo, SessionEndInfo, HParamInfo, MetricInfo, Interval, MetricName, DatasetType
import .tensorboard.hparams
import .tensorboard: SummaryMetadata, Summary
import .tensorboard.hparams: HParamInfo, MetricInfo, Interval

using ProtoBuf
import ProtoBuf.google.protobuf: Value, ListValue

include("PNG.jl")
using .PNGImage
@@ -58,6 +66,9 @@ include("Loggers/LogEmbeddings.jl")
# Custom Scalar Plugin
include("Loggers/LogCustomScalar.jl")

include("Loggers/LogHParams.jl")


include("logger_dispatch.jl")
include("logger_dispatch_overrides.jl")

19 changes: 19 additions & 0 deletions src/logger_dispatch_overrides.jl
@@ -207,3 +207,22 @@ preprocess(name, val::TBVector{T,N}, data) where {T<:Complex,N} =
push!(data, name*"/re"=>TBVector(real.(content(val))), name*"/im"=>TBVector(imag.(content(val))))
summary_impl(name, val::TBVector) = histogram_summary(name, collect(0:length(val.data)),
val.data)

########## Hyperparameters ########################

# FIXME: name unused?
summary_impl(name, val::HParamsConfig) = hparams_config_summary(val)
preprocess(name, val::HParamsConfig, data) = push!(data, name=>val)

struct TBHParams <: WrapperLogType
    # TODO: The types in the hparam domain and this dict's values are constrained.
    # e.g. an hparam with a discrete domain of ["a", "b"] must have string values
    # Consider ways to enforce this relationship in the type system.
    data::Dict{HParam, Any}
Comment on lines +218 to +221
Member

Can you explain this better? I don't understand the todo

    # FIXME: group_name auto generated in the Python implementation (Tensorboard)
    group_name::AbstractString
Comment on lines +222 to +223
Member

Wrapper types are essentially the user API. If this is an internal, auto-generated field, then you should remove it (and auto-generate it when necessary).
If it's optional and in some cases it's useful to define it, then use a constructor to deal with the two cases.
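
A sketch of the second option, where a constructor covers both the explicit and the auto-generated case. Everything here is illustrative: `HParamsRun` and `default_group_name` are hypothetical names, keys are plain strings rather than `HParam` values to keep the block self-contained, and the hash-based default is only one possible rule, not necessarily what the Python implementation does.

```julia
struct HParamsRun
    data::Dict{String, Any}
    group_name::String
    trial_id::String
    start_time_secs::Float64
end

# Hypothetical default: derive a group name from the sorted name/value pairs.
default_group_name(data) = string(hash(sort!(collect(data); by = first)), base = 16)

# Convenience constructor: group_name and start_time_secs are optional
# keyword arguments and are filled in when the caller leaves them out.
function HParamsRun(data::AbstractDict, trial_id::AbstractString;
                    group_name::Union{AbstractString, Nothing} = nothing,
                    start_time_secs::Real = time())
    d = Dict{String, Any}(String(k) => v for (k, v) in data)
    g = group_name === nothing ? default_group_name(d) : String(group_name)
    HParamsRun(d, g, String(trial_id), Float64(start_time_secs))
end

auto_run = HParamsRun(Dict("lr" => 0.1), "trial-1")
named_run = HParamsRun(Dict("lr" => 0.1), "trial-2"; group_name = "sweep")
```

This keeps the stored fields concrete while letting the user-facing call omit anything that can be derived.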

    trial_id::AbstractString
    start_time_secs::Union{Float64, Nothing}
end
content(x::TBHParams) = x.data
summary_impl(name, val::TBHParams) = hparams_summary(val.data, val.group_name, val.trial_id, val.start_time_secs)
8 changes: 6 additions & 2 deletions src/protojl/tensorboard/tensorboard.jl
@@ -34,6 +34,10 @@ module tensorboard
include("plugins/custom_scalar/layout_pb.jl")
include("plugins/text/plugin_data_pb.jl")

#include("plugins/hparams/hparams.jl")

# Needs separate module due to conflicting "_DataType" export
module hparams
include("plugins/hparams/api_pb.jl")
include("plugins/hparams/hparams_util_pb.jl")
include("plugins/hparams/plugin_data_pb.jl")
end
end
52 changes: 51 additions & 1 deletion test/runtests.jl
@@ -1,5 +1,6 @@
using TensorBoardLogger, Logging
using TensorBoardLogger: preprocess, summary_impl
using TensorBoardLogger: IntervalDomain, DiscreteDomain, HParam, Metric, HParamsConfig
using Test
using TestImages
using ImageCore
@@ -139,7 +140,7 @@ end
@test π != log_image(logger, "rand/LCN", rand(10, 3, 2), LCN, step = step)

if VERSION >= v"1.3.0"
using MLDatasets: MNIST
using MLDatasets: MNIST

sample = MNIST.traintensor(1:3)
@test π != log_image(logger, "mnist/HWN", sample, HWN, step = step)
@@ -268,6 +269,55 @@ end
@test π != log_embeddings(logger, "random2", mat, step = step+1)
end

@testset "HParamConfig Logger" begin
    logger = TBLogger(test_log_dir*"t", tb_overwrite)
    step = 1

    interval_domain = IntervalDomain(0.1, 3.0)
    hparam1 = HParam("interval_hparam", interval_domain, "display_name1", "description1")

    discrete_domain_strs = ["a", "b", "c"]
    discrete_domain = DiscreteDomain(discrete_domain_strs)
    hparam2 = HParam("discrete_domain_hparam", discrete_domain, "display_name2", "description2")

    hparams = [hparam1, hparam2]

    metric = Metric("tag", "group", "display_name", "description", :DATASET_VALIDATION)
    metrics = [metric]
    hparams_config = HParamsConfig(hparams, metrics, 1.2)
    ss = TensorBoardLogger.hparams_config_summary(hparams_config)

    @test isa(ss, TensorBoardLogger.Summary_Value)
    @test ss.tag == TensorBoardLogger.EXPERIMENT_TAG

    # TODO: Deserialize and test more properties

    log_hparams_config(logger, hparams_config; step=step)
end

@testset "HParams Logger" begin
    logger = TBLogger(test_log_dir*"t", tb_overwrite)
    step = 1

    interval_domain = IntervalDomain(0.1, 3.0)
    hparam1 = HParam("interval_hparam", interval_domain, "display_name1", "description1")

    discrete_domain_strs = ["a", "b", "c"]
    discrete_domain = DiscreteDomain(discrete_domain_strs)
    hparam2 = HParam("discrete_domain_hparam", discrete_domain, "display_name2", "description2")

    hparams_dict = Dict(hparam1 => 1.2, hparam2 => "b")

    ss = TensorBoardLogger.hparams_summary(hparams_dict, "group_name", "trial_id", nothing)

    @test isa(ss, TensorBoardLogger.Summary_Value)
    @test ss.tag == TensorBoardLogger.SESSION_START_INFO_TAG

    # TODO: Deserialize and test more properties
    log_hparams(logger, hparams_dict, "group_name", "trial_id", nothing; step=step)
end


@testset "Logger dispatch overrides" begin
include("test_logger_dispatch_overrides.jl")
end
19 changes: 19 additions & 0 deletions test/test_logger_dispatch_overrides.jl
@@ -1,5 +1,6 @@
using TensorBoardLogger, Test
using TensorBoardLogger: preprocess, content
using TensorBoardLogger: TBHParams
using TestImages
using ImageCore
@testset "TBText" begin
@@ -61,3 +62,21 @@ end
@test first(data) == ("test/1"=>TBAudio(y, sample_rate))
@test last(data) == ("test/2"=>TBAudio(y, sample_rate))
end

@testset "HParamsConfig" begin
    data = Vector{Pair{String,Any}}()
    hparam = HParam("interval_hparam", IntervalDomain(0.1, 3.0), "display_name1", "description1")
    metric = Metric("tag", "group", "display_name", "description", :DATASET_VALIDATION)
    params_config = HParamsConfig([hparam], [metric], 1.2)
    preprocess("test", params_config, data)
    @test first(data) == ("test"=>params_config)
end

@testset "TBHParams" begin
    data = Vector{Pair{String,Any}}()
    hparam = HParam("interval_hparam", IntervalDomain(0.1, 3.0), "display_name1", "description1")
    hparams_dict = Dict(hparam => 1.2)
    tbh_params = TBHParams(hparams_dict, "group_name", "trial_id", nothing)
    preprocess("test", tbh_params, data)
    @test first(data) == ("test"=>tbh_params)
end