Skip to content

Commit 35ea45f

Browse files
committed
chore(nccl): add purego binding for libnccl.so.2
Implements internal/nccl as a zero-CGo runtime dlopen of libnccl.so.2, mirroring the pattern in internal/cublas/cublas_purego.go. The package now compiles on every platform without -tags cuda; non-linux GOOS returns a clean "not supported" error rather than panicking. ABI constants for ncclResult, ncclDataType, ncclRedOp, and NCCL_UNIQUE_ID_BYTES are hardcoded against the stable NCCL 2.x ABI. ncclCommInitRank takes the 128-byte ncclUniqueId by value. Per AAPCS64 rule B.4 (composites > 16 bytes are passed by hidden pointer), passing &uid.id[0] as a uintptr is the correct calling convention on linux/arm64, which is the only NCCL platform ztensor targets today. CI's go vet exclude list adds /internal/nccl$ alongside the other GPU runtime bindings that rely on unsafe.Pointer(uintptr(...)) trampolines. Refs #78
1 parent 9b8e5aa commit 35ea45f

3 files changed

Lines changed: 357 additions & 4 deletions

File tree

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,5 @@ jobs:
1818
# Run go vet on all packages except those with intentional
1919
# unsafe.Pointer usage for GPU runtime bindings via purego/dlopen.
2020
# These warnings are expected and documented in docs/QUALITY.md.
21-
go vet $(go list ./... | grep -v '/internal/cuda$' | grep -v '/internal/hip$' | grep -v '/internal/opencl$' | grep -v '/internal/cudnn$' | grep -v '/internal/tensorrt$' | grep -v '/internal/fpga$' | grep -v '/internal/sycl$' | grep -v '/internal/metal$' | grep -v '/internal/pjrt$')
21+
go vet $(go list ./... | grep -v '/internal/cuda$' | grep -v '/internal/hip$' | grep -v '/internal/opencl$' | grep -v '/internal/cudnn$' | grep -v '/internal/tensorrt$' | grep -v '/internal/fpga$' | grep -v '/internal/sycl$' | grep -v '/internal/metal$' | grep -v '/internal/pjrt$' | grep -v '/internal/nccl$')
2222
- run: go test -race -timeout 300s ./...

internal/nccl/doc.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// Package nccl provides CGo bindings for the NVIDIA Collective Communications
2-
// Library (NCCL). All functional code requires the "cuda" build tag and a
3-
// working NCCL installation.
1+
// Package nccl provides a zero-CGo binding for the NVIDIA Collective
2+
// Communications Library (NCCL). The library is loaded at runtime via dlopen
3+
// (see nccl_purego.go); a legacy CGo implementation is retained behind the
4+
// `nccl_cgo` build tag for opt-in fallback (nccl_cgo.go).
45
package nccl

internal/nccl/nccl_purego.go

