Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hparams API #129

Merged
merged 31 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
551581c
Started working on adding an API to use the hparams plugin
JamieMair Jul 11, 2023
e4efbf8
Added vscode settings folder to gitignore
JamieMair Jul 12, 2023
49e1797
First draft attempt at adding hyperparameter support
JamieMair Jul 12, 2023
4cc8112
First pass at implementing hparams with unit tests for initialisation
JamieMair Jul 12, 2023
6a2269b
Pkg sorted the project file
JamieMair Jul 12, 2023
099c947
Almost fixed hyperparameters
JamieMair Jul 12, 2023
234dc12
Cleaned up the hparams code a little
JamieMair Jul 13, 2023
3323b7d
Switched to make_event instead
JamieMair Jul 13, 2023
b8fbbb6
Removed .gitignore from readme
JamieMair Jul 16, 2023
867b105
Tidy up and use normal PB OneOf, switch to allowing Real instead of j…
JamieMair Jul 16, 2023
89a34e9
Simplified imports
JamieMair Jul 16, 2023
cf31314
Moved export and wrote a docstring for the API method
JamieMair Jul 16, 2023
a357f35
Removed typing from dict
JamieMair Jul 16, 2023
c02cd5a
Added a note of what to do next
JamieMair Jul 16, 2023
532fa36
Fixed typo in project toml for `gen`
JamieMair Jul 17, 2023
58ba28c
Convert reals to floats to comply with hparams plugin
JamieMair Jul 17, 2023
7e5b00c
Move initialisations to match how they are written to the file
JamieMair Jul 17, 2023
627db33
Add an overload for writing the dictionary
JamieMair Jul 17, 2023
08c89cd
Implemented a new encoder for the hyperparameter dict - needs cleaning
JamieMair Jul 17, 2023
13daad7
Tidied up the functional code and added explanations
JamieMair Jul 17, 2023
0aae85b
Made the test cases cover more ground
JamieMair Jul 17, 2023
6aedb03
Updated the hparams script
JamieMair Jul 17, 2023
afbd961
Tweak params for hparams example script
JamieMair Jul 17, 2023
c4e50fb
Fixed restrictive type information
JamieMair Jul 17, 2023
dac9c1e
Added some very simple docs to detail usage of API
JamieMair Jul 17, 2023
f6f9ca9
Completed the documentation for the feature
JamieMair Jul 17, 2023
022c3b8
Fixed assumption of single byte size
JamieMair Jul 18, 2023
c5e2a3f
Added a decode implementation as well
JamieMair Jul 18, 2023
0a6f14c
Added a unit test to ensure the hparams content is being serialised c…
JamieMair Jul 18, 2023
1dfa89f
Separated the hparams encoding/decoding tests
JamieMair Jul 18, 2023
e4b5ff9
Merge branch 'JuliaLogging:master' into add-hparams-api
JamieMair Jul 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ test/test_logs
docs/Manifest.toml

gen/proto
gen/protojl
gen/protojl
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ StatsBase = "0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
julia = "1.6"

