-
-
Notifications
You must be signed in to change notification settings - Fork 2
/
begin_training.jl
47 lines (41 loc) · 1.09 KB
/
begin_training.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
using ArgParse
using GomalizingFlow
"""
    parse_commandline()

Declare the training command-line interface and parse the process arguments
against it.

- `config` (positional, required): path to a TOML config file.
- `--device`: optional device ID override, returned as the raw string (or
  `nothing` when absent); the caller converts it to an integer.
- `--result`: output directory, `"result"` when not given.
- `--pretrained`: optional path to a saved model, `nothing` when absent.

Returns whatever `ArgParse.parse_args` produces; callers look values up by
argument name, e.g. `args["config"]`.
"""
function parse_commandline()
    settings = ArgParseSettings()
    @add_arg_table! settings begin
        "config"
            help = """
            specify path/to/a/toml/file
            you can find an example 'cfgs/example2d.toml'
            """
            required = true
        "--device"
            help = "override Device ID"
            default = nothing
        "--result"
            help = "path/to/result/dir"
            default = "result"
        "--pretrained"
            help = "load /path/to/trained_model.bson and train with the model"
            default = nothing
    end
    parsed = parse_args(settings)
    return parsed
end
# Convert the optional `--device` value to an `Int`; `nothing` means "not given".
_parse_device_id(::Nothing) = nothing
function _parse_device_id(s::AbstractString)
    id = tryparse(Int, s)
    # Fail with a message that names the offending flag instead of a bare parse error.
    isnothing(id) &&
        throw(ArgumentError("--device expects an integer device ID, got \"$s\""))
    return id
end

# Make an optional path absolute; `nothing` means "not given".
_optional_abspath(::Nothing) = nothing
_optional_abspath(p::AbstractString) = abspath(p)

"""
    main()

Script entry point: parse the command line, build hyperparameters from the
`config` file with `GomalizingFlow.load_hyperparams`, then run
`GomalizingFlow.train` on them and return its result.

`--device` is converted to an `Int` (or left as `nothing`), `--pretrained` is
made absolute when given, and `--result` is always made absolute.

Throws `ArgumentError` if `--device` is not an integer.
"""
function main()
    args = parse_commandline()
    path = args["config"]
    device_id = _parse_device_id(args["device"])
    pretrained = _optional_abspath(args["pretrained"])
    result = abspath(args["result"])
    hp = GomalizingFlow.load_hyperparams(path; device_id, pretrained, result)
    GomalizingFlow.train(hp)
end
# Start training only when this file is the script being executed
# (its path equals PROGRAM_FILE), so `include`-ing it just defines functions.
abspath(PROGRAM_FILE) == @__FILE__ && main()