diff --git a/cypher/models/cypher/format/format.go b/cypher/models/cypher/format/format.go index 495cf806..e8b97adc 100644 --- a/cypher/models/cypher/format/format.go +++ b/cypher/models/cypher/format/format.go @@ -329,7 +329,7 @@ func (s Emitter) formatMapLiteral(output io.Writer, mapLiteral cypher.MapLiteral } first := true - for key, subExpression := range mapLiteral { + if err := mapLiteral.ForEachItem(func(key string, value cypher.Expression) error { if !first { if _, err := io.WriteString(output, ", "); err != nil { return err @@ -346,9 +346,13 @@ func (s Emitter) formatMapLiteral(output io.Writer, mapLiteral cypher.MapLiteral return err } - if err := s.WriteExpression(output, subExpression); err != nil { + if err := s.WriteExpression(output, value); err != nil { return err } + + return nil + }); err != nil { + return err } if _, err := io.WriteString(output, "}"); err != nil { diff --git a/cypher/models/cypher/format/format_test.go b/cypher/models/cypher/format/format_test.go index 327f65d4..5ebd1b5b 100644 --- a/cypher/models/cypher/format/format_test.go +++ b/cypher/models/cypher/format/format_test.go @@ -2,6 +2,7 @@ package format_test import ( "bytes" + "errors" "testing" "github.com/specterops/dawgs/cypher/models/cypher" @@ -27,6 +28,93 @@ func TestCypherEmitter_StripLiterals(t *testing.T) { require.Equal(t, "match (n {value: $STRIPPED}) where n.other = $STRIPPED and n.number = $STRIPPED return n.name, n", buffer.String()) } +func TestCypherEmitter_FormatsMapLiteralInKeyOrder(t *testing.T) { + var ( + buffer = &bytes.Buffer{} + emitter = format.NewCypherEmitter(false) + ) + + err := emitter.WriteExpression(buffer, cypher.MapLiteral{ + "b": cypher.NewLiteral(2, false), + "a": cypher.NewLiteral(1, false), + }) + + require.NoError(t, err) + require.Equal(t, "{a: 1, b: 2}", buffer.String()) +} + +func TestCypherEmitter_MapLiteralPropagatesExpressionError(t *testing.T) { + var ( + buffer = &bytes.Buffer{} + emitter = format.NewCypherEmitter(false) + ) + + err := emitter.WriteExpression(buffer, cypher.MapLiteral{ + "bad": struct{}{}, + }) + + require.ErrorContains(t, err, "unexpected expression type") +} + +func TestCypherEmitter_MapLiteralPropagatesWriterError(t *testing.T) { + expectedErr := errors.New("write failed") + testCases := map[string]struct { + allowedWrites int + mapLiteral cypher.MapLiteral + }{ + "opening delimiter": { + allowedWrites: 0, + mapLiteral: cypher.MapLiteral{ + "b": cypher.NewLiteral(2, false), + }, + }, + "key": { + allowedWrites: 1, + mapLiteral: cypher.MapLiteral{ + "b": cypher.NewLiteral(2, false), + }, + }, + "colon": { + allowedWrites: 2, + mapLiteral: cypher.MapLiteral{ + "b": cypher.NewLiteral(2, false), + }, + }, + "value": { + allowedWrites: 3, + mapLiteral: cypher.MapLiteral{ + "b": cypher.NewLiteral(2, false), + }, + }, + "item separator": { + allowedWrites: 4, + mapLiteral: cypher.MapLiteral{ + "a": cypher.NewLiteral(1, false), + "b": cypher.NewLiteral(2, false), + }, + }, + "closing delimiter": { + allowedWrites: 4, + mapLiteral: cypher.MapLiteral{ + "b": cypher.NewLiteral(2, false), + }, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + writer := &errorAfterNWrites{ + remaining: testCase.allowedWrites, + err: expectedErr, + } + + err := format.NewCypherEmitter(false).WriteExpression(writer, testCase.mapLiteral) + + require.ErrorIs(t, err, expectedErr) + }) + } +} + func TestCypherEmitter_HappyPath(t *testing.T) { test.LoadFixture(t, test.MutationTestCases).Run(t) test.LoadFixture(t, test.PositiveTestCases).Run(t) @@ -83,6 +171,20 @@ func TestNewStringLiteral_Escaping(t *testing.T) { } } +type errorAfterNWrites struct { + remaining int + err error +} + +func (s *errorAfterNWrites) Write(p []byte) (int, error) { + if s.remaining == 0 { + return 0, s.err + } + + s.remaining-- + return len(p), nil +} + func TestNewStringLiteral_InQuery(t *testing.T) { // Test that escaped string literals work correctly in actual Cypher queries testCases := []struct { diff --git a/cypher/models/cypher/model.go b/cypher/models/cypher/model.go index b3a59b77..173919b8 100644 --- a/cypher/models/cypher/model.go +++ b/cypher/models/cypher/model.go @@ -921,30 +921,49 @@ func (s MapLiteral) copy() MapLiteral { return mapCopy } +func (s MapLiteral) sortedKeys() []string { + keys := make([]string, 0, len(s)) + + for key := range s { + keys = append(keys, key) + } + + sort.Strings(keys) + return keys +} + func (s MapLiteral) Items() []*MapItem { items := make([]*MapItem, 0, len(s)) - for key, value := range s { + _ = s.ForEachItem(func(key string, value Expression) error { items = append(items, &MapItem{ Key: key, Value: value, }) - } + return nil + }) return items } +func (s MapLiteral) ForEachItem(delegate func(key string, value Expression) error) error { + for _, key := range s.sortedKeys() { + if err := delegate(key, s[key]); err != nil { + return err + } + } + + return nil +} + func (s MapLiteral) Keys() []any { - keys := make([]any, 0, len(s)) + sortedKeys := s.sortedKeys() + keys := make([]any, len(sortedKeys)) - for key := range s { - keys = append(keys, key) + for idx, key := range sortedKeys { + keys[idx] = key } - sort.Slice(keys, func(i, j int) bool { - return strings.Compare(keys[i].(string), keys[j].(string)) > 0 - }) - return keys } diff --git a/cypher/models/cypher/model_test.go b/cypher/models/cypher/model_test.go new file mode 100644 index 00000000..60acaab7 --- /dev/null +++ b/cypher/models/cypher/model_test.go @@ -0,0 +1,74 @@ +package cypher_test + +import ( + "errors" + "testing" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/stretchr/testify/require" +) + +func TestMapLiteralItemsReturnsSortedItems(t *testing.T) { + aValue := cypher.NewVariableWithSymbol("a_value") + bValue := cypher.NewVariableWithSymbol("b_value") + + items := cypher.MapLiteral{ + "b": bValue, + "a": aValue, + }.Items() + + require.Len(t, items, 2) + require.Equal(t, "a", items[0].Key) + require.Same(t, aValue, items[0].Value) + require.Equal(t, "b", items[1].Key) + require.Same(t, bValue, items[1].Value) +} + +func TestMapLiteralForEachItemReturnsDelegateError(t *testing.T) { + expectedErr := errors.New("stop iteration") + var visitedKeys []string + + err := cypher.MapLiteral{ + "c": cypher.NewVariableWithSymbol("c_value"), + "b": cypher.NewVariableWithSymbol("b_value"), + "a": cypher.NewVariableWithSymbol("a_value"), + }.ForEachItem(func(key string, _ cypher.Expression) error { + visitedKeys = append(visitedKeys, key) + if key == "b" { + return expectedErr + } + + return nil + }) + + require.ErrorIs(t, err, expectedErr) + require.Equal(t, []string{"a", "b"}, visitedKeys) +} + +func TestMapLiteralKeysReturnsSortedKeys(t *testing.T) { + testCases := map[string]struct { + mapLiteral cypher.MapLiteral + expected []any + }{ + "nil": { + expected: []any{}, + }, + "empty": { + mapLiteral: cypher.MapLiteral{}, + expected: []any{}, + }, + "sorted": { + mapLiteral: cypher.MapLiteral{ + "b": cypher.NewVariableWithSymbol("b_value"), + "a": cypher.NewVariableWithSymbol("a_value"), + }, + expected: []any{"a", "b"}, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + require.Equal(t, testCase.expected, testCase.mapLiteral.Keys()) + }) + } +} diff --git a/cypher/models/pgsql/translate/predicate_test.go b/cypher/models/pgsql/translate/predicate_test.go index d9182a10..aac3226c 100644 --- a/cypher/models/pgsql/translate/predicate_test.go +++ b/cypher/models/pgsql/translate/predicate_test.go @@ -79,6 +79,12 @@ func translatePredicateQuery(t *testing.T, cypherQuery string, parameters map[st return formatted } +func TestExclusiveDisjunctionTranslates(t *testing.T) { + formatted := translatePredicateQuery(t, `MATCH (n:NodeKind1) WHERE true XOR false RETURN n`, nil) + + require.Contains(t, formatted, "true != false") +} + func TestDynamicStringPredicatesUseHelperFunctions(t *testing.T) { for _, testCase := range []struct { name string diff --git a/cypher/models/pgsql/translate/references_test.go b/cypher/models/pgsql/translate/references_test.go index b9a71583..894109a7 100644 --- a/cypher/models/pgsql/translate/references_test.go +++ b/cypher/models/pgsql/translate/references_test.go @@ -48,6 +48,13 @@ func TestCollectReferencedIdentifiersIncludesPatternPredicateReferences(t *testi require.True(t, referencedIdentifiers.Contains("r")) } +func TestCollectReferencedIdentifiersIncludesExclusiveDisjunctionOperands(t *testing.T) { + referencedIdentifiers := referencedIdentifiersForQuery(t, "match (n), (m) where n.enabled = true xor m.enabled = true return n") + + require.True(t, referencedIdentifiers.Contains("n")) + require.True(t, referencedIdentifiers.Contains("m")) +} + func TestCollectReferencedIdentifiersIncludesRepeatedMatchPatternDeclarations(t *testing.T) { referencedIdentifiers := referencedIdentifiersForQuery(t, "match (a)-->(b) match (b)-->(c) return c") diff --git a/cypher/models/pgsql/translate/translator.go b/cypher/models/pgsql/translate/translator.go index a5fe459c..b300fcdf 100644 --- a/cypher/models/pgsql/translate/translator.go +++ b/cypher/models/pgsql/translate/translator.go @@ -298,6 +298,11 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) { s.treeTranslator.VisitOperator(pgsql.OperatorOr) } + case *cypher.ExclusiveDisjunction: + for idx := 0; idx < typedExpression.Len()-1; idx++ { + s.treeTranslator.VisitOperator(pgsql.OperatorNotEquals) + } + case *cypher.Conjunction: for idx := 0; idx < typedExpression.Len()-1; idx++ { s.treeTranslator.VisitOperator(pgsql.OperatorAnd) @@ -559,6 +564,13 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } } + case *cypher.ExclusiveDisjunction: + for idx := 0; idx < typedExpression.Len()-1; idx++ { + if err := s.treeTranslator.CompleteBinaryExpression(s.scope, pgsql.OperatorNotEquals); err != nil { + s.SetError(err) + } + } + case *cypher.Conjunction: for idx := 0; idx < typedExpression.Len()-1; idx++ { if err := s.treeTranslator.CompleteBinaryExpression(s.scope, pgsql.OperatorAnd); err != nil { diff --git a/cypher/models/walk/walk.go b/cypher/models/walk/walk.go index 2d6f95d6..45079f61 100644 --- a/cypher/models/walk/walk.go +++ b/cypher/models/walk/walk.go @@ -3,6 +3,7 @@ package walk import ( "errors" "fmt" + "reflect" "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/cypher/models/pgsql" @@ -29,7 +30,7 @@ type Visitor[N any] interface { type cancelableVisitorHandler struct { currentSyntaxNodeConsumed bool done bool - errs []error + err error } func (s *cancelableVisitorHandler) Done() bool { @@ -42,7 +43,12 @@ func (s *cancelableVisitorHandler) SetDone() { func (s *cancelableVisitorHandler) SetError(err error) { if err != nil { - s.errs = append(s.errs, err) + if s.err == nil { + s.err = err + } else { + s.err = errors.Join(s.err, err) + } + s.done = true } } @@ -52,7 +58,7 @@ func (s *cancelableVisitorHandler) SetErrorf(format string, args ...any) { } func (s *cancelableVisitorHandler) Error() error { - return errors.Join(s.errs...) + return s.err } func (s *cancelableVisitorHandler) Consume() { @@ -107,8 +113,13 @@ type simpleVisitor[N any] struct { } func NewSimpleVisitor[N any](visitorFunc SimpleVisitorFunc[N]) Visitor[N] { + return NewSimpleVisitorWithOrder(OrderPrefix, visitorFunc) +} + +func NewSimpleVisitorWithOrder[N any](order Order, visitorFunc SimpleVisitorFunc[N]) Visitor[N] { return &simpleVisitor[N]{ Visitor: NewVisitor[N](), + order: order, visitorFunc: visitorFunc, } } @@ -160,6 +171,21 @@ func (s *Cursor[N]) NextBranch() N { return nextBranch } +func isNilNode[N any](node N) bool { + rawNode := any(node) + if rawNode == nil { + return true + } + + value := reflect.ValueOf(rawNode) + switch value.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Pointer: + return value.IsNil() + default: + return false + } +} + func Generic[E any](node E, visitor Visitor[E], cursorConstructor func(node E) (*Cursor[E], error)) error { var stack []*Cursor[E] @@ -181,15 +207,56 @@ func Generic[E any](node E, visitor Visitor[E], cursorConstructor func(node E) ( if err := visitor.Error(); err != nil { return err } + + if visitor.Done() { + return nil + } } - if nextNode.HasNext() && !visitor.WasConsumed() { + if !nextNode.HasNext() { + visitor.Exit(nextNode.Node) + + if err := visitor.Error(); err != nil { + return err + } + + // Clear any consume flag set by Enter or Exit before visiting the next sibling. + visitor.WasConsumed() + stack = stack[0 : len(stack)-1] + } else if visitor.WasConsumed() { + visitor.Exit(nextNode.Node) + + if err := visitor.Error(); err != nil { + return err + } + + // Clear any consume flag set by Exit before visiting the next sibling. + visitor.WasConsumed() + stack = stack[0 : len(stack)-1] + } else { if !isFirstVisit { visitor.Visit(nextNode.Node) if err := visitor.Error(); err != nil { return err } + + if visitor.Done() { + return nil + } + + if visitor.WasConsumed() { + visitor.Exit(nextNode.Node) + + if err := visitor.Error(); err != nil { + return err + } + + // Clear any consume flag set by Exit before visiting the next sibling. + visitor.WasConsumed() + stack = stack[0 : len(stack)-1] + continue + } } if cursor, err := cursorConstructor(nextNode.NextBranch()); err != nil { @@ -197,14 +264,6 @@ func Generic[E any](node E, visitor Visitor[E], cursorConstructor func(node E) ( } else { stack = append(stack, cursor) } - } else { - visitor.Exit(nextNode.Node) - - if err := visitor.Error(); err != nil { - return err - } - - stack = stack[0 : len(stack)-1] } } @@ -218,3 +277,7 @@ func PgSQL(node pgsql.SyntaxNode, visitor Visitor[pgsql.SyntaxNode]) error { func Cypher(node cypher.SyntaxNode, visitor Visitor[cypher.SyntaxNode]) error { return Generic(node, visitor, newCypherWalkCursor) } + +func CypherStructural(node cypher.SyntaxNode, visitor Visitor[cypher.SyntaxNode]) error { + return Generic(node, visitor, newCypherStructuralWalkCursor) +} diff --git a/cypher/models/walk/walk_benchmark_test.go b/cypher/models/walk/walk_benchmark_test.go new file mode 100644 index 00000000..eb871dbe --- /dev/null +++ b/cypher/models/walk/walk_benchmark_test.go @@ -0,0 +1,73 @@ +package walk_test + +import ( + "fmt" + "testing" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/walk" +) + +func BenchmarkCypherWalkLargeProjection(b *testing.B) { + projection := &cypher.Projection{} + for idx := 0; idx < 512; idx++ { + projection.Items = append(projection.Items, &cypher.ProjectionItem{ + Expression: cypher.NewVariableWithSymbol(fmt.Sprintf("n%d", idx)), + }) + } + + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(cypher.SyntaxNode, walk.VisitorHandler) {}) + + b.ReportAllocs() + b.ResetTimer() + for idx := 0; idx < b.N; idx++ { + if err := walk.Cypher(projection, visitor); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkCypherWalkLargeMapLiteral(b *testing.B) { + mapLiteral := cypher.NewMapLiteral() + for idx := 0; idx < 512; idx++ { + mapLiteral[fmt.Sprintf("k%03d", idx)] = cypher.NewVariableWithSymbol(fmt.Sprintf("v%d", idx)) + } + + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(cypher.SyntaxNode, walk.VisitorHandler) {}) + + b.ReportAllocs() + b.ResetTimer() + for idx := 0; idx < b.N; idx++ { + if err := walk.Cypher(mapLiteral, visitor); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkCypherStructuralWalkLongPattern(b *testing.B) { + patternPart := cypher.NewPatternPart() + patternPart.Variable = cypher.NewVariableWithSymbol("path") + patternPart.AddPatternElements(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("n0"), + }) + for idx := 0; idx < 128; idx++ { + patternPart.AddPatternElements( + &cypher.RelationshipPattern{ + Variable: cypher.NewVariableWithSymbol(fmt.Sprintf("r%d", idx)), + }, + &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(fmt.Sprintf("n%d", idx+1)), + }, + ) + } + + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(cypher.SyntaxNode, walk.VisitorHandler) {}) + + b.ReportAllocs() + b.ResetTimer() + for idx := 0; idx < b.N; idx++ { + if err := walk.CypherStructural(patternPart, visitor); err != nil { + b.Fatal(err) + } + } +} diff --git a/cypher/models/walk/walk_cypher.go b/cypher/models/walk/walk_cypher.go index 7b6fc723..ac81d234 100644 --- a/cypher/models/walk/walk_cypher.go +++ b/cypher/models/walk/walk_cypher.go @@ -8,320 +8,428 @@ import ( "github.com/specterops/dawgs/graph" ) -func cypherSyntaxNodeSliceTypeConvert[F any, FS []F](fs FS) ([]cypher.SyntaxNode, error) { - return ConvertSliceType[cypher.SyntaxNode](fs) +func newCypherWalkCursorWithBranches[F any, FS []F](node cypher.SyntaxNode, branches FS) *Cursor[cypher.SyntaxNode] { + cursor := &Cursor[cypher.SyntaxNode]{ + Node: node, + Branches: make([]cypher.SyntaxNode, 0, len(branches)), + } + + addCypherBranches(cursor, branches) + return cursor } -func newCypherWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], error) { - switch typedNode := node.(type) { - // Types with no AST branches - case *cypher.RangeQuantifier, cypher.Operator, *cypher.Limit, *cypher.Skip, graph.Kinds, *cypher.Parameter: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - }, nil +func newCypherWalkCursorWithBranchPrefix[F any, FS []F](node cypher.SyntaxNode, prefix cypher.SyntaxNode, branches FS) *Cursor[cypher.SyntaxNode] { + cursor := &Cursor[cypher.SyntaxNode]{ + Node: node, + Branches: make([]cypher.SyntaxNode, 0, len(branches)+1), + } - case *cypher.KindMatcher: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Reference}, - }, nil + if !isNilNode(prefix) { + cursor.AddBranches(prefix) + } + addCypherBranches(cursor, branches) + return cursor +} - case *cypher.PropertyLookup: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Atom}, - }, nil +func newCypherWalkCursorWithMapItems(node cypher.SyntaxNode, mapLiteral cypher.MapLiteral) *Cursor[cypher.SyntaxNode] { + cursor := &Cursor[cypher.SyntaxNode]{ + Node: node, + Branches: make([]cypher.SyntaxNode, 0, len(mapLiteral)), + } - case *cypher.MapItem: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Value}, - }, nil + _ = mapLiteral.ForEachItem(func(key string, value cypher.Expression) error { + cursor.AddBranches(&cypher.MapItem{ + Key: key, + Value: value, + }) + return nil + }) - case *cypher.Properties: - if typedNode.Parameter != nil { - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Parameter}, - }, nil - } else if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Map.Items()); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: branches, - }, nil - } + return cursor +} - case *cypher.Literal: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - }, nil +func addCypherBranches[F any, FS []F](cursor *Cursor[cypher.SyntaxNode], branches FS) { + for _, branch := range branches { + cursor.AddBranches(cypher.SyntaxNode(branch)) + } +} - case cypher.MapLiteral: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - }, nil +func newCypherStructuralWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], error) { + if isNilNode(node) { + return nil, fmt.Errorf("unable to negotiate cypher model type %T into a translation cursor", node) + } - case *cypher.ListLiteral: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Expressions()); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: typedNode, - Branches: branches, - }, nil + if cursor, handled := newCypherStructuralValueWalkCursor(node); handled { + return cursor, nil + } + if cursor, handled := newCypherStructuralPatternWalkCursor(node); handled { + return cursor, nil + } + + return newCypherWalkCursor(node) +} + +func newCypherStructuralValueWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], bool) { + switch typedNode := node.(type) { + case *cypher.Limit: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Value != nil { + nextCursor.AddBranches(typedNode.Value) } + return nextCursor, true - case *cypher.Create: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Pattern); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: branches, - }, nil + case *cypher.Skip: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Value != nil { + nextCursor.AddBranches(typedNode.Value) } + return nextCursor, true - case *cypher.Unwind: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Expression, typedNode.Variable}, - }, nil + case *cypher.KindMatcher: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Reference != nil { + nextCursor.AddBranches(typedNode.Reference) + } + if typedNode.Kinds != nil { + nextCursor.AddBranches(typedNode.Kinds) + } + return nextCursor, true - case *cypher.RemoveItem: + case *cypher.Properties: nextCursor := &Cursor[cypher.SyntaxNode]{ Node: node, } + if typedNode.Parameter != nil { + nextCursor.AddBranches(typedNode.Parameter) + } + if typedNode.Map != nil { + nextCursor.AddBranches(typedNode.Map) + } + return nextCursor, true + + case cypher.MapLiteral: + return newCypherWalkCursorWithMapItems(node, typedNode), true + case *cypher.RemoveItem: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.KindMatcher != nil { + nextCursor.AddBranches(typedNode.KindMatcher) + } if typedNode.Property != nil { nextCursor.AddBranches(typedNode.Property) } + return nextCursor, true - return nextCursor, nil - - case *cypher.Remove: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Items); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: typedNode, - Branches: branches, - }, nil + case *cypher.IDInCollection: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Variable != nil { + nextCursor.AddBranches(typedNode.Variable) + } + if typedNode.Expression != nil { + nextCursor.AddBranches(typedNode.Expression) } + return nextCursor, true - case *cypher.Delete: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Expressions); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: typedNode, - Branches: branches, - }, nil + case *cypher.ProjectionItem: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Expression != nil { + nextCursor.AddBranches(typedNode.Expression) } + if typedNode.Alias != nil { + nextCursor.AddBranches(typedNode.Alias) + } + return nextCursor, true - case *cypher.SetItem: - return &Cursor[cypher.SyntaxNode]{ + case *cypher.PartialComparison: + nextCursor := &Cursor[cypher.SyntaxNode]{ Node: node, - Branches: []cypher.SyntaxNode{typedNode.Left, typedNode.Right}, - }, nil - - case *cypher.Set: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Items); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: typedNode, - Branches: branches, - }, nil + Branches: []cypher.SyntaxNode{typedNode.Operator}, + } + if typedNode.Right != nil { + nextCursor.AddBranches(typedNode.Right) } + return nextCursor, true - case *cypher.UpdatingClause: - return &Cursor[cypher.SyntaxNode]{ + case *cypher.PartialArithmeticExpression: + nextCursor := &Cursor[cypher.SyntaxNode]{ Node: node, - Branches: []cypher.SyntaxNode{typedNode.Clause}, - }, nil - - case *cypher.PatternPredicate: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.PatternElements); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: typedNode, - Branches: branches, - }, nil + Branches: []cypher.SyntaxNode{typedNode.Operator}, } - - case *cypher.Order: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Items); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: typedNode, - Branches: branches, - }, nil + if typedNode.Right != nil { + nextCursor.AddBranches(typedNode.Right) } + return nextCursor, true - case *cypher.SortItem: - return &Cursor[cypher.SyntaxNode]{ + case *cypher.UnaryAddOrSubtractExpression: + nextCursor := &Cursor[cypher.SyntaxNode]{ Node: node, - Branches: []cypher.SyntaxNode{typedNode.Expression}, - }, nil + Branches: []cypher.SyntaxNode{typedNode.Operator}, + } + if typedNode.Right != nil { + nextCursor.AddBranches(typedNode.Right) + } + return nextCursor, true - case *cypher.MultiPartQuery: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Parts); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: typedNode, - Branches: append(branches, typedNode.SinglePartQuery), - }, nil + default: + return nil, false + } +} + +func newCypherStructuralPatternWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], bool) { + switch typedNode := node.(type) { + case *cypher.PatternPart: + nextCursor := &Cursor[cypher.SyntaxNode]{ + Node: node, + Branches: make([]cypher.SyntaxNode, 0, len(typedNode.PatternElements)+1), } + if typedNode.Variable != nil { + nextCursor.AddBranches(typedNode.Variable) + } + addCypherBranches(nextCursor, typedNode.PatternElements) + return nextCursor, true - case *cypher.MultiPartQueryPart: + case *cypher.RelationshipPattern: nextCursor := &Cursor[cypher.SyntaxNode]{ Node: node, } - - if len(typedNode.ReadingClauses) > 0 { - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.ReadingClauses); err != nil { - return nil, err - } else { - nextCursor.AddBranches(branches...) - } + if typedNode.Variable != nil { + nextCursor.AddBranches(typedNode.Variable) } - - if len(typedNode.UpdatingClauses) > 0 { - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.UpdatingClauses); err != nil { - return nil, err - } else { - nextCursor.AddBranches(branches...) - } + if typedNode.Kinds != nil { + nextCursor.AddBranches(typedNode.Kinds) } + if typedNode.Range != nil { + nextCursor.AddBranches(typedNode.Range) + } + if typedNode.Properties != nil { + nextCursor.AddBranches(typedNode.Properties) + } + return nextCursor, true - if typedNode.With != nil { - nextCursor.AddBranches(typedNode.With) + case *cypher.NodePattern: + nextCursor := &Cursor[cypher.SyntaxNode]{ + Node: node, + } + if typedNode.Variable != nil { + nextCursor.AddBranches(typedNode.Variable) + } + if typedNode.Kinds != nil { + nextCursor.AddBranches(typedNode.Kinds) + } + if typedNode.Properties != nil { + nextCursor.AddBranches(typedNode.Properties) } + return nextCursor, true - return nextCursor, nil + default: + return nil, false + } +} - case *cypher.With: - nextCursor := &Cursor[cypher.SyntaxNode]{ +func newCypherWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], error) { + if isNilNode(node) { + return nil, fmt.Errorf("unable to negotiate cypher model type %T into a translation cursor", node) + } + + if cursor, handled := newCypherLeafWalkCursor(node); handled { + return cursor, nil + } + if cursor, handled := newCypherValueWalkCursor(node); handled { + return cursor, nil + } + if cursor, handled := newCypherPredicateWalkCursor(node); handled { + return cursor, nil + } + if cursor, handled := newCypherOperatorWalkCursor(node); handled { + return cursor, nil + } + if cursor, handled := newCypherProjectionWalkCursor(node); handled { + return cursor, nil + } + if cursor, handled := newCypherQueryWalkCursor(node); handled { + return cursor, nil + } + if cursor, handled := newCypherUpdatingWalkCursor(node); handled { + return cursor, nil + } + if cursor, handled := newCypherPatternWalkCursor(node); handled { + return cursor, nil + } + + return nil, fmt.Errorf("unable to negotiate cypher model type %T into a translation cursor", node) +} + +func newCypherLeafWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], bool) { + switch node.(type) { + case *cypher.RangeQuantifier, *cypher.PatternRange, cypher.Operator, *cypher.Limit, *cypher.Skip, + graph.Kinds, cypher.MapLiteral, *cypher.Parameter, *cypher.Literal, *cypher.Variable: + return &Cursor[cypher.SyntaxNode]{ Node: node, + }, true + + default: + return nil, false + } +} + +func newCypherValueWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], bool) { + switch typedNode := node.(type) { + case *cypher.KindMatcher: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Reference != nil { + nextCursor.AddBranches(typedNode.Reference) } + return nextCursor, true - if typedNode.Projection != nil { - nextCursor.AddBranches(typedNode.Projection) + case *cypher.PropertyLookup: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Atom != nil { + nextCursor.AddBranches(typedNode.Atom) } + return nextCursor, true - if typedNode.Where != nil { - nextCursor.AddBranches(typedNode.Where) + case *cypher.MapItem: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Value != nil { + nextCursor.AddBranches(typedNode.Value) + } + return nextCursor, true + + case *cypher.Properties: + if typedNode.Parameter != nil { + return &Cursor[cypher.SyntaxNode]{ + Node: node, + Branches: []cypher.SyntaxNode{typedNode.Parameter}, + }, true + } else { + return newCypherWalkCursorWithMapItems(node, typedNode.Map), true } - return nextCursor, nil + case *cypher.ListLiteral: + return newCypherWalkCursorWithBranches(typedNode, typedNode.Expressions()), true + + case *cypher.FunctionInvocation: + return newCypherWalkCursorWithBranches(typedNode, typedNode.Arguments), true + + case *cypher.Parenthetical: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Expression != nil { + nextCursor.AddBranches(typedNode.Expression) + } + return nextCursor, true + + default: + return nil, false + } +} +func newCypherPredicateWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], bool) { + switch typedNode := node.(type) { case *cypher.Quantifier: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Filter}, - }, nil + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Filter != nil { + nextCursor.AddBranches(typedNode.Filter) + } + return nextCursor, true case *cypher.FilterExpression: - nextCursor := &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Specifier}, - } + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Specifier != nil { + nextCursor.AddBranches(typedNode.Specifier) + } if typedNode.Where != nil { nextCursor.AddBranches(typedNode.Where) } - return nextCursor, nil + return nextCursor, true case *cypher.IDInCollection: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Expression}, - }, nil + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Expression != nil { + nextCursor.AddBranches(typedNode.Expression) + } + return nextCursor, true - case *cypher.FunctionInvocation: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Arguments); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: typedNode, - Branches: branches, - }, nil + case *cypher.Where: + return newCypherWalkCursorWithBranches(node, typedNode.Expressions), true + + case *cypher.Negation: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Expression != nil { + nextCursor.AddBranches(typedNode.Expression) } + return nextCursor, true - case *cypher.Parenthetical: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Expression}, - }, nil + case *cypher.Conjunction: + return newCypherWalkCursorWithBranches(node, typedNode.Expressions), true - case *cypher.RegularQuery: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.SingleQuery}, - }, nil + case *cypher.Disjunction: + return newCypherWalkCursorWithBranches(node, typedNode.Expressions), true - case *cypher.SingleQuery: - if typedNode.SinglePartQuery != nil { - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.SinglePartQuery}, - }, nil - } + case *cypher.ExclusiveDisjunction: + return newCypherWalkCursorWithBranches(node, typedNode.Expressions), true - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.MultiPartQuery}, - }, nil + default: + return nil, false + } +} - case *cypher.SinglePartQuery: +func newCypherOperatorWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], bool) { + switch typedNode := node.(type) { + case *cypher.ArithmeticExpression: + return newCypherWalkCursorWithBranchPrefix(node, typedNode.Left, typedNode.Partials), true + + case *cypher.PartialArithmeticExpression: nextCursor := &Cursor[cypher.SyntaxNode]{ - Node: node, + Node: node, + Branches: []cypher.SyntaxNode{typedNode.Operator}, } - - if len(typedNode.ReadingClauses) > 0 { - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.ReadingClauses); err != nil { - return nil, err - } else { - nextCursor.AddBranches(branches...) - } + if typedNode.Right != nil { + nextCursor.AddBranches(typedNode.Right) } + return nextCursor, true - if len(typedNode.UpdatingClauses) > 0 { - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.UpdatingClauses); err != nil { - return nil, err - } else { - nextCursor.AddBranches(branches...) - } + case *cypher.PartialComparison: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Right != nil { + nextCursor.AddBranches(typedNode.Right) } + return nextCursor, true + + case *cypher.Comparison: + return newCypherWalkCursorWithBranchPrefix(node, typedNode.Left, typedNode.Partials), true - if typedNode.Return != nil { - nextCursor.AddBranches(typedNode.Return) + case *cypher.UnaryAddOrSubtractExpression: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Right != nil { + nextCursor.AddBranches(typedNode.Right) } + return nextCursor, true - return nextCursor, nil + default: + return nil, false + } +} - case *cypher.Return: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Projection}, - }, nil +func newCypherProjectionWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], bool) { + switch typedNode := node.(type) { + case *cypher.Order: + return newCypherWalkCursorWithBranches(typedNode, typedNode.Items), true + + case *cypher.SortItem: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Expression != nil { + nextCursor.AddBranches(typedNode.Expression) + } + return nextCursor, true case *cypher.Projection: nextCursor := &Cursor[cypher.SyntaxNode]{ Node: node, } - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Items); err != nil { - return nil, err - } else { - nextCursor.AddBranches(branches...) - } + addCypherBranches(nextCursor, typedNode.Items) if typedNode.Order != nil { nextCursor.AddBranches(typedNode.Order) @@ -335,13 +443,141 @@ func newCypherWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], er nextCursor.AddBranches(typedNode.Limit) } - return nextCursor, nil + return nextCursor, true case *cypher.ProjectionItem: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Expression}, - }, nil + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Expression != nil { + nextCursor.AddBranches(typedNode.Expression) + } + return nextCursor, true + + default: + return nil, false + } +} + +func newCypherQueryWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], bool) { + if cursor, handled := newCypherStatementWalkCursor(node); handled { + return cursor, true + } + if cursor, handled := newCypherClauseWalkCursor(node); handled { + return cursor, true + } + + return nil, false +} + +func newCypherStatementWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], bool) { + switch typedNode := node.(type) { + case *cypher.MultiPartQuery: + nextCursor := newCypherWalkCursorWithBranches(typedNode, typedNode.Parts) + if typedNode.SinglePartQuery != nil { + nextCursor.AddBranches(typedNode.SinglePartQuery) + } + return nextCursor, true + + case *cypher.MultiPartQueryPart: + return newCypherMultiPartQueryPartWalkCursor(typedNode), true + + case *cypher.RegularQuery: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.SingleQuery != nil { + nextCursor.AddBranches(typedNode.SingleQuery) + } + return nextCursor, true + + case *cypher.SingleQuery: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.SinglePartQuery != nil { + nextCursor.AddBranches(typedNode.SinglePartQuery) + } + if typedNode.MultiPartQuery != nil { + nextCursor.AddBranches(typedNode.MultiPartQuery) + } + return nextCursor, true + + case *cypher.SinglePartQuery: + return newCypherSinglePartQueryWalkCursor(typedNode), true + + default: + return nil, false + } +} + +func newCypherMultiPartQueryPartWalkCursor(node *cypher.MultiPartQueryPart) *Cursor[cypher.SyntaxNode] { + nextCursor := &Cursor[cypher.SyntaxNode]{ + Node: node, + } + + if len(node.ReadingClauses) > 0 { + addCypherBranches(nextCursor, node.ReadingClauses) + } + + if len(node.UpdatingClauses) > 0 { + addCypherBranches(nextCursor, node.UpdatingClauses) + } + + if node.With != nil { + nextCursor.AddBranches(node.With) + } + + return nextCursor +} + +func newCypherSinglePartQueryWalkCursor(node *cypher.SinglePartQuery) *Cursor[cypher.SyntaxNode] { + nextCursor := &Cursor[cypher.SyntaxNode]{ + Node: node, + } + + if len(node.ReadingClauses) > 0 { + addCypherBranches(nextCursor, node.ReadingClauses) + } + + if len(node.UpdatingClauses) > 0 { + addCypherBranches(nextCursor, node.UpdatingClauses) + } + + if node.Return != nil { + nextCursor.AddBranches(node.Return) + } + + return nextCursor +} + +func newCypherClauseWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], bool) { + switch typedNode := node.(type) { + case *cypher.Unwind: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Expression != nil { + nextCursor.AddBranches(typedNode.Expression) + } + if typedNode.Variable != nil { + nextCursor.AddBranches(typedNode.Variable) + } + return nextCursor, true + + case *cypher.With: + nextCursor := &Cursor[cypher.SyntaxNode]{ + Node: node, + } + + if typedNode.Projection != nil { + nextCursor.AddBranches(typedNode.Projection) + } + + if typedNode.Where != nil { + nextCursor.AddBranches(typedNode.Where) + } + + return nextCursor, true + + case *cypher.Return: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Projection != nil { + nextCursor.AddBranches(typedNode.Projection) + } + return nextCursor, true case *cypher.ReadingClause: nextCursor := &Cursor[cypher.SyntaxNode]{ @@ -356,159 +592,121 @@ func newCypherWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], er nextCursor.AddBranches(typedNode.Unwind) } - return nextCursor, nil + return nextCursor, true case *cypher.Match: nextCursor := &Cursor[cypher.SyntaxNode]{ Node: node, } - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Pattern); err != nil { - return nil, err - } else { - nextCursor.AddBranches(branches...) - } + addCypherBranches(nextCursor, typedNode.Pattern) if typedNode.Where != nil { nextCursor.AddBranches(typedNode.Where) } - return nextCursor, nil + return nextCursor, true - case *cypher.PatternPart: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.PatternElements); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: branches, - }, nil - } + default: + return nil, false + } +} - case *cypher.PatternElement: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Element}, - }, nil +func newCypherUpdatingWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], bool) { + switch typedNode := node.(type) { + case *cypher.Create: + return newCypherWalkCursorWithBranches(node, typedNode.Pattern), true - case *cypher.RelationshipPattern: + case *cypher.RemoveItem: nextCursor := &Cursor[cypher.SyntaxNode]{ Node: node, } - if typedNode.Properties != nil { - nextCursor.AddBranches(typedNode.Properties) + if typedNode.Property != nil { + nextCursor.AddBranches(typedNode.Property) } - return nextCursor, nil + return nextCursor, true - case *cypher.NodePattern: - nextCursor := &Cursor[cypher.SyntaxNode]{ - Node: node, - } + case *cypher.Remove: + return newCypherWalkCursorWithBranches(typedNode, typedNode.Items), true - if typedNode.Properties != nil { - nextCursor.AddBranches(typedNode.Properties) + case *cypher.Delete: + return newCypherWalkCursorWithBranches(typedNode, typedNode.Expressions), true + + case *cypher.SetItem: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Left != nil { + nextCursor.AddBranches(typedNode.Left) + } + if typedNode.Right != nil { + nextCursor.AddBranches(typedNode.Right) } + return nextCursor, true - return nextCursor, nil + case *cypher.Set: + return newCypherWalkCursorWithBranches(typedNode, typedNode.Items), true - case *cypher.Where: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Expressions); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: branches, - }, nil + case *cypher.UpdatingClause: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Clause != nil { + nextCursor.AddBranches(typedNode.Clause) } + return nextCursor, true - case *cypher.Variable: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - }, nil + case *cypher.Merge: + return newCypherWalkCursorWithBranchPrefix(node, typedNode.PatternPart, typedNode.MergeActions), true - case *cypher.ArithmeticExpression: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Partials); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: append([]cypher.SyntaxNode{typedNode.Left}, branches...), - }, nil + case *cypher.MergeAction: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Set != nil { + nextCursor.AddBranches(typedNode.Set) } + return nextCursor, true - case *cypher.PartialArithmeticExpression: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Operator, typedNode.Right}, - }, nil + default: + return nil, false + } +} - case *cypher.PartialComparison: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Right}, - }, nil +func newCypherPatternWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], bool) { + switch typedNode := node.(type) { + case *cypher.PatternPredicate: + return newCypherWalkCursorWithBranches(typedNode, typedNode.PatternElements), true - case *cypher.Negation: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Expression}, - }, nil + case *cypher.PatternPart: + return newCypherWalkCursorWithBranches(node, typedNode.PatternElements), true - case *cypher.Conjunction: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Expressions); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: branches, - }, nil + case *cypher.PatternElement: + nextCursor := &Cursor[cypher.SyntaxNode]{Node: node} + if typedNode.Element != nil { + nextCursor.AddBranches(typedNode.Element) } + return nextCursor, true - case *cypher.Disjunction: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Expressions); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: branches, - }, nil + case *cypher.RelationshipPattern: + nextCursor := &Cursor[cypher.SyntaxNode]{ + Node: node, } - case *cypher.Comparison: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.Partials); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: append([]cypher.SyntaxNode{typedNode.Left}, branches...), - }, nil + if typedNode.Properties != nil { + nextCursor.AddBranches(typedNode.Properties) } - case *cypher.Merge: - if branches, err := cypherSyntaxNodeSliceTypeConvert(typedNode.MergeActions); err != nil { - return nil, err - } else { - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: append([]cypher.SyntaxNode{typedNode.PatternPart}, branches...), - }, nil + return nextCursor, true + + case *cypher.NodePattern: + nextCursor := &Cursor[cypher.SyntaxNode]{ + Node: node, } - case *cypher.MergeAction: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Set}, - }, nil + if typedNode.Properties != nil { + nextCursor.AddBranches(typedNode.Properties) + } - case *cypher.UnaryAddOrSubtractExpression: - return &Cursor[cypher.SyntaxNode]{ - Node: node, - Branches: []cypher.SyntaxNode{typedNode.Right}, - }, nil + return nextCursor, true default: - return nil, fmt.Errorf("unable to negotiate cypher model type %T into a translation cursor", node) + return nil, false } } diff --git a/cypher/models/walk/walk_pgsql.go b/cypher/models/walk/walk_pgsql.go index 37e0a8c1..f3ac3945 100644 --- a/cypher/models/walk/walk_pgsql.go +++ b/cypher/models/walk/walk_pgsql.go @@ -35,6 +35,10 @@ func newSQLCaseWalkCursor(node pgsql.SyntaxNode, caseExpr pgsql.Case) (*Cursor[p } func newSQLWalkCursor(node pgsql.SyntaxNode) (*Cursor[pgsql.SyntaxNode], error) { + if isNilNode(node) { + return nil, fmt.Errorf("unable to negotiate sql type %T into a translation cursor", node) + } + switch typedNode := node.(type) { case pgsql.Query: nextCursor := &Cursor[pgsql.SyntaxNode]{ diff --git a/cypher/models/walk/walk_test.go b/cypher/models/walk/walk_test.go index 4ccd6b2f..87fc9065 100644 --- a/cypher/models/walk/walk_test.go +++ b/cypher/models/walk/walk_test.go @@ -1,11 +1,14 @@ package walk_test import ( + "errors" + "fmt" "testing" "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/cypher/models/pgsql" "github.com/specterops/dawgs/cypher/models/walk" + "github.com/specterops/dawgs/graph" "github.com/specterops/dawgs/cypher/frontend" "github.com/specterops/dawgs/cypher/test" @@ -32,6 +35,1621 @@ func TestWalk(t *testing.T) { } } +func TestCypherWalkConsumeLeafDoesNotSkipSibling(t *testing.T) { + expression := cypher.NewDisjunction( + cypher.NewVariableWithSymbol("first"), + cypher.NewVariableWithSymbol("second"), + ) + + var visited []string + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, handler walk.VisitorHandler) { + variable, isVariable := node.(*cypher.Variable) + if !isVariable { + return + } + + visited = append(visited, variable.Symbol) + if variable.Symbol == "first" { + handler.Consume() + } + }) + + require.NoError(t, walk.Cypher(expression, visitor)) + require.Equal(t, []string{"first", "second"}, visited) +} + +func TestCypherWalkVisitsExclusiveDisjunction(t *testing.T) { + expression := cypher.NewExclusiveDisjunction( + cypher.NewVariableWithSymbol("left"), + cypher.NewVariableWithSymbol("right"), + ) + + var visited []string + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, _ walk.VisitorHandler) { + if variable, isVariable := node.(*cypher.Variable); isVariable { + visited = append(visited, variable.Symbol) + } + }) + + require.NoError(t, walk.Cypher(expression, visitor)) + require.Equal(t, []string{"left", "right"}, visited) +} + +func TestCypherWalkTreatsBareMapLiteralAsLeaf(t *testing.T) { + mapLiteral := cypher.MapLiteral{ + "b": cypher.NewVariableWithSymbol("b_value"), + "a": cypher.NewVariableWithSymbol("a_value"), + } + + var ( + visitedMapLiterals int + visitedMapItems int + visitedValues []string + ) + + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, _ walk.VisitorHandler) { + switch typedNode := node.(type) { + case cypher.MapLiteral: + visitedMapLiterals++ + + case *cypher.MapItem: + visitedMapItems++ + + case *cypher.Variable: + visitedValues = append(visitedValues, typedNode.Symbol) + } + }) + + require.NoError(t, walk.Cypher(mapLiteral, visitor)) + require.Equal(t, 1, visitedMapLiterals) + require.Zero(t, visitedMapItems) + require.Empty(t, visitedValues) +} + +func TestCypherWalkVisitsPropertiesMapValuesInKeyOrder(t *testing.T) { + properties := &cypher.Properties{ + Map: cypher.MapLiteral{ + "b": cypher.NewVariableWithSymbol("b_value"), + "a": cypher.NewVariableWithSymbol("a_value"), + }, + } + + var ( + visitedKeys []string + visitedValues []string + ) + + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, _ walk.VisitorHandler) { + switch typedNode := node.(type) { + case *cypher.MapItem: + visitedKeys = append(visitedKeys, typedNode.Key) + + case *cypher.Variable: + visitedValues = append(visitedValues, typedNode.Symbol) + } + }) + + require.NoError(t, walk.Cypher(properties, visitor)) + require.Equal(t, []string{"a", "b"}, visitedKeys) + require.Equal(t, []string{"a_value", "b_value"}, visitedValues) +} + +func TestCypherWalkSkipsNilBranches(t *testing.T) { + testCases := map[string]cypher.SyntaxNode{ + "regular query": &cypher.RegularQuery{}, + "single query": &cypher.SingleQuery{}, + "multipart query": &cypher.MultiPartQuery{}, + "return": &cypher.Return{}, + "set item": &cypher.SetItem{}, + "merge action": &cypher.MergeAction{}, + "updating clause": &cypher.UpdatingClause{}, + "projection item": &cypher.ProjectionItem{}, + "pattern element": &cypher.PatternElement{}, + "partial comparison": &cypher.PartialComparison{}, + "partial arithmetic": &cypher.PartialArithmeticExpression{}, + "unary add/subtract": &cypher.UnaryAddOrSubtractExpression{}, + "relationship pattern": &cypher.RelationshipPattern{}, + "node pattern": &cypher.NodePattern{}, + } + + for name, node := range testCases { + t.Run(name, func(t *testing.T) { + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(cypher.SyntaxNode, walk.VisitorHandler) {}) + require.NoError(t, walk.Cypher(node, visitor)) + }) + } +} + +func TestWalkRejectsNilRootsButVisitsTypedNilCollections(t *testing.T) { + t.Run("cypher nil interface root", func(t *testing.T) { + var ( + root cypher.SyntaxNode + visited bool + ) + + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(cypher.SyntaxNode, walk.VisitorHandler) { + visited = true + }) + + err := walk.Cypher(root, visitor) + require.ErrorContains(t, err, "unable to negotiate cypher model type ") + require.False(t, visited) + }) + + t.Run("cypher nil pointer root", func(t *testing.T) { + var ( + root *cypher.Variable + visited bool + ) + + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(cypher.SyntaxNode, walk.VisitorHandler) { + visited = true + }) + + err := walk.Cypher(root, visitor) + require.ErrorContains(t, err, "unable to negotiate cypher model type *cypher.Variable") + require.False(t, visited) + }) + + t.Run("pgsql nil interface root", func(t *testing.T) { + var ( + root pgsql.SyntaxNode + visited bool + ) + + visitor := walk.NewSimpleVisitor[pgsql.SyntaxNode](func(pgsql.SyntaxNode, walk.VisitorHandler) { + visited = true + }) + + err := walk.PgSQL(root, visitor) + require.ErrorContains(t, err, "unable to negotiate sql type ") + require.False(t, visited) + }) + + t.Run("cypher nil interface branch", func(t *testing.T) { + expression := cypher.NewDisjunction(nil) + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(cypher.SyntaxNode, walk.VisitorHandler) {}) + + err := walk.Cypher(expression, visitor) + require.ErrorContains(t, err, "unable to negotiate cypher model type ") + }) + + t.Run("cypher nil pointer branch", func(t *testing.T) { + var variable *cypher.Variable + expression := cypher.NewDisjunction(variable) + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(cypher.SyntaxNode, walk.VisitorHandler) {}) + + err := walk.Cypher(expression, visitor) + require.ErrorContains(t, err, "unable to negotiate cypher model type *cypher.Variable") + }) + + t.Run("cypher nil map literal root", func(t *testing.T) { + var ( + root cypher.MapLiteral + visited bool + ) + + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, _ walk.VisitorHandler) { + _, visited = node.(cypher.MapLiteral) + }) + + require.NoError(t, walk.Cypher(root, visitor)) + require.True(t, visited) + }) + + t.Run("pgsql nil slice node root", func(t *testing.T) { + var ( + root pgsql.CompoundIdentifier + visited bool + ) + + visitor := walk.NewSimpleVisitor[pgsql.SyntaxNode](func(node pgsql.SyntaxNode, _ walk.VisitorHandler) { + _, visited = node.(pgsql.CompoundIdentifier) + }) + + require.NoError(t, walk.PgSQL(root, visitor)) + require.True(t, visited) + }) +} + +func TestSimpleVisitorOrders(t *testing.T) { + expression := cypher.NewDisjunction( + cypher.NewVariableWithSymbol("left"), + cypher.NewVariableWithSymbol("right"), + ) + + testCases := []struct { + name string + order walk.Order + expected []string + }{ + { + name: "prefix", + order: walk.OrderPrefix, + expected: []string{"disjunction", "left", "right"}, + }, + { + name: "infix", + order: walk.OrderInfix, + expected: []string{"disjunction"}, + }, + { + name: "postfix", + order: walk.OrderPostfix, + expected: []string{"left", "right", "disjunction"}, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + var visited []string + visitor := walk.NewSimpleVisitorWithOrder[cypher.SyntaxNode](testCase.order, func(node cypher.SyntaxNode, _ walk.VisitorHandler) { + switch typedNode := node.(type) { + case *cypher.Disjunction: + visited = append(visited, "disjunction") + + case *cypher.Variable: + visited = append(visited, typedNode.Symbol) + } + }) + + require.NoError(t, walk.Cypher(expression, visitor)) + require.Equal(t, testCase.expected, visited) + }) + } +} + +func TestSimpleVisitorConsumeByOrder(t *testing.T) { + t.Run("prefix root consume skips children", func(t *testing.T) { + expression := cypher.NewDisjunction( + cypher.NewVariableWithSymbol("left"), + cypher.NewVariableWithSymbol("right"), + ) + + var visited []string + visitor := walk.NewSimpleVisitorWithOrder[cypher.SyntaxNode](walk.OrderPrefix, func(node cypher.SyntaxNode, handler walk.VisitorHandler) { + switch typedNode := node.(type) { + case *cypher.Disjunction: + visited = append(visited, "disjunction") + handler.Consume() + + case *cypher.Variable: + visited = append(visited, typedNode.Symbol) + } + }) + + require.NoError(t, walk.Cypher(expression, visitor)) + require.Equal(t, []string{"disjunction"}, visited) + }) + + t.Run("infix consume skips remaining siblings", func(t *testing.T) { + expression := cypher.NewDisjunction( + cypher.NewDisjunction( + cypher.NewVariableWithSymbol("left_a"), + cypher.NewVariableWithSymbol("left_b"), + ), + cypher.NewDisjunction( + cypher.NewVariableWithSymbol("right_a"), + cypher.NewVariableWithSymbol("right_b"), + ), + ) + + var visited []string + visitor := walk.NewSimpleVisitorWithOrder[cypher.SyntaxNode](walk.OrderInfix, func(node cypher.SyntaxNode, handler walk.VisitorHandler) { + disjunction, isDisjunction := node.(*cypher.Disjunction) + if !isDisjunction { + return + } + + switch disjunction.Expressions[0].(type) { + case *cypher.Variable: + visited = append(visited, "inner") + + case *cypher.Disjunction: + visited = append(visited, "root") + handler.Consume() + } + }) + + require.NoError(t, walk.Cypher(expression, visitor)) + require.Equal(t, []string{"inner", "root"}, visited) + }) + + t.Run("postfix leaf consume does not skip siblings", func(t *testing.T) { + expression := cypher.NewDisjunction( + cypher.NewVariableWithSymbol("left"), + cypher.NewVariableWithSymbol("right"), + ) + + var visited []string + visitor := walk.NewSimpleVisitorWithOrder[cypher.SyntaxNode](walk.OrderPostfix, func(node cypher.SyntaxNode, handler walk.VisitorHandler) { + switch typedNode := node.(type) { + case *cypher.Variable: + visited = append(visited, typedNode.Symbol) + if typedNode.Symbol == "left" { + handler.Consume() + } + + case *cypher.Disjunction: + visited = append(visited, "disjunction") + } + }) + + require.NoError(t, walk.Cypher(expression, visitor)) + require.Equal(t, []string{"left", "right", "disjunction"}, visited) + }) +} + +func TestGenericSetDoneStopsWithoutUnwindingExit(t *testing.T) { + t.Run("enter", func(t *testing.T) { + root := &genericWalkTestNode{ + name: "root", + children: []*genericWalkTestNode{ + {name: "child"}, + }, + } + visitor := newRecordingGenericWalkVisitor() + + visitor.onEnter = func(node *genericWalkTestNode) { + if node.name == "root" { + visitor.SetDone() + } + } + + require.NoError(t, walk.Generic(root, visitor, newGenericWalkTestCursor)) + require.Equal(t, []string{"enter:root"}, visitor.events) + }) + + t.Run("visit", func(t *testing.T) { + root := &genericWalkTestNode{ + name: "root", + children: []*genericWalkTestNode{ + {name: "left"}, + {name: "right"}, + }, + } + visitor := newRecordingGenericWalkVisitor() + + visitor.onVisit = func(node *genericWalkTestNode) { + if node.name == "root" { + visitor.SetDone() + } + } + + require.NoError(t, walk.Generic(root, visitor, newGenericWalkTestCursor)) + require.Equal(t, []string{"enter:root", "enter:left", "exit:left", "visit:root"}, visitor.events) + }) +} + +func TestGenericSetErrorStopsAndReturnsJoinedError(t *testing.T) { + root := &genericWalkTestNode{ + name: "root", + children: []*genericWalkTestNode{ + {name: "child"}, + }, + } + visitor := newRecordingGenericWalkVisitor() + + visitor.onEnter = func(node *genericWalkTestNode) { + if node.name == "root" { + visitor.SetError(errors.New("first failure")) + visitor.SetErrorf("second %s", "failure") + } + } + + err := walk.Generic(root, visitor, newGenericWalkTestCursor) + require.ErrorContains(t, err, "first failure") + require.ErrorContains(t, err, "second failure") + require.True(t, visitor.Done()) + require.Equal(t, []string{"enter:root"}, visitor.events) +} + +func TestGenericReturnsCursorConstructorErrors(t *testing.T) { + t.Run("root", func(t *testing.T) { + expectedErr := errors.New("root cursor failure") + visitor := newRecordingGenericWalkVisitor() + + err := walk.Generic(&genericWalkTestNode{name: "root"}, visitor, func(*genericWalkTestNode) (*walk.Cursor[*genericWalkTestNode], error) { + return nil, expectedErr + }) + + require.ErrorIs(t, err, expectedErr) + require.Empty(t, visitor.events) + }) + + t.Run("nil root", func(t *testing.T) { + expectedErr := errors.New("nil root cursor failure") + var root *genericWalkTestNode + visitor := newRecordingGenericWalkVisitor() + called := false + + err := walk.Generic(root, visitor, func(node *genericWalkTestNode) (*walk.Cursor[*genericWalkTestNode], error) { + called = true + require.Nil(t, node) + return nil, expectedErr + }) + + require.ErrorIs(t, err, expectedErr) + require.True(t, called) + require.Empty(t, visitor.events) + }) + + t.Run("child", func(t *testing.T) { + expectedErr := errors.New("child cursor failure") + root := &genericWalkTestNode{ + name: "root", + children: []*genericWalkTestNode{ + {name: "bad"}, + }, + } + visitor := newRecordingGenericWalkVisitor() + + err := walk.Generic(root, visitor, func(node *genericWalkTestNode) (*walk.Cursor[*genericWalkTestNode], error) { + if node.name == "bad" { + return nil, expectedErr + } + + return newGenericWalkTestCursor(node) + }) + + require.ErrorIs(t, err, expectedErr) + require.Equal(t, []string{"enter:root"}, visitor.events) + }) + + t.Run("nil child", func(t *testing.T) { + expectedErr := errors.New("nil child cursor failure") + root := &genericWalkTestNode{ + name: "root", + children: []*genericWalkTestNode{nil}, + } + visitor := newRecordingGenericWalkVisitor() + + err := walk.Generic(root, visitor, func(node *genericWalkTestNode) (*walk.Cursor[*genericWalkTestNode], error) { + if node == nil { + return nil, expectedErr + } + + return newGenericWalkTestCursor(node) + }) + + require.ErrorIs(t, err, expectedErr) + require.Equal(t, []string{"enter:root"}, visitor.events) + }) +} + +func TestGenericReturnsVisitorErrorsFromVisitAndExit(t *testing.T) { + t.Run("visit", func(t *testing.T) { + expectedErr := errors.New("visit failure") + root := &genericWalkTestNode{ + name: "root", + children: []*genericWalkTestNode{ + {name: "left"}, + {name: "right"}, + }, + } + visitor := newRecordingGenericWalkVisitor() + + visitor.onVisit = func(node *genericWalkTestNode) { + if node.name == "root" { + visitor.SetError(expectedErr) + } + } + + err := walk.Generic(root, visitor, newGenericWalkTestCursor) + + require.ErrorIs(t, err, expectedErr) + require.Equal(t, []string{"enter:root", "enter:left", "exit:left", "visit:root"}, visitor.events) + }) + + t.Run("exit", func(t *testing.T) { + expectedErr := errors.New("exit failure") + root := &genericWalkTestNode{ + name: "root", + children: []*genericWalkTestNode{ + {name: "left"}, + }, + } + visitor := newRecordingGenericWalkVisitor() + + visitor.onExit = func(node *genericWalkTestNode) { + if node.name == "left" { + visitor.SetError(expectedErr) + } + } + + err := walk.Generic(root, visitor, newGenericWalkTestCursor) + + require.ErrorIs(t, err, expectedErr) + require.Equal(t, []string{"enter:root", "enter:left", "exit:left"}, visitor.events) + }) + + t.Run("consumed node exit", func(t *testing.T) { + expectedErr := errors.New("consumed exit failure") + root := &genericWalkTestNode{ + name: "root", + children: []*genericWalkTestNode{ + {name: "left"}, + {name: "right"}, + }, + } + visitor := newRecordingGenericWalkVisitor() + + visitor.onVisit = func(node *genericWalkTestNode) { + if node.name == "root" { + visitor.Consume() + } + } + visitor.onExit = func(node *genericWalkTestNode) { + if node.name == "root" { + visitor.SetError(expectedErr) + } + } + + err := walk.Generic(root, visitor, newGenericWalkTestCursor) + + require.ErrorIs(t, err, expectedErr) + require.Equal(t, []string{"enter:root", "enter:left", "exit:left", "visit:root", "exit:root"}, visitor.events) + }) +} + +func TestCypherWalkSemanticSkipsDeclarationOnlyFields(t *testing.T) { + testCases := map[string]struct { + node cypher.SyntaxNode + visited []string + notVisited []string + visitedRanges int + }{ + "projection alias": { + node: &cypher.ProjectionItem{ + Expression: cypher.NewVariableWithSymbol("value"), + Alias: cypher.NewVariableWithSymbol("alias"), + }, + visited: []string{"value"}, + notVisited: []string{"alias"}, + }, + "id in collection variable": { + node: &cypher.IDInCollection{ + Variable: cypher.NewVariableWithSymbol("item"), + Expression: cypher.NewVariableWithSymbol("items"), + }, + visited: []string{"items"}, + notVisited: []string{"item"}, + }, + "pattern part variable": { + node: &cypher.PatternPart{ + Variable: cypher.NewVariableWithSymbol("path"), + }, + notVisited: []string{"path"}, + }, + "node pattern variable": { + node: &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("node"), + }, + notVisited: []string{"node"}, + }, + "relationship pattern metadata": { + node: &cypher.RelationshipPattern{ + Variable: cypher.NewVariableWithSymbol("rel"), + Range: &cypher.PatternRange{}, + }, + notVisited: []string{"rel"}, + visitedRanges: 0, + }, + "remove kind matcher": { + node: &cypher.RemoveItem{ + KindMatcher: &cypher.KindMatcher{ + Reference: cypher.NewVariableWithSymbol("node"), + Kinds: graph.Kinds{graph.StringKind("NodeKind")}, + }, + }, + notVisited: []string{"node"}, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + var ( + visitedVariables []string + visitedRanges int + ) + + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, _ walk.VisitorHandler) { + switch typedNode := node.(type) { + case *cypher.Variable: + visitedVariables = append(visitedVariables, typedNode.Symbol) + + case *cypher.PatternRange: + visitedRanges++ + } + }) + + require.NoError(t, walk.Cypher(testCase.node, visitor)) + for _, symbol := range testCase.visited { + require.Contains(t, visitedVariables, symbol) + } + for _, symbol := range testCase.notVisited { + require.NotContains(t, visitedVariables, symbol) + } + if testCase.visitedRanges == 0 { + require.Zero(t, visitedRanges) + } + }) + } +} + +func TestCypherWalkVisitsSemanticChildrenByNodeType(t *testing.T) { + testCases := map[string]struct { + node cypher.SyntaxNode + visited []string + notVisited []string + }{ + "kind matcher visits reference only": { + node: &cypher.KindMatcher{ + Reference: cypher.NewVariableWithSymbol("node"), + Kinds: graph.Kinds{graph.StringKind("NodeKind")}, + }, + visited: []string{"variable:node"}, + notVisited: []string{"kind:NodeKind"}, + }, + "property lookup visits atom": { + node: cypher.NewPropertyLookup("node", "name"), + visited: []string{"variable:node"}, + notVisited: []string{"variable:name"}, + }, + "map item visits value": { + node: &cypher.MapItem{ + Key: "name", + Value: cypher.NewVariableWithSymbol("value"), + }, + visited: []string{"variable:value"}, + }, + "properties parameter visits parameter only": { + node: &cypher.Properties{ + Parameter: cypher.NewParameter("props", map[string]any{}), + Map: cypher.MapLiteral{ + "name": cypher.NewVariableWithSymbol("name"), + }, + }, + visited: []string{"parameter:props"}, + notVisited: []string{"mapitem:name", "variable:name"}, + }, + "properties map visits map items": { + node: &cypher.Properties{ + Map: cypher.MapLiteral{ + "name": cypher.NewVariableWithSymbol("name"), + }, + }, + visited: []string{"mapitem:name", "variable:name"}, + }, + "list literal visits expressions": { + node: &cypher.ListLiteral{ + cypher.NewVariableWithSymbol("left"), + cypher.NewVariableWithSymbol("right"), + }, + visited: []string{"variable:left", "variable:right"}, + }, + "create visits pattern expressions": { + node: &cypher.Create{ + Pattern: []*cypher.PatternPart{{ + PatternElements: []*cypher.PatternElement{{ + Element: &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("node"), + Properties: &cypher.Properties{ + Map: cypher.MapLiteral{ + "name": cypher.NewVariableWithSymbol("name"), + }, + }, + }, + }}, + }}, + }, + visited: []string{"mapitem:name", "variable:name"}, + notVisited: []string{"variable:node"}, + }, + "unwind visits source and binding variable": { + node: &cypher.Unwind{ + Expression: cypher.NewVariableWithSymbol("items"), + Variable: cypher.NewVariableWithSymbol("item"), + }, + visited: []string{"variable:items", "variable:item"}, + }, + "remove item visits property only": { + node: &cypher.RemoveItem{ + KindMatcher: &cypher.KindMatcher{ + Reference: cypher.NewVariableWithSymbol("node"), + }, + Property: cypher.NewPropertyLookup("target", "name"), + }, + visited: []string{"variable:target"}, + notVisited: []string{"variable:node"}, + }, + "set item visits both sides": { + node: &cypher.SetItem{ + Left: cypher.NewPropertyLookup("node", "name"), + Right: cypher.NewVariableWithSymbol("value"), + }, + visited: []string{"variable:node", "variable:value"}, + }, + "quantifier visits filter expression semantics": { + node: &cypher.Quantifier{ + Filter: &cypher.FilterExpression{ + Specifier: &cypher.IDInCollection{ + Variable: cypher.NewVariableWithSymbol("item"), + Expression: cypher.NewVariableWithSymbol("items"), + }, + Where: &cypher.Where{}, + }, + }, + visited: []string{"variable:items"}, + notVisited: []string{"variable:item"}, + }, + "function invocation visits arguments": { + node: cypher.NewSimpleFunctionInvocation( + "coalesce", + cypher.NewVariableWithSymbol("left"), + cypher.NewVariableWithSymbol("right"), + ), + visited: []string{"variable:left", "variable:right"}, + }, + "projection visits items order skip and limit nodes": { + node: &cypher.Projection{ + Items: []cypher.Expression{ + &cypher.ProjectionItem{ + Expression: cypher.NewVariableWithSymbol("value"), + Alias: cypher.NewVariableWithSymbol("alias"), + }, + }, + Order: &cypher.Order{ + Items: []*cypher.SortItem{{ + Expression: cypher.NewVariableWithSymbol("ordered"), + }}, + }, + Skip: &cypher.Skip{Value: cypher.NewLiteral(10, false)}, + Limit: &cypher.Limit{Value: cypher.NewLiteral(20, false)}, + }, + visited: []string{"variable:value", "variable:ordered"}, + notVisited: []string{"variable:alias", "literal:10", "literal:20"}, + }, + "arithmetic expression visits operators and operands": { + node: &cypher.ArithmeticExpression{ + Left: cypher.NewVariableWithSymbol("left"), + Partials: []*cypher.PartialArithmeticExpression{{ + Operator: cypher.OperatorAdd, + Right: cypher.NewVariableWithSymbol("right"), + }}, + }, + visited: []string{"variable:left", "operator:+", "variable:right"}, + }, + "comparison visits right operands without operators": { + node: &cypher.Comparison{ + Left: cypher.NewVariableWithSymbol("left"), + Partials: []*cypher.PartialComparison{{ + Operator: cypher.OperatorEquals, + Right: cypher.NewVariableWithSymbol("right"), + }}, + }, + visited: []string{"variable:left", "variable:right"}, + notVisited: []string{"operator:="}, + }, + "merge visits pattern and actions": { + node: &cypher.Merge{ + PatternPart: &cypher.PatternPart{ + PatternElements: []*cypher.PatternElement{{ + Element: &cypher.NodePattern{ + Properties: &cypher.Properties{ + Map: cypher.MapLiteral{ + "id": cypher.NewVariableWithSymbol("id"), + }, + }, + }, + }}, + }, + MergeActions: []*cypher.MergeAction{{ + Set: &cypher.Set{ + Items: []*cypher.SetItem{{ + Left: cypher.NewPropertyLookup("node", "name"), + Right: cypher.NewVariableWithSymbol("name"), + }}, + }, + }}, + }, + visited: []string{"mapitem:id", "variable:id", "variable:node", "variable:name"}, + }, + "unary add or subtract visits right operand only": { + node: &cypher.UnaryAddOrSubtractExpression{ + Operator: cypher.OperatorSubtract, + Right: cypher.NewVariableWithSymbol("value"), + }, + visited: []string{"variable:value"}, + notVisited: []string{"operator:-"}, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + visited := collectCypherWalkLabels(t, testCase.node, walk.Cypher) + + for _, expectedLabel := range testCase.visited { + require.Contains(t, visited, expectedLabel) + } + for _, unexpectedLabel := range testCase.notVisited { + require.NotContains(t, visited, unexpectedLabel) + } + }) + } +} + +func TestCypherWalkSemanticTraversalSequences(t *testing.T) { + testCases := map[string]struct { + node cypher.SyntaxNode + expected []string + }{ + "projection walks items then order and skips pagination values": { + node: &cypher.Projection{ + Items: []cypher.Expression{ + &cypher.ProjectionItem{ + Expression: cypher.NewVariableWithSymbol("value"), + Alias: cypher.NewVariableWithSymbol("alias"), + }, + }, + Order: &cypher.Order{ + Items: []*cypher.SortItem{{ + Expression: cypher.NewVariableWithSymbol("ordered"), + }}, + }, + Skip: &cypher.Skip{Value: cypher.NewLiteral(10, false)}, + Limit: &cypher.Limit{Value: cypher.NewLiteral(20, false)}, + }, + expected: []string{"variable:value", "variable:ordered"}, + }, + "comparison walks left then right without operator": { + node: &cypher.Comparison{ + Left: cypher.NewVariableWithSymbol("left"), + Partials: []*cypher.PartialComparison{{ + Operator: cypher.OperatorEquals, + Right: cypher.NewVariableWithSymbol("right"), + }}, + }, + expected: []string{"variable:left", "variable:right"}, + }, + "arithmetic walks left operator then right": { + node: &cypher.ArithmeticExpression{ + Left: cypher.NewVariableWithSymbol("left"), + Partials: []*cypher.PartialArithmeticExpression{{ + Operator: cypher.OperatorAdd, + Right: cypher.NewVariableWithSymbol("right"), + }}, + }, + expected: []string{"variable:left", "operator:+", "variable:right"}, + }, + "merge walks pattern before actions": { + node: &cypher.Merge{ + PatternPart: &cypher.PatternPart{ + PatternElements: []*cypher.PatternElement{{ + Element: &cypher.NodePattern{ + Properties: &cypher.Properties{ + Map: cypher.MapLiteral{ + "id": cypher.NewVariableWithSymbol("id"), + }, + }, + }, + }}, + }, + MergeActions: []*cypher.MergeAction{{ + Set: &cypher.Set{ + Items: []*cypher.SetItem{{ + Left: cypher.NewPropertyLookup("node", "name"), + Right: cypher.NewVariableWithSymbol("name"), + }}, + }, + }}, + }, + expected: []string{"mapitem:id", "variable:id", "variable:node", "variable:name"}, + }, + "quantifier walks collection expression then where expression": { + node: &cypher.Quantifier{ + Filter: &cypher.FilterExpression{ + Specifier: &cypher.IDInCollection{ + Variable: cypher.NewVariableWithSymbol("item"), + Expression: cypher.NewVariableWithSymbol("items"), + }, + Where: newCypherWhere(cypher.NewVariableWithSymbol("predicate")), + }, + }, + expected: []string{"variable:items", "variable:predicate"}, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + require.Equal(t, testCase.expected, collectCypherWalkLabels(t, testCase.node, walk.Cypher)) + }) + } +} + +func TestCypherWalkQueryAndClauseTraversalSequences(t *testing.T) { + query := &cypher.MultiPartQuery{ + Parts: []*cypher.MultiPartQueryPart{{ + ReadingClauses: []*cypher.ReadingClause{{ + Match: &cypher.Match{ + Pattern: []*cypher.PatternPart{{ + PatternElements: []*cypher.PatternElement{{ + Element: &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("read_node"), + Properties: &cypher.Properties{ + Map: cypher.MapLiteral{ + "read": cypher.NewVariableWithSymbol("read_value"), + }, + }, + }, + }}, + }}, + Where: newCypherWhere(cypher.NewVariableWithSymbol("match_where")), + }, + }}, + UpdatingClauses: []*cypher.UpdatingClause{ + cypher.NewUpdatingClause(&cypher.Set{ + Items: []*cypher.SetItem{{ + Left: cypher.NewPropertyLookup("update_node", "name"), + Right: cypher.NewVariableWithSymbol("update_value"), + }}, + }), + }, + With: &cypher.With{ + Projection: &cypher.Projection{ + Items: []cypher.Expression{ + &cypher.ProjectionItem{ + Expression: cypher.NewVariableWithSymbol("with_projection"), + }, + }, + }, + Where: newCypherWhere(cypher.NewVariableWithSymbol("with_where")), + }, + }}, + SinglePartQuery: &cypher.SinglePartQuery{ + ReadingClauses: []*cypher.ReadingClause{{ + Unwind: &cypher.Unwind{ + Expression: cypher.NewVariableWithSymbol("final_items"), + Variable: cypher.NewVariableWithSymbol("final_item"), + }, + }}, + UpdatingClauses: []cypher.Expression{ + cypher.NewUpdatingClause(&cypher.Remove{ + Items: []*cypher.RemoveItem{ + cypher.RemoveProperty(cypher.NewPropertyLookup("remove_node", "name")), + }, + }), + }, + Return: &cypher.Return{ + Projection: &cypher.Projection{ + Items: []cypher.Expression{ + &cypher.ProjectionItem{ + Expression: cypher.NewVariableWithSymbol("returned"), + }, + }, + }, + }, + }, + } + + require.Equal(t, []string{ + "mapitem:read", + "variable:read_value", + "variable:match_where", + "variable:update_node", + "variable:update_value", + "variable:with_projection", + "variable:with_where", + "variable:final_items", + "variable:final_item", + "variable:remove_node", + "variable:returned", + }, collectCypherWalkLabels(t, query, walk.Cypher)) +} + +func TestCypherStructuralWalkVisitsDeclarationAndMetadataFields(t *testing.T) { + testCases := map[string]struct { + node cypher.SyntaxNode + variables []string + kinds []string + mapKeys []string + literals []any + operators []cypher.Operator + numRanges int + numMapNode int + }{ + "remove kind matcher": { + node: &cypher.RemoveItem{ + KindMatcher: &cypher.KindMatcher{ + Reference: cypher.NewVariableWithSymbol("node"), + Kinds: graph.Kinds{graph.StringKind("NodeKind")}, + }, + }, + variables: []string{"node"}, + kinds: []string{"NodeKind"}, + }, + "node pattern": { + node: &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("node"), + Kinds: graph.Kinds{graph.StringKind("User")}, + Properties: &cypher.Properties{ + Map: cypher.MapLiteral{ + "name": cypher.NewVariableWithSymbol("name"), + }, + }, + }, + variables: []string{"node", "name"}, + kinds: []string{"User"}, + mapKeys: []string{"name"}, + numMapNode: 1, + }, + "relationship pattern": { + node: &cypher.RelationshipPattern{ + Variable: cypher.NewVariableWithSymbol("rel"), + Kinds: graph.Kinds{graph.StringKind("MemberOf")}, + Range: cypher.NewPatternRange(nil, nil), + Properties: &cypher.Properties{ + Map: cypher.MapLiteral{ + "weight": cypher.NewVariableWithSymbol("weight"), + }, + }, + }, + variables: []string{"rel", "weight"}, + kinds: []string{"MemberOf"}, + mapKeys: []string{"weight"}, + numRanges: 1, + numMapNode: 1, + }, + "projection alias": { + node: &cypher.ProjectionItem{ + Expression: cypher.NewVariableWithSymbol("value"), + Alias: cypher.NewVariableWithSymbol("alias"), + }, + variables: []string{"value", "alias"}, + }, + "id in collection": { + node: &cypher.IDInCollection{ + Variable: cypher.NewVariableWithSymbol("item"), + Expression: cypher.NewVariableWithSymbol("items"), + }, + variables: []string{"item", "items"}, + }, + "pattern part variable": { + node: &cypher.PatternPart{ + Variable: cypher.NewVariableWithSymbol("path"), + PatternElements: []*cypher.PatternElement{{ + Element: &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("node"), + }, + }}, + }, + variables: []string{"path", "node"}, + }, + "skip limit and operators": { + node: &cypher.Projection{ + Skip: &cypher.Skip{ + Value: cypher.NewLiteral(5, false), + }, + Limit: &cypher.Limit{ + Value: cypher.NewLiteral(10, false), + }, + Items: []cypher.Expression{ + &cypher.ProjectionItem{ + Expression: &cypher.Comparison{ + Left: cypher.NewVariableWithSymbol("n"), + Partials: []*cypher.PartialComparison{{ + Operator: cypher.OperatorEquals, + Right: cypher.NewLiteral(1, false), + }}, + }, + }, + }, + }, + variables: []string{"n"}, + literals: []any{1, 5, 10}, + operators: []cypher.Operator{ + cypher.OperatorEquals, + }, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + var ( + visitedVariables []string + visitedKinds []string + visitedMapKeys []string + visitedLiterals []any + visitedOperators []cypher.Operator + visitedRanges int + visitedMapNodes int + ) + + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, _ walk.VisitorHandler) { + switch typedNode := node.(type) { + case *cypher.Variable: + visitedVariables = append(visitedVariables, typedNode.Symbol) + + case graph.Kinds: + for _, kind := range typedNode { + visitedKinds = append(visitedKinds, kind.String()) + } + + case *cypher.MapItem: + visitedMapKeys = append(visitedMapKeys, typedNode.Key) + + case cypher.MapLiteral: + visitedMapNodes++ + + case *cypher.Literal: + visitedLiterals = append(visitedLiterals, typedNode.Value) + + case cypher.Operator: + visitedOperators = append(visitedOperators, typedNode) + + case *cypher.PatternRange: + visitedRanges++ + } + }) + + require.NoError(t, walk.CypherStructural(testCase.node, visitor)) + for _, symbol := range testCase.variables { + require.Contains(t, visitedVariables, symbol) + } + for _, kind := range testCase.kinds { + require.Contains(t, visitedKinds, kind) + } + for _, key := range testCase.mapKeys { + require.Contains(t, visitedMapKeys, key) + } + for _, literal := range testCase.literals { + require.Contains(t, visitedLiterals, literal) + } + for _, operator := range testCase.operators { + require.Contains(t, visitedOperators, operator) + } + require.Equal(t, testCase.numRanges, visitedRanges) + require.Equal(t, testCase.numMapNode, visitedMapNodes) + }) + } +} + +func TestCypherStructuralWalkVisitsModeledChildFields(t *testing.T) { + testCases := map[string]struct { + node cypher.SyntaxNode + visited []string + }{ + "limit value": { + node: &cypher.Limit{Value: cypher.NewLiteral(10, false)}, + visited: []string{"literal:10"}, + }, + "skip value": { + node: &cypher.Skip{Value: cypher.NewLiteral(20, false)}, + visited: []string{"literal:20"}, + }, + "kind matcher reference and kinds": { + node: &cypher.KindMatcher{ + Reference: cypher.NewVariableWithSymbol("node"), + Kinds: graph.Kinds{graph.StringKind("NodeKind")}, + }, + visited: []string{"variable:node", "kinds", "kind:NodeKind"}, + }, + "properties parameter and map": { + node: &cypher.Properties{ + Parameter: cypher.NewParameter("props", map[string]any{}), + Map: cypher.MapLiteral{ + "name": cypher.NewVariableWithSymbol("name"), + }, + }, + visited: []string{"parameter:props", "mapitem:name", "variable:name"}, + }, + "bare map literal": { + node: cypher.MapLiteral{ + "name": cypher.NewVariableWithSymbol("name"), + }, + visited: []string{"mapitem:name", "variable:name"}, + }, + "remove item kind matcher and property": { + node: &cypher.RemoveItem{ + KindMatcher: &cypher.KindMatcher{ + Reference: cypher.NewVariableWithSymbol("node"), + Kinds: graph.Kinds{graph.StringKind("NodeKind")}, + }, + Property: cypher.NewPropertyLookup("target", "name"), + }, + visited: []string{"variable:node", "kind:NodeKind", "variable:target"}, + }, + "id in collection variable and expression": { + node: &cypher.IDInCollection{ + Variable: cypher.NewVariableWithSymbol("item"), + Expression: cypher.NewVariableWithSymbol("items"), + }, + visited: []string{"variable:item", "variable:items"}, + }, + "projection item expression and alias": { + node: &cypher.ProjectionItem{ + Expression: cypher.NewVariableWithSymbol("value"), + Alias: cypher.NewVariableWithSymbol("alias"), + }, + visited: []string{"variable:value", "variable:alias"}, + }, + "pattern part variable and elements": { + node: &cypher.PatternPart{ + Variable: cypher.NewVariableWithSymbol("path"), + PatternElements: []*cypher.PatternElement{{ + Element: &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("node"), + }, + }}, + }, + visited: []string{"variable:path", "variable:node"}, + }, + "relationship pattern metadata": { + node: &cypher.RelationshipPattern{ + Variable: cypher.NewVariableWithSymbol("rel"), + Kinds: graph.Kinds{graph.StringKind("MemberOf")}, + Range: cypher.NewPatternRange(nil, nil), + Properties: &cypher.Properties{ + Map: cypher.MapLiteral{ + "weight": cypher.NewVariableWithSymbol("weight"), + }, + }, + }, + visited: []string{"variable:rel", "kind:MemberOf", "range", "mapitem:weight", "variable:weight"}, + }, + "node pattern metadata": { + node: &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("node"), + Kinds: graph.Kinds{graph.StringKind("User")}, + Properties: &cypher.Properties{ + Map: cypher.MapLiteral{ + "name": cypher.NewVariableWithSymbol("name"), + }, + }, + }, + visited: []string{"variable:node", "kind:User", "mapitem:name", "variable:name"}, + }, + "node pattern empty kind list": { + node: &cypher.NodePattern{ + Kinds: graph.Kinds{}, + }, + visited: []string{"kinds"}, + }, + "partial comparison operator and right": { + node: &cypher.PartialComparison{ + Operator: cypher.OperatorEquals, + Right: cypher.NewVariableWithSymbol("right"), + }, + visited: []string{"operator:=", "variable:right"}, + }, + "partial arithmetic operator and right": { + node: &cypher.PartialArithmeticExpression{ + Operator: cypher.OperatorAdd, + Right: cypher.NewVariableWithSymbol("right"), + }, + visited: []string{"operator:+", "variable:right"}, + }, + "unary add or subtract operator and right": { + node: &cypher.UnaryAddOrSubtractExpression{ + Operator: cypher.OperatorSubtract, + Right: cypher.NewVariableWithSymbol("right"), + }, + visited: []string{"operator:-", "variable:right"}, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + visited := collectCypherWalkLabels(t, testCase.node, walk.CypherStructural) + + for _, expectedLabel := range testCase.visited { + require.Contains(t, visited, expectedLabel) + } + }) + } +} + +func TestCypherStructuralWalkVisitsFullASTTypeSequence(t *testing.T) { + query := &cypher.RegularQuery{ + SingleQuery: &cypher.SingleQuery{ + SinglePartQuery: &cypher.SinglePartQuery{ + ReadingClauses: []*cypher.ReadingClause{{ + Match: &cypher.Match{ + Pattern: []*cypher.PatternPart{{ + Variable: cypher.NewVariableWithSymbol("path"), + PatternElements: []*cypher.PatternElement{ + { + Element: &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("n"), + Kinds: graph.Kinds{graph.StringKind("User")}, + Properties: &cypher.Properties{ + Map: cypher.MapLiteral{ + "name": cypher.NewVariableWithSymbol("name"), + }, + }, + }, + }, + { + Element: &cypher.RelationshipPattern{ + Variable: cypher.NewVariableWithSymbol("r"), + Kinds: graph.Kinds{graph.StringKind("MemberOf")}, + Range: cypher.NewPatternRange(nil, nil), + Properties: &cypher.Properties{ + Parameter: cypher.NewParameter("relProps", map[string]any{}), + }, + }, + }, + { + Element: &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("m"), + }, + }, + }, + }}, + Where: newCypherWhere(&cypher.Comparison{ + Left: cypher.NewVariableWithSymbol("n"), + Partials: []*cypher.PartialComparison{{ + Operator: cypher.OperatorEquals, + Right: cypher.NewLiteral(true, false), + }}, + }), + }, + }}, + UpdatingClauses: []cypher.Expression{ + cypher.NewUpdatingClause(&cypher.Set{ + Items: []*cypher.SetItem{{ + Left: cypher.NewPropertyLookup("n", "seen"), + Right: cypher.NewVariableWithSymbol("seen"), + }}, + }), + }, + Return: &cypher.Return{ + Projection: &cypher.Projection{ + Items: []cypher.Expression{ + &cypher.ProjectionItem{ + Expression: cypher.NewVariableWithSymbol("n"), + Alias: cypher.NewVariableWithSymbol("alias"), + }, + }, + Order: &cypher.Order{ + Items: []*cypher.SortItem{{ + Expression: cypher.NewVariableWithSymbol("alias"), + }}, + }, + Skip: &cypher.Skip{Value: cypher.NewLiteral(1, false)}, + Limit: &cypher.Limit{Value: cypher.NewLiteral(2, false)}, + }, + }, + }, + }, + } + + require.Equal(t, []string{ + "*cypher.RegularQuery", + "*cypher.SingleQuery", + "*cypher.SinglePartQuery", + "*cypher.ReadingClause", + "*cypher.Match", + "*cypher.PatternPart", + "*cypher.Variable", + "*cypher.PatternElement", + "*cypher.NodePattern", + "*cypher.Variable", + "graph.Kinds", + "*cypher.Properties", + "cypher.MapLiteral", + "*cypher.MapItem", + "*cypher.Variable", + "*cypher.PatternElement", + "*cypher.RelationshipPattern", + "*cypher.Variable", + "graph.Kinds", + "*cypher.PatternRange", + "*cypher.Properties", + "*cypher.Parameter", + "*cypher.PatternElement", + "*cypher.NodePattern", + "*cypher.Variable", + "*cypher.Where", + "*cypher.Comparison", + "*cypher.Variable", + "*cypher.PartialComparison", + "cypher.Operator", + "*cypher.Literal", + "*cypher.UpdatingClause", + "*cypher.Set", + "*cypher.SetItem", + "*cypher.PropertyLookup", + "*cypher.Variable", + "*cypher.Variable", + "*cypher.Return", + "*cypher.Projection", + "*cypher.ProjectionItem", + "*cypher.Variable", + "*cypher.Variable", + "*cypher.Order", + "*cypher.SortItem", + "*cypher.Variable", + "*cypher.Skip", + "*cypher.Literal", + "*cypher.Limit", + "*cypher.Literal", + }, collectCypherWalkTypes(t, query, walk.CypherStructural)) +} + +func collectCypherWalkLabels(t *testing.T, node cypher.SyntaxNode, walkFunc func(cypher.SyntaxNode, walk.Visitor[cypher.SyntaxNode]) error) []string { + t.Helper() + + var visited []string + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, _ walk.VisitorHandler) { + switch typedNode := node.(type) { + case *cypher.Variable: + visited = append(visited, "variable:"+typedNode.Symbol) + + case *cypher.Parameter: + visited = append(visited, "parameter:"+typedNode.Symbol) + + case *cypher.MapItem: + visited = append(visited, "mapitem:"+typedNode.Key) + + case *cypher.Literal: + visited = append(visited, fmt.Sprintf("literal:%v", typedNode.Value)) + + case cypher.Operator: + visited = append(visited, "operator:"+typedNode.String()) + + case graph.Kinds: + visited = append(visited, "kinds") + for _, kind := range typedNode { + visited = append(visited, "kind:"+kind.String()) + } + + case *cypher.PatternRange: + visited = append(visited, "range") + } + }) + + require.NoError(t, walkFunc(node, visitor)) + return visited +} + +func collectCypherWalkTypes(t *testing.T, node cypher.SyntaxNode, walkFunc func(cypher.SyntaxNode, walk.Visitor[cypher.SyntaxNode]) error) []string { + t.Helper() + + var visited []string + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, _ walk.VisitorHandler) { + visited = append(visited, fmt.Sprintf("%T", node)) + }) + + require.NoError(t, walkFunc(node, visitor)) + return visited +} + +func newCypherWhere(expressions ...cypher.Expression) *cypher.Where { + where := cypher.NewWhere() + where.AddSlice(expressions) + return where +} + +type genericWalkTestNode struct { + name string + children []*genericWalkTestNode +} + +type recordingGenericWalkVisitor struct { + walk.Visitor[*genericWalkTestNode] + + events []string + onEnter func(*genericWalkTestNode) + onVisit func(*genericWalkTestNode) + onExit func(*genericWalkTestNode) +} + +func newRecordingGenericWalkVisitor() *recordingGenericWalkVisitor { + return &recordingGenericWalkVisitor{ + Visitor: walk.NewVisitor[*genericWalkTestNode](), + } +} + +func (s *recordingGenericWalkVisitor) Enter(node *genericWalkTestNode) { + s.events = append(s.events, "enter:"+node.name) + if s.onEnter != nil { + s.onEnter(node) + } +} + +func (s *recordingGenericWalkVisitor) Visit(node *genericWalkTestNode) { + s.events = append(s.events, "visit:"+node.name) + if s.onVisit != nil { + s.onVisit(node) + } +} + +func (s *recordingGenericWalkVisitor) Exit(node *genericWalkTestNode) { + s.events = append(s.events, "exit:"+node.name) + if s.onExit != nil { + s.onExit(node) + } +} + +func newGenericWalkTestCursor(node *genericWalkTestNode) (*walk.Cursor[*genericWalkTestNode], error) { + cursor := &walk.Cursor[*genericWalkTestNode]{ + Node: node, + } + cursor.AddBranches(node.children...) + return cursor, nil +} + +func TestCypherWalkSupportsKnownSyntaxNodeTypes(t *testing.T) { + testCases := map[string]cypher.SyntaxNode{ + "arithmetic expression": &cypher.ArithmeticExpression{}, + "comparison": &cypher.Comparison{}, + "conjunction": cypher.NewConjunction(), + "create": &cypher.Create{}, + "delete": &cypher.Delete{}, + "disjunction": cypher.NewDisjunction(), + "exclusive disjunction": cypher.NewExclusiveDisjunction(), + "filter expression": &cypher.FilterExpression{}, + "function invocation": &cypher.FunctionInvocation{}, + "graph kinds": graph.Kinds{graph.StringKind("NodeKind")}, + "id in collection": &cypher.IDInCollection{}, + "kind matcher": &cypher.KindMatcher{}, + "limit": &cypher.Limit{}, + "list literal": cypher.NewListLiteral(), + "literal": cypher.NewLiteral(1, false), + "map item": &cypher.MapItem{}, + "map literal": cypher.MapLiteral{"value": cypher.NewLiteral(1, false)}, + "match": &cypher.Match{}, + "merge": &cypher.Merge{}, + "merge action": &cypher.MergeAction{}, + "multipart query": &cypher.MultiPartQuery{}, + "multipart query part": &cypher.MultiPartQueryPart{}, + "negation": &cypher.Negation{}, + "node pattern": &cypher.NodePattern{}, + "operator": cypher.Operator("="), + "order": &cypher.Order{}, + "parameter": &cypher.Parameter{}, + "parenthetical": &cypher.Parenthetical{}, + "partial arithmetic": &cypher.PartialArithmeticExpression{}, + "partial comparison": &cypher.PartialComparison{}, + "pattern element": &cypher.PatternElement{}, + "pattern part": &cypher.PatternPart{}, + "pattern predicate": &cypher.PatternPredicate{}, + "pattern range": &cypher.PatternRange{}, + "projection": &cypher.Projection{}, + "projection item": &cypher.ProjectionItem{}, + "properties map": &cypher.Properties{Map: cypher.MapLiteral{"value": cypher.NewLiteral(1, false)}}, + "properties parameter": &cypher.Properties{Parameter: cypher.NewParameter("props", map[string]any{})}, + "quantifier": &cypher.Quantifier{}, + "range quantifier": &cypher.RangeQuantifier{}, + "reading clause": &cypher.ReadingClause{}, + "regular query": &cypher.RegularQuery{}, + "relationship pattern": &cypher.RelationshipPattern{}, + "remove": &cypher.Remove{}, + "remove item": &cypher.RemoveItem{}, + "return": &cypher.Return{}, + "set": &cypher.Set{}, + "set item": &cypher.SetItem{}, + "single part query": &cypher.SinglePartQuery{}, + "single query": &cypher.SingleQuery{}, + "skip": &cypher.Skip{}, + "sort item": &cypher.SortItem{}, + "unary add/subtract expression": &cypher.UnaryAddOrSubtractExpression{}, + "unwind": &cypher.Unwind{}, + "updating clause": &cypher.UpdatingClause{}, + "variable": &cypher.Variable{}, + "where": &cypher.Where{}, + "with": &cypher.With{}, + } + + for name, node := range testCases { + t.Run(name, func(t *testing.T) { + visitor := walk.NewSimpleVisitor[cypher.SyntaxNode](func(cypher.SyntaxNode, walk.VisitorHandler) {}) + require.NoError(t, walk.Cypher(node, visitor)) + }) + } +} + func TestPgSQLWalkVisitsJoinTable(t *testing.T) { query := pgsql.Query{ Body: pgsql.Select{ diff --git a/docs/cypher_ast_tooling_validation.md b/docs/cypher_ast_tooling_validation.md new file mode 100644 index 00000000..e07150aa --- /dev/null +++ b/docs/cypher_ast_tooling_validation.md @@ -0,0 +1,145 @@ +# Cypher AST Tooling Validation + +Validation date: 2026-05-27. + +This records the validation pass for the Cypher AST tooling test-hardening work. + +## Review Remediation Preflight + +- Branch: `main`, 21 commits ahead of `upstream/main`. +- Baseline: `upstream/main` resolves locally at `9fe779703362543ce2ef6a46fd93f4c040ac1ac0`. +- Existing untracked files left untouched during preflight: `review.md` and `docs/cypher_support_v4.md`. +- Integration validation was run separately for the supplied Neo4j and PostgreSQL connection strings. + +## Semantic Traversal Audit + +`walk.Cypher` consumers in `cypher/models/pgsql/translate`, `cypher/models/pgsql/optimize`, `query/builder.go`, and +`query/neo4j` were audited for the newly reachable `*cypher.ExclusiveDisjunction` node. The PostgreSQL translator +needed an explicit XOR translation path; it now lowers XOR expression-list joins to PostgreSQL boolean inequality. +Reference and source collectors operate on descendant variables and tolerate the newly visible operand sub-trees. + +The reviewed broadening where bare `cypher.MapLiteral` values expanded into `*cypher.MapItem` children in all semantic +expression contexts was reverted. `walk.Cypher` now preserves the upstream semantic contract: bare map literals are +leaf nodes, while `*cypher.Properties` exposes map items for pattern/create/update property maps. `walk.CypherStructural` +owns bare map literal expansion for AST inspection use cases. + +## Walker Benchmark Comparison + +The prior benchmark pass was captured with: + +```bash +go test -run '^$' -bench 'BenchmarkCypher' -benchmem -count=10 ./cypher/models/walk +``` + +`upstream/main` does not have `walk.CypherStructural`, so the upstream worktree used the branch benchmark file with the +structural benchmark omitted. The comparable semantic walker results were: + +| Benchmark | `upstream/main` | `HEAD` | Delta | +| --- | ---: | ---: | ---: | +| `CypherWalkLargeProjection-20` | 69.72 us/op | 83.54 us/op | +19.83% | + +The earlier branch-only map-literal benchmark delta came from semantic traversal expanding bare map literals. That +broadening was reverted, so the semantic map-literal benchmark is expected to remain leaf-equivalent with upstream. +The projection benchmark remains slower after moving optional-field nil filtering into cursor constructors; allocations +are effectively flat at 74.55 KiB/op on `upstream/main` vs. 74.41 KiB/op on `HEAD`. + +The branch-only structural benchmark measured: + +| Benchmark | `HEAD` | +| --- | ---: | +| `CypherStructuralWalkLongPattern-20` | 70.19 us/op, 49.02 KiB/op, 1288 allocs/op | + +## Walker Coverage Comparison + +Package-local coverage was captured with: + +```bash +go test -covermode=count -coverprofile=.coverage/walk-head.cover ./cypher/models/walk +go test -covermode=count -coverprofile=.coverage/walk-upstream.cover ./cypher/models/walk +``` + +| Revision | Coverage | +| --- | ---: | +| `upstream/main` | 53.2% | +| `HEAD` | 81.3% | + +`HEAD` does not lower `cypher/models/walk` package coverage. + +## PR Description Notes + +Behavior changes to call out: + +- `cypher.MapLiteral.Keys()` now returns keys in ascending lexical order. It previously returned descending order. +- `walk.Cypher` now traverses `*cypher.ExclusiveDisjunction`; translator and collector visitors now see XOR operand + sub-trees. +- `walk.Cypher` keeps bare `cypher.MapLiteral` values as semantic leaves; only `*cypher.Properties` exposes + `*cypher.MapItem` children in semantic traversal. `walk.CypherStructural` traverses bare map literal contents. +- `walk.Generic` continues to surface nil roots and nil branches through cursor-constructor negotiation errors instead + of treating them as successful optional traversals. +- `cancelableVisitorHandler.SetError` now accumulates repeated errors with `errors.Join` in a left-associated chain + instead of storing a flat slice before joining. + +New exported APIs: + +- `walk.CypherStructural` +- `walk.NewSimpleVisitorWithOrder` +- `walk.OrderInfix` +- `walk.OrderPostfix` +- `cypher.MapLiteral.ForEachItem` + +`README.md` has build/test/metric workflow guidance but no walker API summary, so no README API update was needed. + +## Commands + +- `go test ./cypher/models/walk` + - Result: pass. +- `go test ./cypher/models/walk ./cypher/models/cypher ./cypher/models/cypher/format` + - Result: pass. +- `go test -covermode=count -coverpkg=./cypher/models/walk,./cypher/models/cypher,./cypher/models/cypher/format -coverprofile=/tmp/cypher_ast_tooling_validation.cover ./cypher/models/walk ./cypher/models/cypher ./cypher/models/cypher/format` + - Result: pass. + - Package coverage: `walk` 27.1%, `cypher` 18.4%, `format` 28.4%. +- `make test` + - Result: pass. + - Wrote `.coverage/coverage.txt`. +- `make format` + - Initial result: fail in the local environment with `xargs: goimports: Permission denied` because the + wrapper-managed Go bin directory was not on `PATH`. +- `PATH="/home/zinic/codex/config/go/bin:$PATH" make format` + - Result: pass. + - No file changes after formatting. +- `CONNECTION_STRING= make test_neo4j` + - Result: pass. +- `CONNECTION_STRING= make test_pg` + - Result: pass. + +## CRAP Snapshot + +CRAP was calculated from the focused coverage profile for the altered Cypher AST tooling paths. + +| CRAP | Complexity | Coverage | Function | +| ---: | ---: | ---: | --- | +| 18.00 | 18 | 100.0% | `cypher/models/walk/walk.go:189 Generic` | +| 14.00 | 14 | 100.0% | `cypher/models/walk/walk_cypher.go:66 newCypherStructuralValueWalkCursor` | +| 11.00 | 11 | 100.0% | `cypher/models/walk/walk_cypher.go:540 newCypherUpdatingWalkCursor` | +| 11.00 | 11 | 100.0% | `cypher/models/walk/walk_cypher.go:478 newCypherClauseWalkCursor` | +| 10.00 | 10 | 100.0% | `cypher/models/walk/walk_cypher.go:267 newCypherPredicateWalkCursor` | +| 10.00 | 10 | 100.0% | `cypher/models/walk/walk_cypher.go:217 newCypherValueWalkCursor` | +| 9.02 | 9 | 94.1% | `cypher/models/walk/walk_cypher.go:175 newCypherWalkCursor` | +| 9.00 | 9 | 100.0% | `cypher/models/cypher/format/format.go:326 (Emitter).formatMapLiteral` | +| 8.03 | 8 | 92.3% | `cypher/models/walk/walk_cypher.go:591 newCypherPatternWalkCursor` | +| 8.00 | 8 | 100.0% | `cypher/models/walk/walk_cypher.go:347 newCypherProjectionWalkCursor` | +| 7.00 | 7 | 100.0% | `cypher/models/walk/walk_cypher.go:401 newCypherStatementWalkCursor` | +| 6.00 | 6 | 100.0% | `cypher/models/walk/walk_cypher.go:316 newCypherOperatorWalkCursor` | +| 6.00 | 6 | 100.0% | `cypher/models/walk/walk_cypher.go:143 newCypherStructuralPatternWalkCursor` | +| 6.00 | 2 | 0.0% | `cypher/models/cypher/model.go:1001 (*ListLiteral).Keys` | +| 4.00 | 4 | 100.0% | `cypher/models/walk/walk_cypher.go:458 newCypherSinglePartQueryWalkCursor` | +| 4.00 | 4 | 100.0% | `cypher/models/walk/walk_cypher.go:438 newCypherMultiPartQueryPartWalkCursor` | +| 3.00 | 3 | 100.0% | `cypher/models/walk/walk_cypher.go:390 newCypherQueryWalkCursor` | +| 3.00 | 3 | 100.0% | `cypher/models/walk/walk_cypher.go:55 newCypherStructuralWalkCursor` | +| 3.00 | 3 | 100.0% | `cypher/models/cypher/model.go:949 (MapLiteral).ForEachItem` | +| 2.00 | 2 | 100.0% | `cypher/models/walk/walk_cypher.go:204 newCypherLeafWalkCursor` | +| 2.00 | 2 | 100.0% | `cypher/models/cypher/model.go:959 (MapLiteral).Keys` | +| 1.00 | 1 | 100.0% | `cypher/models/cypher/model.go:935 (MapLiteral).Items` | + +`(*ListLiteral).Keys` is included in the snapshot because it matched the measured function-name set, but it was not part +of this change sequence. diff --git a/docs/cypher_support_v4.md b/docs/cypher_support_v4.md new file mode 100644 index 00000000..9c71cfa4 --- /dev/null +++ b/docs/cypher_support_v4.md @@ -0,0 +1,212 @@ +# Cypher Translation Support V4 Plan + +This plan organizes implementation work for the next CySQL translation completeness targets: +list indexing and slicing, CASE expressions, existential subqueries, recursive pattern predicates, +list and pattern comprehensions, and UNION / UNION ALL. + +## Phase 0: Baseline + +Update the support matrix so it reflects current translator behavior. The existing support document is stale in places: +CREATE, UNWIND, quantifiers, labels(), head(), and tail() have translator and test coverage now, despite older notes +listing some of them as unsupported or defective. + +Add negative tests for the exact currently-missing constructs before implementing each feature. This keeps every feature +with a clear before/after signal. + +## 1. List Indexing And Slicing + +Target forms: + +```cypher +nodes(p)[0] +relationships(p)[-1] +list[1] +list[1..3] +list[..3] +list[1..] +``` + +Implementation: + +- Add Cypher AST nodes such as `ListIndexExpression` and `ListSliceExpression`. +- Implement `oC_ListOperatorExpression` handling in `frontend/expression.go`. +- Add copy, format, and walk support. +- Lower to existing PgSQL `ArrayIndex` and `ArraySlice`. +- Define Cypher-to-Postgres index semantics explicitly. Cypher is zero-based; PostgreSQL arrays are one-based. +- Handle negative indexes with `cardinality(array) + index + 1`. +- Normalize open-ended slices and test empty/null behavior against Neo4j. + +Tests: + +- Translation cases for literals, parameters, `nodes(p)`, `relationships(p)`, and `collect(...)`. +- Integration cases for first, last, middle, out-of-range, null list, and empty list. + +## 2. CASE Expressions + +Target forms: + +```cypher +CASE WHEN cond THEN x ELSE y END +CASE expr WHEN value THEN x ELSE y END +``` + +Implementation: + +- Remove the frontend unsupported-rule rejection for `oC_CaseExpression`. +- Add AST nodes: `CaseExpression` and `CaseAlternative`. +- Add visitors for `oC_CaseExpression` and `oC_CaseAlternative`. +- Reuse the existing `pgsql.Case` model and formatter. +- Add type inference for CASE result type, likely by finding a common supertype across `THEN` and `ELSE`. +- Ensure aggregate grouping logic can traverse CASE expressions. + +Tests: + +- Scalar CASE in `RETURN`, `WHERE`, and `WITH`. +- CASE over properties, labels, nulls, parameters, and aggregates. +- CASE inside `ORDER BY` and grouped aggregation. + +## 3. Existential Subqueries + +Target forms: + +```cypher +EXISTS { MATCH ... } +EXISTS { (n)-[:R]->() } +NOT EXISTS { ... } +``` + +Implementation: + +- Remove the frontend unsupported-rule rejection for `oC_ExistentialSubquery`. +- Add an AST node `ExistentialSubquery` with either a `RegularQuery` or `Pattern + Where`. +- For pattern-only forms, translate through existing pattern predicate machinery. +- For query forms, compile to `pgsql.ExistsExpression` with a correlated subquery. +- Reuse the current scope for outer references, but isolate subquery-local bindings. +- Start read-only: support `MATCH`, `WHERE`, and `RETURN` first. Defer updates inside existential subqueries unless + backend semantics require a hard rejection. + +Tests: + +- Correlated and uncorrelated exists. +- Nested `EXISTS`. +- `NOT EXISTS`. +- Optional matches inside exists. +- Equivalence with current pattern predicates where applicable. + +## 4. Recursive Pattern Predicates + +Current blocker: pattern predicates reject traversal steps with expansion. + +Implementation: + +- Reuse variable-length traversal translation, but render it under an `EXISTS` subquery instead of a top-level match + frame. +- Start with single expansion predicates: + +```cypher +WHERE (n)-[:R*1..3]->(m) +``` + +- Then support anonymous endpoint and endpoint predicates: + +```cypher +WHERE (n)-[:R*]->(:Kind {prop: 1}) +``` + +- Ensure pattern predicate frames do not become visible to outer projection or path rendering. +- Add optimizer safety checks so recursive predicate lowering does not accidentally use path materialization meant for + returned paths. + +Tests: + +- Positive and negated recursive predicates. +- Bound source, bound target, both bound, and anonymous target. +- Recursive predicate combined with normal match expansion in the same query. + +## 5. List Comprehensions + +Target forms: + +```cypher +[x IN list WHERE pred | expr] +[x IN list | expr] +[x IN list WHERE pred] +``` + +Implementation: + +- Add an AST node `ListComprehension` using the existing `FilterExpression` shape plus an optional projection + expression. +- Translate with a correlated `SELECT array_agg(...) FROM unnest(...)`. +- Define the comprehension variable in an isolated scope frame. +- Use `coalesce(array_agg(...), array[]::[])` to preserve list-return shape. +- Reuse quantifier `IDInCollection` handling where practical, without over-coupling list comprehensions to boolean + quantifiers. + +Tests: + +- Literal lists, property arrays, parameters, and `collect(...)`. +- Projection omitted vs. projection provided. +- Predicates using outer variables. +- Empty input and null input semantics. + +## 6. Pattern Comprehensions + +Target forms: + +```cypher +[(a)-->(b) | b.name] +[p = (a)-[*]->(b) WHERE b.enabled | p] +``` + +Implementation: + +- Add an AST node `PatternComprehension` with optional path variable, pattern, optional where, and projection + expression. +- Lower to a correlated subquery that runs pattern translation and aggregates the projected expression. +- Start with fixed-length patterns only. +- Add variable-length/path support after recursive pattern predicates are solid. +- Use the same scope isolation rules as existential subqueries. + +Tests: + +- Fixed relationship pattern returning nodes/properties. +- Correlated outer variable source. +- Optional path binding. +- Empty result returns empty list. +- Later: variable-length pattern comprehension. + +## 7. UNION / UNION ALL + +This should land last because it touches query shape and result contracts. + +Implementation: + +- Remove the frontend unsupported-rule rejection for `oC_Union`. +- Extend `RegularQuery` AST to represent multiple `SingleQuery` branches plus `ALL` flags. +- Translate each branch independently with its own scope and parameter namespace. +- Validate projection compatibility: same column count and compatible aliases/types. +- Lower to existing PgSQL `SetOperation` with `UNION` / `UNION ALL`. +- Decide alias source. Cypher generally uses the first branch's return names. +- Block branch-local updates initially unless existing update semantics are clearly safe. + +Tests: + +- `UNION` distinct vs. `UNION ALL`. +- Matching projection aliases and mismatched aliases. +- Parameters in both branches. +- Branch-level vs. final `ORDER BY`, `SKIP`, and `LIMIT`, if the grammar allows the form. + +## Suggested Order + +1. List indexing and slicing. +2. CASE expressions. +3. Existential subqueries, pattern-only first. +4. Recursive pattern predicates. +5. List comprehensions. +6. Pattern comprehensions. +7. UNION / UNION ALL. + +This order builds reusable machinery before the harder features need it: array indexing helps path/list expressions, +CASE exercises scalar AST plumbing, existential subqueries establish correlated subquery scope, and that scope model then +carries comprehensions and recursive pattern predicates. diff --git a/docs/cypher_walker_semantics.md b/docs/cypher_walker_semantics.md new file mode 100644 index 00000000..124e9454 --- /dev/null +++ b/docs/cypher_walker_semantics.md @@ -0,0 +1,16 @@ +# Cypher Walker Semantics + +Cypher has two traversal needs that should stay separate: + +- `walk.Cypher` is the semantic walker used by translation and optimizer code. It walks expression-bearing children that participate in translation order, and intentionally skips declaration-only fields such as projection aliases, pattern variables, kind metadata, and quantifier binding variables where those fields are handled by parent nodes or clause-specific logic. Bare `cypher.MapLiteral` values are semantic leaves; only `*cypher.Properties` exposes map item/value children in semantic traversal. +- `walk.CypherStructural` is the structural walker for AST inspection. It should visit all modeled child nodes, including declarations, aliases, pattern metadata, relationship ranges, and map/list contents. + +When adding a Cypher AST element, update both walker modes deliberately: + +- Add semantic traversal only for fields that should affect translator/optimizer expression stack behavior. +- Add structural traversal for every modeled child field. +- Add tests that assert actual visited children, not only that cursor construction succeeds. + +Nil handling is part of the contract. Nil traversal roots and nil branches should surface cursor negotiation errors, not successful no-op walks. Optional nil pointer children should be skipped by the cursor constructor that owns the optional field, while valid empty syntax nodes such as empty map literals, empty list literals, empty kind lists, and empty identifiers should still be visitable when they are the traversal root. + +Visitor cancellation is immediate. `SetDone`, `SetError`, and `SetErrorf` stop traversal after the current callback returns; the walker does not unwind pending `Exit` callbacks for nodes still on the traversal stack. Visitors that need balanced enter/exit state should use `Consume` for subtree pruning and reserve cancellation/error APIs for terminal traversal.