Skip to content

Commit d68cb06

Browse files
committed
test(nccl): skip when libnccl unavailable
Drops the //go:build cuda guard from nccl_test.go so the package's tests compile on every platform. Tests that require libnccl.so.2 call a requireNccl helper that t.Skips when Available() returns false. Adds two new tests that exercise the pure-Go marshaling and ABI-constant paths without touching the runtime library.
1 parent 106c6a3 commit d68cb06

1 file changed

Lines changed: 48 additions & 2 deletions

File tree

internal/nccl/nccl_test.go

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
//go:build cuda
2-
31
package nccl
42

53
import (
@@ -10,7 +8,49 @@ import (
108
"github.com/zerfoo/ztensor/internal/cuda"
119
)
1210

11+
// requireNccl skips the test when libnccl.so.2 cannot be dlopen'd. Pure
12+
// constant/marshaling tests do not need this guard and should not call it.
13+
func requireNccl(t *testing.T) {
14+
t.Helper()
15+
if !Available() {
16+
t.Skip("libnccl.so.2 not available on this host")
17+
}
18+
}
19+
20+
func TestConstants(t *testing.T) {
21+
if Float32 != 7 || Float64 != 8 || Int32 != 2 || Int64 != 4 {
22+
t.Fatalf("unexpected NCCL DataType ABI constants: f32=%d f64=%d i32=%d i64=%d",
23+
Float32, Float64, Int32, Int64)
24+
}
25+
if Sum != 0 || Avg != 4 || Max != 2 || Min != 3 {
26+
t.Fatalf("unexpected NCCL ReduceOp ABI constants: sum=%d avg=%d max=%d min=%d",
27+
Sum, Avg, Max, Min)
28+
}
29+
}
30+
31+
func TestUniqueIDFromBytesRoundTripNoLib(t *testing.T) {
32+
// This exercises the pure-Go marshaling path and runs on every platform.
33+
src := make([]byte, 128)
34+
for i := range src {
35+
src[i] = byte(i)
36+
}
37+
uid, err := UniqueIDFromBytes(src)
38+
if err != nil {
39+
t.Fatalf("UniqueIDFromBytes: %v", err)
40+
}
41+
out := uid.Bytes()
42+
if len(out) != 128 {
43+
t.Fatalf("Bytes length = %d, want 128", len(out))
44+
}
45+
for i := range src {
46+
if out[i] != src[i] {
47+
t.Fatalf("byte %d: got %d want %d", i, out[i], src[i])
48+
}
49+
}
50+
}
51+
1352
func TestGetUniqueID(t *testing.T) {
53+
requireNccl(t)
1454
uid, err := GetUniqueID()
1555
if err != nil {
1656
t.Fatalf("GetUniqueID: %v", err)
@@ -22,6 +62,7 @@ func TestGetUniqueID(t *testing.T) {
2262
}
2363

2464
func TestUniqueIDRoundTrip(t *testing.T) {
65+
requireNccl(t)
2566
uid, err := GetUniqueID()
2667
if err != nil {
2768
t.Fatalf("GetUniqueID: %v", err)
@@ -50,6 +91,7 @@ func TestUniqueIDFromBytesInvalidLength(t *testing.T) {
5091
}
5192

5293
func TestSingleRankInitDestroy(t *testing.T) {
94+
requireNccl(t)
5395
count, err := cuda.GetDeviceCount()
5496
if err != nil || count < 1 {
5597
t.Skip("requires at least 1 CUDA device")
@@ -74,6 +116,7 @@ func TestSingleRankInitDestroy(t *testing.T) {
74116
}
75117

76118
func TestSingleRankAllReduce(t *testing.T) {
119+
requireNccl(t)
77120
count, err := cuda.GetDeviceCount()
78121
if err != nil || count < 1 {
79122
t.Skip("requires at least 1 CUDA device")
@@ -134,6 +177,7 @@ func TestSingleRankAllReduce(t *testing.T) {
134177
}
135178

136179
func TestTwoGPUAllReduce(t *testing.T) {
180+
requireNccl(t)
137181
count, err := cuda.GetDeviceCount()
138182
if err != nil || count < 2 {
139183
t.Skip("requires at least 2 CUDA devices")
@@ -231,6 +275,7 @@ func TestTwoGPUAllReduce(t *testing.T) {
231275
}
232276

233277
func TestTwoGPUBroadcast(t *testing.T) {
278+
requireNccl(t)
234279
count, err := cuda.GetDeviceCount()
235280
if err != nil || count < 2 {
236281
t.Skip("requires at least 2 CUDA devices")
@@ -326,6 +371,7 @@ func TestTwoGPUBroadcast(t *testing.T) {
326371
}
327372

328373
func TestGroupStartEnd(t *testing.T) {
374+
requireNccl(t)
329375
// GroupStart/GroupEnd can be called without a communicator.
330376
if err := GroupStart(); err != nil {
331377
t.Fatalf("GroupStart: %v", err)

0 commit comments

Comments
 (0)