|
3 | 3 | #include "puffernet.h" |
4 | 4 |
|
5 | 5 | void demo() { |
6 | | - Weights* weights = load_weights("resources/breakout/breakout_weights.bin", 147844); |
| 6 | + Weights* weights = load_weights("resources/breakout/breakout_weights.bin", 32384); |
7 | 7 | int logit_sizes[1] = {3}; |
8 | | - LinearLSTM* net = make_linearlstm(weights, 1, 118, logit_sizes, 1); |
| 8 | + PufferNet* net = make_puffernet(weights, 1, 118, 64, 2, logit_sizes, 1); |
9 | 9 |
|
10 | 10 | Breakout env = { |
11 | 11 | .frameskip = 1, |
@@ -46,109 +46,19 @@ void demo() { |
46 | 46 | } |
47 | 47 | } else if (frame % 4 == 0) { |
48 | 48 | // Apply frameskip outside the env for smoother rendering |
49 | | - int* actions = (int*)env.actions; |
50 | | - forward_linearlstm(net, env.observations, actions); |
51 | | - env.actions[0] = actions[0]; |
| 49 | + forward_puffernet(net, env.observations, env.actions); |
52 | 50 | } |
53 | 51 |
|
54 | 52 | frame = (frame + 1) % 4; |
55 | 53 | c_step(&env); |
56 | 54 | c_render(&env); |
57 | 55 | } |
58 | | - free_linearlstm(net); |
| 56 | + free_puffernet(net); |
59 | 57 | free(weights); |
60 | 58 | free_allocated(&env); |
61 | 59 | close_client(env.client); |
62 | 60 | } |
63 | 61 |
|
64 | | -void test_performance(int timeout) { |
65 | | - Breakout env = { |
66 | | - .num_agents = 1, |
67 | | - .frameskip = 4, |
68 | | - .width = 576, |
69 | | - .height = 330, |
70 | | - .initial_paddle_width = 62, |
71 | | - .paddle_width = 62, |
72 | | - .paddle_height = 8, |
73 | | - .ball_width = 32, |
74 | | - .ball_height = 32, |
75 | | - .brick_width = 32, |
76 | | - .brick_height = 12, |
77 | | - .brick_rows = 6, |
78 | | - .brick_cols = 18, |
79 | | - .initial_ball_speed = 256, |
80 | | - .max_ball_speed = 448, |
81 | | - .paddle_speed = 620, |
82 | | - .continuous = 0, |
83 | | - }; |
84 | | - allocate(&env); |
85 | | - c_reset(&env); |
86 | | - |
87 | | - int start = time(NULL); |
88 | | - int num_steps = 0; |
89 | | - while (time(NULL) - start < timeout) { |
90 | | - for (int i = 0; i < 1000; i++) { |
91 | | - //env.actions[0] = 1;//rand() % 3; |
92 | | - c_step(&env); |
93 | | - num_steps++; |
94 | | - } |
95 | | - } |
96 | | - |
97 | | - int end = time(NULL); |
98 | | - float sps = num_steps / (end - start); |
99 | | - printf("Test Environment SPS: %f\n", sps); |
100 | | - free_allocated(&env); |
101 | | -} |
102 | | - |
103 | | -void test_performance_multi(int num_envs, int timeout) { |
104 | | - Breakout* envs = (Breakout*)calloc(num_envs, sizeof(Breakout)); |
105 | | - for (int i = 0; i < num_envs; i++) { |
106 | | - envs[i] = (Breakout){ |
107 | | - .num_agents = 1, |
108 | | - .frameskip = 4, |
109 | | - .width = 576, |
110 | | - .height = 330, |
111 | | - .initial_paddle_width = 62, |
112 | | - .paddle_width = 62, |
113 | | - .paddle_height = 8, |
114 | | - .ball_width = 32, |
115 | | - .ball_height = 32, |
116 | | - .brick_width = 32, |
117 | | - .brick_height = 12, |
118 | | - .brick_rows = 6, |
119 | | - .brick_cols = 18, |
120 | | - .initial_ball_speed = 256, |
121 | | - .max_ball_speed = 448, |
122 | | - .paddle_speed = 620, |
123 | | - .continuous = 0, |
124 | | - }; |
125 | | - allocate(&envs[i]); |
126 | | - c_reset(&envs[i]); |
127 | | - } |
128 | | - |
129 | | - int start = time(NULL); |
130 | | - int num_steps = 0; |
131 | | - while (time(NULL) - start < timeout) { |
132 | | - for (int i = 0; i < num_envs; i++) { |
133 | | - envs[i].actions[0] = 1; |
134 | | - c_step(&envs[i]); |
135 | | - num_steps++; |
136 | | - } |
137 | | - } |
138 | | - |
139 | | - int end = time(NULL); |
140 | | - float sps = num_steps / (end - start); |
141 | | - printf("Test Environment SPS: %f\n", sps); |
142 | | - |
143 | | - for (int i = 0; i < num_envs; i++) { |
144 | | - free_allocated(&envs[i]); |
145 | | - } |
146 | | - free(envs); |
147 | | -} |
148 | | - |
149 | | - |
150 | 62 | int main() { |
151 | | - //demo(); |
152 | | - //test_performance(5); |
153 | | - test_performance_multi(65536, 5); |
| 63 | + demo(); |
154 | 64 | } |
0 commit comments