Skip to content

Commit

Permalink
AutoMerger: support alpha weight mode (issue #103)
Browse files Browse the repository at this point in the history
  • Loading branch information
wkpark committed Jan 17, 2024
1 parent 08a8bf5 commit fa861bc
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
6 changes: 5 additions & 1 deletion scripts/model_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2834,10 +2834,14 @@ def before_process(self, p, enabled, model_a, base_model, mm_max_models, mm_fine
if getattr(shared, "modelmixer_overrides", None) is not None:
overrides = getattr(shared, "modelmixer_overrides")
_weights = overrides["weights"]
_alpha = overrides["alpha"]
_uses = overrides["uses"]
args = list(args_)
for j in range(len(_uses)):
if _uses[j] and len(_weights) > j:
if _uses[j] and len(_alpha) > j and _alpha[j] != "":
mm_alpha[j] = _alpha[j]
args[num_models*3+j] = _alpha[j]
elif _uses[j] and len(_weights) > j:
mm_weights[j] = _weights[j]
# update args to set extra_params
args[num_models*7+j] = _weights[j]
Expand Down
36 changes: 35 additions & 1 deletion sd_modelmixer/hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def hyper_score(localargs):
tunables = localargs.pass_through["tunables"]
isxl = localargs.pass_through["isxl"]
uses = localargs.pass_through["uses"]
usembws = localargs.pass_through["usembws"]
testweights = localargs.pass_through["weights"].copy()
prompt = localargs.pass_through["prompt"]
payload_path = localargs.pass_through["payload_path"]
Expand All @@ -244,13 +245,18 @@ def hyper_score(localargs):
if shared.state.interrupted:
raise ValueError("Error: Interrupted!")

testalpha = [""] * len(usembws)
# gather tunable variables into override weights
for k in tunables:
name = k.split(".")
modelidx = ord(name[0].split("_")[1]) - 98
if uses[modelidx] is False:
continue

if len(usembws[modelidx]) == 0:
testalpha[modelidx] = localargs[k]
continue

weight = testweights[modelidx]
j = BLOCKS.index(name[1])
weight[j] = localargs[k]
Expand All @@ -259,10 +265,12 @@ def hyper_score(localargs):
_weights = [""] * len(testweights)
for j in range(len(testweights)):
_weights[j] = ','.join([("0" if float(w) == 0.0 else str(w)) for w in testweights[j]])

print(" - test weights: ", _weights)
print(" - test alphas: ", testalpha)

# setup override weights. will be replaced with mm_weights
shared.modelmixer_overrides = {"weights": _weights, "uses": uses}
shared.modelmixer_overrides = {"weights": _weights, "alpha": testalpha, "uses": uses}

if len(payloads) == 0:
images = []
Expand Down Expand Up @@ -332,6 +340,7 @@ def hyper_score(localargs):
uses = initial["uses"] # used models
weights = initial["weights"] # normalized weights
usembws = initial["usembws"] # merged blocks
alpha = initial["alpha"] # alpha values without merged blocks
selected_blocks = initial["selected"]

isxl = shared.sd_model.is_sdxl
Expand Down Expand Up @@ -367,6 +376,20 @@ def hyper_score(localargs):
continue

name = f"model_{chr(i + 98)}"
if len(usembws[k]) == 0:
# no merged block weighs
val = alpha[k]
# setup range, lower + val ~ val + upper < search max. e.g.) -0.3 + val ~ val + 0.3 < 0.5
lower = max(val + search_lower, 0)
upper = min(val + search_upper, search_max)
if steps_or_inc >= 1:
search_space[f"{name}.alpha"] = [*np.round(np.linspace(lower, upper, steps_or_inc), 8)]
elif steps_or_inc < 1:
search_space[f"{name}.alpha"] = [*np.round(np.arange(lower, upper, steps_or_inc), 8)]

k += 1
continue

weight = weights[k]
mbw = normalize_mbw(usembws[k], isxl)
for b in selected_blocks:
Expand Down Expand Up @@ -394,6 +417,7 @@ def hyper_score(localargs):
_uses = current["uses"] # used models
_weights = current["weights"] # normalized weights
_usembws = current["usembws"] # merged blocks
_alpha = current["alpha"] # merged blocks
_selected_blocks = current["selected"]
k = 0 # fix index for not used model. e.g.) A, B, E (C is not selected case)
for i in range(len(_uses)):
Expand All @@ -406,6 +430,14 @@ def hyper_score(localargs):
continue

name = f"model_{chr(i + 98)}"
if len(_usembws[k]) == 0:
# no merged block weighs
val = _alpha[k]
warm[f"{name}.alpha"] = val

k += 1
continue

weight = _weights[k]
mbw = normalize_mbw(_usembws[k], isxl)
for b in _selected_blocks:
Expand Down Expand Up @@ -452,7 +484,9 @@ def hyper_score(localargs):
pass_through = {
"tunables": [*search_space.keys()],
"weights": weights,
"alpha": alpha,
"uses": override_uses,
"usembws": usembws,
"classifier": classifier,
"payload_path": payload_path,
"tally_type": tally_type,
Expand Down

0 comments on commit fa861bc

Please sign in to comment.