@@ -4,18 +4,23 @@
 import os
 import sys
 import math
+import time
 
 import torch
 import torch.optim as optim
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.nn.functional as F
+import tensorboard_logger as tb
+
+import my_optim
 from envs import create_atari_env
 from model import ActorCritic
 from train import train
 from test import test
 from utils import logger
-import my_optim
+from utils.shared_memory import SharedCounter
+
 
 logger = logger.getLogger('main')
 
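Note: the new utils.shared_memory.SharedCounter import is consumed further down, where a global step counter is shared across worker processes. Its implementation is not part of this diff; the following is a minimal sketch of such a counter, assuming it wraps a multiprocessing Value behind a lock (the increment()/get() method names are assumptions, not confirmed by the diff):

    # Sketch only: the real class lives in utils/shared_memory.py.
    import torch.multiprocessing as mp

    class SharedCounter(object):
        def __init__(self, initval=0):
            # 'i' = C int in shared memory, visible to every child process
            self.val = mp.Value('i', initval)
            self.lock = mp.Lock()

        def increment(self, n=1):
            # serialize updates so concurrent workers do not lose increments
            with self.lock:
                self.val.value += n

        def get(self):
            with self.lock:
                return self.val.value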
@@ -41,16 +46,27 @@
                     help='environment to train on (default: PongDeterministic-v3)')
 parser.add_argument('--no-shared', default=False, metavar='O',
                     help='use an optimizer without shared momentum.')
-parser.add_argument('--max-iters', type=int, default=math.inf,
-                    help='maximum iterations per process.')
-
+parser.add_argument('--max-episode-count', type=int, default=math.inf,
+                    help='maximum number of episodes to run per process.')
 parser.add_argument('--debug', action='store_true', default=False,
                     help='run in a way that is easier to debug')
+parser.add_argument('--short-description', default='no_descr',
+                    help='short description of the run params (used in tensorboard)')
+
+def setup_loggings(args):
+    logger.debug('CONFIGURATION: {}'.format(args))
+
+    cur_path = os.path.dirname(os.path.realpath(__file__))
+    args.summ_base_dir = (cur_path + '/runs/{}/{}({})').format(
+        args.env_name, time.strftime('%d.%m-%H.%M'), args.short_description)
+    logger.info('logging run logs to {}'.format(args.summ_base_dir))
+    tb.configure(args.summ_base_dir)
 
 if __name__ == '__main__':
     args = parser.parse_args()
-
+    setup_loggings(args)
     torch.manual_seed(args.seed)
+
     env = create_atari_env(args.env_name)
     shared_model = ActorCritic(
         env.observation_space.shape[0], env.action_space)
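Note: setup_loggings() only picks the run directory and calls tb.configure() once; scalar values are emitted elsewhere (train/test) via tensorboard_logger's log_value(). A hedged usage sketch, with illustrative scalar names and values that are not taken from this diff:

    import tensorboard_logger as tb

    # one-time, per-process configuration (mirrors setup_loggings above)
    tb.configure('runs/PongDeterministic-v3/01.01-12.00(no_descr)')

    # later: write a scalar against a monotonically increasing global step
    for step, reward in enumerate([0.0, 0.5, 1.0]):
        tb.log_value('episode_reward', reward, step)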
@@ -61,20 +77,23 @@
     else:
         optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr)
         optimizer.share_memory()
-
+
+    gl_step_cnt = SharedCounter()
 
     if not args.debug:
         processes = []
 
-        p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
+        p = mp.Process(target=test, args=(args.num_processes, args,
+                                          shared_model, gl_step_cnt))
         p.start()
         processes.append(p)
         for rank in range(0, args.num_processes):
-            p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
+            p = mp.Process(target=train, args=(rank, args, shared_model,
+                                               gl_step_cnt, optimizer))
             p.start()
             processes.append(p)
         for p in processes:
             p.join()
     else:  ## debug is enabled
         # run only one process in main; easier to debug
         train(0, args, shared_model, gl_step_cnt, optimizer)
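Note: passing gl_step_cnt into every train/test process is what makes the tensorboard step global rather than per-worker. Below is a self-contained sketch of the cross-process behaviour this relies on; fake_worker is hypothetical and stands in for train(), and the increment()/get() interface is the one assumed in the sketch above:

    import torch.multiprocessing as mp
    from utils.shared_memory import SharedCounter

    def fake_worker(rank, counter):
        # stand-in for train(): bump the shared counter once per episode
        for _ in range(10):
            counter.increment()

    if __name__ == '__main__':
        cnt = SharedCounter()
        procs = [mp.Process(target=fake_worker, args=(r, cnt)) for r in range(4)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        # every process's increments are visible: prints 40
        print(cnt.get())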