Skip to content

Commit f01c4f1

Browse files
authored
Merge pull request #9454 from SparkiDev/rsa_dec_too_small_output_fix
RSA decrypt: don't write past buffer end on error
2 parents 1dfa4d1 + 23c5678 commit f01c4f1

2 files changed

Lines changed: 27 additions & 6 deletions

File tree

tests/api/test_rsa.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,15 +785,18 @@ int test_wc_RsaPublicEncryptDecrypt(void)
785785
WC_DECLARE_VAR(in, byte, TEST_STRING_SZ, NULL);
786786
WC_DECLARE_VAR(plain, byte, TEST_STRING_SZ, NULL);
787787
WC_DECLARE_VAR(cipher, byte, TEST_RSA_BYTES, NULL);
788+
WC_DECLARE_VAR(shortPlain, byte, TEST_STRING_SZ - 4, NULL);
788789

789790
WC_ALLOC_VAR(in, byte, TEST_STRING_SZ, NULL);
790791
WC_ALLOC_VAR(plain, byte, TEST_STRING_SZ, NULL);
791792
WC_ALLOC_VAR(cipher, byte, TEST_RSA_BYTES, NULL);
793+
WC_ALLOC_VAR(shortPlain, byte, TEST_STRING_SZ - 4, NULL);
792794

793795
#ifdef WC_DECLARE_VAR_IS_HEAP_ALLOC
794796
ExpectNotNull(in);
795797
ExpectNotNull(plain);
796798
ExpectNotNull(cipher);
799+
ExpectNotNull(shortPlain);
797800
#endif
798801
ExpectNotNull(XMEMCPY(in, inStr, inLen));
799802

@@ -820,6 +823,11 @@ int test_wc_RsaPublicEncryptDecrypt(void)
820823
ExpectIntEQ(XMEMCMP(plain, inStr, plainLen), 0);
821824
/* Pass bad args - tested in another testing function.*/
822825

826+
/* Test for when plain length is less than required. */
827+
ExpectIntEQ(wc_RsaPrivateDecrypt(cipher, cipherLenResult, shortPlain,
828+
TEST_STRING_SZ - 4, &key), RSA_BUFFER_E);
829+
830+
WC_FREE_VAR(shortPlain, NULL);
823831
WC_FREE_VAR(in, NULL);
824832
WC_FREE_VAR(plain, NULL);
825833
WC_FREE_VAR(cipher, NULL);

wolfcrypt/src/rsa.c

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3636,15 +3636,28 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
36363636
if (rsa_type == RSA_PRIVATE_DECRYPT) {
36373637
word32 i = 0;
36383638
word32 j;
3639+
byte last = 0;
36393640
int start = (int)((size_t)pad - (size_t)key->data);
36403641

36413642
for (j = 0; j < key->dataLen; j++) {
3642-
signed char c;
3643-
out[i] = key->data[j];
3644-
c = (signed char)ctMaskGTE((int)j, start);
3645-
c &= (signed char)ctMaskLT((int)i, (int)outLen);
3646-
/* 0 - no add, -1 add */
3647-
i += (word32)((byte)(-c));
3643+
signed char incMask;
3644+
signed char maskData;
3645+
3646+
/* When j < start + outLen then out[i] = key->data[j]
3647+
* else out[i] = last
3648+
*/
3649+
maskData = (signed char)ctMaskLT((int)j,
3650+
start + (int)outLen);
3651+
out[i] = (byte)(key->data[j] & maskData ) |
3652+
(byte)(last & (~maskData));
3653+
last = out[i];
3654+
3655+
/* Increment i when j is in range:
3656+
* [start..(start + outLen - 1)]. */
3657+
incMask = (signed char)ctMaskGTE((int)j, start);
3658+
incMask &= (signed char)ctMaskLT((int)j,
3659+
start + (int)outLen - 1);
3660+
i += (word32)((byte)(-incMask));
36483661
}
36493662
}
36503663
else

0 commit comments

Comments
 (0)