# train_funcs.py — 66 lines (55 loc), 2.85 KB
import torch
import numpy as np
from tqdm import tqdm
import ast
"""
TRAIN FUNCTION DEFINITION:
train(model: StableDiffusionPipeline,
projection_matrices: list[size=L](nn.Module),
og_matrices: list[size=L](nn.Module),
contexts: list[size=N](torch.tensor[size=MAX_LEN,...]),
valuess: list[size=N](list[size=L](torch.tensor[size=MAX_LEN,...])),
old_texts: list[size=N](str),
new_texts: list[size=N](str),
**kwargs)
where L is the number of matrices to edit, and N is the number of sentences to train on (batch size).
PARAMS:
model: the model to use.
projection_matrices: list of projection matrices to edit from the model.
og_matrices: list of original values for the projection matrices. detached from the model.
contexts: list of context vectors (inputs to the matrices) to edit.
valuess: list of results from all matrices for each context vector.
old_texts: list of sentences to be edited.
new_texts: list of target sentences to be aimed at.
**kwargs: additional command line arguments.
TRAIN_FUNC_DICT defined at the bottom of the file.
"""
def baseline_train(model, projection_matrices, og_matrices, contexts, valuess, old_texts, new_texts, **kwargs):
    """No-op baseline: performs no editing and leaves the model unchanged.

    Accepts the common train-function signature documented at the top of this
    file — including arbitrary extra command-line ``**kwargs`` — so it is
    interchangeable with the other entries in TRAIN_FUNC_DICT (which do take
    extra kwargs such as ``layers_to_edit`` and ``lamb``).

    Returns:
        None.
    """
    return None
def train_closed_form(ldm_stable, projection_matrices, og_matrices, contexts, valuess, old_texts,
                      new_texts, layers_to_edit=None, lamb=0.1):
    """Edit the projection matrices in closed form (least-squares model editing).

    For each selected layer, solves the regularized least-squares update
        W_new = (lamb * W + sum_i v_i k_i^T) @ (lamb * I + sum_i k_i k_i^T)^{-1}
    where the k_i are the context (key) vectors and the v_i are the desired
    per-layer values, then writes W_new back into the live model.

    Args:
        ldm_stable: the pipeline/model (unused here; kept for the common
            train-function signature documented at the top of this file).
        projection_matrices: list of nn.Linear-like modules edited in place.
        og_matrices: detached originals of the matrices (unused here; kept
            for the common signature).
        contexts: list of key tensors — assumed shape (MAX_LEN, d_in) per
            the header docstring; TODO confirm against caller.
        valuess: per-sentence list of per-layer value tensors, each assumed
            (MAX_LEN, d_out).
        old_texts: source sentences (unused here; kept for the signature).
        new_texts: target sentences (unused here; kept for the signature).
        layers_to_edit: optional collection of layer indices to edit; may
            arrive as a string from the command line, in which case it is
            parsed safely with ast.literal_eval. None edits every layer.
        lamb: regularization strength; may also arrive as a string.

    Returns:
        None. Mutates ``projection_matrices[i].weight`` in place.
    """
    # Command-line arguments may arrive as strings; parse them safely
    # (literal_eval, never eval).
    if isinstance(layers_to_edit, str):
        layers_to_edit = ast.literal_eval(layers_to_edit)
    if isinstance(lamb, str):
        lamb = ast.literal_eval(lamb)

    for layer_num in tqdm(range(len(projection_matrices))):
        if (layers_to_edit is not None) and (layer_num not in layers_to_edit):
            continue
        with torch.no_grad():
            weight = projection_matrices[layer_num].weight
            # mat1 = \lambda W + \sum{v k^T}
            mat1 = lamb * weight
            # mat2 = \lambda I + \sum{k k^T}
            # Match the weight's dtype as well as its device, so editing an
            # fp16/bf16 model does not silently promote the layer to fp32.
            mat2 = lamb * torch.eye(weight.shape[1], device=weight.device, dtype=weight.dtype)
            # Accumulate both sums over all training sentences.
            for context, values in zip(contexts, valuess):
                context_vector = context.reshape(context.shape[0], context.shape[1], 1)
                context_vector_T = context.reshape(context.shape[0], 1, context.shape[1])
                value_vector = values[layer_num].reshape(values[layer_num].shape[0], values[layer_num].shape[1], 1)
                mat1 += (value_vector @ context_vector_T).sum(dim=0)
                mat2 += (context_vector @ context_vector_T).sum(dim=0)
            # Closed-form update: replace the layer weight with mat1 @ mat2^{-1}.
            projection_matrices[layer_num].weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2))
# Dispatch table: command-line method name -> train function implementing it
# (all entries share the signature documented at the top of this file).
TRAIN_FUNC_DICT = dict(
    baseline=baseline_train,
    train_closed_form=train_closed_form,
)