|
| 1 | +// Zero-CGo binding to libnccl.so.2 loaded at runtime via dlopen. On non-linux |
| 2 | +// platforms every exported entry point returns a clean "not supported" error |
| 3 | +// without attempting dlopen. On linux the library is dlopen'd lazily on first |
| 4 | +// use; if libnccl.so.2 cannot be found the same error is returned. |
| 5 | +// |
| 6 | +// On AArch64 (DGX hardware) AAPCS64 rule B.4 means aggregates larger than 16 |
| 7 | +// bytes are passed by hidden pointer, which lets us hand the 128-byte |
| 8 | +// ncclUniqueId to ncclCommInitRank as a plain uintptr without an ABI |
| 9 | +// trampoline. The legacy CGo implementation is retained behind the |
| 10 | +// `nccl_cgo` build tag (see nccl_cgo.go). |
| 11 | +package nccl |
| 12 | + |
| 13 | +import ( |
| 14 | + "fmt" |
| 15 | + "runtime" |
| 16 | + "sync" |
| 17 | + "unsafe" |
| 18 | + |
| 19 | + "github.com/zerfoo/ztensor/internal/cuda" |
| 20 | +) |
| 21 | + |
// ncclUniqueIDBytes is the fixed size of an ncclUniqueId
// (NCCL_UNIQUE_ID_BYTES in nccl.h). The blob is opaque and is copied
// byte-for-byte between ranks during bootstrap.
const ncclUniqueIDBytes = 128

// ncclSuccess is the NCCL result code for success; any other code is an
// error and is decoded via ncclGetErrorString.
const ncclSuccess = 0

// NCCL data type enum values (ncclDataType_t, stable ABI for NCCL 2.x).
// These are passed straight through to libnccl, so the numeric values must
// match nccl.h exactly.
const (
	ncclInt8     = 0
	ncclUint8    = 1
	ncclInt32    = 2
	ncclUint32   = 3
	ncclInt64    = 4
	ncclUint64   = 5
	ncclFloat16  = 6
	ncclFloat32  = 7
	ncclFloat64  = 8
	ncclBfloat16 = 9
)

// NCCL reduction op enum values (ncclRedOp_t, stable ABI for NCCL 2.x).
// As above, the numeric values must match nccl.h exactly.
const (
	ncclSum  = 0
	ncclProd = 1
	ncclMax  = 2
	ncclMin  = 3
	ncclAvg  = 4
)
| 50 | + |
// DataType specifies the element type for NCCL operations. The values are
// raw ncclDataType_t enum values, so a DataType can be handed directly to
// libnccl without translation.
type DataType int

const (
	Float32 DataType = ncclFloat32 // 32-bit float
	Float64 DataType = ncclFloat64 // 64-bit float
	Int32   DataType = ncclInt32   // signed 32-bit integer
	Int64   DataType = ncclInt64   // signed 64-bit integer
)

// ReduceOp specifies the reduction operation for collective calls. The
// values are raw ncclRedOp_t enum values.
type ReduceOp int

const (
	Sum ReduceOp = ncclSum // element-wise sum
	Avg ReduceOp = ncclAvg // element-wise mean across ranks
	Max ReduceOp = ncclMax // element-wise maximum
	Min ReduceOp = ncclMin // element-wise minimum
)
| 70 | + |
// ncclLib holds dlsym-resolved function pointers for libnccl.so.2. All
// fields are raw code addresses invoked through cuda.Ccall; a nil *ncclLib
// means the library could not be loaded.
type ncclLib struct {
	getUniqueId       uintptr // ncclGetUniqueId
	commInitRank      uintptr // ncclCommInitRank
	commDestroy       uintptr // ncclCommDestroy
	commGetAsyncError uintptr // ncclCommGetAsyncError
	allReduce         uintptr // ncclAllReduce
	broadcast         uintptr // ncclBroadcast
	groupStart        uintptr // ncclGroupStart
	groupEnd          uintptr // ncclGroupEnd
	getErrorString    uintptr // ncclGetErrorString
}

// Lazy-load state: ncclOnce guards the single dlopen attempt; the resolved
// library (or the load error) is cached for the life of the process.
var (
	ncclLibInst *ncclLib
	ncclOnce    sync.Once
	ncclLoadErr error
)