[extras]
Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
Fontconfig = "186bb1d3-e1f7-5a2c-a377-96d770f13627"
Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"
ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Minio = "4281f0d9-7ae0-406e-9172-b7277c1efa20"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"
Cairo="159f3aea-2a34-519c-b102-8c37f9878175"
Fontconfig="186bb1d3-e1f7-5a2c-a377-96d770f13627"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand Down
4 changes: 3 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ makedocs(
"Backends" => "custom_behaviour.md",
"Reading back data" => "deserialization.md",
"Extending" => "extending_behaviour.md",
"Explicit Interface" => "explicit_interface.md"
"Explicit Interface" => "explicit_interface.md",
"Hyperparameter logging" => "hyperparameters.md"
],
"Examples" => Any[
"Flux.jl" => "examples/flux.md"
"Optim.jl" => "examples/optim.md"
"Hyperparameter tuning" => "examples/hyperparameter_tuning.md"
]
],
format = Documenter.HTML(
Expand Down
53 changes: 53 additions & 0 deletions docs/src/examples/hyperparameter_tuning.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Hyperparameter tuning

We will start this example by setting up a simple random walk experiment, and seeing the effect of the hyperparameter `bias` on the results.

First, import the packages we will need with:
```julia
using TensorBoardLogger, Logging
using Random
```
Next, we will create a function which runs the experiment and logs the results, include the hyperparameters stored in the `config` dictionary.
```julia
function run_experiment(id, config)
logger = TBLogger("random_walk/run$id", tb_append)

# Specify all the metrics we want to track in a list
metric_names = ["scalar/position"]
write_hparams!(logger, config, metric_names)

epochs = config["epochs"]
sigma = config["sigma"]
bias = config["bias"]
with_logger(logger) do
x = 0.0
for i in 1:epochs
x += sigma * randn() + bias
@info "scalar" position = x
end
end
nothing
end
```
Now we can write a script which runs an experiment over a set of parameter values.
```julia
id = 0
for bias in LinRange(-0.1, 0.1, 11)
for epochs in [50, 100]
config = Dict(
"bias"=>bias,
"epochs"=>epochs,
"sigma"=>0.1
)
run_experiment(id, config)
id += 1
end
end
```

Below is an example of the dashboard you get when you open Tensorboard with the command:
```sh
tensorboard --logdir=random_walk
```

![tuning plot](tuning.png)
Binary file added docs/src/examples/tuning.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 10 additions & 0 deletions docs/src/hyperparameters.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Hyperparameter logging

In additition to logging the experiments, you may wish to also visualise the effect of hyperparameters on some plotted metrics. This can be done by logging the hyperparameters via the `write_hparams!` function, which takes a dictionary mapping hyperparameter names to their values (currently limited to `Real`, `Bool` or `String` types), along with the names of any metrics that you want to view the effects of.

You can see how the HParams dashboard in Tensorboard can be used to tune hyperparameters on the [tensorboard website](https://www.tensorflow.org/tensorboard/hyperparameter_tuning_with_hparams).

## API
```@docs
write_hparams!
```
9 changes: 6 additions & 3 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ We also support logging custom types from a the following third-party libraries:

## Explicit logging

In alternative, you can also log data to TensorBoard through its functional interface,
by calling the relevant method with a tag string and the data. For information
on this interface refer to [Explicit interface](@ref)...
As an alternative, you can also log data to TensorBoard through its functional interface, by calling the relevant method with a tag string and the data. For information on this interface refer to [Explicit interface](@ref).

## Hyperparameter tuning

Many experiments rely on hyperparameters, which can be difficult to tune. Tensorboard allows you to visualise the effect of your hyperparameters on your metrics, giving you an intuition for the correct hyperparameters for your task. For information on this API, see the [Hyperparameter logging](@ref) manual page.

38 changes: 38 additions & 0 deletions examples/HParams.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using TensorBoardLogger #import the TensorBoardLogger package
using Logging #import Logging package
using Random # Exports randn

# Run 10 experiments to see a plot
for j in 1:10
logger = TBLogger("random_walks/run$j", tb_append)

sigma = 0.1
epochs = 200
bias = (rand()*2 - 1) / 10 # create a random bias
use_seed = false
# Add in the a dummy loss metric
with_logger(logger) do
x = 0.0
for i in 1:epochs
x += sigma * randn() + bias
@info "scalar" loss = x
end
end

# Hyperparameter is a dictionary of parameter names to their values. This
# supports numerical types, bools and strings. Non-bool numerical types
# are converted to Float64 to be displayed.
hparams_config = Dict{String, Any}(
"sigma"=>sigma,
"epochs"=>epochs,
"bias"=>bias,
"use_seed"=>use_seed,
"method"=>"MC"
)
# Specify a list of tags that you want to show up in the hyperparameter
# comparison
metrics = ["scalar/loss"]

# Write the hyperparameters and metrics config to the logger.
write_hparams!(logger, hparams_config, metrics)
end
2 changes: 1 addition & 1 deletion gen/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429"

[comapt]
[compat]
ProtoBuf = "0.9.1"
3 changes: 2 additions & 1 deletion src/TensorBoardLogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using Base.CoreLogging: CoreLogging, AbstractLogger, LogLevel, Info,

export TBLogger, reset!, set_step!, increment_step!, set_step_increment!, with_TBLogger_hold_step
export log_histogram, log_value, log_vector, log_text, log_image, log_images,
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar, write_hparams!
export map_summaries, TBReader

export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC,
Expand Down Expand Up @@ -62,6 +62,7 @@ include("ImageFormat.jl")
const TB_PLUGIN_JLARRAY_NAME = "_jl_tbl_array_sz"

include("TBLogger.jl")
include("hparams.jl")
include("utils.jl") # CRC Utils
include("event.jl")
include("Loggers/base.jl")
Expand Down
177 changes: 177 additions & 0 deletions src/hparams.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import .tensorboard_plugin_hparams.hparams: var"#DataType" as HParamDataType, DatasetType as HDatasetType
import .tensorboard_plugin_hparams.google.protobuf: ListValue as HListValue, Value as HValue
import .tensorboard_plugin_hparams.hparams as HP
import ProtoBuf as PB

struct HParamRealDomain
min_value::Float64
max_value::Float64
end
struct HParamSetDomain{T<:Union{String,Bool,Float64}}
values::Vector{T}
end
Base.@kwdef struct HParamConfig
name::String
datatype::DataType
displayname::String = ""
description::String = ""
domain::Union{Nothing,HParamRealDomain,HParamSetDomain} = nothing
end
Base.@kwdef struct MetricConfig
name::String
displayname::String = ""
description::String = ""
end

default_domain(::Val{Bool}) = HParamSetDomain([false, true])
default_domain(::Val{Float64}) = HParamRealDomain(typemin(Float64), typemax(Float64))
JamieMair marked this conversation as resolved.
Show resolved Hide resolved
default_domain(::Val{String}) = nothing

_to_proto_hparam_dtype(::Val{Bool}) = HParamDataType.DATA_TYPE_BOOL
_to_proto_hparam_dtype(::Val{Float64}) = HParamDataType.DATA_TYPE_FLOAT64
_to_proto_hparam_dtype(::Val{String}) = HParamDataType.DATA_TYPE_STRING

function _convert_value(v::T) where {T<:Union{String,Bool,Real}}
if v isa String
return HValue(OneOf(:string_value, v))
elseif v isa Bool
return HValue(OneOf(:bool_value, v))
elseif v isa Real
return HValue(OneOf(:number_value, Float64(v)))
else
error("Unrecognised type!")
end
end

_convert_hparam_domain(::Nothing) = nothing
_convert_hparam_domain(domain::HParamRealDomain) = OneOf(:domain_interval, HP.Interval(domain.min_value, domain.max_value))
_convert_hparam_domain(domain::HParamSetDomain) = OneOf(:domain_discrete, HListValue([_convert_value(v) for v in domain.values]))

function hparam_info(c::HParamConfig)
datatype = c.datatype
domain = c.domain
if isnothing(c.domain)
domain = default_domain(Val(datatype))
else
if isa(domain, HParamRealDomain)
@assert datatype == Float64 "Real domains require Float64"
elseif isa(domain, HParamSetDomain{String})
@assert datatype == String "Domains with strings require a datatype of String"
elseif isa(domain, HParamSetDomain{Bool})
@assert datatype == Bool "Domains with bools require a datatype of Bool"
elseif isa(domain, HParamSetDomain{Float64})
@assert datatype <: Real "Domains with floats require a datatype a Real datatype"
end
end

dtype = _to_proto_hparam_dtype(Val(datatype))
converted_domain = _convert_hparam_domain(domain)
return HP.HParamInfo(c.name, c.displayname, c.description, dtype, converted_domain)
end
function metric_info(c::MetricConfig)
mname = HP.MetricName("", c.name)
return HP.MetricInfo(mname, c.displayname, c.description, HDatasetType.DATASET_UNKNOWN)
end

function encode_bytes(content::HP.HParamsPluginData)
data = PipeBuffer()
encode(ProtoEncoder(data), content)
return take!(data)
end

# Dictionary serialisation in ProtoBuf does not work for this specific map type
# and must be overloaded so that it can be parsed. The format was derived by
# looking at the binary output of a log file created by tensorboardX.
# These protobuf overloads should be removed once https://github.com/JuliaIO/ProtoBuf.jl/pull/234 is merged.
function PB.encode(e::ProtoEncoder, i::Int, x::Dict{String,HValue})
for (k, v) in x
PB.Codecs.encode_tag(e, 1, PB.Codecs.LENGTH_DELIMITED)
total_size = PB.Codecs._encoded_size(k, 1) + PB.Codecs._encoded_size(v, 2)
PB.Codecs.vbyte_encode(e.io, UInt32(total_size)) # Add two for the wire type and length
PB.Codecs.encode(e, 1, k)
PB.Codecs.encode(e, 2, v)
end
return nothing
end

# Similarly, we must overload the size calculation to take into account the new
# format.
function PB.Codecs._encoded_size(d::Dict{String,HValue}, i::Int)
mapreduce(x->begin
total_size = PB.Codecs._encoded_size(x.first, 1) + PB.Codecs._encoded_size(x.second, 2)
return 1 + PB.Codecs._varint_size(total_size) + total_size
end, +, d, init=0)
end

function PB.Codecs.decode!(d::ProtoDecoder, buffer::Dict{String,HValue})
len = PB.Codecs.vbyte_decode(d.io, UInt32)
endpos = position(d.io) + len
while position(d.io) < endpos
pair_field_number, pair_wire_type = PB.Codecs.decode_tag(d)
pair_len = PB.Codecs.vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
field_number, wire_type = PB.Codecs.decode_tag(d)
key = PB.Codecs.decode(d, K)
field_number, wire_type = PB.Codecs.decode_tag(d)
val = PB.Codecs.decode(d, Ref{V})
@assert position(d.io) == pair_end_pos
buffer[key] = val
end
@assert position(d.io) == endpos
nothing
end

"""
write_hparams!(logger::TBLogger, hparams::Dict{String, Any}, metrics::AbstractArray{String})

Writes the supplied hyperparameters to the logger, along with noting all metrics that should be tracked.

The value of `hparams` can be a `String`, `Bool` or a subtype of `Real`. All `Real` values are converted
to `Float64` when writing the logs.

`metrics` should be a list of tags, which correspond to scalars that have been logged. Tensorboard will
automatically extract the latest metric logged to use for this value.
"""
function write_hparams!(logger::TBLogger, hparams::Dict{String,<:Any}, metrics::AbstractArray{String})
PLUGIN_NAME = "hparams"
PLUGIN_DATA_VERSION = 0

EXPERIMENT_TAG = "_hparams_/experiment"
SESSION_START_INFO_TAG = "_hparams_/session_start_info"
SESSION_END_INFO_TAG = "_hparams_/session_end_info"

# Check for datatypes
for (k, v) in hparams
@assert typeof(v) <: Union{Bool,String,Real} "Hyperparameters must be of types String, Bool or Real"
# Cast non-supported numerical values to Float64
if !(typeof(v) <: Bool) && typeof(v) <: Real
hparams[k] = Float64(v)
end
end

hparam_infos = [hparam_info(HParamConfig(; name=k, datatype=typeof(v))) for (k, v) in hparams]
metric_infos = [metric_info(MetricConfig(; name=metric)) for metric in metrics]

hparams_dict = Dict(k => _convert_value(v) for (k, v) in hparams)

experiment = HP.Experiment("", "", "", time(), hparam_infos, metric_infos)
experiment_content = HP.HParamsPluginData(PLUGIN_DATA_VERSION, OneOf(:experiment, experiment))
experiment_md = SummaryMetadata(SummaryMetadata_PluginData(PLUGIN_NAME, encode_bytes(experiment_content)), "", "", DataClass.DATA_CLASS_UNKNOWN)
experiment_summary = Summary([Summary_Value("", EXPERIMENT_TAG, experiment_md, nothing)])

session_start_info = HP.SessionStartInfo(hparams_dict, "", "", "", time())
session_start_content = HP.HParamsPluginData(PLUGIN_DATA_VERSION, OneOf(:session_start_info, session_start_info))
session_start_md = SummaryMetadata(SummaryMetadata_PluginData(PLUGIN_NAME, encode_bytes(session_start_content)), "", "", DataClass.DATA_CLASS_UNKNOWN)
session_start_summary = Summary([Summary_Value("", SESSION_START_INFO_TAG, session_start_md, nothing)])

session_end_info = HP.SessionEndInfo(HP.Status.STATUS_SUCCESS, time())
session_end_content = HP.HParamsPluginData(PLUGIN_DATA_VERSION, OneOf(:session_end_info, session_end_info))
session_end_md = SummaryMetadata(SummaryMetadata_PluginData(PLUGIN_NAME, encode_bytes(session_end_content)), "", "", DataClass.DATA_CLASS_UNKNOWN)
session_end_summary = Summary([Summary_Value("", SESSION_END_INFO_TAG, session_end_md, nothing)])

for s in (experiment_summary, session_start_summary, session_end_summary)
event = make_event(logger, s)
write_event(logger, event)
end
nothing
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ end
close.(values(logger.all_files))
end

@testset "Hyperparameter logging" begin
include("test_hparams.jl")
end

@testset "Image processing interface" begin
#2-d image
data = Vector{Pair{String,Any}}()
Expand Down
Loading