-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
(WIP) Add support for logging hyperparameters #87
Changes from all commits
56f85ce
91a1ce6
bd03b80
be20a3f
09da9cd
3e64864
e8b16c1
ef403c9
64b442a
b03b33e
81ef93a
3dd84fe
ad3b5ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,187 @@ | ||||||||||
PLUGIN_NAME = "hparams" | ||||||||||
PLUGIN_DATA_VERSION = 0 | ||||||||||
|
||||||||||
EXPERIMENT_TAG = "_hparams_/experiment" | ||||||||||
SESSION_START_INFO_TAG = "_hparams_/session_start_info" | ||||||||||
|
||||||||||
|
||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a few comments on what are |
||||||||||
struct DiscreteDomain{T} | ||||||||||
values::AbstractVector{T} | ||||||||||
end | ||||||||||
|
||||||||||
DiscreteDomainElem = Union{String, Float64, Bool} | ||||||||||
|
||||||||||
hparams_datatype_sym(::Type{String}) = :DATA_TYPE_STRING | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sometimes |
||||||||||
hparams_datatype_sym(::Type{Float64}) = :DATA_TYPE_FLOAT64 | ||||||||||
hparams_datatype_sym(::Type{Bool}) = :DATA_TYPE_BOOL | ||||||||||
|
||||||||||
function hparams_datatype(domain::DiscreteDomain{T}) where T <: DiscreteDomainElem | ||||||||||
tensorboard.hparams._DataType[hparams_datatype_sym(T)] | ||||||||||
end | ||||||||||
|
||||||||||
ProtoBuf.google.protobuf.Value(x::Number) = Value(number_value=x) | ||||||||||
ProtoBuf.google.protobuf.Value(x::Bool) = Value(bool_value=x) | ||||||||||
ProtoBuf.google.protobuf.Value(x::AbstractString) = Value(string_value=x) | ||||||||||
function ProtoBuf.google.protobuf.Value(x) | ||||||||||
@warn "Cannot create a ProtoBuf.google.protobuf.Value of type $(typeof(x)), defaulting to null." | ||||||||||
Value(null_value=Int32(0)) | ||||||||||
end | ||||||||||
Comment on lines
+22
to
+28
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is all type-piracy, and it might stop working in the future when something is changed in You should not define specialised methods for types we do not own on functions we do not own. |
||||||||||
|
||||||||||
|
||||||||||
function ProtoBuf.google.protobuf.ListValue(domain::DiscreteDomain{T}) where T <: DiscreteDomainElem | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As above, type piracy. Don't import |
||||||||||
ProtoBuf.google.protobuf.ListValue( | ||||||||||
values = ProtoBuf.google.protobuf.Value.(domain.values) | ||||||||||
) | ||||||||||
end | ||||||||||
|
||||||||||
struct IntervalDomain | ||||||||||
min_value::Float64 | ||||||||||
max_value::Float64 | ||||||||||
end | ||||||||||
|
||||||||||
Interval(d::IntervalDomain) = Interval(min_value=d.min_value, max_value=d.max_value) | ||||||||||
|
||||||||||
HParamDomain = Union{IntervalDomain, DiscreteDomain} | ||||||||||
|
||||||||||
struct HParam | ||||||||||
name::AbstractString | ||||||||||
domain::HParamDomain | ||||||||||
display_name::AbstractString | ||||||||||
description::AbstractString | ||||||||||
end | ||||||||||
Comment on lines
+46
to
+51
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we can only serialise standard tensor board types (Float32 and C-strings) I think that this here should be a type-stable struct with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, it depends: is this an internal type used for logging, or the user-accessible API-type? |
||||||||||
|
||||||||||
|
||||||||||
function HParamInfo(hparam::HParam) | ||||||||||
domain = hparam.domain | ||||||||||
domain_kwargs = if isa(domain, IntervalDomain) | ||||||||||
(;domain_interval = Interval(domain)) | ||||||||||
else | ||||||||||
@assert isa(domain, DiscreteDomain) | ||||||||||
(_type = hparams_datatype(domain), | ||||||||||
domain_discrete = ProtoBuf.google.protobuf.ListValue(domain)) | ||||||||||
end | ||||||||||
HParamInfo(;name = hparam.name, | ||||||||||
description = hparam.description, | ||||||||||
display_name = hparam.display_name, | ||||||||||
domain_kwargs...) | ||||||||||
end | ||||||||||
|
||||||||||
struct Metric | ||||||||||
tag::AbstractString | ||||||||||
group::AbstractString | ||||||||||
display_name::AbstractString | ||||||||||
description::AbstractString | ||||||||||
dataset_type::Symbol | ||||||||||
|
||||||||||
function Metric(tag::AbstractString, | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the inner constructor really necessary? |
||||||||||
group::AbstractString, | ||||||||||
display_name::AbstractString, | ||||||||||
description::AbstractString, | ||||||||||
dataset_type::Symbol) | ||||||||||
valid_dataset_types = keys(tensorboard.hparams.DatasetType) | ||||||||||
if dataset_type ∉ valid_dataset_types | ||||||||||
throw(ArgumentError("dataset_type of $(dataset_type) is not one of $(map(string, valid_dataset_types)).")) | ||||||||||
else | ||||||||||
new(tag, group, display_name, description, dataset_type) | ||||||||||
end | ||||||||||
end | ||||||||||
end | ||||||||||
|
||||||||||
function MetricInfo(metric::Metric) | ||||||||||
MetricInfo( | ||||||||||
name=MetricName( | ||||||||||
group=metric.group, | ||||||||||
tag=metric.tag, | ||||||||||
), | ||||||||||
display_name=metric.display_name, | ||||||||||
description=metric.description, | ||||||||||
dataset_type=tensorboard.hparams.DatasetType[metric.dataset_type] | ||||||||||
) | ||||||||||
end | ||||||||||
|
||||||||||
struct HParamsConfig | ||||||||||
hparams::AbstractVector{HParam} | ||||||||||
metrics::AbstractVector{Metric} | ||||||||||
Comment on lines
+103
to
+104
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
time_created_secs::Float64 | ||||||||||
end | ||||||||||
|
||||||||||
function SummaryMetadata(hparams_plugin_data::HParamsPluginData) | ||||||||||
SummaryMetadata( | ||||||||||
plugin_data = SummaryMetadata_PluginData( | ||||||||||
plugin_name = PLUGIN_NAME, | ||||||||||
content = serialize_proto(hparams_plugin_data) | ||||||||||
) | ||||||||||
) | ||||||||||
end | ||||||||||
|
||||||||||
function Summary_Value(tag, hparams_plugin_data::HParamsPluginData) | ||||||||||
null_tensor = TensorProto(dtype = _DataType.DT_FLOAT, tensor_shape = TensorShapeProto(dim=[])) | ||||||||||
Summary_Value( | ||||||||||
tag = tag, | ||||||||||
metadata = SummaryMetadata(hparams_plugin_data), | ||||||||||
tensor = null_tensor | ||||||||||
) | ||||||||||
end | ||||||||||
|
||||||||||
function log_hparams_config(logger::TBLogger, | ||||||||||
hparams_config::HParamsConfig; | ||||||||||
step=nothing) | ||||||||||
summ = SummaryCollection( | ||||||||||
hparams_config_summary(hparams_config) | ||||||||||
) | ||||||||||
write_event(logger.file, make_event(logger, summ, step=step)) | ||||||||||
end | ||||||||||
|
||||||||||
function hparams_config_summary(config::HParamsConfig) | ||||||||||
Summary_Value( | ||||||||||
EXPERIMENT_TAG, | ||||||||||
HParamsPluginData( | ||||||||||
version = PLUGIN_DATA_VERSION, | ||||||||||
experiment = Experiment( | ||||||||||
hparam_infos = HParamInfo.(config.hparams), | ||||||||||
metric_infos = MetricInfo.(config.metrics), | ||||||||||
time_created_secs = config.time_created_secs | ||||||||||
) | ||||||||||
) | ||||||||||
) | ||||||||||
end | ||||||||||
|
||||||||||
function log_hparams(logger::TBLogger, | ||||||||||
hparams_dict::Dict{HParam, Any}, | ||||||||||
group_name::AbstractString, | ||||||||||
trial_id::AbstractString, | ||||||||||
start_time_secs::Union{Float64, Nothing}; | ||||||||||
step=nothing) | ||||||||||
summ = SummaryCollection( | ||||||||||
hparams_summary(hparams_dict, | ||||||||||
group_name, | ||||||||||
trial_id, | ||||||||||
start_time_secs) | ||||||||||
) | ||||||||||
write_event(logger.file, make_event(logger, summ, step=step)) | ||||||||||
end | ||||||||||
|
||||||||||
function hparams_summary(hparams_dict::Dict{HParam, Any}, | ||||||||||
group_name::AbstractString, | ||||||||||
trial_id::AbstractString, | ||||||||||
start_time_secs=Union{Float64, Nothing}) | ||||||||||
if start_time_secs === nothing | ||||||||||
start_time_secs = time() | ||||||||||
end | ||||||||||
|
||||||||||
Summary_Value( | ||||||||||
SESSION_START_INFO_TAG, | ||||||||||
HParamsPluginData( | ||||||||||
version = PLUGIN_DATA_VERSION, | ||||||||||
session_start_info = SessionStartInfo( | ||||||||||
group_name = group_name, | ||||||||||
start_time_secs = start_time_secs, | ||||||||||
hparams = Dict( | ||||||||||
hparam.name => ProtoBuf.google.protobuf.Value(val) | ||||||||||
for (hparam, val) ∈ hparams_dict | ||||||||||
) | ||||||||||
) | ||||||||||
) | ||||||||||
) | ||||||||||
end | ||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -207,3 +207,22 @@ preprocess(name, val::TBVector{T,N}, data) where {T<:Complex,N} = | |
push!(data, name*"/re"=>TBVector(real.(content(val))), name*"/im"=>TBVector(imag.(content(val)))) | ||
summary_impl(name, val::TBVector) = histogram_summary(name, collect(0:length(val.data)), | ||
val.data) | ||
|
||
########## Hyperparameters ######################## | ||
|
||
# FIXME: name unused? | ||
summary_impl(name, val::HParamsConfig) = hparams_config_summary(val.data) | ||
preprocess(name, val::HParamsConfig, data) = push!(data, name=>val) | ||
|
||
struct TBHParams <: WrapperLogType | ||
# TODO: The types in the hparam domain and this dict's values are constrained. | ||
# e.g. an hparam with a discrete domain of ["a", "b"] must have string values | ||
# Consider ways to enforce this relationship in the type system. | ||
data::Dict{HParam, Any} | ||
Comment on lines
+218
to
+221
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you explain this better? I don't understand the todo |
||
# FIXME: group_name auto generated in the Python implementation (Tensorboard) | ||
group_name::AbstractString | ||
Comment on lines
+222
to
+223
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wrapper types are essentially the user-api. if this is an internal field autogenerated, then you should remove it (and auto-generate it when necessary). |
||
trial_id::AbstractString | ||
start_time_secs::Union{Float64, Nothing} | ||
end | ||
content(x::TBHParams) = x.data | ||
summary_impl(name, val::TBHParams) = hparams_summary(val.data, val.group_name, val.trial_id, val.start_time_secs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You got those from tensorboardx.py?