
TRAK on CUDA 12.1 raises a CUDA error #82

@enkeejunior1

Description

  • Minimal code to reproduce the error:
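import torch
import trak
from trak.projectors import ProjectionType, CudaProjector

# Check the TRAK installation, including the fast_jl CUDA kernel
print("trak.test_install:", trak.test_install(use_fast_jl=True))

grad_dim = int(1e6)  # length of each flattened gradient
projector = CudaProjector(
    grad_dim=grad_dim,
    proj_dim=32768,  # target dimension after projection
    seed=42,
    proj_type=ProjectionType.normal,  # Gaussian projection matrix
    device='cuda:0',
    max_batch_size=8,
)

# A batch of 8 random "gradients", matching max_batch_size
grad = torch.randn(8, grad_dim, device='cuda:0')
proj = projector.project(grad, model_id=0)  # expected shape: (8, 32768)
print(proj)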
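For reference, the call above maps a batch of shape (8, grad_dim) to (8, proj_dim) by multiplying with a random Gaussian matrix. The sketch below is a hypothetical, pure-Python stand-in (not TRAK's implementation) that materializes that matrix at toy sizes so the expected shapes and determinism can be checked without a GPU. The way it combines `seed` and `model_id` is an illustrative assumption, not TRAK's RNG scheme, and it applies no scaling factor; TRAK's own scaling convention is not stated here. At the repro's sizes a dense 1e6 × 32768 float32 matrix would need about 131 GB, so `CudaProjector` presumably generates its entries on the fly instead.

```python
import random
from typing import List

Matrix = List[List[float]]


def gaussian_projection(grads: Matrix, proj_dim: int, seed: int, model_id: int = 0) -> Matrix:
    """Project each row of `grads` onto `proj_dim` random Gaussian directions.

    Hypothetical reference helper, not part of TRAK. The (seed, model_id)
    mixing below is an illustrative choice only.
    """
    grad_dim = len(grads[0])
    rng = random.Random(seed * 1_000_003 + model_id)
    # Dense grad_dim x proj_dim matrix with i.i.d. N(0, 1) entries.
    # Only feasible at toy sizes; the repro's dimensions would not fit in memory.
    proj = [[rng.gauss(0.0, 1.0) for _ in range(proj_dim)] for _ in range(grad_dim)]
    out: Matrix = []
    for row in grads:
        if len(row) != grad_dim:
            raise ValueError("all gradient rows must have the same length")
        # Row-vector times matrix: out[j] = sum_i row[i] * proj[i][j]
        out.append([sum(row[i] * proj[i][j] for i in range(grad_dim)) for j in range(proj_dim)])
    return out


if __name__ == "__main__":
    data_rng = random.Random(0)
    grads = [[data_rng.gauss(0.0, 1.0) for _ in range(64)] for _ in range(8)]
    projected = gaussian_projection(grads, proj_dim=16, seed=42, model_id=0)
    print(len(projected), len(projected[0]))  # 8 16
```

Reusing the same `seed` and `model_id` reproduces the same projection, which is the property a checkpoint-specific projector needs.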
  • Environment installation commands:
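pip install scikit-learn matplotlib einops ipykernel
# PyTorch wheels built against CUDA 12.1
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121

# CUDA 12.1 toolchain from the nvidia channel (nvcc is used to build the fast extension)
conda install cuda=12.1 -c nvidia
conda install cuda-nvcc=12.1 -c nvidia -y
conda install cuda-toolkit=12.1 -c nvidia -y
export CUDA_HOME=$CONDA_PREFIX
# replace python3.x with the environment's Python version
export PYTHONPATH=$CONDA_PREFIX/lib/python3.x/site-packages:$PYTHONPATH
# quoted so shells such as zsh do not treat the brackets as a glob
pip install "traker[fast]"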
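Since the `[fast]` extra appears to compile a CUDA extension (hence the `cuda-nvcc` and `CUDA_HOME` steps above), one thing worth ruling out is a mismatch between the CUDA version PyTorch was built with and the `nvcc` the build picks up. The sketch below is a hypothetical diagnostic, not part of TRAK: it looks for `nvcc` on `PATH` and then under `$CUDA_HOME/bin`, parses the `release X.Y` field of `nvcc --version`, and compares it with `torch.version.cuda` when PyTorch is importable.

```python
import os
import re
import shutil
import subprocess
from typing import Optional


def parse_nvcc_release(nvcc_output: str) -> Optional[str]:
    """Extract 'MAJOR.MINOR' from `nvcc --version` output, e.g. '12.1'."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    return f"{match.group(1)}.{match.group(2)}" if match else None


def major_minor(version: Optional[str]) -> Optional[str]:
    """Reduce a version string such as '12.1.105' to '12.1'."""
    if not version:
        return None
    parts = version.split(".")
    return ".".join(parts[:2]) if len(parts) >= 2 else None


def find_nvcc() -> Optional[str]:
    """Look for nvcc on PATH first, then under $CUDA_HOME/bin."""
    found = shutil.which("nvcc")
    if found:
        return found
    cuda_home = os.environ.get("CUDA_HOME")
    if cuda_home:
        candidate = os.path.join(cuda_home, "bin", "nvcc")
        if os.path.isfile(candidate):
            return candidate
    return None


def check_cuda_toolchain() -> dict:
    """Report the CUDA versions seen by PyTorch and by nvcc."""
    report = {"CUDA_HOME": os.environ.get("CUDA_HOME"), "nvcc": None, "torch_cuda": None}
    nvcc = find_nvcc()
    if nvcc:
        try:
            result = subprocess.run([nvcc, "--version"], capture_output=True, text=True, timeout=30)
            report["nvcc"] = parse_nvcc_release(result.stdout)
        except (OSError, subprocess.SubprocessError):
            pass  # nvcc found but not runnable; leave it as None
    try:
        import torch  # optional: only used to read the CUDA version torch was built with
        report["torch_cuda"] = torch.version.cuda  # None for CPU-only builds
    except ImportError:
        pass
    # True only when both versions are known and agree on MAJOR.MINOR
    report["match"] = report["nvcc"] is not None and major_minor(report["torch_cuda"]) == report["nvcc"]
    return report


if __name__ == "__main__":
    for key, value in check_cuda_toolchain().items():
        print(f"{key}: {value}")
```

With the cu121 wheels and the conda toolchain above, one would hope to see both versions report `12.1` and `match: True`; any other result points at the environment rather than at `CudaProjector` itself.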

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests