-
Notifications
You must be signed in to change notification settings - Fork 24
/
test.py
75 lines (66 loc) · 2.24 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import sys
# Make the directory that contains this script importable, so that the
# ``tools`` package used by main() resolves no matter what the cwd is.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# Default OpenMP to 8 threads unless the caller already chose a value. This is
# done at import time, presumably so it takes effect before torch/OpenMP
# initialise (NOTE(review): confirm).
os.environ.setdefault('OMP_NUM_THREADS', '8')
import sys
import argparse
import socket
from contextlib import closing
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``config``, ``checkpoint``,
        ``data`` (list or None), ``gpu_ids`` (list of int or None),
        ``seed`` (int or None) and ``deterministic`` (bool).
    """
    cli = argparse.ArgumentParser(description='Test and eval a model')
    add = cli.add_argument

    # Required positional inputs.
    add('config', help='config file path')
    add('checkpoint', help='checkpoint file')

    # Optional switches; registration order is kept so --help is unchanged.
    add('--data', type=str, nargs='+')
    add('--gpu-ids', type=int, nargs='+', help='ids of gpus to use')
    add('--seed', type=int, help='random seed')
    add('--deterministic', action='store_true',
        help='whether to set deterministic options for CUDNN backend.')

    return cli.parse_args()
def args_to_str(args):
    """Rebuild the argv list that is forwarded to ``tools/test.py``.

    The positional config and checkpoint come first. They are followed by
    ``--seed``, ``--deterministic`` and ``--data`` when supplied.
    ``gpu_ids`` is not forwarded; main() applies it through
    ``CUDA_VISIBLE_DEVICES`` instead.

    Args:
        args: the namespace returned by ``parse_args``.

    Returns:
        A new list of command-line tokens.
    """
    forwarded = [args.config, args.checkpoint]
    seed = args.seed
    if seed is not None:
        forwarded.extend(('--seed', str(seed)))
    if args.deterministic:
        forwarded += ['--deterministic']
    data = args.data
    if data is not None:
        forwarded.extend(['--data'] + data)
    return forwarded
def _resolve_gpu_ids(args):
    """Return the GPU ids to run on: --gpu-ids, else CUDA_VISIBLE_DEVICES, else [0].

    Blank entries in ``CUDA_VISIBLE_DEVICES`` (for example an exported but
    empty variable) are skipped instead of crashing on ``int('')``. If no ids
    remain, the default ``[0]`` is used, exactly as when the variable is unset.
    Non-integer entries (such as GPU UUIDs) still raise ``ValueError``.
    """
    if args.gpu_ids is not None:
        return args.gpu_ids
    visible = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    ids = [int(tok) for tok in visible.split(',') if tok.strip()]
    return ids or [0]


def _find_free_port(start=29500, end=65536):
    """Return the first TCP port in ``[start, end)`` that can be bound on localhost.

    Probing with ``bind`` (rather than ``connect_ex``) avoids connecting to
    unrelated services. It also notices ports that are bound but not
    listening. The probe socket is closed before returning, so another
    process can still take the port before the launcher binds it.

    Raises:
        RuntimeError: if no port in the range can be bound.
    """
    for port in range(start, end):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            try:
                sock.bind(('localhost', port))
            except OSError:
                continue
        return port
    raise RuntimeError('no free port in range [{}, {})'.format(start, end))


def main():
    """Evaluate a checkpoint on one or more GPUs.

    With a single GPU, ``tools.test.main`` runs in this process. With
    several, ``torch.distributed.launch`` starts ``len(gpu_ids)`` copies of
    ``tools/test.py`` with the pytorch launcher on a free master port.
    """
    args = parse_args()
    gpu_ids = _resolve_gpu_ids(args)
    # Expose exactly the selected GPUs to this process and its children.
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in gpu_ids)

    if len(gpu_ids) == 1:
        import tools.test
        # tools.test.main() parses sys.argv, so hand it the forwarded options.
        sys.argv = [''] + args_to_str(args)
        tools.test.main()
        return

    from torch.distributed import launch
    # Resolve tools/test.py from this file's directory, which is the directory
    # the single-GPU path imports ``tools`` from, instead of from the cwd.
    # This way the multi-GPU path also works when started from elsewhere.
    test_script = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'tools', 'test.py')
    sys.argv = [
        '',
        '--nproc_per_node={}'.format(len(gpu_ids)),
        '--master_port={}'.format(_find_free_port()),
        test_script,
    ] + args_to_str(args) + ['--launcher', 'pytorch', '--diff_seed']
    launch.main()
if __name__ == '__main__':
    # Run the launcher only when executed as a script, not when imported.
    main()