"""
Train the land use dataset with 3 classes and 4 image bands
"""
# Import PyTorch libraries
import torch
import torch.nn as nn
from torch import device, cuda, optim, autocast, save
# Import algorithms
from Models.deep3 import deeplab
from Models.Simple_CNN import Net
from utils.validation_accuracy import evaluate_accuracy
from Data.dataset_Oakville_v2 import train_set, getLength
num_classes = 4
num_bands = 4
epochs = 2
learning_rate = 1e-8
mem_args = dict(memory_format=torch.channels_last)
out_path = "/mnt/d/LandUseClassification.pth"

def train_model(model, device_hw, epoch_num, lr):
    # Set the optimizer and learning rate scheduler
    optimizer = optim.SGD(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
    # !! ignore_index is very important !! this is how I handle nodata values (in a hacky way)
    # (see the _demo_ignore_index sketch below for a tiny illustration of what it does)
    criterion = nn.CrossEntropyLoss(ignore_index=4)
    gradient_scaler = torch.amp.GradScaler()
    step_glob = 0
    # Start training
    for epoch in range(epoch_num):
        model.train()
        epoch_loss = 0
        for batch in train_set:
            images = batch["image"]
            masks = batch["mask"]
            images = images.to(device=device_hw)
            masks = masks.to(device=device_hw)
            optimizer.zero_grad()
            # Forward pass under mixed precision
            with autocast(device_hw.type):
                mask_prediction = model(images)['out']
                masks = masks.squeeze(1)
                loss = criterion(mask_prediction, masks)
            gradient_scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to the real gradients
            gradient_scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            gradient_scaler.step(optimizer)
            gradient_scaler.update()
            step_glob += 1
            epoch_loss += loss.item()
            print("Train loss: ", loss.item(), "Step: ", step_glob, "epoch: ", epoch)
"""
# Evaluate Model
# 5 * batch size
step_div = (getLength() // (5 * 20))
if step_div > 0:
if step_glob % step_div == 0:
validation_score = evaluate_accuracy(model, validation_set, device_hw)
scheduler.step(validation_score)
print("Validation Score: ", validation_score)
"""
print("Training Complete!")
state_dict = model.state_dict()
save(state_dict, out_path)
print("Model Saved")


if __name__ == '__main__':
    print("Using PyTorch version: ", torch.__version__)
    device = device('cuda' if cuda.is_available() else 'cpu')
    print(f"Training on {device}")
    network = deeplab(num_classes, num_bands, device)
    try:
        train_model(
            model=network,
            device_hw=device,
            epoch_num=epochs,
            lr=learning_rate
        )
    except cuda.OutOfMemoryError:
        print("Out of memory!")
        cuda.empty_cache()