1- //go:build cuda
2-
31package nccl
42
53import (
@@ -10,7 +8,49 @@ import (
108 "github.com/zerfoo/ztensor/internal/cuda"
119)
1210
11+ // requireNccl skips the test when libnccl.so.2 cannot be dlopen'd. Pure
12+ // constant/marshaling tests do not need this guard and should not call it.
13+ func requireNccl (t * testing.T ) {
14+ t .Helper ()
15+ if ! Available () {
16+ t .Skip ("libnccl.so.2 not available on this host" )
17+ }
18+ }
19+
20+ func TestConstants (t * testing.T ) {
21+ if Float32 != 7 || Float64 != 8 || Int32 != 2 || Int64 != 4 {
22+ t .Fatalf ("unexpected NCCL DataType ABI constants: f32=%d f64=%d i32=%d i64=%d" ,
23+ Float32 , Float64 , Int32 , Int64 )
24+ }
25+ if Sum != 0 || Avg != 4 || Max != 2 || Min != 3 {
26+ t .Fatalf ("unexpected NCCL ReduceOp ABI constants: sum=%d avg=%d max=%d min=%d" ,
27+ Sum , Avg , Max , Min )
28+ }
29+ }
30+
31+ func TestUniqueIDFromBytesRoundTripNoLib (t * testing.T ) {
32+ // This exercises the pure-Go marshaling path and runs on every platform.
33+ src := make ([]byte , 128 )
34+ for i := range src {
35+ src [i ] = byte (i )
36+ }
37+ uid , err := UniqueIDFromBytes (src )
38+ if err != nil {
39+ t .Fatalf ("UniqueIDFromBytes: %v" , err )
40+ }
41+ out := uid .Bytes ()
42+ if len (out ) != 128 {
43+ t .Fatalf ("Bytes length = %d, want 128" , len (out ))
44+ }
45+ for i := range src {
46+ if out [i ] != src [i ] {
47+ t .Fatalf ("byte %d: got %d want %d" , i , out [i ], src [i ])
48+ }
49+ }
50+ }
51+
1352func TestGetUniqueID (t * testing.T ) {
53+ requireNccl (t )
1454 uid , err := GetUniqueID ()
1555 if err != nil {
1656 t .Fatalf ("GetUniqueID: %v" , err )
@@ -22,6 +62,7 @@ func TestGetUniqueID(t *testing.T) {
2262}
2363
2464func TestUniqueIDRoundTrip (t * testing.T ) {
65+ requireNccl (t )
2566 uid , err := GetUniqueID ()
2667 if err != nil {
2768 t .Fatalf ("GetUniqueID: %v" , err )
@@ -50,6 +91,7 @@ func TestUniqueIDFromBytesInvalidLength(t *testing.T) {
5091}
5192
5293func TestSingleRankInitDestroy (t * testing.T ) {
94+ requireNccl (t )
5395 count , err := cuda .GetDeviceCount ()
5496 if err != nil || count < 1 {
5597 t .Skip ("requires at least 1 CUDA device" )
@@ -74,6 +116,7 @@ func TestSingleRankInitDestroy(t *testing.T) {
74116}
75117
76118func TestSingleRankAllReduce (t * testing.T ) {
119+ requireNccl (t )
77120 count , err := cuda .GetDeviceCount ()
78121 if err != nil || count < 1 {
79122 t .Skip ("requires at least 1 CUDA device" )
@@ -134,6 +177,7 @@ func TestSingleRankAllReduce(t *testing.T) {
134177}
135178
136179func TestTwoGPUAllReduce (t * testing.T ) {
180+ requireNccl (t )
137181 count , err := cuda .GetDeviceCount ()
138182 if err != nil || count < 2 {
139183 t .Skip ("requires at least 2 CUDA devices" )
@@ -231,6 +275,7 @@ func TestTwoGPUAllReduce(t *testing.T) {
231275}
232276
233277func TestTwoGPUBroadcast (t * testing.T ) {
278+ requireNccl (t )
234279 count , err := cuda .GetDeviceCount ()
235280 if err != nil || count < 2 {
236281 t .Skip ("requires at least 2 CUDA devices" )
@@ -326,6 +371,7 @@ func TestTwoGPUBroadcast(t *testing.T) {
326371}
327372
328373func TestGroupStartEnd (t * testing.T ) {
374+ requireNccl (t )
329375 // GroupStart/GroupEnd can be called without a communicator.
330376 if err := GroupStart (); err != nil {
331377 t .Fatalf ("GroupStart: %v" , err )
0 commit comments