Skip to content

Commit 5f54de0

Browse files
authored
Merge pull request #10076 from rizlik/dtls13_ack_improvements
Dtls13: ack management improvements
2 parents 1a3daf0 + 1496614 commit 5f54de0

4 files changed

Lines changed: 190 additions & 31 deletions

File tree

src/dtls13.c

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -734,9 +734,19 @@ int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq)
734734
Dtls13RecordNumber** prevNext = &ssl->dtls13Rtx.seenRecords;
735735
Dtls13RecordNumber* cur = ssl->dtls13Rtx.seenRecords;
736736

737+
if (ssl->dtls13Rtx.seenRecordsCount >= DTLS13_ACK_MAX_RECORDS) {
738+
#ifdef WOLFSSL_RW_THREADED
739+
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
740+
#endif
741+
return 0; /* list full, silently drop */
742+
}
743+
737744
for (; cur != NULL; prevNext = &cur->next, cur = cur->next) {
738745
if (w64Equal(cur->epoch, epoch) && w64Equal(cur->seq, seq)) {
739746
/* already in list. no duplicates. */
747+
#ifdef WOLFSSL_RW_THREADED
748+
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
749+
#endif
740750
return 0;
741751
}
742752
else if (w64LT(epoch, cur->epoch)
@@ -747,11 +757,16 @@ int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq)
747757
}
748758

749759
rn = Dtls13NewRecordNumber(epoch, seq, ssl->heap);
750-
if (rn == NULL)
760+
if (rn == NULL) {
761+
#ifdef WOLFSSL_RW_THREADED
762+
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
763+
#endif
751764
return MEMORY_E;
765+
}
752766

753767
*prevNext = rn;
754768
rn->next = cur;
769+
ssl->dtls13Rtx.seenRecordsCount++;
755770
#ifdef WOLFSSL_RW_THREADED
756771
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
757772
#endif
@@ -781,6 +796,7 @@ static void Dtls13RtxFlushAcks(WOLFSSL* ssl)
781796
}
782797

783798
ssl->dtls13Rtx.seenRecords = NULL;
799+
ssl->dtls13Rtx.seenRecordsCount = 0;
784800
#ifdef WOLFSSL_RW_THREADED
785801
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
786802
#endif
@@ -830,6 +846,11 @@ static void Dtls13RtxRemoveCurAck(WOLFSSL* ssl)
830846
{
831847
Dtls13RecordNumber *rn, **prevNext;
832848

849+
#ifdef WOLFSSL_RW_THREADED
850+
if (wc_LockMutex(&ssl->dtls13Rtx.mutex) != 0)
851+
return;
852+
#endif
853+
833854
prevNext = &ssl->dtls13Rtx.seenRecords;
834855
rn = ssl->dtls13Rtx.seenRecords;
835856

@@ -838,12 +859,21 @@ static void Dtls13RtxRemoveCurAck(WOLFSSL* ssl)
838859
w64Equal(rn->seq, ssl->keys.curSeq)) {
839860
*prevNext = rn->next;
840861
XFREE(rn, ssl->heap, DYNAMIC_TYPE_DTLS_MSG);
862+
if (ssl->dtls13Rtx.seenRecordsCount > 0)
863+
ssl->dtls13Rtx.seenRecordsCount--;
864+
#ifdef WOLFSSL_RW_THREADED
865+
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
866+
#endif
841867
return;
842868
}
843869

844870
prevNext = &rn->next;
845871
rn = rn->next;
846872
}
873+
874+
#ifdef WOLFSSL_RW_THREADED
875+
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
876+
#endif
847877
}
848878

849879
static void Dtls13MaybeSaveClientHello(WOLFSSL* ssl)
@@ -2544,39 +2574,26 @@ int Dtls13SetRecordNumberKeys(WOLFSSL* ssl, enum encrypt_side side)
25442574
return NOT_COMPILED_IN;
25452575
}
25462576

2547-
/* 64 bits epoch + 64 bits sequence */
2548-
#define DTLS13_RN_SIZE 16
2549-
2550-
static int Dtls13GetAckListLength(Dtls13RecordNumber* list, word16* length)
2551-
{
2552-
int numberElements;
2553-
2554-
numberElements = 0;
2555-
2556-
/* TODO: check that we don't exceed the maximum length */
2557-
2558-
while (list != NULL) {
2559-
list = list->next;
2560-
numberElements++;
2561-
}
2562-
2563-
*length = (word16)(DTLS13_RN_SIZE * numberElements);
2564-
return 0;
2565-
}
25662577

