Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 110 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"context"
"errors"
"fmt"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"sync"

Expand All @@ -19,6 +21,8 @@ import (
dg "github.com/dolan-in/dgman/v2"
"github.com/go-logr/logr"
"github.com/go-playground/validator/v10"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)

// Client provides an interface for ModusGraph operations
Expand Down Expand Up @@ -119,6 +123,7 @@ type clientOptions struct {
poolSize int
maxEdgeTraversal int
cacheSizeMB int
maxRecvMsgSize int
namespace string
logger logr.Logger
validator StructValidator
Expand Down Expand Up @@ -173,6 +178,17 @@ func WithCacheSizeMB(size int) ClientOpt {
}
}

// WithMaxRecvMsgSize sets the maximum gRPC receive message size in bytes for
// remote Dgraph connections. When greater than zero, the client is constructed
// via dgo.NewClient with grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(size))
// applied so large query responses (e.g. wide subgraphs or vector results) do
// not exceed the default 4 MB gRPC limit. Ignored for embedded (file://) URIs.
func WithMaxRecvMsgSize(size int) ClientOpt {
return func(o *clientOptions) {
o.maxRecvMsgSize = size
}
}

// WithValidator sets a validator instance for struct validation.
// The validator will be used to validate structs before insert, upsert, and update operations.
// If no validator is provided, validation will be skipped.
Expand Down Expand Up @@ -259,10 +275,24 @@ func NewClient(uri string, opts ...ClientOpt) (Client, error) {

switch {
case strings.HasPrefix(uri, dgraphURIPrefix):
client.pool = newClientPool(options.poolSize, func() (*dgo.Dgraph, error) {
factory := func() (*dgo.Dgraph, error) {
client.logger.V(2).Info("Opening new Dgraph connection", "uri", uri)
return dgo.Open(uri)
}, client.logger)
}
if options.maxRecvMsgSize > 0 {
endpoint, dgoOpts, err := parseDgraphURI(uri)
if err != nil {
return nil, err
}
dgoOpts = append(dgoOpts, dgo.WithGrpcOption(
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(options.maxRecvMsgSize))))
factory = func() (*dgo.Dgraph, error) {
client.logger.V(2).Info("Opening new Dgraph connection",
"uri", uri, "maxRecvMsgSize", options.maxRecvMsgSize)
return dgo.NewClient(endpoint, dgoOpts...)
}
}
client.pool = newClientPool(options.poolSize, factory, client.logger)
dg.SetLogger(client.logger)
clientMap[key] = client
return client, nil
Expand Down Expand Up @@ -308,6 +338,81 @@ func NewClient(uri string, opts ...ClientOpt) (Client, error) {

}

// parseDgraphURI mirrors dgo.Open's connection-string parsing so callers can
// route through dgo.NewClient with additional dgo.ClientOption values (e.g.
// custom grpc.DialOption settings). It returns the host:port endpoint and the
// dgo client options derived from the URI's auth, sslmode, and namespace params.
func parseDgraphURI(connStr string) (string, []dgo.ClientOption, error) {
u, err := url.Parse(connStr)
if err != nil {
return "", nil, fmt.Errorf("invalid connection string: %w", err)
}
if u.Scheme != "dgraph" {
return "", nil, errors.New("invalid scheme: must start with dgraph://")
}
if !strings.Contains(u.Host, ":") {
return "", nil, errors.New("invalid connection string: host url must have both host and port")
}
if strings.Split(u.Host, ":")[1] == "" {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: Port parsing in parseDgraphURI rejects valid IPv6 endpoints when WithMaxRecvMsgSize is enabled.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At client.go, line 356:

<comment>Port parsing in `parseDgraphURI` rejects valid IPv6 endpoints when `WithMaxRecvMsgSize` is enabled.</comment>

<file context>
@@ -308,6 +338,81 @@ func NewClient(uri string, opts ...ClientOpt) (Client, error) {
+	if !strings.Contains(u.Host, ":") {
+		return "", nil, errors.New("invalid connection string: host url must have both host and port")
+	}
+	if strings.Split(u.Host, ":")[1] == "" {
+		return "", nil, errors.New("invalid connection string: missing port after port-separator colon")
+	}
</file context>
Suggested change
if strings.Split(u.Host, ":")[1] == "" {
if u.Port() == "" {

return "", nil, errors.New("invalid connection string: missing port after port-separator colon")
}

params, err := url.ParseQuery(u.RawQuery)
if err != nil {
return "", nil, fmt.Errorf("malformed connection string: %w", err)
}

apiKey := params.Get("apikey")
bearerToken := params.Get("bearertoken")
if apiKey != "" && bearerToken != "" {
return "", nil, errors.New("invalid connection string: both apikey and bearertoken cannot be provided")
}

opts := []dgo.ClientOption{}
if apiKey != "" {
opts = append(opts, dgo.WithDgraphAPIKey(apiKey))
}
if bearerToken != "" {
opts = append(opts, dgo.WithBearerToken(bearerToken))
}

sslMode := params.Get("sslmode")
if sslMode == "" {
sslMode = "disable"
}
switch sslMode {
case "disable":
opts = append(opts, dgo.WithGrpcOption(
grpc.WithTransportCredentials(insecure.NewCredentials())))
case "require":
opts = append(opts, dgo.WithSkipTLSVerify())
case "verify-ca":
opts = append(opts, dgo.WithSystemCertPool())
default:
return "", nil, fmt.Errorf(
"invalid SSL mode: %s (must be one of disable, require, verify-ca)", sslMode)
}

if nsParam := params.Get("namespace"); nsParam != "" {
nsID, err := strconv.ParseUint(nsParam, 10, 64)
if err != nil {
return "", nil, fmt.Errorf("invalid namespace ID: %w", err)
}
opts = append(opts, dgo.WithNamespace(nsID))
}

if u.User != nil {
username := u.User.Username()
password, _ := u.User.Password()
if username == "" || password == "" {
return "", nil, errors.New("invalid connection string: both username and password must be provided")
}
opts = append(opts, dgo.WithACLCreds(username, password))
}

return u.Host, opts, nil
}

type client struct {
uri string
engine *Engine
Expand All @@ -325,8 +430,9 @@ func (c client) key() string {
if c.options.embeddingProvider != nil {
embeddingKey = fmt.Sprintf("%p", c.options.embeddingProvider)
}
return fmt.Sprintf("%s:%t:%d:%d:%d:%s:%s:%s", c.uri, c.options.autoSchema, c.options.poolSize,
c.options.maxEdgeTraversal, c.options.cacheSizeMB, c.options.namespace, validatorKey, embeddingKey)
return fmt.Sprintf("%s:%t:%d:%d:%d:%d:%s:%s:%s", c.uri, c.options.autoSchema, c.options.poolSize,
c.options.maxEdgeTraversal, c.options.cacheSizeMB, c.options.maxRecvMsgSize,
c.options.namespace, validatorKey, embeddingKey)
}

// embeddingProvider implements the embeddingClient interface, exposing the
Expand Down
Loading