forked from time-series-foundation-models/lag-llama
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pretrain.sh
64 lines (52 loc) · 2.41 KB
/
pretrain.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# PLEASE FOLLOW THE BELOW INSTRUCTIONS FIRST
# 1. Install the requirements. It is recommend to use a new Anaconda environment with Python 3.10.8. Execute the below command (remove the #)
# !pip install -r requirements.txt
# 2. Please download https://drive.google.com/file/d/1JrDWMZyoPsc6d1wAAjgm3PosbGus-jCE/view?usp=sharing and use the below command to download the non-monash datasets (remove the #)
# tar -xvzf nonmonash_datasets.tar.gz -C datasets
# 3. Edit the Weights and Biases arguments on line 59 of this script
mkdir -p experiments
mkdir -p experiments/seeds
mkdir -p experiments/results
EXP_NAME="pretraining_lag_llama"
FILENAME="experiments/seeds/${EXP_NAME}"
CONFIGPATH="configs/lag_llama.json"
echo $EXP_NAME
# NUM_SEEDS used only if it is a new experiment
NUM_SEEDS=1
# Create seeds
if [ -f $FILENAME ]; then
echo "${FILENAME} already exists."
SEEDS=()
while read -r LINE; do
SEEDS+=("$LINE")
done < $FILENAME
echo "Found ${#SEEDS[@]} seeds for training."
else
# Write seeds
echo "${FILENAME} created. Writing seeds."
touch $FILENAME
for (( i = 0; i < $NUM_SEEDS; i++ ))
do
SEED=$((RANDOM + 1))
echo $SEED >> $FILENAME
done
# Read them
SEEDS=()
while read -r LINE; do
SEEDS+=("$LINE")
done < $FILENAME
fi
# Train
for SEED in "${SEEDS[@]}"
do
EXPERIMENT_NAME="${EXP_NAME}_seed_${SEED}"
python run.py \
-e $EXP_NAME -d "datasets" --seed $SEED \
-r "experiments/results" \
--batch_size 512 -m 1000 -n 128 \
--wandb_entity "enter-wandb-entity" --wandb_project "enter-wandb-project" --wandb_tags "enter-wandb-tags-or-remove-this-argument" \
--all_datasets "australian_electricity_demand" "electricity_hourly" "london_smart_meters_without_missing" "solar_10_minutes" "wind_farms_without_missing" "pedestrian_counts" "uber_tlc_hourly" "traffic" "kdd_cup_2018_without_missing" "saugeenday" "sunspot_without_missing" "exchange_rate" "cpu_limit_minute" "cpu_usage_minute" "function_delay_minute" "instances_minute" "memory_limit_minute" "memory_usage_minute" "platform_delay_minute" "requests_minute" "ett_h1" "ett_h2" "ett_m1" "ett_m2" "beijing_pm25" "AirQualityUCI" "beijing_multisite" "weather" \
--test_datasets "weather" "pedestrian_counts" "exchange_rate" "ett_m2" "platform_delay_minute" "requests_minute" "beijing_pm25" \
--num_workers 2 --args_from_dict_path $CONFIGPATH --search_batch_size \
--lr 0.0001
done