From af42adfc7c8b10004bd7b00484a8fdd1c066bb28 Mon Sep 17 00:00:00 2001 From: Matthew McNeely Date: Wed, 27 May 2026 13:46:17 -0400 Subject: [PATCH] Add WithMaxRecvMsgSize option for gRPC connections Routes through dgo.NewClient when set so callers can raise the default 4 MB gRPC receive limit for large query responses against remote Dgraph. Co-Authored-By: Claude Opus 4.7 (1M context) --- client.go | 114 ++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 110 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index e37f84b..be9813b 100644 --- a/client.go +++ b/client.go @@ -9,8 +9,10 @@ import ( "context" "errors" "fmt" + "net/url" "os" "reflect" + "strconv" "strings" "sync" @@ -19,6 +21,8 @@ import ( dg "github.com/dolan-in/dgman/v2" "github.com/go-logr/logr" "github.com/go-playground/validator/v10" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" ) // Client provides an interface for ModusGraph operations @@ -119,6 +123,7 @@ type clientOptions struct { poolSize int maxEdgeTraversal int cacheSizeMB int + maxRecvMsgSize int namespace string logger logr.Logger validator StructValidator @@ -173,6 +178,17 @@ func WithCacheSizeMB(size int) ClientOpt { } } +// WithMaxRecvMsgSize sets the maximum gRPC receive message size in bytes for +// remote Dgraph connections. When greater than zero, the client is constructed +// via dgo.NewClient with grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(size)) +// applied so large query responses (e.g. wide subgraphs or vector results) do +// not exceed the default 4 MB gRPC limit. Ignored for embedded (file://) URIs. +func WithMaxRecvMsgSize(size int) ClientOpt { + return func(o *clientOptions) { + o.maxRecvMsgSize = size + } +} + // WithValidator sets a validator instance for struct validation. // The validator will be used to validate structs before insert, upsert, and update operations. // If no validator is provided, validation will be skipped. @@ -259,10 +275,24 @@ func NewClient(uri string, opts ...ClientOpt) (Client, error) { switch { case strings.HasPrefix(uri, dgraphURIPrefix): - client.pool = newClientPool(options.poolSize, func() (*dgo.Dgraph, error) { + factory := func() (*dgo.Dgraph, error) { client.logger.V(2).Info("Opening new Dgraph connection", "uri", uri) return dgo.Open(uri) - }, client.logger) + } + if options.maxRecvMsgSize > 0 { + endpoint, dgoOpts, err := parseDgraphURI(uri) + if err != nil { + return nil, err + } + dgoOpts = append(dgoOpts, dgo.WithGrpcOption( + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(options.maxRecvMsgSize)))) + factory = func() (*dgo.Dgraph, error) { + client.logger.V(2).Info("Opening new Dgraph connection", + "uri", uri, "maxRecvMsgSize", options.maxRecvMsgSize) + return dgo.NewClient(endpoint, dgoOpts...) + } + } + client.pool = newClientPool(options.poolSize, factory, client.logger) dg.SetLogger(client.logger) clientMap[key] = client return client, nil @@ -308,6 +338,81 @@ func NewClient(uri string, opts ...ClientOpt) (Client, error) { } +// parseDgraphURI mirrors dgo.Open's connection-string parsing so callers can +// route through dgo.NewClient with additional dgo.ClientOption values (e.g. +// custom grpc.DialOption settings). It returns the host:port endpoint and the +// dgo client options derived from the URI's auth, sslmode, and namespace params. +func parseDgraphURI(connStr string) (string, []dgo.ClientOption, error) { + u, err := url.Parse(connStr) + if err != nil { + return "", nil, fmt.Errorf("invalid connection string: %w", err) + } + if u.Scheme != "dgraph" { + return "", nil, errors.New("invalid scheme: must start with dgraph://") + } + if !strings.Contains(u.Host, ":") { + return "", nil, errors.New("invalid connection string: host url must have both host and port") + } + if strings.Split(u.Host, ":")[1] == "" { + return "", nil, errors.New("invalid connection string: missing port after port-separator colon") + } + + params, err := url.ParseQuery(u.RawQuery) + if err != nil { + return "", nil, fmt.Errorf("malformed connection string: %w", err) + } + + apiKey := params.Get("apikey") + bearerToken := params.Get("bearertoken") + if apiKey != "" && bearerToken != "" { + return "", nil, errors.New("invalid connection string: both apikey and bearertoken cannot be provided") + } + + opts := []dgo.ClientOption{} + if apiKey != "" { + opts = append(opts, dgo.WithDgraphAPIKey(apiKey)) + } + if bearerToken != "" { + opts = append(opts, dgo.WithBearerToken(bearerToken)) + } + + sslMode := params.Get("sslmode") + if sslMode == "" { + sslMode = "disable" + } + switch sslMode { + case "disable": + opts = append(opts, dgo.WithGrpcOption( + grpc.WithTransportCredentials(insecure.NewCredentials()))) + case "require": + opts = append(opts, dgo.WithSkipTLSVerify()) + case "verify-ca": + opts = append(opts, dgo.WithSystemCertPool()) + default: + return "", nil, fmt.Errorf( + "invalid SSL mode: %s (must be one of disable, require, verify-ca)", sslMode) + } + + if nsParam := params.Get("namespace"); nsParam != "" { + nsID, err := strconv.ParseUint(nsParam, 10, 64) + if err != nil { + return "", nil, fmt.Errorf("invalid namespace ID: %w", err) + } + opts = append(opts, dgo.WithNamespace(nsID)) + } + + if u.User != nil { + username := u.User.Username() + password, _ := u.User.Password() + if username == "" || password == "" { + return "", nil, errors.New("invalid connection string: both username and password must be provided") + } + opts = append(opts, dgo.WithACLCreds(username, password)) + } + + return u.Host, opts, nil +} + type client struct { uri string engine *Engine @@ -325,8 +430,9 @@ func (c client) key() string { if c.options.embeddingProvider != nil { embeddingKey = fmt.Sprintf("%p", c.options.embeddingProvider) } - return fmt.Sprintf("%s:%t:%d:%d:%d:%s:%s:%s", c.uri, c.options.autoSchema, c.options.poolSize, - c.options.maxEdgeTraversal, c.options.cacheSizeMB, c.options.namespace, validatorKey, embeddingKey) + return fmt.Sprintf("%s:%t:%d:%d:%d:%d:%s:%s:%s", c.uri, c.options.autoSchema, c.options.poolSize, + c.options.maxEdgeTraversal, c.options.cacheSizeMB, c.options.maxRecvMsgSize, + c.options.namespace, validatorKey, embeddingKey) } // embeddingProvider implements the embeddingClient interface, exposing the