Lines changed: 352 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,352 @@
1+
// Zero-CGo binding to libnccl.so.2 loaded at runtime via dlopen. On non-linux
2+
// platforms every exported entry point returns a clean "not supported" error
3+
// without attempting dlopen. On linux the library is dlopen'd lazily on first
4+
// use; if libnccl.so.2 cannot be found the same error is returned.
5+
//
6+
// On AArch64 (DGX hardware) AAPCS64 rule B.4 means aggregates larger than 16
7+
// bytes are passed by hidden pointer, which lets us hand the 128-byte
8+
// ncclUniqueId to ncclCommInitRank as a plain uintptr without an ABI
9+
// trampoline. The legacy CGo implementation is retained behind the
10+
// `nccl_cgo` build tag (see nccl_cgo.go).
11+
package nccl
12+
13+
import (
14+
"fmt"
15+
"runtime"
16+
"sync"
17+
"unsafe"
18+
19+
"github.com/zerfoo/ztensor/internal/cuda"
20+
)
21+
22+
// NCCL_UNIQUE_ID_BYTES is the fixed size of an ncclUniqueId.
23+
const ncclUniqueIDBytes = 128
24+
25+
// NCCL result codes.
26+
const ncclSuccess = 0
27+
28+
// NCCL data type enum (stable ABI for NCCL 2.x).
29+
const (
30+
ncclInt8 = 0
31+
ncclUint8 = 1
32+
ncclInt32 = 2
33+
ncclUint32 = 3
34+
ncclInt64 = 4
35+
ncclUint64 = 5
36+
ncclFloat16 = 6
37+
ncclFloat32 = 7
38+
ncclFloat64 = 8
39+
ncclBfloat16 = 9
40+
)
41+
42+
// NCCL reduction op enum (stable ABI for NCCL 2.x).
43+
const (
44+
ncclSum = 0
45+
ncclProd = 1
46+
ncclMax = 2
47+
ncclMin = 3
48+
ncclAvg = 4
49+
)
50+
51+
// DataType specifies the element type for NCCL operations.
52+
type DataType int
53+
54+
const (
55+
Float32 DataType = ncclFloat32
56+
Float64 DataType = ncclFloat64
57+
Int32 DataType = ncclInt32
58+
Int64 DataType = ncclInt64
59+
)
60+
61+
// ReduceOp specifies the reduction operation for collective calls.
62+
type ReduceOp int
63+
64+
const (
65+
Sum ReduceOp = ncclSum
66+
Avg ReduceOp = ncclAvg
67+
Max ReduceOp = ncclMax
68+
Min ReduceOp = ncclMin
69+
)
70+
71+
// ncclLib holds dlsym-resolved function pointers for libnccl.so.2.
72+
type ncclLib struct {
73+
getUniqueId uintptr // ncclGetUniqueId
74+
commInitRank uintptr // ncclCommInitRank
75+
commDestroy uintptr // ncclCommDestroy
76+
commGetAsyncError uintptr // ncclCommGetAsyncError
77+
allReduce uintptr // ncclAllReduce
78+
broadcast uintptr // ncclBroadcast
79+
groupStart uintptr // ncclGroupStart
80+
groupEnd uintptr // ncclGroupEnd
81+
getErrorString uintptr // ncclGetErrorString
82+
}
83+
84+
var (
85+
ncclLibInst *ncclLib
86+
ncclOnce sync.Once
87+
ncclLoadErr error
88+
)
89+
90+
// libnccl candidate paths (linux only).
91+
var ncclLibPaths = []string{
92+
"libnccl.so.2",
93+
"libnccl.so",
94+
}
95+
96+
func loadNccl() (*ncclLib, error) {
97+
if runtime.GOOS != "linux" {
98+
return nil, fmt.Errorf("nccl: not supported on %s", runtime.GOOS)
99+
}
100+
var handle uintptr
101+
var lastErr string
102+
for _, path := range ncclLibPaths {
103+
h, err := cuda.DlopenPath(path)
104+
if err == nil {
105+
handle = h
106+
break
107+
}
108+
lastErr = err.Error()
109+
}
110+
if handle == 0 {
111+
return nil, fmt.Errorf("nccl: dlopen failed: %s", lastErr)
112+
}
113+
114+
lib := &ncclLib{}
115+
type sym struct {
116+
name string
117+
ptr *uintptr
118+
}
119+
syms := []sym{
120+
{"ncclGetUniqueId", &lib.getUniqueId},
121+
{"ncclCommInitRank", &lib.commInitRank},
122+
{"ncclCommDestroy", &lib.commDestroy},
123+
{"ncclCommGetAsyncError", &lib.commGetAsyncError},
124+
{"ncclAllReduce", &lib.allReduce},
125+
{"ncclBroadcast", &lib.broadcast},
126+
{"ncclGroupStart", &lib.groupStart},
127+
{"ncclGroupEnd", &lib.groupEnd},
128+
{"ncclGetErrorString", &lib.getErrorString},
129+
}
130+
for _, s := range syms {
131+
addr, err := cuda.Dlsym(handle, s.name)
132+
if err != nil {
133+
return nil, fmt.Errorf("nccl: %w", err)
134+
}
135+
*s.ptr = addr
136+
}
137+
return lib, nil
138+
}
139+
140+
func getNcclLib() (*ncclLib, error) {
141+
ncclOnce.Do(func() {
142+
ncclLibInst, ncclLoadErr = loadNccl()
143+
})
144+
return ncclLibInst, ncclLoadErr
145+
}
146+
147+
// Available returns true if libnccl can be loaded at runtime.
148+
func Available() bool {
149+
_, err := getNcclLib()
150+
return err == nil
151+
}
152+
153+
// errorString returns the human-readable error string for an NCCL result code.
154+
// Falls back to a numeric description if ncclGetErrorString cannot be invoked.
155+
func (l *ncclLib) errorString(rc uintptr) string {
156+
if l == nil || l.getErrorString == 0 {
157+
return fmt.Sprintf("ncclResult=%d", rc)
158+
}
159+
cstr := cuda.Ccall(l.getErrorString, rc)
160+
if cstr == 0 {
161+
return fmt.Sprintf("ncclResult=%d", rc)
162+
}
163+
// Read C-string at cstr.
164+
var b []byte
165+
for i := 0; i < 1024; i++ {
166+
c := *(*byte)(unsafe.Pointer(cstr + uintptr(i)))
167+
if c == 0 {
168+
break
169+
}
170+
b = append(b, c)
171+
}
172+
return string(b)
173+
}
174+
175+
// UniqueID wraps an ncclUniqueId (128-byte opaque blob) used to bootstrap
176+
// communicator creation.
177+
type UniqueID struct {
178+
id [ncclUniqueIDBytes]byte
179+
}
180+
181+
// GetUniqueID generates a new unique ID for communicator initialization.
182+
// Exactly one rank should call this and broadcast the result to all other ranks.
183+
func GetUniqueID() (*UniqueID, error) {
184+
lib, err := getNcclLib()
185+
if err != nil {
186+
return nil, err
187+
}
188+
uid := &UniqueID{}
189+
rc := cuda.Ccall(lib.getUniqueId, uintptr(unsafe.Pointer(&uid.id[0])))
190+
if rc != ncclSuccess {
191+
return nil, fmt.Errorf("ncclGetUniqueId failed: %s", lib.errorString(rc))
192+
}
193+
return uid, nil
194+
}
195+
196+
// Bytes returns a copy of the raw bytes of the unique ID for serialization.
197+
func (u *UniqueID) Bytes() []byte {
198+
out := make([]byte, ncclUniqueIDBytes)
199+
copy(out, u.id[:])
200+
return out
201+
}
202+
203+
// UniqueIDFromBytes reconstructs a UniqueID from raw bytes.
204+
func UniqueIDFromBytes(b []byte) (*UniqueID, error) {
205+
if len(b) != ncclUniqueIDBytes {
206+
return nil, fmt.Errorf("UniqueIDFromBytes: expected %d bytes, got %d", ncclUniqueIDBytes, len(b))
207+
}
208+
uid := &UniqueID{}
209+
copy(uid.id[:], b)
210+
return uid, nil
211+
}
212+
213+
// Comm wraps an ncclComm_t communicator (opaque pointer).
214+
type Comm struct {
215+
comm uintptr
216+
}
217+
218+
// InitRank initializes a communicator for a given rank in a group of nRanks.
219+
// All ranks must call this with the same UniqueID and nRanks. The CUDA device
220+
// for this rank must be set via cuda.SetDevice before calling InitRank.
221+
//
222+
// On AArch64 (AAPCS64) and other AAPCS-derived ABIs, aggregates larger than
223+
// 16 bytes are passed by hidden pointer (rule B.4), so passing &uid.id[0]
224+
// matches the C calling convention for ncclUniqueId-by-value. This binding
225+
// is therefore correct on linux/arm64 (the supported NCCL platform); other
226+
// ABIs (System V AMD64) pass large aggregates on the stack and would need a
227+
// dedicated trampoline.
228+
func InitRank(uid *UniqueID, nRanks, rank int) (*Comm, error) {
229+
if uid == nil {
230+
return nil, fmt.Errorf("nccl InitRank: nil UniqueID")
231+
}
232+
lib, err := getNcclLib()
233+
if err != nil {
234+
return nil, err
235+
}
236+
var comm uintptr
237+
rc := cuda.Ccall(lib.commInitRank,
238+
uintptr(unsafe.Pointer(&comm)),
239+
uintptr(nRanks),
240+
uintptr(unsafe.Pointer(&uid.id[0])),
241+
uintptr(rank),
242+
)
243+
if rc != ncclSuccess {
244+
return nil, fmt.Errorf("ncclCommInitRank(nRanks=%d, rank=%d) failed: %s",
245+
nRanks, rank, lib.errorString(rc))
246+
}
247+
return &Comm{comm: comm}, nil
248+
}
249+
250+
// Destroy releases the communicator resources.
251+
func (c *Comm) Destroy() error {
252+
lib, err := getNcclLib()
253+
if err != nil {
254+
return err
255+
}
256+
rc := cuda.Ccall(lib.commDestroy, c.comm)
257+
if rc != ncclSuccess {
258+
return fmt.Errorf("ncclCommDestroy failed: %s", lib.errorString(rc))
259+
}
260+
return nil
261+
}
262+
263+
// AllReduce performs an in-place all-reduce across all ranks. sendBuf and
264+
// recvBuf may be the same pointer for in-place operation. count is the number
265+
// of elements (not bytes). The stream parameter is a cudaStream_t as
266+
// unsafe.Pointer.
267+
func (c *Comm) AllReduce(sendBuf, recvBuf unsafe.Pointer, count int, dtype DataType, op ReduceOp, stream unsafe.Pointer) error {
268+
lib, err := getNcclLib()
269+
if err != nil {
270+
return err
271+
}
272+
rc := cuda.Ccall(lib.allReduce,
273+
uintptr(sendBuf),
274+
uintptr(recvBuf),
275+
uintptr(count),
276+
uintptr(dtype),
277+
uintptr(op),
278+
c.comm,
279+
uintptr(stream),
280+
)
281+
if rc != ncclSuccess {
282+
return fmt.Errorf("ncclAllReduce failed: %s", lib.errorString(rc))
283+
}
284+
return nil
285+
}
286+
287+
// Broadcast sends count elements from root's sendBuf to all ranks' recvBuf.
288+
// For root, sendBuf and recvBuf may differ or be the same.
289+
func (c *Comm) Broadcast(sendBuf, recvBuf unsafe.Pointer, count int, dtype DataType, root int, stream unsafe.Pointer) error {
290+
lib, err := getNcclLib()
291+
if err != nil {
292+
return err
293+
}
294+
rc := cuda.Ccall(lib.broadcast,
295+
uintptr(sendBuf),
296+
uintptr(recvBuf),
297+
uintptr(count),
298+
uintptr(dtype),
299+
uintptr(root),
300+
c.comm,
301+
uintptr(stream),
302+
)
303+
if rc != ncclSuccess {
304+
return fmt.Errorf("ncclBroadcast failed: %s", lib.errorString(rc))
305+
}
306+
return nil
307+
}
308+
309+
// GroupStart begins a group of NCCL operations. All NCCL calls between
310+
// GroupStart and GroupEnd are batched into a single launch.
311+
func GroupStart() error {
312+
lib, err := getNcclLib()
313+
if err != nil {
314+
return err
315+
}
316+
rc := cuda.Ccall(lib.groupStart)
317+
if rc != ncclSuccess {
318+
return fmt.Errorf("ncclGroupStart failed: %s", lib.errorString(rc))
319+
}
320+
return nil
321+
}
322+
323+
// GroupEnd completes a group of NCCL operations and launches them.
324+
func GroupEnd() error {
325+
lib, err := getNcclLib()
326+
if err != nil {
327+
return err
328+
}
329+
rc := cuda.Ccall(lib.groupEnd)
330+
if rc != ncclSuccess {
331+
return fmt.Errorf("ncclGroupEnd failed: %s", lib.errorString(rc))
332+
}
333+
return nil
334+
}
335+
336+
// GetAsyncError queries the communicator for any asynchronous errors that
337+
// occurred during previous operations.
338+
func (c *Comm) GetAsyncError() error {
339+
lib, err := getNcclLib()
340+
if err != nil {
341+
return err
342+
}
343+
var result uintptr
344+
rc := cuda.Ccall(lib.commGetAsyncError, c.comm, uintptr(unsafe.Pointer(&result)))
345+
if rc != ncclSuccess {
346+
return fmt.Errorf("ncclCommGetAsyncError query failed: %s", lib.errorString(rc))
347+
}
348+
if result != ncclSuccess {
349+
return fmt.Errorf("NCCL async error: %s", lib.errorString(result))
350+
}
351+
return nil
352+
}

0 commit comments

Comments
 (0)