// libnccl candidate paths (linux only), tried in order; the versioned
// soname is preferred over the unversioned development symlink.
var ncclLibPaths = []string{
	"libnccl.so.2",
	"libnccl.so",
}
| 95 | + |
| 96 | +func loadNccl() (*ncclLib, error) { |
| 97 | + if runtime.GOOS != "linux" { |
| 98 | + return nil, fmt.Errorf("nccl: not supported on %s", runtime.GOOS) |
| 99 | + } |
| 100 | + var handle uintptr |
| 101 | + var lastErr string |
| 102 | + for _, path := range ncclLibPaths { |
| 103 | + h, err := cuda.DlopenPath(path) |
| 104 | + if err == nil { |
| 105 | + handle = h |
| 106 | + break |
| 107 | + } |
| 108 | + lastErr = err.Error() |
| 109 | + } |
| 110 | + if handle == 0 { |
| 111 | + return nil, fmt.Errorf("nccl: dlopen failed: %s", lastErr) |
| 112 | + } |
| 113 | + |
| 114 | + lib := &ncclLib{} |
| 115 | + type sym struct { |
| 116 | + name string |
| 117 | + ptr *uintptr |
| 118 | + } |
| 119 | + syms := []sym{ |
| 120 | + {"ncclGetUniqueId", &lib.getUniqueId}, |
| 121 | + {"ncclCommInitRank", &lib.commInitRank}, |
| 122 | + {"ncclCommDestroy", &lib.commDestroy}, |
| 123 | + {"ncclCommGetAsyncError", &lib.commGetAsyncError}, |
| 124 | + {"ncclAllReduce", &lib.allReduce}, |
| 125 | + {"ncclBroadcast", &lib.broadcast}, |
| 126 | + {"ncclGroupStart", &lib.groupStart}, |
| 127 | + {"ncclGroupEnd", &lib.groupEnd}, |
| 128 | + {"ncclGetErrorString", &lib.getErrorString}, |
| 129 | + } |
| 130 | + for _, s := range syms { |
| 131 | + addr, err := cuda.Dlsym(handle, s.name) |
| 132 | + if err != nil { |
| 133 | + return nil, fmt.Errorf("nccl: %w", err) |
| 134 | + } |
| 135 | + *s.ptr = addr |
| 136 | + } |
| 137 | + return lib, nil |
| 138 | +} |
| 139 | + |
// getNcclLib returns the process-wide libnccl handle, performing the dlopen
// and symbol resolution exactly once. The result (success or failure) of
// that single attempt is cached: a failed load is never retried.
func getNcclLib() (*ncclLib, error) {
	ncclOnce.Do(func() {
		ncclLibInst, ncclLoadErr = loadNccl()
	})
	return ncclLibInst, ncclLoadErr
}
| 146 | + |
| 147 | +// Available returns true if libnccl can be loaded at runtime. |
| 148 | +func Available() bool { |
| 149 | + _, err := getNcclLib() |
| 150 | + return err == nil |
| 151 | +} |
| 152 | + |
| 153 | +// errorString returns the human-readable error string for an NCCL result code. |
| 154 | +// Falls back to a numeric description if ncclGetErrorString cannot be invoked. |
| 155 | +func (l *ncclLib) errorString(rc uintptr) string { |
| 156 | + if l == nil || l.getErrorString == 0 { |
| 157 | + return fmt.Sprintf("ncclResult=%d", rc) |
| 158 | + } |
| 159 | + cstr := cuda.Ccall(l.getErrorString, rc) |
| 160 | + if cstr == 0 { |
| 161 | + return fmt.Sprintf("ncclResult=%d", rc) |
| 162 | + } |
| 163 | + // Read C-string at cstr. |
| 164 | + var b []byte |
| 165 | + for i := 0; i < 1024; i++ { |
| 166 | + c := *(*byte)(unsafe.Pointer(cstr + uintptr(i))) |
| 167 | + if c == 0 { |
| 168 | + break |
| 169 | + } |
| 170 | + b = append(b, c) |
| 171 | + } |
| 172 | + return string(b) |
| 173 | +} |
| 174 | + |
// UniqueID wraps an ncclUniqueId (128-byte opaque blob) used to bootstrap
// communicator creation. Rank 0 generates it with GetUniqueID and every
// other rank receives the same bytes out of band (see Bytes /
// UniqueIDFromBytes).
type UniqueID struct {
	// id is the raw ncclUniqueId payload; NCCL treats it as opaque.
	id [ncclUniqueIDBytes]byte
}
| 180 | + |
| 181 | +// GetUniqueID generates a new unique ID for communicator initialization. |
| 182 | +// Exactly one rank should call this and broadcast the result to all other ranks. |
| 183 | +func GetUniqueID() (*UniqueID, error) { |
| 184 | + lib, err := getNcclLib() |
| 185 | + if err != nil { |
| 186 | + return nil, err |
| 187 | + } |
| 188 | + uid := &UniqueID{} |
| 189 | + rc := cuda.Ccall(lib.getUniqueId, uintptr(unsafe.Pointer(&uid.id[0]))) |
| 190 | + if rc != ncclSuccess { |
| 191 | + return nil, fmt.Errorf("ncclGetUniqueId failed: %s", lib.errorString(rc)) |
| 192 | + } |
| 193 | + return uid, nil |
| 194 | +} |
| 195 | + |
| 196 | +// Bytes returns a copy of the raw bytes of the unique ID for serialization. |
| 197 | +func (u *UniqueID) Bytes() []byte { |
| 198 | + out := make([]byte, ncclUniqueIDBytes) |
| 199 | + copy(out, u.id[:]) |
| 200 | + return out |
| 201 | +} |
| 202 | + |
| 203 | +// UniqueIDFromBytes reconstructs a UniqueID from raw bytes. |
| 204 | +func UniqueIDFromBytes(b []byte) (*UniqueID, error) { |
| 205 | + if len(b) != ncclUniqueIDBytes { |
| 206 | + return nil, fmt.Errorf("UniqueIDFromBytes: expected %d bytes, got %d", ncclUniqueIDBytes, len(b)) |
| 207 | + } |
| 208 | + uid := &UniqueID{} |
| 209 | + copy(uid.id[:], b) |
| 210 | + return uid, nil |
| 211 | +} |
| 212 | + |
// Comm wraps an ncclComm_t communicator (opaque pointer). A zero comm field
// means the communicator is not initialized (or has been destroyed).
type Comm struct {
	// comm is the raw ncclComm_t handle returned by ncclCommInitRank.
	comm uintptr
}
| 217 | + |
| 218 | +// InitRank initializes a communicator for a given rank in a group of nRanks. |
| 219 | +// All ranks must call this with the same UniqueID and nRanks. The CUDA device |
| 220 | +// for this rank must be set via cuda.SetDevice before calling InitRank. |
| 221 | +// |
| 222 | +// On AArch64 (AAPCS64) and other AAPCS-derived ABIs, aggregates larger than |
| 223 | +// 16 bytes are passed by hidden pointer (rule B.4), so passing &uid.id[0] |
| 224 | +// matches the C calling convention for ncclUniqueId-by-value. This binding |
| 225 | +// is therefore correct on linux/arm64 (the supported NCCL platform); other |
| 226 | +// ABIs (System V AMD64) pass large aggregates on the stack and would need a |
| 227 | +// dedicated trampoline. |
| 228 | +func InitRank(uid *UniqueID, nRanks, rank int) (*Comm, error) { |
| 229 | + if uid == nil { |
| 230 | + return nil, fmt.Errorf("nccl InitRank: nil UniqueID") |
| 231 | + } |
| 232 | + lib, err := getNcclLib() |
| 233 | + if err != nil { |
| 234 | + return nil, err |
| 235 | + } |
| 236 | + var comm uintptr |
| 237 | + rc := cuda.Ccall(lib.commInitRank, |
| 238 | + uintptr(unsafe.Pointer(&comm)), |
| 239 | + uintptr(nRanks), |
| 240 | + uintptr(unsafe.Pointer(&uid.id[0])), |
| 241 | + uintptr(rank), |
| 242 | + ) |
| 243 | + if rc != ncclSuccess { |
| 244 | + return nil, fmt.Errorf("ncclCommInitRank(nRanks=%d, rank=%d) failed: %s", |
| 245 | + nRanks, rank, lib.errorString(rc)) |
| 246 | + } |
| 247 | + return &Comm{comm: comm}, nil |
| 248 | +} |
| 249 | + |
| 250 | +// Destroy releases the communicator resources. |
| 251 | +func (c *Comm) Destroy() error { |
| 252 | + lib, err := getNcclLib() |
| 253 | + if err != nil { |
| 254 | + return err |
| 255 | + } |
| 256 | + rc := cuda.Ccall(lib.commDestroy, c.comm) |
| 257 | + if rc != ncclSuccess { |
| 258 | + return fmt.Errorf("ncclCommDestroy failed: %s", lib.errorString(rc)) |
| 259 | + } |
| 260 | + return nil |
| 261 | +} |
| 262 | + |
| 263 | +// AllReduce performs an in-place all-reduce across all ranks. sendBuf and |
| 264 | +// recvBuf may be the same pointer for in-place operation. count is the number |
| 265 | +// of elements (not bytes). The stream parameter is a cudaStream_t as |
| 266 | +// unsafe.Pointer. |
| 267 | +func (c *Comm) AllReduce(sendBuf, recvBuf unsafe.Pointer, count int, dtype DataType, op ReduceOp, stream unsafe.Pointer) error { |
| 268 | + lib, err := getNcclLib() |
| 269 | + if err != nil { |
| 270 | + return err |
| 271 | + } |
| 272 | + rc := cuda.Ccall(lib.allReduce, |
| 273 | + uintptr(sendBuf), |
| 274 | + uintptr(recvBuf), |
| 275 | + uintptr(count), |
| 276 | + uintptr(dtype), |
| 277 | + uintptr(op), |
| 278 | + c.comm, |
| 279 | + uintptr(stream), |
| 280 | + ) |
| 281 | + if rc != ncclSuccess { |
| 282 | + return fmt.Errorf("ncclAllReduce failed: %s", lib.errorString(rc)) |
| 283 | + } |
| 284 | + return nil |
| 285 | +} |
| 286 | + |
| 287 | +// Broadcast sends count elements from root's sendBuf to all ranks' recvBuf. |
| 288 | +// For root, sendBuf and recvBuf may differ or be the same. |
| 289 | +func (c *Comm) Broadcast(sendBuf, recvBuf unsafe.Pointer, count int, dtype DataType, root int, stream unsafe.Pointer) error { |
| 290 | + lib, err := getNcclLib() |
| 291 | + if err != nil { |
| 292 | + return err |
| 293 | + } |
| 294 | + rc := cuda.Ccall(lib.broadcast, |
| 295 | + uintptr(sendBuf), |
| 296 | + uintptr(recvBuf), |
| 297 | + uintptr(count), |
| 298 | + uintptr(dtype), |
| 299 | + uintptr(root), |
| 300 | + c.comm, |
| 301 | + uintptr(stream), |
| 302 | + ) |
| 303 | + if rc != ncclSuccess { |
| 304 | + return fmt.Errorf("ncclBroadcast failed: %s", lib.errorString(rc)) |
| 305 | + } |
| 306 | + return nil |
| 307 | +} |
| 308 | + |
| 309 | +// GroupStart begins a group of NCCL operations. All NCCL calls between |
| 310 | +// GroupStart and GroupEnd are batched into a single launch. |
| 311 | +func GroupStart() error { |
| 312 | + lib, err := getNcclLib() |
| 313 | + if err != nil { |
| 314 | + return err |
| 315 | + } |
| 316 | + rc := cuda.Ccall(lib.groupStart) |
| 317 | + if rc != ncclSuccess { |
| 318 | + return fmt.Errorf("ncclGroupStart failed: %s", lib.errorString(rc)) |
| 319 | + } |
| 320 | + return nil |
| 321 | +} |
| 322 | + |
| 323 | +// GroupEnd completes a group of NCCL operations and launches them. |
| 324 | +func GroupEnd() error { |
| 325 | + lib, err := getNcclLib() |
| 326 | + if err != nil { |
| 327 | + return err |
| 328 | + } |
| 329 | + rc := cuda.Ccall(lib.groupEnd) |
| 330 | + if rc != ncclSuccess { |
| 331 | + return fmt.Errorf("ncclGroupEnd failed: %s", lib.errorString(rc)) |
| 332 | + } |
| 333 | + return nil |
| 334 | +} |
| 335 | + |
| 336 | +// GetAsyncError queries the communicator for any asynchronous errors that |
| 337 | +// occurred during previous operations. |
| 338 | +func (c *Comm) GetAsyncError() error { |
| 339 | + lib, err := getNcclLib() |
| 340 | + if err != nil { |
| 341 | + return err |
| 342 | + } |
| 343 | + var result uintptr |
| 344 | + rc := cuda.Ccall(lib.commGetAsyncError, c.comm, uintptr(unsafe.Pointer(&result))) |
| 345 | + if rc != ncclSuccess { |
| 346 | + return fmt.Errorf("ncclCommGetAsyncError query failed: %s", lib.errorString(rc)) |
| 347 | + } |
| 348 | + if result != ncclSuccess { |
| 349 | + return fmt.Errorf("NCCL async error: %s", lib.errorString(result)) |
| 350 | + } |
| 351 | + return nil |
| 352 | +} |
0 commit comments