Skip to content

Commit

Permalink
Merge pull request #38 from bankofcanada/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
bbejanov authored Apr 5, 2023
2 parents a9a9de5 + d6c1fd2 commit b545079
Show file tree
Hide file tree
Showing 12 changed files with 481 additions and 143 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "StateSpaceEcon"
uuid = "e4c825b0-b65c-11ea-0b5a-6176b64e7b7f"
authors = ["Atai Akunov <[email protected]>", "Boyan Bejanov <[email protected]>", "Nicholas Labelle St-Pierre <[email protected]>"]
version = "0.4"
version = "0.4.1"

[deps]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Expand All @@ -20,9 +20,9 @@ TimeSeriesEcon = "8b6756d2-c55c-11ea-2998-5f67ea17da60"
DiffResults = "1.0"
ForwardDiff = "0.10"
JLD2 = "0.4"
ModelBaseEcon = "0.5"
ModelBaseEcon = "0.5.2"
Suppressor = "0.2"
TimeSeriesEcon = "0.5"
TimeSeriesEcon = "0.5.1"
julia = "1.7"

[extras]
Expand Down
55 changes: 32 additions & 23 deletions src/Plans.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
##################################################################################
# This file is part of StateSpaceEcon.jl
# BSD 3-Clause License
# Copyright (c) 2020-2022, Bank of Canada
# Copyright (c) 2020-2023, Bank of Canada
# All rights reserved.
##################################################################################

Expand Down Expand Up @@ -97,13 +97,17 @@ Base.length(p::Plan) = length(p.range)
Base.IndexStyle(::Plan) = IndexLinear()
Base.similar(p::Plan) = Plan(p.range, p.varshks, similar(p.exogenous))
Base.copy(p::Plan) = Plan(p.range, p.varshks, copy(p.exogenous))
TimeSeriesEcon.rangeof(p::Plan) = p.range

function Base.copyto!(dest::Plan,rng::AbstractUnitRange,scr::Plan)
function Base.copyto!(dest::Plan, rng::AbstractUnitRange, scr::Plan)
dest.varshks == scr.varshks || throw(ArgumentError("Both plans must have the same variables and shocks in the same order."))
idx2 = axes(dest.exogenous, 2) # same for both dest and scr
copyto!(dest.exogenous, _offset(dest, rng), idx2, scr.exogenous, _offset(scr, rng), idx2)
return dest
end
Base.copyto!(dest::Plan,rng::MIT,scr::Plan) = Base.copyto!(dest,rng:rng,scr)

Base.copyto!(dest::Plan, rng::MIT, scr::Plan) = Base.copyto!(dest, rng:rng, scr)
Base.copyto!(dest::Plan, src::Plan) = Base.copyto!(dest, intersect(rangeof(dest), rangeof(src)), src)

_offset(p::Plan{T}, idx::T) where {T<:MIT} = convert(Int, idx - first(p.range) + 1)
_offset(p::Plan{T}, idx::AbstractUnitRange{T}) where {T<:MIT} =
Expand Down Expand Up @@ -132,8 +136,9 @@ end
# A range with a model returns a plan trimmed over that range and extended for initial and final conditions.
Base.getindex(p::Plan{MIT{Unit}}, rng::AbstractUnitRange{Int}, m::Model) = p[UnitRange{MIT{Unit}}(rng), m]
@inline function Base.getindex(p::Plan{T}, rng::AbstractUnitRange{T}, m::Model) where {T<:MIT}
rng = (rng.start-m.maxlag):(rng.stop+m.maxlead)
return p[rng]
# rng = (rng.start-m.maxlag):(rng.stop+m.maxlead)
# return p[rng]
copyto!(Plan(m, rng), rng, p)
end

