# eval.py — evaluation script (84 lines / 68 loc, 2.11 KB)
import numpy as np
import torch
import os
from copy import deepcopy
from tqdm import tqdm
import utils
from video import VideoRecorder
from arguments import parse_args
from env.wrappers import make_pad_env
from agent.agent import make_agent
from utils import get_curl_pos_neg
def evaluate(env, agent, args, video, adapt=False):
    """Run `args.ear_num_episodes` evaluation episodes and return the mean reward.

    Args:
        env: environment with `reset()` / `step(action)` (gym-style 4-tuple).
        agent: policy providing `select_action(obs)`; copied per episode so the
            base agent is never mutated.
        args: parsed CLI namespace (reads `ear_num_episodes` and `mode`).
        video: VideoRecorder; one clip is saved per episode.
        adapt: only affects the saved video filename ('ear' vs 'eval' tag).

    Returns:
        Mean episode reward as a numpy scalar.
    """
    rewards = []
    for episode in tqdm(range(args.ear_num_episodes)):
        # Fresh copy each episode: any in-episode state changes stay local.
        episode_agent = deepcopy(agent)
        video.init(enabled=True)

        obs = env.reset()
        done, total_reward, step_count = False, 0, 0
        losses = []  # kept for video.record's signature; nothing appends here
        episode_agent.train()

        while not done:
            # eval_mode wraps only action selection, as in the original.
            with utils.eval_mode(episode_agent):
                action = episode_agent.select_action(obs)
            obs, reward, done, _ = env.step(action)
            total_reward += reward
            video.record(env, losses)
            step_count += 1

        tag = 'ear' if adapt else 'eval'
        video.save(f'{args.mode}_{tag}_{episode}.mp4')
        rewards.append(total_reward)

    return np.mean(rewards)
def init_env(args):
    """Seed every RNG, then build and return the PAD evaluation environment.

    All environment parameters come straight from the parsed CLI namespace.
    """
    utils.set_seed_everywhere(args.seed)
    env_config = dict(
        domain_name=args.domain_name,
        task_name=args.task_name,
        seed=args.seed,
        episode_length=args.episode_length,
        action_repeat=args.action_repeat,
        mode=args.mode,
    )
    return make_pad_env(**env_config)
def main(args):
    """Load a checkpoint, evaluate it, and persist results to the work dir.

    Side effects: creates `model`/`video` subdirectories under
    `args.work_dir`, prints progress, and writes `pad_<mode>.pt`
    containing the args and the mean evaluation reward.
    """
    env = init_env(args)

    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
    # Recorder becomes a no-op when --save_video is off (dir is None).
    video = VideoRecorder(
        video_dir if args.save_video else None, height=448, width=448
    )

    assert torch.cuda.is_available(), 'must have cuda enabled'

    # Observations are stacked RGB frames cropped to 84x84.
    cropped_obs_shape = (3 * args.frame_stack, 84, 84)
    agent = make_agent(
        obs_shape=cropped_obs_shape,
        action_shape=env.action_space.shape,
        args=args,
    )
    agent.load(model_dir, args.ear_checkpoint)

    print(f'Evaluating {args.work_dir} for {args.ear_num_episodes} episodes (mode: {args.mode})')
    eval_reward = evaluate(env, agent, args, video)
    print('eval reward:', int(eval_reward))

    results_fp = os.path.join(args.work_dir, f'pad_{args.mode}.pt')
    torch.save({'args': args, 'eval_reward': eval_reward}, results_fp)
    print('Saved results to', results_fp)
if __name__ == '__main__':
    # Parse CLI flags and run the evaluation pipeline.
    main(parse_args())