Skip to content

Commit 8d1ad63

Browse files
authored
[To dev/1.3] enhance cppclient tsblock deserialize validation (#17464) (#17518)
* fix tsblock deserialize * fix ut error on win * Revert "fix ut error on win" This reverts commit 34b8de4.
1 parent c155666 commit 8d1ad63

5 files changed

Lines changed: 134 additions & 2 deletions

File tree

iotdb-client/client-cpp/src/main/ColumnDecoder.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ std::unique_ptr<Column> BinaryArrayColumnDecoder::readColumn(
151151
if (!nullIndicators.empty() && nullIndicators[i]) continue;
152152

153153
int32_t length = buffer.getInt();
154+
if (length < 0) {
155+
throw IoTDBException("BinaryArrayColumnDecoder: negative TEXT length");
156+
}
154157

155158
std::vector<uint8_t> value(length);
156159
for (int32_t j = 0; j < length; j++) {

iotdb-client/client-cpp/src/main/Common.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "Common.h"
2121
#include <boost/date_time/gregorian/gregorian.hpp>
22+
#include <stdexcept>
2223

2324
int32_t parseDateExpressionToInt(const boost::gregorian::date& date) {
2425
if (date.is_not_a_date()) {
@@ -292,6 +293,10 @@ double MyStringBuffer::getDouble() {
292293
}
293294

294295
char MyStringBuffer::getChar() {
296+
if (pos >= str.size()) {
297+
throw IoTDBException("MyStringBuffer::getChar: read past end (pos=" + std::to_string(pos) +
298+
", size=" + std::to_string(str.size()) + ")");
299+
}
295300
return str[pos++];
296301
}
297302

@@ -300,8 +305,16 @@ bool MyStringBuffer::getBool() {
300305
}
301306

302307
std::string MyStringBuffer::getString() {
303-
size_t len = getInt();
304-
size_t tmpPos = pos;
308+
const int lenInt = getInt();
309+
if (lenInt < 0) {
310+
throw IoTDBException("MyStringBuffer::getString: negative length");
311+
}
312+
const size_t len = static_cast<size_t>(lenInt);
313+
if (pos > str.size() || len > str.size() - pos) {
314+
throw IoTDBException("MyStringBuffer::getString: length exceeds buffer (pos=" + std::to_string(pos) +
315+
", len=" + std::to_string(len) + ", size=" + std::to_string(str.size()) + ")");
316+
}
317+
const size_t tmpPos = pos;
305318
pos += len;
306319
return str.substr(tmpPos, len);
307320
}
@@ -350,6 +363,10 @@ void MyStringBuffer::checkBigEndian() {
350363
}
351364

352365
const char* MyStringBuffer::getOrderedByte(size_t len) {
366+
if (pos > str.size() || len > str.size() - pos) {
367+
throw IoTDBException("MyStringBuffer::getOrderedByte: read past end (pos=" + std::to_string(pos) +
368+
", len=" + std::to_string(len) + ", size=" + std::to_string(str.size()) + ")");
369+
}
353370
const char* p = nullptr;
354371
if (isBigEndian) {
355372
p = str.c_str() + pos;
@@ -454,3 +471,28 @@ const std::vector<char>& BitMap::getByteArray() const {
454471
size_t BitMap::getSize() const {
455472
return this->size;
456473
}
474+
475+
TEndPoint UrlUtils::parseTEndPointIpv4AndIpv6Url(const std::string& endPointUrl) {
476+
TEndPoint endPoint;
477+
const size_t colonPos = endPointUrl.find_last_of(':');
478+
if (colonPos == std::string::npos) {
479+
endPoint.__set_ip(endPointUrl);
480+
endPoint.__set_port(0);
481+
return endPoint;
482+
}
483+
std::string ip = endPointUrl.substr(0, colonPos);
484+
const std::string portStr = endPointUrl.substr(colonPos + 1);
485+
try {
486+
const int port = std::stoi(portStr);
487+
endPoint.__set_port(port);
488+
} catch (const std::logic_error&) {
489+
endPoint.__set_ip(endPointUrl);
490+
endPoint.__set_port(0);
491+
return endPoint;
492+
}
493+
if (ip.size() >= 2 && ip.front() == '[' && ip.back() == ']') {
494+
ip = ip.substr(1, ip.size() - 2);
495+
}
496+
endPoint.__set_ip(ip);
497+
return endPoint;
498+
}

iotdb-client/client-cpp/src/main/Common.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,5 +480,13 @@ class RpcUtils {
480480
static std::shared_ptr<TSFetchResultsResp> getTSFetchResultsResp(const TSStatus& status);
481481
};
482482

483+
class UrlUtils {
484+
public:
485+
UrlUtils() = delete;
486+
487+
/** Parse host:port; aligns with Java UrlUtils.parseTEndPointIpv4AndIpv6Url plus test edge cases. */
488+
static TEndPoint parseTEndPointIpv4AndIpv6Url(const std::string& endPointUrl);
489+
};
490+
483491

484492
#endif

iotdb-client/client-cpp/src/main/TsBlock.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19+
#include <cstdint>
1920
#include <stdexcept>
2021
#include <algorithm>
2122
#include "TsBlock.h"
@@ -34,6 +35,14 @@ std::shared_ptr<TsBlock> TsBlock::deserialize(const std::string& data) {
3435

3536
// Read value column count
3637
int32_t valueColumnCount = buffer.getInt();
38+
if (valueColumnCount < 0) {
39+
throw IoTDBException("TsBlock::deserialize: negative valueColumnCount");
40+
}
41+
const int64_t minHeaderBytes =
42+
9LL + 2LL * static_cast<int64_t>(valueColumnCount);
43+
if (minHeaderBytes > static_cast<int64_t>(data.size())) {
44+
throw IoTDBException("TsBlock::deserialize: truncated header");
45+
}
3746

3847
// Read value column data types
3948
std::vector<TSDataType::TSDataType> valueColumnDataTypes(valueColumnCount);
@@ -43,6 +52,9 @@ std::shared_ptr<TsBlock> TsBlock::deserialize(const std::string& data) {
4352

4453
// Read position count
4554
int32_t positionCount = buffer.getInt();
55+
if (positionCount < 0) {
56+
throw IoTDBException("TsBlock::deserialize: negative positionCount");
57+
}
4658

4759
// Read column encodings
4860
std::vector<ColumnEncoding> columnEncodings(valueColumnCount + 1);

iotdb-client/client-cpp/src/test/cpp/sessionIT.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
#include "catch.hpp"
2121
#include "Session.h"
22+
#include "TsBlock.h"
23+
#include <sstream>
2224

2325
using namespace std;
2426

@@ -728,3 +730,68 @@ TEST_CASE("Test executeLastDataQuery ", "[testExecuteLastDataQuery]") {
728730
sessionDataSet->setFetchSize(1024);
729731
REQUIRE(sessionDataSet->hasNext() == false);
730732
}
733+
734+
// Helper function for comparing TEndPoint with detailed error message
735+
void assertTEndPointEqual(const TEndPoint& actual,
736+
const std::string& expectedIp,
737+
int expectedPort,
738+
const char* file,
739+
int line) {
740+
if (actual.ip != expectedIp || actual.port != expectedPort) {
741+
std::stringstream ss;
742+
ss << "\nTEndPoint mismatch:\nExpected: " << expectedIp << ":" << expectedPort
743+
<< "\nActual: " << actual.ip << ":" << actual.port;
744+
Catch::SourceLineInfo location(file, line);
745+
Catch::AssertionHandler handler("TEndPoint comparison", location, ss.str(), Catch::ResultDisposition::Normal);
746+
handler.handleMessage(Catch::ResultWas::ExplicitFailure, ss.str());
747+
handler.complete();
748+
}
749+
}
750+
751+
// Macro to simplify test assertions
752+
#define REQUIRE_TENDPOINT(actual, expectedIp, expectedPort) \
753+
assertTEndPointEqual(actual, expectedIp, expectedPort, __FILE__, __LINE__)
754+
755+
TEST_CASE("UrlUtils - parseTEndPointIpv4AndIpv6Url", "[UrlUtils]") {
756+
// Test valid IPv4 addresses
757+
SECTION("Valid IPv4") {
758+
REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("192.168.1.1:8080"), "192.168.1.1", 8080);
759+
REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("10.0.0.1:80"), "10.0.0.1", 80);
760+
}
761+
762+
// Test valid IPv6 addresses
763+
SECTION("Valid IPv6") {
764+
REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("[2001:db8::1]:8080"), "2001:db8::1", 8080);
765+
REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("[::1]:80"), "::1", 80);
766+
}
767+
768+
// Test hostnames
769+
SECTION("Hostnames") {
770+
REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("localhost:8080"), "localhost", 8080);
771+
REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("example.com:443"), "example.com", 443);
772+
}
773+
774+
// Test edge cases
775+
SECTION("Edge cases") {
776+
REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url(""), "", 0);
777+
REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("127.0.0.1"), "127.0.0.1", 0);
778+
}
779+
780+
// Test invalid inputs
781+
SECTION("Invalid inputs") {
782+
REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("192.168.1.1:abc"), "192.168.1.1:abc", 0);
783+
REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("]invalid[:80"), "]invalid[", 80);
784+
}
785+
786+
// Test port ranges
787+
SECTION("Port ranges") {
788+
REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("localhost:0"), "localhost", 0);
789+
REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("127.0.0.1:65535"), "127.0.0.1", 65535);
790+
}
791+
}
792+
793+
TEST_CASE("TsBlock deserialize rejects truncated malicious payload", "[TsBlockDeserialize]") {
794+
std::string data(18, '\0');
795+
data[3] = '\x10';
796+
REQUIRE_THROWS_AS(TsBlock::deserialize(data), IoTDBException);
797+
}

0 commit comments

Comments
 (0)