Skip to content

Commit

Permalink
new consensus structure and torch support
Browse files Browse the repository at this point in the history
  • Loading branch information
TimRoith committed Aug 16, 2024
1 parent b6b1103 commit 2b4cbb0
Show file tree
Hide file tree
Showing 5 changed files with 490 additions and 50 deletions.
1 change: 0 additions & 1 deletion cbx/dynamics/cbo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .pdyn import CBXDynamic


def cbo_update(drift, lamda, dt, sigma, noise):
return -lamda * dt * drift + sigma * noise
#%% CBO
Expand Down
55 changes: 30 additions & 25 deletions cbx/dynamics/pdyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from ..scheduler import scheduler, multiply, effective_sample_size
from ..utils.termination import max_it_term
from ..utils.history import track_x, track_energy, track_update_norm, track_consensus, track_drift, track_drift_mean
from ..utils.particle_init import init_particles
from cbx.utils.objective_handling import _promote_objective

#%%
Expand Down Expand Up @@ -126,7 +125,6 @@ def __init__(
f_dim: str = '1D',
check_f_dims: bool = True,
x: Union[None, np.ndarray] = None,
x_min: float = -1., x_max: float = 1.,
M: int = 1, N: int = 20, d: int = None,
max_it: int = 1000,
term_criteria: List[Callable] = None,
Expand All @@ -136,18 +134,17 @@ def __init__(
norm: Callable = None,
sampler: Callable = None,
post_process: Callable = None,
seed: int = None
) -> None:

self.verbosity = verbosity
self.seed = seed

# set utilities
self.copy = copy if copy is not None else np.copy
self.norm = norm if norm is not None else np.linalg.norm
rng = np.random.default_rng(12345)
self.sampler = sampler if sampler is not None else rng.standard_normal

# set array backend funs
self.set_array_backend_funs(copy, norm, sampler)

# init particles
self.init_x(x, M, N, d, x_min, x_max)
self.init_x(x, M, N, d)

# set and promote objective function
self.init_f(f, f_dim, check_f_dims)
Expand All @@ -166,7 +163,13 @@ def __init__(
# post processing
self.post_process = post_process if post_process is not None else post_process_default

def init_x(self, x, M, N, d, x_min, x_max):
def set_array_backend_funs(self, copy, norm, sampler):
self.copy = copy if copy is not None else np.copy
self.norm = norm if norm is not None else np.linalg.norm
rng = np.random.default_rng(self.seed)
self.sampler = sampler if sampler is not None else rng.standard_normal

def init_x(self, x, M, N, d):
"""
Initialize the particle system with the given parameters.
Expand All @@ -185,22 +188,22 @@ def init_x(self, x, M, N, d, x_min, x_max):
if x is None:
if d is None:
raise RuntimeError('If the inital partical system is not given, the dimension d must be specified!')
x = init_particles(
shape=(M, N, d),
x_min = x_min, x_max = x_max
)
self.x = self.init_particles(shape=(M, N, d))

else: # if x is given correct shape
if len(x.shape) == 1:
if x.ndim == 1:
x = x[None, None, :]
elif len(x.shape) == 2:
x = x[None, :]
elif x.ndim == 2:
x = x[None, ...]
self.x = self.copy(x)

self.M, self.N = self.x.shape[:2]
self.d = self.x.shape[2:]
self.ddims = tuple(i for i in range(2, self.x.ndim))

self.M = x.shape[0]
self.N = x.shape[1]
self.d = x.shape[2:]
self.ddims = tuple(i for i in range(2, x.ndim))
self.x = self.copy(x)

def init_particles(self, shape=None):
return np.random.uniform(-1., 1., size=shape)

def init_f(self, f, f_dim, check_f_dims):
self.f = _promote_objective(f, f_dim)
Expand Down Expand Up @@ -620,16 +623,18 @@ def __init__(self, f,
self.set_noise(noise)

self.init_batch_idx(batch_args)

self.consensus = None #consensus point
self._compute_consensus = compute_consensus if compute_consensus is not None else compute_consensus_default()
self.init_consensus(compute_consensus)

known_tracks = {
'consensus': track_consensus,
'drift_mean': track_drift_mean,
'drift': track_drift,
**ParticleDynamic.known_tracks,}

def init_consensus(self, compute_consensus):
self.consensus = None #consensus point
self._compute_consensus = compute_consensus if compute_consensus is not None else compute_consensus_default()

def init_alpha(self, alpha):
'''
Initialize alpha per batch. If alpha is a float it is broadcasted to an array similar to x with dimensions (x.shape[0], 1).
Expand Down
2 changes: 1 addition & 1 deletion cbx/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def __init__(self, b=0., c=0.):
self.minima = np.array([[self.b, self.b]])

def apply(self, x):
return (1/x.shape[-1]) * np.sum((x - self.b)**2 - 10*np.cos(2*np.pi*(x - self.b)) + 10, axis=-1) + self.c
return (1/x.shape[-1]) * ((x - self.b)**2 - 10*np.cos(2*np.pi*(x - self.b)) + 10).sum(-1) + self.c


class Rastrigin_multimodal(cbx_objective):
Expand Down
30 changes: 29 additions & 1 deletion cbx/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def torch_decorated(*args, **kwargs):

@requires_torch
def norm_torch(x, axis, **kwargs):
return torch.linalg.norm(x, dim=axis, **kwargs)
return torch.linalg.vector_norm(x, dim=axis, **kwargs)

@requires_torch
def compute_consensus_torch(energy, x, alpha):
Expand All @@ -48,6 +48,34 @@ def _normal_torch(size=None):
return torch.randn(size=size).to(device)
return _normal_torch

def set_array_backend_funs_torch(self, copy, norm, sampler):
self.copy = copy if copy is not None else torch.clone
self.norm = norm if norm is not None else norm_torch
self.sampler = sampler if sampler is not None else standard_normal_torch(self.device)

def init_particles(self, shape=None,):
return torch.zeros(size=shape).uniform_(-1., 1.)

def init_consensus(self, compute_consensus):
self.consensus = None #consensus point
self._compute_consensus = compute_consensus if compute_consensus is not None else compute_consensus_torch

def to_torch_dynamic(dyn_cls):
def add_device_init(self, *args, device='cpu', **kwargs):
self.device = device
dyn_cls.__init__(self, *args, **kwargs)

return type(dyn_cls.__name__ + str('_torch'),
(dyn_cls,),
dict(
__init__ = add_device_init,
set_array_backend_funs=set_array_backend_funs_torch,
init_particles=init_particles,
init_consensus=init_consensus
)
)


@requires_torch
def eval_model(x, model, w, pprop):
params = {p: w[pprop[p][-2]:pprop[p][-1]].view(pprop[p][0]) for p in pprop}
Expand Down
452 changes: 430 additions & 22 deletions docs/examples/polarcbo.ipynb

Large diffs are not rendered by default.

0 comments on commit 2b4cbb0

Please sign in to comment.