Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial work towards hparam support (rebase of #87) #99

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions src/Loggers/LogHParams.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Identifiers used when emitting TensorBoard "hparams" plugin summaries.
# Declared `const` so that functions reading them see a concretely typed value
# instead of an untyped, reassignable global.
const PLUGIN_NAME = "hparams"
const PLUGIN_DATA_VERSION = 0

# Summary tags used in this file for the experiment description and for the
# session start info, respectively.
const EXPERIMENT_TAG = "_hparams_/experiment"
const SESSION_START_INFO_TAG = "_hparams_/session_start_info"


"""
    DiscreteDomain{T}(values)

Hyperparameter domain described by a vector of `values` with element type `T`.
"""
struct DiscreteDomain{T}
    values::AbstractVector{T}
end

# Element types for which `hparams_datatype` / `_ProtoBuf_ListValue` accept a
# `DiscreteDomain`. `const` because a type alias must not be a mutable global.
const DiscreteDomainElem = Union{String, Float64, Bool}

# Map a Julia element type onto the symbol used to index the hparams
# `_DataType` enum (see `hparams_datatype`).
function hparams_datatype_sym(::Type{<:AbstractString})
    return :DATA_TYPE_STRING
end

function hparams_datatype_sym(::Type{Float64})
    return :DATA_TYPE_FLOAT64
end

function hparams_datatype_sym(::Type{Bool})
    return :DATA_TYPE_BOOL
end

# Look up the hparams `_DataType` enum entry matching the element type `T` of
# a discrete domain. Only the type parameter is used; the domain's values are
# not inspected.
function hparams_datatype(domain::DiscreteDomain{T}) where {T<:DiscreteDomainElem}
    type_sym = hparams_datatype_sym(T)
    return tensorboard.hparams._DataType[type_sym]
end

# Custom constructors for ProtoBuf.google.protobuf.Value.
# Each method wraps a Julia value in the matching field of `Value`; the
# untyped fallback warns and produces a null value instead of throwing.
function _ProtoBuf_Value(x::Bool)
    return Value(bool_value=x)
end

function _ProtoBuf_Value(x::Number)
    return Value(number_value=x)
end

function _ProtoBuf_Value(x::AbstractString)
    return Value(string_value=x)
end

function _ProtoBuf_Value(x)
    @warn "Cannot create a ProtoBuf.google.protobuf.Value of type $(typeof(x)), defaulting to null."
    return Value(null_value=Int32(0))
end

# Convert every value of a discrete domain to a protobuf `Value` and bundle
# the results into a protobuf `ListValue`.
function _ProtoBuf_ListValue(domain::DiscreteDomain{T}) where {T<:DiscreteDomainElem}
    wrapped_values = _ProtoBuf_Value.(domain.values)
    return ProtoBuf.google.protobuf.ListValue(values = wrapped_values)
end

"""
    IntervalDomain(min_value, max_value)

Hyperparameter domain described by two `Float64` bounds. No ordering check is
performed on the bounds.
"""
struct IntervalDomain
    min_value::Float64
    max_value::Float64
end

# Build the protobuf `Interval` message from an `IntervalDomain`, copying both
# bounds over unchanged.
function Interval(d::IntervalDomain)
    lower, upper = d.min_value, d.max_value
    return Interval(min_value=lower, max_value=upper)
end

# Any domain a hyperparameter may be declared over. `const` because a type
# alias must not be a reassignable global.
const HParamDomain = Union{IntervalDomain, DiscreteDomain}

"""
    HParam(name, domain, display_name, description)

Description of a single hyperparameter: its `name`, the `domain` it is declared
over, and the `display_name` / `description` copied into the `HParamInfo`
message.

The text fields are stored as concrete `String`s (any `AbstractString` passed
to the constructor is converted), matching how `Metric` stores its text fields
and avoiding abstractly typed fields.
"""
struct HParam
    name::String
    domain::HParamDomain
    display_name::String
    description::String
end


# Build the protobuf `HParamInfo` message for `hparam`.
# Interval domains populate `domain_interval`; discrete domains populate
# `_type` together with `domain_discrete`.
function HParamInfo(hparam::HParam)
    dom = hparam.domain
    shared = (name = hparam.name,
              description = hparam.description,
              display_name = hparam.display_name)

    if dom isa IntervalDomain
        return HParamInfo(; shared..., domain_interval = Interval(dom))
    end

    @assert dom isa DiscreteDomain
    return HParamInfo(; shared...,
                      _type = hparams_datatype(dom),
                      domain_discrete = _ProtoBuf_ListValue(dom))
end

"""
    Metric(tag, group, display_name, description, dataset_type)

Description of a metric. All text fields are stored as `String`.

Throws an `ArgumentError` when `dataset_type` is not one of the keys of
`tensorboard.hparams.DatasetType`.
"""
struct Metric
    tag::String
    group::String
    display_name::String
    description::String
    dataset_type::Symbol

    function Metric(tag::AbstractString,
                    group::AbstractString,
                    display_name::AbstractString,
                    description::AbstractString,
                    dataset_type::Symbol)
        valid_dataset_types = keys(tensorboard.hparams.DatasetType)
        if !(dataset_type in valid_dataset_types)
            throw(ArgumentError("dataset_type of $(dataset_type) is not one of $(map(string, valid_dataset_types))."))
        end
        # `new` converts each argument to the declared `String` field type.
        return new(tag, group, display_name, description, dataset_type)
    end
end

# Build the protobuf `MetricInfo` message for `metric`. The symbolic
# `dataset_type` is translated through the `DatasetType` enum table.
function MetricInfo(metric::Metric)
    metric_name = MetricName(group=metric.group, tag=metric.tag)
    dataset = tensorboard.hparams.DatasetType[metric.dataset_type]
    return MetricInfo(
        name=metric_name,
        display_name=metric.display_name,
        description=metric.description,
        dataset_type=dataset,
    )
end

"""
    HParamsConfig(hparams, metrics, time_created_secs)

Experiment-level configuration: the hyperparameters and metrics to describe,
plus the creation time that is copied into the `Experiment` message.
"""
struct HParamsConfig
    hparams::Vector{HParam}
    metrics::Vector{Metric}
    time_created_secs::Float64
end

# Wrap serialized hparams plugin data in a `SummaryMetadata` message tagged
# with this plugin's name.
function SummaryMetadata(hparams_plugin_data::HParamsPluginData)
    payload = serialize_proto(hparams_plugin_data)
    plugin_data = SummaryMetadata_PluginData(plugin_name = PLUGIN_NAME, content = payload)
    return SummaryMetadata(plugin_data = plugin_data)
end

# Build the `Summary_Value` stored under `tag` for hparams plugin data.
# The plugin data travels in the metadata; the tensor is a float tensor with
# an empty dimension list.
function Summary_Value(tag, hparams_plugin_data::HParamsPluginData)
    empty_shape = TensorShapeProto(dim=[])
    placeholder = TensorProto(dtype = _DataType.DT_FLOAT, tensor_shape = empty_shape)
    plugin_metadata = SummaryMetadata(hparams_plugin_data)
    return Summary_Value(tag = tag, metadata = plugin_metadata, tensor = placeholder)
end

# Write the experiment configuration summary for `hparams_config` as one event
# to the logger's file. `step` is forwarded to `make_event`.
function log_hparams_config(logger::TBLogger,
                            hparams_config::HParamsConfig;
                            step=nothing)
    config_summary = hparams_config_summary(hparams_config)
    event = make_event(logger, SummaryCollection(config_summary), step=step)
    return write_event(logger.file, event)
end

# Build the summary value, tagged EXPERIMENT_TAG, that describes the
# experiment: one `HParamInfo` per hyperparameter, one `MetricInfo` per metric.
function hparams_config_summary(config::HParamsConfig)
    experiment = Experiment(
        hparam_infos = HParamInfo.(config.hparams),
        metric_infos = MetricInfo.(config.metrics),
        time_created_secs = config.time_created_secs,
    )
    plugin_data = HParamsPluginData(version = PLUGIN_DATA_VERSION, experiment = experiment)
    return Summary_Value(EXPERIMENT_TAG, plugin_data)
end

"""
    log_hparams(logger, hparams_dict, group_name, trial_id,
                start_time_secs=nothing; step=nothing)

Write the session start info for the hyperparameter values in `hparams_dict`
as one event to the logger's file.

# Arguments
- `hparams_dict`: maps each `HParam` to its value. Any `AbstractDict` keyed by
  `HParam` is accepted, including homogeneous ones such as
  `Dict{HParam, Float64}`.
- `group_name`, `trial_id`: forwarded to `hparams_summary`.
- `start_time_secs`: forwarded to `hparams_summary`; `nothing` (the default)
  lets it pick the current time.
- `step`: forwarded to `make_event`.
"""
function log_hparams(logger::TBLogger,
                     hparams_dict::AbstractDict{HParam},
                     group_name::AbstractString,
                     trial_id::AbstractString,
                     start_time_secs::Union{Float64, Nothing}=nothing;
                     step=nothing)
    # Dict types are invariant in their value type, so normalize to
    # `Dict{HParam, Any}` before building the summary. This is a no-op when
    # the argument already has that type.
    widened = convert(Dict{HParam, Any}, hparams_dict)
    summ = SummaryCollection(
        hparams_summary(widened, group_name, trial_id, start_time_secs)
    )
    return write_event(logger.file, make_event(logger, summ, step=step))
end

"""
    hparams_summary(hparams_dict, group_name, trial_id, start_time_secs=nothing)

Build the summary value, tagged `SESSION_START_INFO_TAG`, that records the
hyperparameter values of a session.

# Arguments
- `hparams_dict`: maps each `HParam` to its value; every value is converted
  with `_ProtoBuf_Value` and keyed by the hyperparameter's `name`.
- `group_name`: copied into the `SessionStartInfo` message.
- `trial_id`: accepted for interface compatibility; not used by this function.
- `start_time_secs`: session start time in seconds; when `nothing` (the
  default) the current `time()` is used.
"""
function hparams_summary(hparams_dict::AbstractDict{HParam},
                         group_name::AbstractString,
                         trial_id::AbstractString,
                         start_time_secs::Union{Real, Nothing}=nothing)
    # The previous signature wrote `start_time_secs=Union{Float64, Nothing}`,
    # which made the *type* the default value rather than constraining the
    # argument; the default is now `nothing`.
    start_secs = start_time_secs === nothing ? time() : Float64(start_time_secs)

    hparam_values = Dict(
        hparam.name => _ProtoBuf_Value(val)
        for (hparam, val) in hparams_dict
    )
    session_start_info = SessionStartInfo(
        group_name = group_name,
        start_time_secs = start_secs,
        hparams = hparam_values,
    )
    return Summary_Value(
        SESSION_START_INFO_TAG,
        HParamsPluginData(
            version = PLUGIN_DATA_VERSION,
            session_start_info = session_start_info,
        ),
    )
end

15 changes: 13 additions & 2 deletions src/TensorBoardLogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ using Base.CoreLogging: CoreLogging, AbstractLogger, LogLevel, Info,

export TBLogger, reset!, set_step!, increment_step!, set_step_increment!
export log_histogram, log_value, log_vector, log_text, log_image, log_images,
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar,
log_hparams, log_hparams_config
export map_summaries, TBReader

export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC,
Expand All @@ -30,7 +31,7 @@ export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC,
export tb_multiline, tb_margin

# Wrapper types
export TBText, TBVector, TBHistogram, TBImage, TBImages, TBAudio, TBAudios
export TBText, TBVector, TBHistogram, TBImage, TBImages, TBAudio, TBAudios, TBHParams

# workaround for FileIO pre 1.6
include("FileIO_workaround.jl")
Expand All @@ -40,6 +41,13 @@ include("protojl/tensorboard/tensorboard.jl")
using .tensorboard: Summary_Value, GraphDef, Summary, Event, SessionLog_SessionStatus, SessionLog
using .tensorboard: TensorShapeProto_Dim, TensorShapeProto, TextPluginData
using .tensorboard: TensorProto, SummaryMetadata, SummaryMetadata_PluginData, _DataType
using .tensorboard.hparams: HParamsPluginData, Experiment, SessionStartInfo, SessionEndInfo, HParamInfo, MetricInfo, HParamInfo, Interval, MetricName, DatasetType
import .tensorboard.hparams
import .tensorboard: SummaryMetadata, Summary
import .tensorboard.hparams: HParamInfo, MetricInfo, Interval

using ProtoBuf
import ProtoBuf.google.protobuf: Value, ListValue

include("PNG.jl")
using .PNGImage
Expand All @@ -61,6 +69,9 @@ include("Loggers/LogEmbeddings.jl")
# Custom Scalar Plugin
include("Loggers/LogCustomScalar.jl")

include("Loggers/LogHParams.jl")


include("logger_dispatch.jl")
include("logger_dispatch_overrides.jl")

Expand Down
19 changes: 19 additions & 0 deletions src/logger_dispatch_overrides.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,22 @@ preprocess(name, val::TBVector{T,N}, data) where {T<:Complex,N} =
push!(data, name*"/re"=>TBVector(real.(content(val))), name*"/im"=>TBVector(imag.(content(val))))
summary_impl(name, val::TBVector) = histogram_summary(name, collect(0:length(val.data)),
val.data)

########## Hyperparameters ########################

# `name` is required by the generic `summary_impl` signature but is not used:
# `hparams_config_summary` always writes under the fixed EXPERIMENT_TAG.
# The config object itself is passed on; `HParamsConfig` has no `data` field.
summary_impl(name, val::HParamsConfig) = hparams_config_summary(val)
preprocess(name, val::HParamsConfig, data) = push!(data, name => val)

"""
    TBHParams(data, group_name, trial_id, start_time_secs)

Wrapper log type bundling hyperparameter values with the arguments that
`hparams_summary` takes (`group_name`, `trial_id`, `start_time_secs`).
"""
struct TBHParams <: WrapperLogType
    # TODO: The types in the hparam domain and this dict's values are constrained,
    # e.g. an hparam with a discrete domain of ["a", "b"] must have string values.
    # Consider ways to enforce this relationship in the type system.
    data::Dict{HParam, Any}
    # FIXME: group_name is auto-generated in the Python implementation (TensorBoard).
    group_name::AbstractString
    trial_id::AbstractString
    start_time_secs::Union{Float64, Nothing}
end
# The wrapped payload of a `TBHParams` is its hyperparameter dictionary.
function content(x::TBHParams)
    return x.data
end

# `name` is not used here; all fields of the wrapper are forwarded to
# `hparams_summary`.
function summary_impl(name, val::TBHParams)
    return hparams_summary(val.data, val.group_name, val.trial_id, val.start_time_secs)
end
8 changes: 6 additions & 2 deletions src/protojl/tensorboard/tensorboard.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ module tensorboard
include("plugins/custom_scalar/layout_pb.jl")
include("plugins/text/plugin_data_pb.jl")

#include("plugins/hparams/hparams.jl")

# Needs separate module due to conflicting "_DataType" export:
# the hparams protobuf definitions expose their own `_DataType`, distinct from
# the `_DataType` of the enclosing module, so they are kept in a nested
# namespace and reached as `hparams._DataType`.
# NOTE(review): the include order below presumably follows the dependencies
# between the generated files — confirm before reordering.
module hparams
include("plugins/hparams/api_pb.jl")
include("plugins/hparams/hparams_util_pb.jl")
include("plugins/hparams/plugin_data_pb.jl")
end
end
6 changes: 4 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ test_log_dir = "test_logs/"
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
ENV["GKSwstype"] = "100"

ENV["DATADEPS_ALWAYS_ACCEPT"] = true

@testset "TensorBoardLogger" begin

@testset "TBLogger" begin
Expand Down Expand Up @@ -270,6 +268,10 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true

close.(values(logger.all_files))
end

@testset "HParams Backend" begin
include("test_hparams.jl")
end

@testset "Embedding Logger" begin
logger = TBLogger(test_log_dir*"t", tb_overwrite)
Expand Down
56 changes: 56 additions & 0 deletions test/test_hparams.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using TensorBoardLogger
using Logging
using TensorBoardLogger: preprocess, summary_impl
using TensorBoardLogger: IntervalDomain, DiscreteDomain, HParam, Metric, HParamsConfig
using Test

test_log_dir = "test_logs/"
ENV["DATADEPS_ALWAYS_ACCEPT"] = true

# Builds an experiment configuration with one interval and one discrete
# hyperparameter, checks the generated summary, and logs it.
@testset "HParamConfig Logger" begin
    logger = TBLogger(test_log_dir*"t", tb_overwrite)
    step = 1

    interval_domain = IntervalDomain(0.1, 3.0)
    hparam1 = HParam("interval_hparam", interval_domain, "display_name1", "description1")

    discrete_domain_strs = ["a", "b", "c"]
    discrete_domain = DiscreteDomain(discrete_domain_strs)
    hparam2 = HParam("discrete_domain_hparam", discrete_domain, "display_name2", "description2")

    hparams = [hparam1, hparam2]

    metric = Metric("tag", "group", "display_name", "description", :DATASET_VALIDATION)
    metrics = [metric]
    hparams_config = HParamsConfig(hparams, metrics, 1.2)
    ss = TensorBoardLogger.hparams_config_summary(hparams_config)

    @test isa(ss, TensorBoardLogger.Summary_Value)
    @test ss.tag == TensorBoardLogger.EXPERIMENT_TAG

    # TODO: Deserialize and test more properties

    log_hparams_config(logger, hparams_config; step=step)

    # Release the event files, as the other logger testsets do; the next
    # testset re-creates the same log directory with `tb_overwrite`.
    close.(values(logger.all_files))
end

# Builds a dictionary of hyperparameter values, checks the generated session
# start summary, and logs it.
@testset "HParams Logger" begin
    logger = TBLogger(test_log_dir*"t", tb_overwrite)
    step = 1

    interval_domain = IntervalDomain(0.1, 3.0)
    hparam1 = HParam("interval_hparam", interval_domain, "display_name1", "description1")

    discrete_domain_strs = ["a", "b", "c"]
    discrete_domain = DiscreteDomain(discrete_domain_strs)
    hparam2 = HParam("discrete_domain_hparam", discrete_domain, "display_name2", "description2")

    # Mixed value types make this a `Dict{HParam, Any}`.
    hparams_dict = Dict(hparam1 => 1.2, hparam2 => "b")

    ss = TensorBoardLogger.hparams_summary(hparams_dict, "group_name", "trial_id", nothing)

    @test isa(ss, TensorBoardLogger.Summary_Value)
    @test ss.tag == TensorBoardLogger.SESSION_START_INFO_TAG

    # TODO: Deserialize and test more properties
    log_hparams(logger, hparams_dict, "group_name", "trial_id", nothing; step=step)

    # Release the event files, as the other logger testsets do.
    close.(values(logger.all_files))
end
19 changes: 19 additions & 0 deletions test/test_logger_dispatch_overrides.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using TensorBoardLogger, Test
using TensorBoardLogger: preprocess, content
using TensorBoardLogger: TBHParams
using TestImages
using ImageCore
@testset "TBText" begin
Expand Down Expand Up @@ -61,3 +62,21 @@ end
@test first(data) == ("test/1"=>TBAudio(y, sample_rate))
@test last(data) == ("test/2"=>TBAudio(y, sample_rate))
end

# `preprocess` must pass an `HParamsConfig` through unchanged under its name.
# The hparams types are not exported and not imported by this file, so they
# are referenced through the module to keep the testset self-contained.
@testset "HParamsConfig" begin
    data = Vector{Pair{String,Any}}()
    hparam = TensorBoardLogger.HParam(
        "interval_hparam",
        TensorBoardLogger.IntervalDomain(0.1, 3.0),
        "display_name1",
        "description1",
    )
    metric = TensorBoardLogger.Metric(
        "tag", "group", "display_name", "description", :DATASET_VALIDATION
    )
    params_config = TensorBoardLogger.HParamsConfig([hparam], [metric], 1.2)
    preprocess("test", params_config, data)
    @test first(data) == ("test"=>params_config)
end

# `preprocess` must pass a `TBHParams` wrapper through unchanged under its name.
# `HParam` and `IntervalDomain` are not exported and not imported by this file,
# so they are referenced through the module.
@testset "TBHParams" begin
    data = Vector{Pair{String,Any}}()
    hparam = TensorBoardLogger.HParam(
        "interval_hparam",
        TensorBoardLogger.IntervalDomain(0.1, 3.0),
        "display_name1",
        "description1",
    )
    hparams_dict = Dict(hparam => 1.2)
    tbh_params = TBHParams(hparams_dict, "group_name", "trial_id", nothing)
    preprocess("test", tbh_params, data)
    @test first(data) == ("test"=>tbh_params)
end