Base.setindex!(p::Plan, x, i...) = error("Cannot assign directly. Use `exogenize` and `endogenize` to alter plan.")
Expand All @@ -142,10 +147,18 @@ Base.setindex!(p::Plan, x, i...) = error("Cannot assign directly. Use `exogenize
# query the exo-end status of a variable

@inline Base.getindex(p::Plan{T}, vars::Symbol...) where {T} = begin
var_inds = [p.varshks[v] for v in vars]
var_inds = Int[p.varshks[vars]...]
Plan{T}(p.range, NamedTuple{(vars...,)}(eachindex(vars)), p.exogenous[:, var_inds])
end

@inline Base.getindex(p::Plan{T}, rng::AbstractUnitRange{T}, vars::Symbol...) where {T} = begin
rng.start < p.range.start && throw(BoundsError(p, rng.start))
rng.stop > p.range.stop && throw(BoundsError(p, rng.stop))
var_inds = Int[p.varshks[vars]...]
Plan{T}(rng, NamedTuple{(vars...,)}(eachindex(vars)), p.exogenous[_offset(p, rng), var_inds])
end


#######################################
# Pretty printing

Expand Down Expand Up @@ -181,7 +194,7 @@ function Base.show(io::IO, p::Plan)
limit = get(io, :limit, true)
cp = collapsed_range(p)
# find the longest string left of "=>" for padding
maxl = maximum(length("$k") for (k, v) in cp)
maxl = maximum(length string first, cp)
if limit
dcol = ncol - maxl - 6
else
Expand Down Expand Up @@ -332,19 +345,12 @@ function importplan(io::IO)
if m === nothing
error("expected Variables: at the start of line 3, got ", line, ".")
end
nt = let
nt = parse_namedtuple(m.captures[1], 3)
nms = collect(keys(nt))
inds = collect(nt)
if inds != 1:length(inds)
if Set(inds) != Set(1:length(inds))
error("indexes of variables on line 3 are not valid.")
end
tmp = Dict(i => n for (i, n) in zip(inds, nms))
nt = NamedTuple{((tmp[i] for i = 1:length(inds))...,)}(1:length(inds))
end
nt
nt = parse_namedtuple(m.captures[1], 3)
if Set(nt) != Set(1:length(nt))
error("Indexes of variables on line 3 are not valid.")
end
# sort nt by its values (variable indexes)
nt = (; sort!(collect(pairs(nt)), by=last)...)
# parse line 4. Example "(X) = Exogenous, (-) = Endogenous
line = readline(io)
m = match(r"\((.+?)\) = Exogenous, \((.+?)\) = Endogenous:", line)
Expand All @@ -360,10 +366,13 @@ function importplan(io::IO)
end
_name_delim = m.captures[1]
_range_delim = m.captures[3]
ranges = [parse_range(strip(str), true, 5) for str in split(m.captures[2], _range_delim; keepempty=false)]
# parse the rest of it
p = Plan{MIT{freq}}(rng, nt, falses(length(rng), length(nt)))
ranges_inds = [[_offset(p, r)...] for r in ranges]
ranges_inds = Vector{Int}[]
for str in split(m.captures[2], _range_delim; keepempty=false)
r = parse_range(strip(str), true, 5)
push!(ranges_inds, [_offset(p, r);])
end
# parse the rest of it
pat = Regex("\\s+(\\w+)\\s*$(_name_delim)\\s*(.*)")
for i = 1:length(nt)
line = readline(io)
Expand Down Expand Up @@ -416,7 +425,7 @@ end

function parse_namedtuple(str, line=nothing)
e = Meta.parse(str)
ans = Meta.isexpr(e, :tuple) && all(Meta.isexpr(ee, :(=)) for ee in e.args) ? eval(e) : nothing
ans = Meta.isexpr(e, :tuple) && all(Base.Fix2(Meta.isexpr, :(=)), e.args) ? eval(e) : nothing
if !(ans isa NamedTuple{NAMES,NTuple{N,Int}} where {NAMES,N})
error("expected NamedTuple, got ", str, line === nothing ? "." : " on line $line.")
end
Expand Down
5 changes: 5 additions & 0 deletions src/StackedTimeSolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ import ..steadystatearray
import ..SimData
import ..rawdata

import ..SimFailed
import ..isfailed
import ..MaybeSimData

import ModelBaseEcon.hasevaldata
import ModelBaseEcon.getevaldata
import ModelBaseEcon.setevaldata!
Expand All @@ -41,6 +45,7 @@ include("stackedtime/misc.jl")
include("stackedtime/solverdata.jl")
include("stackedtime/simulate.jl")
include("stackedtime/shockdecomp.jl")
include("stackedtime/stoch_simulate.jl")

end # module

Expand Down
36 changes: 29 additions & 7 deletions src/simdata.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
##################################################################################
# This file is part of StateSpaceEcon.jl
# BSD 3-Clause License
# Copyright (c) 2020-2022, Bank of Canada
# Copyright (c) 2020-2023, Bank of Canada
# All rights reserved.
##################################################################################

Expand All @@ -24,19 +24,21 @@ export SimData
# same constructors as should work for SimData
SimData(args...) = MVTSeries(args...)

const _getname = Base.Fix2(getfield, :name)

const _MVCollection = Union{Vector{ModelVariable},NTuple{N,ModelVariable}} where {N}
# we should allow indexing with model variables
Base.getindex(sd::SimData, vars::_MVCollection) = getindex(sd, map(v -> v.name, vars))
Base.getindex(sd::SimData, vars::_MVCollection) = getindex(sd, map(_getname, vars))
Base.getindex(sd::SimData, vars::ModelVariable) = getindex(sd, vars.name)
Base.setindex!(sd::SimData, val, vars::_MVCollection) = setindex!(sd, val, map(v -> v.name, vars))
Base.setindex!(sd::SimData, val, vars::_MVCollection) = setindex!(sd, val, map(_getname, vars))
Base.setindex!(sd::SimData, val, vars::ModelVariable) = setindex!(sd, val, vars.name)

Base.getindex(sd::SimData, rows, vars::_MVCollection) = getindex(sd, rows, map(v->v.name, vars))
Base.getindex(sd::SimData, rows, vars::_MVCollection) = getindex(sd, rows, map(_getname, vars))
Base.getindex(sd::SimData, rows, vars::ModelVariable) = getindex(sd, rows, vars.name)
Base.setindex!(sd::SimData, val, rows, vars::_MVCollection) = setindex!(sd, val, rows, map(v->v.name, vars))
Base.setindex!(sd::SimData, val, rows, vars::_MVCollection) = setindex!(sd, val, rows, map(_getname, vars))
Base.setindex!(sd::SimData, val, rows, vars::ModelVariable) = setindex!(sd, val, rows, vars.name)

Base.view(sd::SimData, rows, vars::_MVCollection) = view(sd, rows, map(v->v.name, vars))
Base.view(sd::SimData, rows, vars::_MVCollection) = view(sd, rows, map(_getname, vars))
Base.view(sd::SimData, rows, vars::ModelVariable) = view(sd, rows, vars.name)

#######################################################
Expand Down Expand Up @@ -116,7 +118,7 @@ function workspace2data(w::Workspace, vars, range::AbstractUnitRange; copy=false
ret = SimData(range, vars, NaN)
for v in vars
wv = w[Symbol(v)]
copyto!(ret[v], intersect(range,rangeof(wv)), wv)
copyto!(ret[v], intersect(range, rangeof(wv)), wv)
end
return ret
end
Expand Down Expand Up @@ -154,3 +156,23 @@ export dict2array, array2dict, dict2data, data2dict
@deprecate dict2data(d::Workspace, args...; kwargs...) workspace2data(d, args...; kwargs...)
@deprecate array2dict(args...; kwargs...) array2workspace(args...; kwargs...)
@deprecate data2dict(args...; kwargs...) data2workspace(args...; kwargs...)


struct SimFailed <: Exception
info
end
Base.showerror(io::IO, ex::SimFailed) =
isnothing(ex.info) ? print(io, "Simulation failed.") :
ex.info isa MIT ? print(io, "Simulation failed in period $(ex.info).") :
ex.info isa AbstractUnitRange{<:MIT} ? print(io, "Simulation over $(ex.info) failed.") :
print(io, "Simulation failed: $(ex.info)")
isfailed(f::SimFailed)::Bool = !isnothing(f.info)
isfailed(f::SimData)::Bool = false
isfailed(f::Workspace)::Bool = false
isfailed(f)::Bool = throw(ArgumentError("Unexpected $(typeof(f)) argument."))
const MaybeSimData = Union{<:SimData,SimFailed}
Base.promote_rule(T::Type{<:SimData}, S::Type{<:SimFailed}) = Union{T, S}::Type
export SimFailed
export isfailed
export MaybeSimData

Loading

0 comments on commit b545079

Please sign in to comment.