forked from h2oai/h2o-llmstudio
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_wave.py
163 lines (140 loc) · 5.45 KB
/
train_wave.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
# Pin the native thread pools (OpenMP, MKL, OpenBLAS, vecLib, numexpr) to a
# single thread and turn off tokenizers parallelism. This is done before any
# other module is imported, so that libraries reading these variables when
# they are loaded already see the values.
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
        "VECLIB_MAXIMUM_THREADS": "1",
        "NUMEXPR_NUM_THREADS": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)
import argparse
import logging
import sys
import time
import psutil
def check_for_done(process_queue):
    """Checks for finished process ids.

    A process counts as finished when its pid no longer exists, or when it is
    a zombie (it has exited but its parent has not reaped it yet).

    Args:
        process_queue: list of process ids

    Returns:
        (True, process_idx) for the first finished process, where process_idx
        is its position in process_queue;
        (False, False) if no process in the queue has finished yet
    """
    for i, pid in enumerate(process_queue):
        zombie = False
        try:
            zombie = psutil.Process(pid).status() == psutil.STATUS_ZOMBIE
        except psutil.ZombieProcess:
            # ZombieProcess is a subclass of NoSuchProcess, so it must be
            # caught first. psutil raises it for a process it can only
            # identify as a zombie. Such a process has exited, but
            # pid_exists() would still report it as present.
            zombie = True
        except psutil.NoSuchProcess:
            # The process is gone. pid_exists() below is expected to
            # report that.
            pass
        if zombie or not psutil.pid_exists(pid):
            return True, i
    return False, False
if __name__ == "__main__":
    # Entry point:
    #   1. parse the CLI;
    #   2. wait until every pid in --process-queue has finished;
    #   3. load the experiment config and mark the run as running;
    #   4. train;
    #   5. on failure, record a categorized error flag and clean up DDP
    #      processes.
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-C", "--config", help="config filename", default=argparse.SUPPRESS
    )
    parser.add_argument("-Y", "--yaml", help="yaml filename", default=argparse.SUPPRESS)
    parser.add_argument(
        "-Q",
        "--process-queue",
        help="process queue to wait for",
        default=argparse.SUPPRESS,
    )

    parser_args, _ = parser.parse_known_args(sys.argv)

    # Without a config there is nothing to train. Fail with a usage message
    # now, rather than waiting on the whole queue first and then crashing
    # with a NameError on `cfg`.
    if "config" not in parser_args and "yaml" not in parser_args:
        parser.error("one of the arguments -C/--config or -Y/--yaml is required")

    # --process-queue is a comma-separated list of pids to wait for. Blank
    # entries (e.g. from a trailing comma) are ignored instead of crashing
    # on int("").
    process_queue = []
    if "process_queue" in parser_args and parser_args.process_queue != "":
        process_queue = [
            int(x) for x in parser_args.process_queue.split(",") if x.strip()
        ]

    # Drain the queue one finished process at a time; poll every 30 seconds
    # while none of them has finished.
    while process_queue:
        done, num = check_for_done(process_queue)
        if done:
            process_queue.pop(num)
        else:
            time.sleep(30)

    # delayed imports from llm_studio, only after we want to start training
    import subprocess

    import torch

    from llm_studio.src.utils.config_utils import load_config_py, load_config_yaml
    from llm_studio.src.utils.exceptions import (
        LLMAugmentationsException,
        LLMDataException,
        LLMMetricException,
        LLMModelException,
        LLMTrainingException,
    )
    from llm_studio.src.utils.gpu_utils import is_oom_error
    from llm_studio.src.utils.logging_utils import initialize_logging, write_flag
    from llm_studio.src.utils.utils import kill_ddp_processes
    from train import run

    # --config takes precedence over --yaml. At least one of them is
    # guaranteed by the check above.
    if "config" in parser_args:
        cfg = load_config_py(parser_args.config)
    else:
        cfg = load_config_yaml(parser_args.yaml)

    # "running" is written to flags.json (empty suffix). Failures are
    # written to flags<local_rank>.json, e.g. flags0.json for rank 0.
    flag_path = os.path.join(cfg.output_directory, "flags{}.json")

    # Check if DDP
    # (WORLD_SIZE / LOCAL_RANK are presumably set by the distributed
    # launcher — confirm.)
    if "WORLD_SIZE" in os.environ:
        local_rank = int(os.environ["LOCAL_RANK"])
        if local_rank == 0:
            write_flag(flag_path.format(""), "status", "running")
    else:
        write_flag(flag_path.format(""), "status", "running")
        local_rank = 0

    initialize_logging(cfg)

    # Non-OOM failure categories as (exception type, log message, info flag).
    # Checked in this order; the first isinstance match wins.
    error_categories = (
        (
            LLMDataException,
            "Data error occurred during H2O LLM Studio run:",
            "Data error",
        ),
        (
            LLMTrainingException,
            "Training error occurred during H2O LLM Studio run:",
            "Training error",
        ),
        (
            LLMMetricException,
            "Validation metric failed. Please make sure selected validation "
            "metric is suitable for your current problem setup.",
            "Metric error",
        ),
        (
            LLMAugmentationsException,
            "Custom augmentations error occurred during " "H2O LLM Studio run:",
            "Augmentations error",
        ),
        (
            LLMModelException,
            "Model error occurred during H2O LLM Studio run:",
            "Model error",
        ),
    )

    try:
        run(cfg=cfg)
    except Exception as exception:
        write_flag(flag_path.format(local_rank), "status", "failed")
        if is_oom_error(exception):
            logging.error(
                "GPU Out-of-Memory (OOM) error occurred. "
                "Please, reduce the batch size, or input data size, "
                "or model size. Or try gradient checkpointing.",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "OOM error")
            # Diagnostics are best effort. nvidia-smi can be missing (e.g. on
            # a CPU-only host) or exit non-zero. An error raised here would
            # mask the original exception and skip kill_ddp_processes() below.
            try:
                logging.info(
                    "<pre>"
                    + subprocess.check_output(["nvidia-smi"]).decode("utf-8")
                    + "</pre>"
                )
            except (OSError, subprocess.SubprocessError):
                logging.warning("Could not collect nvidia-smi output.", exc_info=True)
            if torch.cuda.is_available():
                logging.info(
                    "<pre>" + torch.cuda.memory_summary().replace("-", "=") + "</pre>"
                )
        else:
            for exception_type, message, info in error_categories:
                if isinstance(exception, exception_type):
                    break
            else:
                # No specific category matched.
                message = "Exception occurred during H2O LLM Studio run:"
                info = "See logs"
            logging.error(message, exc_info=True)
            write_flag(flag_path.format(local_rank), "info", info)
        kill_ddp_processes()