Skip to content

Commit c00f147

Browse files
[raycicmd] Parse array adjustments config (#463)
Add arrayAdjustment struct and parseArrayAdjustments() to parse the adjustments field nested under array. Each adjustment has a required `with` map (dimension→value) and optional `skip` (bool). The `adjustments` key in the array map is reserved and excluded from dimension parsing. Duplicate `with` values are rejected. Topic: parse-adjustments Signed-off-by: andrew <andrew@anyscale.com> Signed-off-by: andrew <andrew@anyscale.com>
1 parent d450ad6 commit c00f147

2 files changed

Lines changed: 225 additions & 0 deletions

File tree

raycicmd/array_config.go

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,19 @@ type arrayElement struct {
1818
}
1919

2020
// parseArrayConfig parses the array field from a step.
21+
// The "adjustments" key, if present, is reserved and not treated as a dimension.
2122
//
2223
// array:
2324
// python: ["3.10", "3.11"]
2425
// cuda: ["12.1.1", "12.8.1"]
26+
// adjustments:
27+
// - with:
28+
// python: "3.10"
29+
// cuda: "12.1.1"
30+
// skip: true
31+
// - with:
32+
// python: "3.12"
33+
// cuda: "12.8.1"
2534
func parseArrayConfig(v any) (*arrayConfig, error) {
2635
m, ok := v.(map[string]any)
2736
if !ok {
@@ -35,6 +44,9 @@ func parseArrayConfig(v any) (*arrayConfig, error) {
3544
dims: make(map[string][]string, len(m)),
3645
}
3746
for dim, vals := range m {
47+
if dim == "adjustments" {
48+
continue // handled separately
49+
}
3850
valsSlice, ok := vals.([]any)
3951
if !ok {
4052
return nil, fmt.Errorf("array.%s must be an array, got %T", dim, vals)
@@ -48,6 +60,9 @@ func parseArrayConfig(v any) (*arrayConfig, error) {
4860
}
4961
cfg.dims[dim] = values
5062
}
63+
if len(cfg.dims) == 0 {
64+
return nil, fmt.Errorf("array must have at least one dimension")
65+
}
5166
return cfg, nil
5267
}
5368

@@ -63,6 +78,20 @@ func toStringSlice(arr []any) ([]string, error) {
6378
return result, nil
6479
}
6580

81+
func toStringMap(m map[string]any) (map[string]string, error) {
82+
result := make(map[string]string, len(m))
83+
for k, v := range m {
84+
s, ok := v.(string)
85+
if !ok {
86+
return nil, fmt.Errorf(
87+
"value for key %q must be a string, got %T", k, v,
88+
)
89+
}
90+
result[k] = s
91+
}
92+
return result, nil
93+
}
94+
6695
func (cfg *arrayConfig) sortedDimensions() []string {
6796
dims := make([]string, 0, len(cfg.dims))
6897
for dim := range cfg.dims {
@@ -156,3 +185,89 @@ func (elem *arrayElement) substituteString(s string) string {
156185
func hasArrayPlaceholder(s string) bool {
157186
return strings.Contains(s, "{{array.")
158187
}
188+
189+
// arrayAdjustment represents a single adjustment to the array expansion.
190+
type arrayAdjustment struct {
191+
with map[string]string
192+
skip bool
193+
}
194+
195+
// parseArrayAdjustments parses the adjustments field from a step.
196+
func parseArrayAdjustments(v any) ([]*arrayAdjustment, error) {
197+
arr, ok := v.([]any)
198+
if !ok {
199+
return nil, fmt.Errorf(
200+
"adjustments must be an array, got %T", v,
201+
)
202+
}
203+
204+
seen := make(map[string]struct{})
205+
var result []*arrayAdjustment
206+
for i, item := range arr {
207+
m, ok := item.(map[string]any)
208+
if !ok {
209+
return nil, fmt.Errorf(
210+
"adjustments[%d] must be a map, got %T", i, item,
211+
)
212+
}
213+
214+
adj, err := parseAdjustment(m)
215+
if err != nil {
216+
return nil, fmt.Errorf("adjustments[%d]: %w", i, err)
217+
}
218+
219+
key := marshalStringMap(adj.with)
220+
if _, dup := seen[key]; dup {
221+
return nil, fmt.Errorf(
222+
"adjustments[%d]: duplicate \"with\" %v",
223+
i, adj.with,
224+
)
225+
}
226+
seen[key] = struct{}{}
227+
228+
result = append(result, adj)
229+
}
230+
return result, nil
231+
}
232+
233+
// marshalStringMap returns a deterministic string for dedup of string maps.
234+
func marshalStringMap(m map[string]string) string {
235+
keys := make([]string, 0, len(m))
236+
for k := range m {
237+
keys = append(keys, k)
238+
}
239+
sort.Strings(keys)
240+
var parts []string
241+
for _, k := range keys {
242+
parts = append(parts, k+"="+m[k])
243+
}
244+
return strings.Join(parts, ",")
245+
}
246+
247+
func parseAdjustment(m map[string]any) (*arrayAdjustment, error) {
248+
withVal, ok := m["with"]
249+
if !ok {
250+
return nil, fmt.Errorf("missing required \"with\" key")
251+
}
252+
withMap, ok := withVal.(map[string]any)
253+
if !ok {
254+
return nil, fmt.Errorf("\"with\" must be a map, got %T", withVal)
255+
}
256+
if len(withMap) == 0 {
257+
return nil, fmt.Errorf("\"with\" cannot be empty")
258+
}
259+
with, err := toStringMap(withMap)
260+
if err != nil {
261+
return nil, fmt.Errorf("with: %w", err)
262+
}
263+
264+
skip, _ := m["skip"].(bool)
265+
266+
for k := range m {
267+
if k != "with" && k != "skip" {
268+
return nil, fmt.Errorf("unknown key %q", k)
269+
}
270+
}
271+
272+
return &arrayAdjustment{with: with, skip: skip}, nil
273+
}

raycicmd/array_config_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,111 @@ func TestHasArrayPlaceholder(t *testing.T) {
125125
}
126126
}
127127

128+
func TestParseArrayAdjustments(t *testing.T) {
129+
input := []any{
130+
map[string]any{
131+
"with": map[string]any{"os": "windows", "arch": "arm64"},
132+
"skip": true,
133+
},
134+
map[string]any{
135+
"with": map[string]any{"os": "Plan 9", "arch": "arm64"},
136+
},
137+
}
138+
139+
adjs, err := parseArrayAdjustments(input)
140+
if err != nil {
141+
t.Fatalf("parseArrayAdjustments() error = %v", err)
142+
}
143+
if len(adjs) != 2 {
144+
t.Fatalf("len(adjs) = %d, want 2", len(adjs))
145+
}
146+
147+
if !adjs[0].skip {
148+
t.Error("adjs[0].skip = false, want true")
149+
}
150+
if adjs[0].with["os"] != "windows" || adjs[0].with["arch"] != "arm64" {
151+
t.Errorf("adjs[0].with = %v, want {os:windows, arch:arm64}", adjs[0].with)
152+
}
153+
154+
if adjs[1].skip {
155+
t.Error("adjs[1] should be a pure addition, got skip=true")
156+
}
157+
}
158+
159+
func TestParseArrayAdjustmentsErrors(t *testing.T) {
160+
tests := []struct {
161+
name string
162+
input any
163+
wantErr string
164+
}{
165+
{
166+
name: "not an array",
167+
input: "bad",
168+
wantErr: "must be an array",
169+
},
170+
{
171+
name: "element not a map",
172+
input: []any{"bad"},
173+
wantErr: "must be a map",
174+
},
175+
{
176+
name: "missing with",
177+
input: []any{map[string]any{"skip": true}},
178+
wantErr: "missing required \"with\"",
179+
},
180+
{
181+
name: "with not a map",
182+
input: []any{map[string]any{"with": "bad"}},
183+
wantErr: "\"with\" must be a map",
184+
},
185+
{
186+
name: "empty with",
187+
input: []any{map[string]any{"with": map[string]any{}}},
188+
wantErr: "\"with\" cannot be empty",
189+
},
190+
{
191+
name: "with value not a string",
192+
input: []any{map[string]any{
193+
"with": map[string]any{"os": 42},
194+
}},
195+
wantErr: "must be a string",
196+
},
197+
{
198+
name: "duplicate with",
199+
input: []any{
200+
map[string]any{
201+
"with": map[string]any{"os": "linux"},
202+
"skip": true,
203+
},
204+
map[string]any{
205+
"with": map[string]any{"os": "linux"},
206+
},
207+
},
208+
wantErr: "duplicate \"with\"",
209+
},
210+
{
211+
name: "unknown key",
212+
input: []any{map[string]any{
213+
"with": map[string]any{"os": "linux"},
214+
"skp": true,
215+
}},
216+
wantErr: "unknown key",
217+
},
218+
}
219+
220+
for _, tt := range tests {
221+
t.Run(tt.name, func(t *testing.T) {
222+
_, err := parseArrayAdjustments(tt.input)
223+
if err == nil {
224+
t.Fatal("parseArrayAdjustments() expected error, got nil")
225+
}
226+
if !strings.Contains(err.Error(), tt.wantErr) {
227+
t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr)
228+
}
229+
})
230+
}
231+
}
232+
128233
func TestParseArrayConfigErrors(t *testing.T) {
129234
tests := []struct {
130235
name string
@@ -156,6 +261,11 @@ func TestParseArrayConfigErrors(t *testing.T) {
156261
input: map[string]any{"python": []any{}},
157262
wantErr: "cannot be empty",
158263
},
264+
{
265+
name: "only adjustments key",
266+
input: map[string]any{"adjustments": []any{}},
267+
wantErr: "at least one dimension",
268+
},
159269
}
160270

161271
for _, tt := range tests {

0 commit comments

Comments
 (0)