Skip to content

Commit 4eed459

Browse files
authored
[client] Fix DNS resolution with userspace WireGuard and kernel firewall (#5873)
1 parent 1353954 commit 4eed459

5 files changed

Lines changed: 146 additions & 45 deletions

File tree

client/firewall/create_linux.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
5656
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
5757
}
5858

59+
// Native firewall handles packet filtering, but the userspace WireGuard bind
60+
// needs a device filter for DNS interception hooks. Install a minimal
61+
// hooks-only filter that passes all traffic through to the kernel firewall.
62+
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
63+
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
64+
}
65+
5966
return fm, nil
6067
}
6168

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package common
2+
3+
import (
4+
"net/netip"
5+
"sync/atomic"
6+
)
7+
8+
// PacketHook stores a registered hook for a specific IP:port.
9+
type PacketHook struct {
10+
IP netip.Addr
11+
Port uint16
12+
Fn func([]byte) bool
13+
}
14+
15+
// HookMatches checks if a packet's destination matches the hook and invokes it.
16+
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
17+
if h == nil {
18+
return false
19+
}
20+
if h.IP == dstIP && h.Port == dport {
21+
return h.Fn(packetData)
22+
}
23+
return false
24+
}
25+
26+
// SetHook atomically stores a hook, handling nil removal.
27+
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
28+
if hook == nil {
29+
ptr.Store(nil)
30+
return
31+
}
32+
ptr.Store(&PacketHook{
33+
IP: ip,
34+
Port: dPort,
35+
Fn: hook,
36+
})
37+
}

client/firewall/uspfilter/filter.go

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,8 @@ type Manager struct {
142142
mssClampEnabled bool
143143

144144
// Only one hook per protocol is supported. Outbound direction only.
145-
udpHookOut atomic.Pointer[packetHook]
146-
tcpHookOut atomic.Pointer[packetHook]
147-
}
148-
149-
// packetHook stores a registered hook for a specific IP:port.
150-
type packetHook struct {
151-
ip netip.Addr
152-
port uint16
153-
fn func([]byte) bool
145+
udpHookOut atomic.Pointer[common.PacketHook]
146+
tcpHookOut atomic.Pointer[common.PacketHook]
154147
}
155148

156149
// decoder for packages
@@ -912,21 +905,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
912905
}
913906

914907
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
915-
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
908+
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
916909
}
917910

918911
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
919-
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
920-
}
921-
922-
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
923-
if h == nil {
924-
return false
925-
}
926-
if h.ip == dstIP && h.port == dport {
927-
return h.fn(packetData)
928-
}
929-
return false
912+
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
930913
}
931914

932915
// filterInbound implements filtering logic for incoming packets.
@@ -1337,28 +1320,12 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
13371320

13381321
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
13391322
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
1340-
if hook == nil {
1341-
m.udpHookOut.Store(nil)
1342-
return
1343-
}
1344-
m.udpHookOut.Store(&packetHook{
1345-
ip: ip,
1346-
port: dPort,
1347-
fn: hook,
1348-
})
1323+
common.SetHook(&m.udpHookOut, ip, dPort, hook)
13491324
}
13501325

13511326
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
13521327
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
1353-
if hook == nil {
1354-
m.tcpHookOut.Store(nil)
1355-
return
1356-
}
1357-
m.tcpHookOut.Store(&packetHook{
1358-
ip: ip,
1359-
port: dPort,
1360-
fn: hook,
1361-
})
1328+
common.SetHook(&m.tcpHookOut, ip, dPort, hook)
13621329
}
13631330

13641331
// SetLogLevel sets the log level for the firewall manager

client/firewall/uspfilter/filter_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,9 @@ func TestSetUDPPacketHook(t *testing.T) {
202202

203203
h := manager.udpHookOut.Load()
204204
require.NotNil(t, h)
205-
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
206-
assert.Equal(t, uint16(8000), h.port)
207-
assert.True(t, h.fn(nil))
205+
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
206+
assert.Equal(t, uint16(8000), h.Port)
207+
assert.True(t, h.Fn(nil))
208208
assert.True(t, called)
209209

210210
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
@@ -226,9 +226,9 @@ func TestSetTCPPacketHook(t *testing.T) {
226226

227227
h := manager.tcpHookOut.Load()
228228
require.NotNil(t, h)
229-
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
230-
assert.Equal(t, uint16(53), h.port)
231-
assert.True(t, h.fn(nil))
229+
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
230+
assert.Equal(t, uint16(53), h.Port)
231+
assert.True(t, h.Fn(nil))
232232
assert.True(t, called)
233233

234234
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package uspfilter
2+
3+
import (
4+
"encoding/binary"
5+
"net/netip"
6+
"sync/atomic"
7+
8+
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
9+
"github.com/netbirdio/netbird/client/iface/device"
10+
)
11+
12+
const (
13+
ipv4HeaderMinLen = 20
14+
ipv4ProtoOffset = 9
15+
ipv4FlagsOffset = 6
16+
ipv4DstOffset = 16
17+
ipProtoUDP = 17
18+
ipProtoTCP = 6
19+
ipv4FragOffMask = 0x1fff
20+
// dstPortOffset is the offset of the destination port within a UDP or TCP header.
21+
dstPortOffset = 2
22+
)
23+
24+
// HooksFilter is a minimal packet filter that only handles outbound DNS hooks.
25+
// It is installed on the WireGuard interface when the userspace bind is active
26+
// but a full firewall filter (Manager) is not needed because a native kernel
27+
// firewall (nftables/iptables) handles packet filtering.
28+
type HooksFilter struct {
29+
udpHook atomic.Pointer[common.PacketHook]
30+
tcpHook atomic.Pointer[common.PacketHook]
31+
}
32+
33+
var _ device.PacketFilter = (*HooksFilter)(nil)
34+
35+
// FilterOutbound checks outbound packets for DNS hook matches.
36+
// Only IPv4 packets matching the registered hook IP:port are intercepted.
37+
// IPv6 and non-IP packets pass through unconditionally.
38+
func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool {
39+
if len(packetData) < ipv4HeaderMinLen {
40+
return false
41+
}
42+
43+
// Only process IPv4 packets, let everything else pass through.
44+
if packetData[0]>>4 != 4 {
45+
return false
46+
}
47+
48+
ihl := int(packetData[0]&0x0f) * 4
49+
if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 {
50+
return false
51+
}
52+
53+
// Skip non-first fragments: they don't carry L4 headers.
54+
flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2])
55+
if flagsAndOffset&ipv4FragOffMask != 0 {
56+
return false
57+
}
58+
59+
dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4])
60+
if !ok {
61+
return false
62+
}
63+
64+
proto := packetData[ipv4ProtoOffset]
65+
dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2])
66+
67+
switch proto {
68+
case ipProtoUDP:
69+
return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData)
70+
case ipProtoTCP:
71+
return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData)
72+
default:
73+
return false
74+
}
75+
}
76+
77+
// FilterInbound allows all inbound packets (native firewall handles filtering).
78+
func (f *HooksFilter) FilterInbound([]byte, int) bool {
79+
return false
80+
}
81+
82+
// SetUDPPacketHook registers the UDP packet hook.
83+
func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
84+
common.SetHook(&f.udpHook, ip, dPort, hook)
85+
}
86+
87+
// SetTCPPacketHook registers the TCP packet hook.
88+
func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
89+
common.SetHook(&f.tcpHook, ip, dPort, hook)
90+
}

0 commit comments

Comments
 (0)