Skip to content

Commit c1d16a4

Browse files
authored
Improve state type validation in CombineRequest deserialization. (#17449)
* Improve state type validation in CombineRequest deserialization. Tighten request deserialization by validating state type before instantiation to reduce unexpected type usage risk, and add targeted tests for accepted and rejected state class names. Made-with: Cursor * spotless
1 parent 57fe1c9 commit c1d16a4

2 files changed

Lines changed: 67 additions & 1 deletion

File tree

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/processor/twostage/exchange/payload/CombineRequest.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.apache.iotdb.db.pipe.processor.twostage.exchange.payload;
2121

2222
import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.IoTDBSinkRequestVersion;
23+
import org.apache.iotdb.db.pipe.processor.twostage.state.CountState;
2324
import org.apache.iotdb.db.pipe.processor.twostage.state.State;
2425
import org.apache.iotdb.service.rpc.thrift.TPipeTransferReq;
2526

@@ -109,7 +110,7 @@ private CombineRequest translateFromTPipeTransferReq(TPipeTransferReq transferRe
109110
combineId = ReadWriteIOUtils.readString(transferReq.body);
110111

111112
final String stateClassName = ReadWriteIOUtils.readString(transferReq.body);
112-
state = (State) Class.forName(stateClassName).newInstance();
113+
state = instantiateState(stateClassName);
113114
state.deserialize(transferReq.body);
114115

115116
version = transferReq.version;
@@ -118,6 +119,13 @@ private CombineRequest translateFromTPipeTransferReq(TPipeTransferReq transferRe
118119
return this;
119120
}
120121

122+
private State instantiateState(final String stateClassName) throws Exception {
123+
if (CountState.class.getName().equals(stateClassName)) {
124+
return new CountState();
125+
}
126+
throw new IllegalArgumentException("Unexpected state class: " + stateClassName);
127+
}
128+
121129
@Override
122130
public String toString() {
123131
return "CombineRequest{"

iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import org.apache.iotdb.commons.path.PartialPath;
2323
import org.apache.iotdb.commons.pipe.sink.payload.thrift.response.PipeTransferFilePieceResp;
2424
import org.apache.iotdb.commons.schema.SchemaConstant;
25+
import org.apache.iotdb.db.pipe.processor.twostage.exchange.payload.CombineRequest;
26+
import org.apache.iotdb.db.pipe.processor.twostage.state.CountState;
2527
import org.apache.iotdb.db.pipe.sink.payload.evolvable.request.PipeTransferDataNodeHandshakeV1Req;
2628
import org.apache.iotdb.db.pipe.sink.payload.evolvable.request.PipeTransferPlanNodeReq;
2729
import org.apache.iotdb.db.pipe.sink.payload.evolvable.request.PipeTransferSchemaSnapshotPieceReq;
@@ -43,6 +45,7 @@
4345
import org.apache.iotdb.db.queryengine.plan.statement.Statement;
4446
import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertBaseStatement;
4547
import org.apache.iotdb.rpc.RpcUtils;
48+
import org.apache.iotdb.service.rpc.thrift.TPipeTransferReq;
4649

4750
import org.apache.tsfile.common.conf.TSFileConfig;
4851
import org.apache.tsfile.enums.TSDataType;
@@ -69,6 +72,61 @@ public class PipeDataNodeThriftRequestTest {
6972

7073
private static final String TIME_PRECISION = "ms";
7174

75+
@Test
76+
public void testCombineRequest() throws Exception {
77+
final CombineRequest req =
78+
CombineRequest.toTPipeTransferReq("pipe", 1L, 2, "combine", new CountState(123L));
79+
final CombineRequest deserializeReq = CombineRequest.fromTPipeTransferReq(req);
80+
81+
Assert.assertEquals(req.getVersion(), deserializeReq.getVersion());
82+
Assert.assertEquals(req.getType(), deserializeReq.getType());
83+
Assert.assertEquals("pipe", deserializeReq.getPipeName());
84+
Assert.assertEquals(1L, deserializeReq.getCreationTime());
85+
Assert.assertEquals(2, deserializeReq.getRegionId());
86+
Assert.assertEquals("combine", deserializeReq.getCombineId());
87+
Assert.assertTrue(deserializeReq.getState() instanceof CountState);
88+
Assert.assertEquals(123L, ((CountState) deserializeReq.getState()).getCount());
89+
}
90+
91+
@Test
92+
public void testCombineRequestWithUnexpectedStateClassName() throws Exception {
93+
final CombineRequest req =
94+
CombineRequest.toTPipeTransferReq("pipe", 1L, 2, "combine", new CountState(123L));
95+
96+
final ByteBuffer bodyBuffer = req.body.duplicate();
97+
final String pipeName = ReadWriteIOUtils.readString(bodyBuffer);
98+
final long creationTime = ReadWriteIOUtils.readLong(bodyBuffer);
99+
final int regionId = ReadWriteIOUtils.readInt(bodyBuffer);
100+
final String combineId = ReadWriteIOUtils.readString(bodyBuffer);
101+
ReadWriteIOUtils.readString(bodyBuffer);
102+
final long count = ReadWriteIOUtils.readLong(bodyBuffer);
103+
104+
final ByteBuffer tamperedBody;
105+
try (final PublicBAOS byteArrayOutputStream = new PublicBAOS();
106+
final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream)) {
107+
ReadWriteIOUtils.write(pipeName, outputStream);
108+
ReadWriteIOUtils.write(creationTime, outputStream);
109+
ReadWriteIOUtils.write(regionId, outputStream);
110+
ReadWriteIOUtils.write(combineId, outputStream);
111+
ReadWriteIOUtils.write("java.lang.String", outputStream);
112+
ReadWriteIOUtils.write(count, outputStream);
113+
tamperedBody =
114+
ByteBuffer.wrap(byteArrayOutputStream.getBuf(), 0, byteArrayOutputStream.size());
115+
}
116+
117+
final TPipeTransferReq tamperedReq = new TPipeTransferReq();
118+
tamperedReq.version = req.version;
119+
tamperedReq.type = req.type;
120+
tamperedReq.body = tamperedBody;
121+
122+
try {
123+
CombineRequest.fromTPipeTransferReq(tamperedReq);
124+
Assert.fail("Expected IllegalArgumentException");
125+
} catch (final IllegalArgumentException e) {
126+
Assert.assertTrue(e.getMessage().contains("Unexpected state class"));
127+
}
128+
}
129+
72130
@Test
73131
public void testPipeTransferDataNodeHandshakeReq() throws IOException {
74132
final PipeTransferDataNodeHandshakeV1Req req =

0 commit comments

Comments
 (0)