1212import java .util .List ;
1313import java .util .concurrent .atomic .AtomicInteger ;
1414import org .assertj .core .api .Assertions ;
15+ import org .junit .jupiter .api .AfterAll ;
16+ import org .junit .jupiter .api .BeforeEach ;
1517import org .junit .jupiter .api .Test ;
1618import org .mockito .Mockito ;
1719import org .reactivestreams .Publisher ;
2628
2729public class LoadbalanceTest {
2830
29- @ Test
30- public void shouldDeliverAllTheRequestsWithRoundRobinStrategy () {
31+ @ BeforeEach
32+ void setUp () {
3133 Hooks .onErrorDropped ((__ ) -> {});
34+ }
35+
36+ @ AfterAll
37+ static void afterAll () {
38+ Hooks .resetOnErrorDropped ();
39+ }
3240
41+ @ Test
42+ public void shouldDeliverAllTheRequestsWithRoundRobinStrategy () throws Exception {
3343 final AtomicInteger counter = new AtomicInteger ();
3444 final ClientTransport mockTransport = Mockito .mock (ClientTransport .class );
3545 final RSocketConnector rSocketConnectorMock = Mockito .mock (RSocketConnector .class );
@@ -76,21 +86,28 @@ public Mono<Void> fireAndForget(Payload payload) {
7686 });
7787
7888 Assertions .assertThat (counter .get ()).isEqualTo (1000 );
79-
8089 counter .set (0 );
8190 }
8291 }
8392
8493 @ Test
85- public void shouldDeliverAllTheRequestsWithWightedStrategy () {
86- Hooks .onErrorDropped ((__ ) -> {});
87-
94+ public void shouldDeliverAllTheRequestsWithWeightedStrategy () throws InterruptedException {
8895 final AtomicInteger counter = new AtomicInteger ();
89- final ClientTransport mockTransport = Mockito .mock (ClientTransport .class );
90- final RSocketConnector rSocketConnectorMock = Mockito .mock (RSocketConnector .class );
9196
92- Mockito .when (rSocketConnectorMock .connect (Mockito .any (ClientTransport .class )))
93- .then (im -> Mono .just (new TestRSocket (new WeightedRSocket (counter ))));
97+ final ClientTransport mockTransport1 = Mockito .mock (ClientTransport .class );
98+ final ClientTransport mockTransport2 = Mockito .mock (ClientTransport .class );
99+
100+ final LoadbalanceTarget target1 = LoadbalanceTarget .from ("1" , mockTransport1 );
101+ final LoadbalanceTarget target2 = LoadbalanceTarget .from ("2" , mockTransport2 );
102+
103+ final WeightedRSocket weightedRSocket1 = new WeightedRSocket (counter );
104+ final WeightedRSocket weightedRSocket2 = new WeightedRSocket (counter );
105+
106+ final RSocketConnector rSocketConnectorMock = Mockito .mock (RSocketConnector .class );
107+ Mockito .when (rSocketConnectorMock .connect (mockTransport1 ))
108+ .then (im -> Mono .just (new TestRSocket (weightedRSocket1 )));
109+ Mockito .when (rSocketConnectorMock .connect (mockTransport2 ))
110+ .then (im -> Mono .just (new TestRSocket (weightedRSocket2 )));
94111
95112 for (int i = 0 ; i < 1000 ; i ++) {
96113 final TestPublisher <List <LoadbalanceTarget >> source = TestPublisher .create ();
@@ -99,42 +116,39 @@ public void shouldDeliverAllTheRequestsWithWightedStrategy() {
99116 rSocketConnectorMock ,
100117 source ,
101118 WeightedLoadbalanceStrategy .builder ()
102- .weightedStatsResolver (r -> (WeightedStats ) r )
119+ .weightedStatsResolver (
120+ rsocket ->
121+ ((PooledRSocket ) rsocket ).target () == target1
122+ ? weightedRSocket1
123+ : weightedRSocket2 )
103124 .build ());
104125
105126 RaceTestUtils .race (
106127 () -> {
107128 for (int j = 0 ; j < 1000 ; j ++) {
108129 Mono .defer (() -> rSocketPool .select ().fireAndForget (EmptyPayload .INSTANCE ))
109130 .retry ()
110- .subscribe ();
131+ .subscribe (aVoid -> {}, Throwable :: printStackTrace );
111132 }
112133 },
113134 () -> {
114135 for (int j = 0 ; j < 100 ; j ++) {
115136 source .next (Collections .emptyList ());
116- source .next (Collections .singletonList (LoadbalanceTarget .from ("1" , mockTransport )));
117- source .next (
118- Arrays .asList (
119- LoadbalanceTarget .from ("1" , mockTransport ),
120- LoadbalanceTarget .from ("2" , mockTransport )));
121- source .next (Collections .singletonList (LoadbalanceTarget .from ("1" , mockTransport )));
122- source .next (Collections .singletonList (LoadbalanceTarget .from ("2" , mockTransport )));
137+ source .next (Collections .singletonList (target1 ));
138+ source .next (Arrays .asList (target1 , target2 )).next (Collections .singletonList (target1 ));
139+ source .next (Collections .singletonList (target2 ));
123140 source .next (Collections .emptyList ());
124- source .next (Collections .singletonList (LoadbalanceTarget . from ( "2" , mockTransport ) ));
141+ source .next (Collections .singletonList (target2 ));
125142 }
126143 });
127144
128145 Assertions .assertThat (counter .get ()).isEqualTo (1000 );
129-
130146 counter .set (0 );
131147 }
132148 }
133149
134150 @ Test
135151 public void ensureRSocketIsCleanedFromThePoolIfSourceRSocketIsDisposed () {
136- Hooks .onErrorDropped ((__ ) -> {});
137-
138152 final AtomicInteger counter = new AtomicInteger ();
139153 final ClientTransport mockTransport = Mockito .mock (ClientTransport .class );
140154 final RSocketConnector rSocketConnectorMock = Mockito .mock (RSocketConnector .class );
@@ -179,8 +193,6 @@ public Mono<Void> fireAndForget(Payload payload) {
179193
180194 @ Test
181195 public void ensureContextIsPropagatedCorrectlyForRequestChannel () {
182- Hooks .onErrorDropped ((__ ) -> {});
183-
184196 final AtomicInteger counter = new AtomicInteger ();
185197 final ClientTransport mockTransport = Mockito .mock (ClientTransport .class );
186198 final RSocketConnector rSocketConnectorMock = Mockito .mock (RSocketConnector .class );
0 commit comments