25672578
int Dtls13WriteAckMessage(WOLFSSL* ssl,
2568-
Dtls13RecordNumber* recordNumberList, word32* length)
2579+
Dtls13RecordNumber* recordNumberList, word16 recordsCount, word32* length)
25692580
{
25702581
word16 msgSz, headerLength;
25712582
byte *output, *ackMessage;
25722583
word32 sendSz;
2584+
word32 written;
25732585
int ret;
25742586

25752587
sendSz = 0;
2588+
written = 0;
25762589

25772590
if (ssl->dtls13EncryptEpoch == NULL)
25782591
return BAD_STATE_E;
25792592

2593+
if (recordsCount > DTLS13_ACK_MAX_RECORDS)
2594+
return BUFFER_E;
2595+
msgSz = (word16)(DTLS13_RN_SIZE * recordsCount);
2596+
25802597
if (w64IsZero(ssl->dtls13EncryptEpoch->epochNumber)) {
25812598
/* unprotected ACK */
25822599
headerLength = DTLS_RECORD_HEADER_SZ;
@@ -2586,10 +2603,6 @@ int Dtls13WriteAckMessage(WOLFSSL* ssl,
25862603
sendSz += MAX_MSG_EXTRA;
25872604
}
25882605

2589-
ret = Dtls13GetAckListLength(recordNumberList, &msgSz);
2590-
if (ret != 0)
2591-
return ret;
2592-
25932606
sendSz += headerLength;
25942607

25952608
/* ACK list 2 bytes length field */
@@ -2612,15 +2625,21 @@ int Dtls13WriteAckMessage(WOLFSSL* ssl,
26122625
WOLFSSL_MSG("write ack records");
26132626

26142627
while (recordNumberList != NULL) {
2628+
if (written + DTLS13_RN_SIZE > msgSz)
2629+
return BUFFER_E;
26152630
WOLFSSL_MSG_EX("epoch %d seq %d", recordNumberList->epoch,
26162631
recordNumberList->seq);
26172632
c64toa(&recordNumberList->epoch, ackMessage);
26182633
ackMessage += OPAQUE64_LEN;
26192634
c64toa(&recordNumberList->seq, ackMessage);
26202635
ackMessage += OPAQUE64_LEN;
26212636
recordNumberList = recordNumberList->next;
2637+
written += DTLS13_RN_SIZE;
26222638
}
26232639

2640+
if (written != msgSz)
2641+
return BUFFER_E;
2642+
26242643
*length = msgSz + OPAQUE16_LEN;
26252644

26262645
return 0;
@@ -2731,6 +2750,7 @@ int Dtls13DoScheduledWork(WOLFSSL* ssl)
27312750
tail = &(*tail)->next;
27322751
*tail = ssl->dtls13Rtx.seenRecords;
27332752
ssl->dtls13Rtx.seenRecords = NULL;
2753+
ssl->dtls13Rtx.seenRecordsCount = 0;
27342754
ssl->dupWrite->sendAcks = 1;
27352755
wc_UnLockMutex(&ssl->dupWrite->dupMutex);
27362756
}
@@ -2944,12 +2964,13 @@ int SendDtls13Ack(WOLFSSL* ssl)
29442964
if (ret < 0)
29452965
return ret;
29462966
#endif
2947-
ret = Dtls13WriteAckMessage(ssl, ssl->dtls13Rtx.seenRecords, &length);
2967+
ret = Dtls13WriteAckMessage(ssl, ssl->dtls13Rtx.seenRecords,
2968+
ssl->dtls13Rtx.seenRecordsCount, &length);
29482969
#ifdef WOLFSSL_RW_THREADED
29492970
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
29502971
#endif
2951-
if (ret != 0)
2952-
return ret;
2972+
if (ret != 0)
2973+
return ret;
29532974

29542975
output = GetOutputBuffer(ssl);
29552976

tests/api/test_dtls.c

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,7 @@ int test_dtls13_ack_order(void)
918918
ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 2), w64From32(0, 2)), 0);
919919
ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 2), w64From32(0, 2)), 0);
920920
ExpectIntEQ(Dtls13WriteAckMessage(ssl_c, ssl_c->dtls13Rtx.seenRecords,
921-
&length), 0);
921+
ssl_c->dtls13Rtx.seenRecordsCount, &length), 0);
922922

