# Base toolchain: initialize conda, pin the interpreter to $PY_VERSION, and
# pull the CUDA toolkit matching $DOCKER_CUDA_VERSION from NVIDIA's channel.
RUN conda init && \
    conda install python=$PY_VERSION -y && \
    conda install -c "nvidia/label/cuda-$DOCKER_CUDA_VERSION" cuda -y
18- # RUN conda init && \
19- # pip install torch==$TORCH_VERSION torchvision==$TORCH_VISION_VERSION torchaudio==$TORCH_VERSION --index-url https://download.pytorch.org/whl/cu$(cat /CUDA_VERSION | awk -F. '{print $1$2}')
# ---------------------------------------------------------------------------
# install_torch_env — helper baked into the image.
#
# Usage: install_torch_env TORCH_VER VISION_VER CUDA_TAG FLASH_ATTN_WHEEL_URL
#
# Creates a conda env "torch<TORCH_VER>", installs matching
# torch/torchvision/torchaudio wheels plus a prebuilt flash-attention wheel,
# and writes /usr/local/bin/init_torch<TORCH_VER>, which makes that env the
# default for future shells.
#
# The quoted 'EOF' delimiter disables Dockerfile build-time expansion, so the
# script body needs no "\$" escaping.
# ---------------------------------------------------------------------------
COPY <<'EOF' /usr/local/bin/install_torch_env
#!/bin/bash
set -euo pipefail

TORCH_VER=$1
VISION_VER=$2
CUDA_VER=$3         # pytorch wheel index tag, e.g. cu124
FLASH_ATTN_LINK=$4  # URL of a prebuilt flash-attention wheel

ENV_DIR="/root/anaconda/envs/torch${TORCH_VER}"

# -y: never block the docker build on an interactive confirmation.
# NOTE(review): the env is created empty; if it has no pip of its own, the
# base interpreter's pip is picked up instead — consider
# "conda create -n torch${TORCH_VER} python=3.10 pip -y" (wheels are cp310).
conda create -n "torch${TORCH_VER}" -y

# Temporarily put the env's bin/ first so the pip calls below target it.
PATH_BCK=$PATH
export PATH="${ENV_DIR}/bin:${PATH}"

pip install --no-cache-dir \
    torch=="${TORCH_VER}" torchvision=="${VISION_VER}" torchaudio=="${TORCH_VER}" \
    --index-url "https://download.pytorch.org/whl/${CUDA_VER}"
pip install --no-cache-dir "${FLASH_ATTN_LINK}"

# init_torch<VER>: running it switches the current and future shells to this
# env.  "\$PATH" stays literal in the generated file so it expands when the
# line is later sourced from ~/.bashrc, not when it is written.
cat > "/usr/local/bin/init_torch${TORCH_VER}" <<INIT
#!/bin/bash
echo "export PATH=${ENV_DIR}/bin:\$PATH && source activate torch${TORCH_VER}" >> ~/.bashrc
source activate torch${TORCH_VER}
INIT
chmod +x "/usr/local/bin/init_torch${TORCH_VER}"

# Restore the build-time PATH so later Dockerfile steps are unaffected.
export PATH=$PATH_BCK
EOF

RUN chmod +x /usr/local/bin/install_torch_env
43+
# Prebuilt torch environments (251209 update).
# Arguments: TORCH_VER  VISION_VER  CUDA_TAG  FLASH_ATTN_WHEEL_URL
RUN install_torch_env 2.4.1 0.19.1 cu124 https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.3.12/flash_attn-2.8.0+cu124torch2.4-cp310-cp310-linux_x86_64.whl

RUN install_torch_env 2.5.1 0.20.1 cu124 https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.11/flash_attn-2.8.3+cu124torch2.5-cp310-cp310-linux_x86_64.whl

RUN install_torch_env 2.6.0 0.21.0 cu124 https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.11/flash_attn-2.8.3+cu124torch2.6-cp310-cp310-linux_x86_64.whl

RUN install_torch_env 2.7.1 0.22.1 cu126 https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.11/flash_attn-2.8.3+cu126torch2.7-cp310-cp310-linux_x86_64.whl

# NOTE(review): this wheel is flash-attn 2.6.3, older than the 2.8.x used for
# torch 2.5-2.7 — confirm that is intentional for torch 2.9.0.
RUN install_torch_env 2.9.0 0.24.0 cu126 https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.6.3+cu126torch2.9-cp310-cp310-linux_x86_64.whl
# (trailing GitHub UI text — not part of the Dockerfile)