-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain_d.m
88 lines (76 loc) · 3.28 KB
/
train_d.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
addpath layers/
addpath layers_adapters/
addpath mem/
addpath optimization/
addpath pipeline/
addpath utils/
% ---- Setup -------------------------------------------------------------------
% Reset the shared global state so a previous run cannot leak into this one,
% then re-declare it for this script.
clearvars -global config;
clearvars -global mem;
global config mem;
denoise_configure();  % presumably fills in config (NEW_MEM, batch_size, output_size, ...) - confirm

% Load the validation arrays through a struct instead of splatting every
% variable stored in the file into the workspace.
val_data = load('val/val_1');
test_samples = val_data.test_samples;
test_labels = val_data.test_labels;
clear val_data;
init(0);

% Keep a small random subset of the validation set for the periodic test pass.
% The subset size is capped at the number of available samples, so a
% validation file with fewer than n_test_max samples no longer causes an
% index-out-of-bounds error. The draw happens after init(0), as before.
n_test_max = 5;
n_val = size(test_samples, 4);
pick = randperm(n_val, min(n_test_max, n_val));
test_samples = config.NEW_MEM(test_samples(:,:,:,pick));
test_labels = config.NEW_MEM(test_labels(:,:,:,pick));

% Progress/bookkeeping state used by the training loop below.
count = 0;            % number of display ticks printed so far
cost_avg = 0;         % training cost reported at each save point
epoc = 0;             % number of save points (checkpoints) reached
points_seen = 0;      % training samples consumed so far (advanced by batch_size)
display_points = 10;  % print a tick whenever points_seen is a multiple of this
save_points = 10;     % test + checkpoint whenever points_seen is a multiple of this
max_grad = 1;         % NOTE(review): not referenced anywhere in this script - confirm before removing
fprintf('%s\n', datestr(now, 'dd-mm-yyyy HH:MM:SS '));
% ---- Training loop ------------------------------------------------------------
% Locations and sampling limits (previously hard-coded inline).
train_dir = 'F:/bm3d/BM3D/train';     % holds patches_<p>.mat
results_dir = 'F:/bm3d/BM3D/results'; % checkpoints are written here as epoc<k>.mat
train_file_ids = 1;                   % which patches_<p>.mat files are used each pass
max_patches_per_file = 10;            % patches drawn per file per pass; Inf = use all

% Number of passes over the training files, read from size1.mat.
size1_data = load('size1.mat');
size1 = size1_data.size1;
clear size1_data;

% Central crop of the (fixed) test labels to the network output size; the
% offsets come from the test samples themselves, not from the last training
% batch, and are computed once because the test subset never changes.
t_row_off = (size(test_samples, 1) - config.output_size(1)) / 2;
t_col_off = (size(test_samples, 2) - config.output_size(2)) / 2;
t_rows = t_row_off + 1:t_row_off + config.output_size(1);
t_cols = t_col_off + 1:t_col_off + config.output_size(2);
n_test_batches = floor(size(test_samples, 4) / config.batch_size);
if n_test_batches == 0
    % Previously this case silently reported a test cost of 0.
    warning('train_d:noTestBatches', ...
        'Test subset (%d samples) is smaller than batch_size (%d); test cost will be NaN.', ...
        size(test_samples, 4), config.batch_size);
end

% Running sum/count of config.cost since the last save point. The old
% "(avg + cost) / 2" update was an exponential moving average dominated by the
% latest batches, and it treated a genuine cost of 0 as "no data yet".
cost_sum = 0;
cost_n = 0;

for pass = 1:size1
    for p = train_file_ids
        patches = load(fullfile(train_dir, sprintf('patches_%d.mat', p)));
        n_avail = size(patches.samples, 4);
        % Random subset drawn from the whole file; the old randperm(10) only
        % ever shuffled the first 10 patches and failed on smaller files.
        perm = randperm(n_avail, min(max_patches_per_file, n_avail));
        train_imgs = config.NEW_MEM(patches.samples(:,:,:,perm));
        train_labels = config.NEW_MEM(patches.labels(:,:,:,perm));
        clear patches;

        % Central crop of the training labels to the output size. It depends
        % only on the input patch size, so it is hoisted out of the batch loop.
        row_off = (size(train_imgs, 1) - config.output_size(1)) / 2;
        col_off = (size(train_imgs, 2) - config.output_size(2)) / 2;
        rows = row_off + 1:row_off + config.output_size(1);
        cols = col_off + 1:col_off + config.output_size(2);

        % Trailing samples that do not fill a whole batch are skipped.
        n_batches = floor(size(train_labels, 4) / config.batch_size);
        for b = 1:n_batches
            idx = (b-1)*config.batch_size+1:b*config.batch_size;
            points_seen = points_seen + config.batch_size;
            in = train_imgs(:,:,:,idx);
            out = train_labels(rows, cols, :, idx);
            % operate the training pipeline
            op_train_pipe(in, out);
            % update the weights
            config.UPDATE_WEIGHTS();
            cost_sum = cost_sum + config.cost;
            cost_n = cost_n + 1;

            % display point
            if mod(points_seen, display_points) == 0
                count = count + 1;
                fprintf('%d ', count);
            end

            % save point: evaluate on the held-out subset, then checkpoint
            if mod(points_seen, save_points) == 0
                fprintf('\n%s', datestr(now, 'dd-mm-yyyy HH:MM:SS '));
                epoc = epoc + 1;
                cost_avg = cost_sum / cost_n;  % cost_n >= 1: this batch was just counted

                test_cost = 0;
                for t = 1:n_test_batches
                    t_idx = (t-1)*config.batch_size+1:t*config.batch_size;
                    % Keep all four subscripts: the old t_label(r, c, :) folded
                    % the batch dimension into the third dimension. test_labels
                    % was already wrapped with config.NEW_MEM during setup.
                    op_test_pipe(test_samples(:,:,:,t_idx), test_labels(t_rows, t_cols, :, t_idx));
                    test_cost = test_cost + config.cost;
                end
                % Mean per-batch cost, on the same scale as the training
                % average (it used to be divided by the sample count instead).
                % NaN when no full test batch exists.
                test_cost = test_cost / n_test_batches;

                fprintf('\nepoc %d, training avg cost: %f, test avg cost: %f\n', epoc, cost_avg, test_cost);
                save_weights(fullfile(results_dir, sprintf('epoc%d.mat', epoc)));
                cost_sum = 0;
                cost_n = 0;
                cost_avg = 0;
            end
        end
    end
end
disp('done')