923923
/* must zero the span reserved for the header to avoid read of uninited
924924
* data.
@@ -939,6 +939,124 @@ int test_dtls13_ack_order(void)
939939
return EXPECT_RESULT();
940940
}
941941

942+
int test_dtls13_ack_overflow(void)
943+
{
944+
EXPECT_DECLS;
945+
#if defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && defined(WOLFSSL_DTLS13)
946+
WOLFSSL_CTX *ctx_c = NULL, *ctx_s = NULL;
947+
WOLFSSL *ssl_c = NULL, *ssl_s = NULL;
948+
struct test_memio_ctx test_ctx;
949+
unsigned char readBuf[50];
950+
word32 length = 0;
951+
int i;
952+
953+
XMEMSET(&test_ctx, 0, sizeof(test_ctx));
954+
955+
ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, &ctx_s, &ssl_c, &ssl_s,
956+
wolfDTLSv1_3_client_method, wolfDTLSv1_3_server_method), 0);
957+
ExpectIntEQ(test_memio_do_handshake(ssl_c, ssl_s, 10, NULL), 0);
958+
ExpectIntEQ(wolfSSL_read(ssl_c, readBuf, sizeof(readBuf)), -1);
959+
ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ);
960+
ExpectIntEQ(wolfSSL_read(ssl_s, readBuf, sizeof(readBuf)), -1);
961+
ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), WOLFSSL_ERROR_WANT_READ);
962+
963+
/* Edge case 1: one below limit - all inserts must succeed */
964+
for (i = 0; i < DTLS13_ACK_MAX_RECORDS - 1; i++) {
965+
ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 0),
966+
w64From32(0, (word32)i)), 0);
967+
}
968+
ExpectIntEQ(ssl_c->dtls13Rtx.seenRecordsCount, DTLS13_ACK_MAX_RECORDS - 1);
969+
970+
/* Edge case 2: insert the last allowed record - must succeed */
971+
ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 0),
972+
w64From32(0, (word32)(DTLS13_ACK_MAX_RECORDS - 1))), 0);
973+
ExpectIntEQ(ssl_c->dtls13Rtx.seenRecordsCount, DTLS13_ACK_MAX_RECORDS);
974+
975+
/* Writing a full-but-valid list must succeed */
976+
ExpectIntEQ(Dtls13WriteAckMessage(ssl_c, ssl_c->dtls13Rtx.seenRecords,
977+
ssl_c->dtls13Rtx.seenRecordsCount, &length), 0);
978+
979+
/* Edge case 3: one over limit - must be silently dropped */
980+
ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 0),
981+
w64From32(0, (word32)DTLS13_ACK_MAX_RECORDS)), 0);
982+
ExpectIntEQ(ssl_c->dtls13Rtx.seenRecordsCount, DTLS13_ACK_MAX_RECORDS);
983+
984+
/* Bypass the insert guard to force the list one element over the limit,
985+
* then verify Dtls13WriteAckMessage errors out instead of overflowing */
986+
ssl_c->dtls13Rtx.seenRecordsCount = 0;
987+
ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 1),
988+
w64From32(0, (word32)DTLS13_ACK_MAX_RECORDS)), 0);
989+
ssl_c->dtls13Rtx.seenRecordsCount = (word16)(DTLS13_ACK_MAX_RECORDS + 1);
990+
ExpectIntEQ(Dtls13WriteAckMessage(ssl_c, ssl_c->dtls13Rtx.seenRecords,
991+
ssl_c->dtls13Rtx.seenRecordsCount, &length), BUFFER_E);
992+
993+
wolfSSL_free(ssl_c);
994+
wolfSSL_CTX_free(ctx_c);
995+
wolfSSL_free(ssl_s);
996+
wolfSSL_CTX_free(ctx_s);
997+
#endif
998+
return EXPECT_RESULT();
999+
}
1000+
1001+
int test_dtls13_ack_dup_write_counter(void)
1002+
{
1003+
EXPECT_DECLS;
1004+
#if defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && defined(WOLFSSL_DTLS13) \
1005+
&& defined(HAVE_WRITE_DUP)
1006+
WOLFSSL_CTX *ctx_c = NULL, *ctx_s = NULL;
1007+
WOLFSSL *ssl_c = NULL, *ssl_s = NULL;
1008+
WOLFSSL *ssl_c2 = NULL;
1009+
struct test_memio_ctx test_ctx;
1010+
unsigned char readBuf[50];
1011+
int i;
1012+
1013+
XMEMSET(&test_ctx, 0, sizeof(test_ctx));
1014+
1015+
ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, &ctx_s, &ssl_c, &ssl_s,
1016+
wolfDTLSv1_3_client_method, wolfDTLSv1_3_server_method), 0);
1017+
ExpectIntEQ(test_memio_do_handshake(ssl_c, ssl_s, 10, NULL), 0);
1018+
/* Drain any post-handshake messages */
1019+
ExpectIntEQ(wolfSSL_read(ssl_c, readBuf, sizeof(readBuf)), -1);
1020+
ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ);
1021+
ExpectIntEQ(wolfSSL_read(ssl_s, readBuf, sizeof(readBuf)), -1);
1022+
ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), WOLFSSL_ERROR_WANT_READ);
1023+
1024+
/* Split ssl_c: ssl_c becomes READ_DUP_SIDE, ssl_c2 becomes WRITE_DUP_SIDE */
1025+
ExpectNotNull(ssl_c2 = wolfSSL_write_dup(ssl_c));
1026+
1027+
/* Cycle 1: add records, trigger handoff, verify counter is reset to 0 */
1028+
for (i = 0; i < 5; i++)
1029+
ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 3),
1030+
w64From32(0, (word32)i)), 0);
1031+
ExpectIntEQ(ssl_c->dtls13Rtx.seenRecordsCount, 5);
1032+
ssl_c->dtls13Rtx.sendAcks = 1;
1033+
ExpectIntEQ(Dtls13DoScheduledWork(ssl_c), 0);
1034+
/* seenRecords ownership was transferred to dupWrite->sendAckList;
1035+
* seenRecordsCount must be reset to 0, not left at 5. */
1036+
ExpectNull(ssl_c->dtls13Rtx.seenRecords);
1037+
ExpectIntEQ(ssl_c->dtls13Rtx.seenRecordsCount, 0);
1038+
1039+
/* Cycle 2 (different epoch to avoid the dup-filter): verify the counter
1040+
* did not accumulate across the previous transfer. Without the fix,
1041+
* seenRecordsCount would now be 10 after this second batch. */
1042+
for (i = 0; i < 5; i++)
1043+
ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 4),
1044+
w64From32(0, (word32)i)), 0);
1045+
ExpectIntEQ(ssl_c->dtls13Rtx.seenRecordsCount, 5);
1046+
ssl_c->dtls13Rtx.sendAcks = 1;
1047+
ExpectIntEQ(Dtls13DoScheduledWork(ssl_c), 0);
1048+
ExpectNull(ssl_c->dtls13Rtx.seenRecords);
1049+
ExpectIntEQ(ssl_c->dtls13Rtx.seenRecordsCount, 0);
1050+
1051+
wolfSSL_free(ssl_c);
1052+
wolfSSL_free(ssl_c2);
1053+
wolfSSL_CTX_free(ctx_c);
1054+
wolfSSL_free(ssl_s);
1055+
wolfSSL_CTX_free(ctx_s);
1056+
#endif
1057+
return EXPECT_RESULT();
1058+
}
1059+
9421060
int test_dtls_version_checking(void)
9431061
{
9441062
EXPECT_DECLS;

tests/api/test_dtls.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ int test_wolfSSL_dtls_cid_parse(void);
3030
int test_wolfSSL_dtls_set_pending_peer(void);
3131
int test_dtls13_epochs(void);
3232
int test_dtls13_ack_order(void);
33+
int test_dtls13_ack_overflow(void);
34+
int test_dtls13_ack_dup_write_counter(void);
3335
int test_dtls_version_checking(void);
3436
int test_dtls_short_ciphertext(void);
3537
int test_dtls12_record_length_mismatch(void);
@@ -60,6 +62,8 @@ int test_dtls13_min_rtx_interval(void);
6062
TEST_DECL_GROUP("dtls", test_wolfSSL_dtls_set_pending_peer), \
6163
TEST_DECL_GROUP("dtls", test_dtls13_epochs), \
6264
TEST_DECL_GROUP("dtls", test_dtls13_ack_order), \
65+
TEST_DECL_GROUP("dtls", test_dtls13_ack_overflow), \
66+
TEST_DECL_GROUP("dtls", test_dtls13_ack_dup_write_counter), \
6367
TEST_DECL_GROUP("dtls", test_dtls_version_checking), \
6468
TEST_DECL_GROUP("dtls", test_dtls_short_ciphertext), \
6569
TEST_DECL_GROUP("dtls", test_dtls12_record_length_mismatch), \

wolfssl/internal.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5857,6 +5857,20 @@ enum {
58575857
DTLS13_EPOCH_TRAFFIC0 = 3
58585858
};
58595859

5860+
/* 64-bit epoch + 64-bit sequence number */
5861+
#define DTLS13_RN_SIZE (OPAQUE64_LEN + OPAQUE64_LEN)
5862+
/* Maximum number of ACK records allowed in an ACK message */
5863+
#ifndef DTLS13_ACK_MAX_RECORDS
5864+
#define DTLS13_ACK_MAX_RECORDS 128
5865+
#endif
5866+
/* WOLFSSL_MAX_16BIT / DTLS13_RN_SIZE (0xffff / (OPAQUE64_LEN + OPAQUE64_LEN))
5867+
* Literals are used because OPAQUE64_LEN is an enum value, invisible to the
5868+
* preprocessor. */
5869+
#if DTLS13_ACK_MAX_RECORDS > 0xffff / 16
5870+
#error "DTLS13_ACK_MAX_RECORDS exceeds the maximum encodable in the word16 length field"
5871+
#endif
5872+
5873+
58605874
typedef struct Dtls13Epoch {
58615875
w64wrapper epochNumber;
58625876

@@ -5925,6 +5939,7 @@ typedef struct Dtls13Rtx {
59255939
Dtls13RtxRecord *rtxRecords;
59265940
Dtls13RtxRecord **rtxRecordTailPtr;
59275941
Dtls13RecordNumber *seenRecords;
5942+
word16 seenRecordsCount;
59285943
#ifdef WOLFSSL_32BIT_MILLI_TIME
59295944
word32 lastRtx;
59305945
#else
@@ -7224,6 +7239,7 @@ WOLFSSL_LOCAL void DtlsSetSeqNumForReply(WOLFSSL* ssl);
72247239
#define Dtls13CheckEpoch wolfSSL_Dtls13CheckEpoch
72257240
#define Dtls13WriteAckMessage wolfSSL_Dtls13WriteAckMessage
72267241
#define Dtls13RtxAddAck wolfSSL_Dtls13RtxAddAck
7242+
#define Dtls13DoScheduledWork wolfSSL_Dtls13DoScheduledWork
72277243
#endif
72287244

72297245
WOLFSSL_TEST_VIS struct Dtls13Epoch* Dtls13GetEpoch(WOLFSSL* ssl,
@@ -7238,7 +7254,7 @@ WOLFSSL_LOCAL int Dtls13GetSeq(WOLFSSL* ssl, int order, word32* seq,
72387254
byte increment);
72397255
WOLFSSL_LOCAL void Dtls13RtxRemoveRecord(WOLFSSL* ssl, w64wrapper epoch,
72407256
w64wrapper seq);
7241-
WOLFSSL_LOCAL int Dtls13DoScheduledWork(WOLFSSL* ssl);
7257+
WOLFSSL_TEST_VIS int Dtls13DoScheduledWork(WOLFSSL* ssl);
72427258
WOLFSSL_LOCAL int Dtls13DeriveSnKeys(WOLFSSL* ssl, int provision);
72437259
WOLFSSL_LOCAL int Dtls13SetRecordNumberKeys(WOLFSSL* ssl,
72447260
enum encrypt_side side);
@@ -7279,7 +7295,7 @@ WOLFSSL_LOCAL int Dtls13ReconstructEpochNumber(WOLFSSL* ssl, byte epochBits,
72797295
WOLFSSL_LOCAL int Dtls13ReconstructSeqNumber(WOLFSSL* ssl,
72807296
Dtls13UnifiedHdrInfo* hdrInfo, w64wrapper* out);
72817297
WOLFSSL_TEST_VIS int Dtls13WriteAckMessage(WOLFSSL* ssl,
7282-
Dtls13RecordNumber* recordNumberList, word32* length);
7298+
Dtls13RecordNumber* recordNumberList, word16 recordsCount, word32* length);
72837299
WOLFSSL_LOCAL int SendDtls13Ack(WOLFSSL* ssl);
72847300
WOLFSSL_TEST_VIS int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq);
72857301
WOLFSSL_LOCAL int Dtls13RtxProcessingCertificate(WOLFSSL* ssl, byte* input,

0 commit comments

Comments
 (0)