@@ -278,6 +278,23 @@ func (c *clientImpl) DispatchNexusTask(
278278 ctx context.Context ,
279279 request * matchingservice.DispatchNexusTaskRequest ,
280280 opts ... grpc.CallOption ,
281+ ) (* matchingservice.DispatchNexusTaskResponse , error ) {
282+ if ! isPartitionAwareKind (request .GetTaskQueue ().GetKind ()) {
283+ return c .dispatchNexusTask (ctx , PartitionCounts {}, request , opts )
284+ }
285+ pkey := c .partitionCache .makeKey (
286+ request .GetNamespaceId (),
287+ request .GetTaskQueue ().GetName (),
288+ enumspb .TASK_QUEUE_TYPE_NEXUS ,
289+ )
290+ return invokeWithPartitionCounts (ctx , c .logger , c .partitionCache , pkey , request , opts , c .dispatchNexusTask )
291+ }
292+
293+ func (c * clientImpl ) dispatchNexusTask (
294+ ctx context.Context ,
295+ pc PartitionCounts ,
296+ request * matchingservice.DispatchNexusTaskRequest ,
297+ opts []grpc.CallOption ,
281298) (* matchingservice.DispatchNexusTaskResponse , error ) {
282299 // use shallow copy since Request may contain a large payload
283300 request = & matchingservice.DispatchNexusTaskRequest {
@@ -286,7 +303,13 @@ func (c *clientImpl) DispatchNexusTask(
286303 Request : request .Request ,
287304 ForwardInfo : request .ForwardInfo ,
288305 }
289- client , err := c .pickClientForWrite (request .GetTaskQueue (), request .GetNamespaceId (), enumspb .TASK_QUEUE_TYPE_NEXUS , request .GetForwardInfo ().GetSourcePartition ())
306+ client , err := c .pickClientForWrite (
307+ request .GetTaskQueue (),
308+ request .GetNamespaceId (),
309+ enumspb .TASK_QUEUE_TYPE_NEXUS ,
310+ request .GetForwardInfo ().GetSourcePartition (),
311+ pc ,
312+ )
290313 if err != nil {
291314 return nil , err
292315 }
@@ -299,13 +322,32 @@ func (c *clientImpl) PollNexusTaskQueue(
299322 ctx context.Context ,
300323 request * matchingservice.PollNexusTaskQueueRequest ,
301324 opts ... grpc.CallOption ,
325+ ) (* matchingservice.PollNexusTaskQueueResponse , error ) {
326+ if ! isPartitionAwareKind (request .GetRequest ().GetTaskQueue ().GetKind ()) {
327+ return c .pollNexusTaskQueue (ctx , PartitionCounts {}, request , opts )
328+ }
329+ pkey := c .partitionCache .makeKey (
330+ request .GetNamespaceId (),
331+ request .GetRequest ().GetTaskQueue ().GetName (),
332+ enumspb .TASK_QUEUE_TYPE_NEXUS ,
333+ )
334+ return invokeWithPartitionCounts (ctx , c .logger , c .partitionCache , pkey , request , opts , c .pollNexusTaskQueue )
335+ }
336+
337+ func (c * clientImpl ) pollNexusTaskQueue (
338+ ctx context.Context ,
339+ pc PartitionCounts ,
340+ request * matchingservice.PollNexusTaskQueueRequest ,
341+ opts []grpc.CallOption ,
302342) (* matchingservice.PollNexusTaskQueueResponse , error ) {
303343 request = common .CloneProto (request )
304344 client , release , err := c .pickClientForRead (
305345 request .GetRequest ().GetTaskQueue (),
306346 request .GetNamespaceId (),
307347 enumspb .TASK_QUEUE_TYPE_NEXUS ,
308- request .GetForwardedSource ())
348+ request .GetForwardedSource (),
349+ pc ,
350+ )
309351 if err != nil {
310352 return nil , err
311353 }
0 commit comments