Skip to content

Commit dedc9ac

Browse files
nat: Now cheaper checksum recalc, integrated with conntrack(WIP)
1 parent 9accfc4 commit dedc9ac

2 files changed

Lines changed: 50 additions & 12 deletions

File tree

api/net/nat/napt.hpp

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <net/port_util.hpp>
2424
#include <net/inet>
2525
#include <net/tcp/tcp.hpp>
26+
#include <net/tcp/connection_tracker.hpp>
2627

2728
namespace net {
2829
namespace nat {
@@ -37,6 +38,10 @@ class NAPT {
3738

3839
public:
3940

41+
NAPT() {
42+
tcp_tracker.on_close = {this, &NAPT::conn_close};
43+
}
44+
4045
// NAT
4146
IP4::IP_packet_ptr nat(IP4::IP_packet_ptr pkt, const Stack& inet);
4247

@@ -56,25 +61,38 @@ class NAPT {
5661
//printf("NAT entry: %s => %u\n", sock.to_string().c_str(), port);
5762
}
5863

64+
void remove_entry(Socket socket)
65+
{
66+
auto it = std::find_if(tcp_trans.begin(), tcp_trans.end(),
67+
[socket] (auto& ent) {
68+
return ent.second == socket;
69+
});
70+
71+
tcp_ports.unbind(it->first);
72+
tcp_trans.erase(it);
73+
}
74+
75+
void conn_close(const tcp::Connection_tracker::Tuple& tuple)
76+
{
77+
remove_entry(tuple.first);
78+
}
79+
5980
private:
6081
Port_util tcp_ports;
6182
Port_util udp_ports;
6283

6384
Translation_table tcp_trans;
6485
Translation_table udp_trans;
6586

87+
tcp::Connection_tracker tcp_tracker;
88+
6689
// Source NAT
6790
void snat(tcp::Packet& pkt, ip4::Addr src_ip);
6891

6992
// Destination NAT
7093
void dnat(tcp::Packet& pkt);
7194

72-
void recalculate_checksum(tcp::Packet& pkt) noexcept
73-
{
74-
pkt.set_checksum(0);
75-
pkt.set_ip_checksum();
76-
pkt.set_checksum(TCP::checksum(pkt));
77-
}
95+
void recalculate_checksum(tcp::Packet& pkt, Socket osock, Socket nsock);
7896

7997
}; // < class NAPT
8098

src/net/nat/napt.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ IP4::IP_packet_ptr NAPT::dnat(IP4::IP_packet_ptr pkt, const Stack& inet)
6666

6767
void NAPT::snat(tcp::Packet& pkt, ip4::Addr src_ip)
6868
{
69+
tcp_tracker.incoming(pkt);
6970
// Get the Socket
7071
Socket socket = pkt.source();
7172

@@ -96,32 +97,51 @@ void NAPT::snat(tcp::Packet& pkt, ip4::Addr src_ip)
9697
// At last, replace the source address
9798
pkt.set_ip_src(src_ip);
9899

99-
printf("SNAT %s => %s\n", socket.to_string().c_str(), pkt.source().to_string().c_str());
100+
debug2("SNAT %s => %s\n", socket.to_string().c_str(), pkt.source().to_string().c_str());
100101

101102
// Recalculate checksum
102-
recalculate_checksum(pkt);
103+
recalculate_checksum(pkt, socket, pkt.source());
103104
}
104105

105106
void NAPT::dnat(tcp::Packet& pkt)
106107
{
107-
auto dst_port = pkt.dst_port();
108+
auto orgsock = pkt.destination();
108109

109110
// Is there an entry?
110-
auto it = tcp_trans.find(dst_port);
111+
auto it = tcp_trans.find(orgsock.port());
111112

112113
// If there already is an entry
113114
if(it != tcp_trans.end())
114115
{
115116
// Get the Socket
116117
auto socket = it->second;
117-
printf("DNAT %s => %s\n", pkt.destination().to_string().c_str(), socket.to_string().c_str());
118+
debug2("DNAT %s => %s\n", orgsock.to_string().c_str(), socket.to_string().c_str());
118119
// Replace the destination port with the original one
119120
pkt.set_destination(socket);
120121

121122
// Recalculate checksum
122-
recalculate_checksum(pkt);
123+
recalculate_checksum(pkt, orgsock, socket);
124+
125+
tcp_tracker.outgoing(pkt);
123126
}
124127
}
125128

129+
void NAPT::recalculate_checksum(tcp::Packet& pkt, Socket osock, Socket nsock)
130+
{
131+
auto old_addr = osock.address();
132+
auto new_addr = nsock.address();
133+
auto old_port = htons(osock.port());
134+
auto new_port = htons(nsock.port());
135+
136+
auto ip_sum = pkt.ip_checksum();
137+
checksum_adjust(&ip_sum, &old_addr, &new_addr);
138+
pkt.set_ip_checksum(ip_sum);
139+
140+
auto tcp_sum = pkt.tcp_checksum();
141+
checksum_adjust(&tcp_sum, &old_addr, &new_addr);
142+
checksum_adjust<uint16_t>(&tcp_sum, &old_port, &new_port);
143+
pkt.set_checksum(tcp_sum);
144+
}
145+
126146
}
127147
}

0 commit comments

Comments
 (0)