-
Notifications
You must be signed in to change notification settings - Fork 2
/
01_pretrain.py
51 lines (39 loc) · 1.26 KB
/
01_pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import sys
from time import time
from label_the_sky.training.trainer import Trainer, set_random_seeds
set_random_seeds()
if len(sys.argv) != 6:
print('usage: python {} <dataset> <backbone> <n_channels> <target> <timestamp>'.format(
sys.argv[0]))
exit(1)
dataset = sys.argv[1]
backbone = sys.argv[2]
n_channels = int(sys.argv[3])
target = sys.argv[4]
timestamp = sys.argv[5]
base_dir = os.environ['HOME']
trainer = Trainer(
backbone=backbone,
n_channels=n_channels,
output_type=target,
base_dir=base_dir,
weights=None,
model_name=f'{timestamp}_{backbone}_{n_channels}_{dataset}'
)
trainer.describe(verbose=True)
print('loading data')
X_train, y_train = trainer.load_data(dataset=dataset, split='train')
X_val, y_val = trainer.load_data(dataset=dataset, split='val')
X_test, y_test = trainer.load_data(dataset=dataset, split='test')
start = time()
print('pretraining model')
trainer.train(X_train, y_train, X_val, y_val, epochs=100)
trainer.dump_history('history')
print('--- minutes taken:', int((time() - start) / 60))
print('evaluating model on validation set')
trainer.pick_best_model()
trainer.evaluate(X_val, y_val)
print('--- minutes taken:', int((time() - start) / 60))
print('printing history')
trainer.print_history()