Skip to content

Commit 145d82f

Browse files
authored
[client] Replace iOS DNS IsPrivate heuristic with route manager check (#5694)
1 parent a8b9570 commit 145d82f

5 files changed

Lines changed: 35 additions & 3 deletions

File tree

client/internal/dns/mock_server.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
8585
return nil
8686
}
8787

88+
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
89+
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
90+
// Mock implementation - no-op
91+
}
92+
8893
// BeginBatch mock implementation of BeginBatch from Server interface
8994
func (m *MockServer) BeginBatch() {
9095
// Mock implementation - no-op

client/internal/dns/server.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ type Server interface {
5757
ProbeAvailability()
5858
UpdateServerConfig(domains dnsconfig.ServerDomains) error
5959
PopulateManagementDomain(mgmtURL *url.URL) error
60+
SetRouteChecker(func(netip.Addr) bool)
6061
}
6162

6263
type nsGroupsByDomain struct {
@@ -104,6 +105,7 @@ type DefaultServer struct {
104105

105106
statusRecorder *peer.Status
106107
stateManager *statemanager.Manager
108+
routeMatch func(netip.Addr) bool
107109

108110
probeMu sync.Mutex
109111
probeCancel context.CancelFunc
@@ -229,6 +231,14 @@ func newDefaultServer(
229231
return defaultServer
230232
}
231233

234+
// SetRouteChecker sets the function used by upstream resolvers to determine
235+
// whether an IP is routed through the tunnel.
236+
func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) {
237+
s.mux.Lock()
238+
defer s.mux.Unlock()
239+
s.routeMatch = f
240+
}
241+
232242
// RegisterHandler registers a handler for the given domains with the given priority.
233243
// Any previously registered handler for the same domain and priority will be replaced.
234244
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
@@ -743,6 +753,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
743753
log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
744754
return
745755
}
756+
handler.routeMatch = s.routeMatch
746757

747758
for _, ns := range originalNameservers {
748759
if ns == config.ServerIP {
@@ -852,6 +863,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
852863
if err != nil {
853864
return nil, fmt.Errorf("create upstream resolver: %v", err)
854865
}
866+
handler.routeMatch = s.routeMatch
855867

856868
for _, ns := range nsGroup.NameServers {
857869
if ns.NSType != nbdns.UDPNameServerType {
@@ -1036,6 +1048,7 @@ func (s *DefaultServer) addHostRootZone() {
10361048
log.Errorf("unable to create a new upstream resolver, error: %v", err)
10371049
return
10381050
}
1051+
handler.routeMatch = s.routeMatch
10391052

10401053
handler.upstreamServers = maps.Keys(hostDNSServers)
10411054
handler.deactivate = func(error) {}

client/internal/dns/upstream.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ type upstreamResolverBase struct {
7070
deactivate func(error)
7171
reactivate func()
7272
statusRecorder *peer.Status
73+
routeMatch func(netip.Addr) bool
7374
}
7475

7576
type upstreamFailure struct {

client/internal/dns/upstream_ios.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,13 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
6565
} else {
6666
upstreamIP = upstreamIP.Unmap()
6767
}
68-
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
69-
log.Debugf("using private client to query upstream: %s", upstream)
68+
needsPrivate := u.lNet.Contains(upstreamIP) ||
69+
(u.routeMatch != nil && u.routeMatch(upstreamIP))
70+
if needsPrivate {
71+
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
7072
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
7173
if err != nil {
72-
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
74+
return nil, 0, fmt.Errorf("create private client: %s", err)
7375
}
7476
}
7577

client/internal/engine.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,17 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
499499

500500
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
501501

502+
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
503+
for _, routes := range e.routeManager.GetClientRoutes() {
504+
for _, r := range routes {
505+
if r.Network.Contains(ip) {
506+
return true
507+
}
508+
}
509+
}
510+
return false
511+
})
512+
502513
if err = e.wgInterfaceCreate(); err != nil {
503514
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
504515
e.close()

0 commit comments

Comments
 (0)