Skip to content

Commit ef3c144

Browse files
committed
Add triton and flash-attn to the supported libraries
1 parent 52a6a53 commit ef3c144

2 files changed

Lines changed: 12 additions & 7 deletions

File tree

uv/Dockerfile

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ SHELL ["/bin/bash", "-c"]
5858
ENV PATH=/root/.local/bin:$PATH
5959

6060
RUN uv python pin ${PY_VERSION}
61-
RUN uv pip install --no-cache --system ipython tqdm rich jupyter jupyterlab ipykernel pandas einops safetensors pyyaml requests psutil opencv-python-headless matplotlib seaborn scikit-learn scipy pillow tensorboard h5py
61+
RUN uv pip install --no-cache --system ipython tqdm rich jupyter jupyterlab ipykernel pandas einops safetensors pyyaml requests psutil opencv-python-headless matplotlib seaborn scikit-learn scipy pillow tensorboard h5py triton
6262

6363
# torch installer
6464
COPY <<EOF /usr/local/bin/install_torch_env
@@ -67,28 +67,31 @@ set -e
6767
TORCH_VER=\$1
6868
VISION_VER=\$2
6969
CUDA_VER=\$3
70+
FLASH_ATTN_LINK=\$4
71+
7072
BASE_DIR=/ddiff-base/py\${PY_VERSION}-torch\${TORCH_VER}
7173
mkdir -p "\$BASE_DIR" && cd "\$BASE_DIR"
7274
uv venv --python "\${PY_VERSION}" --system-site-packages --seed
7375
uv pip install --no-cache torch=="\${TORCH_VER}" torchvision=="\${VISION_VER}" torchaudio=="\${TORCH_VER}" --index-url "https://download.pytorch.org/whl/\${CUDA_VER}"
74-
echo "ln -sfn \${BASE_DIR}/.venv ./.venv && uv add torch==\${TORCH_VER} torchvision==\${VISION_VER} torchaudio==\${TORCH_VER} && echo Created uv virtual environment with torch==\${TORCH_VER}" > /usr/local/bin/uv_init_torch\${TORCH_VER}
76+
uv pip install --no-cache \$FLASH_ATTN_LINK
77+
echo "ln -sfn \${BASE_DIR}/.venv ./.venv && [ -f pyproject.toml ] && uv add torch==\${TORCH_VER} torchvision==\${VISION_VER} torchaudio==\${TORCH_VER}; echo Created uv virtual environment with torch==\${TORCH_VER}" > /usr/local/bin/uv_init_torch\${TORCH_VER}
7578
chmod +x /usr/local/bin/uv_init_torch\${TORCH_VER}
7679
EOF
7780
RUN chmod +x /usr/local/bin/install_torch_env
7881

7982
ENV UV_NO_CACHE=1
8083

8184
# pytorch 2.4.1 (251209 Update)
82-
RUN install_torch_env 2.4.1 0.19.1 cu124
85+
RUN install_torch_env 2.4.1 0.19.1 cu124 https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.3.12/flash_attn-2.8.0+cu124torch2.4-cp310-cp310-linux_x86_64.whl
8386

8487
# pytorch 2.5.1 (251209 Update)
85-
RUN install_torch_env 2.5.1 0.20.1 cu124
88+
RUN install_torch_env 2.5.1 0.20.1 cu124 https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.11/flash_attn-2.8.3+cu124torch2.5-cp310-cp310-linux_x86_64.whl
8689

8790
# pytorch 2.6.0 (251209 Update)
88-
RUN install_torch_env 2.6.0 0.21.0 cu124
91+
RUN install_torch_env 2.6.0 0.21.0 cu124 https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.11/flash_attn-2.8.3+cu124torch2.6-cp310-cp310-linux_x86_64.whl
8992

9093
# pytorch 2.7.1 (251209 Update)
91-
RUN install_torch_env 2.7.1 0.22.1 cu126
94+
RUN install_torch_env 2.7.1 0.22.1 cu126 https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.11/flash_attn-2.8.3+cu126torch2.7-cp310-cp310-linux_x86_64.whl
9295

9396
# pytorch 2.9.0 (251209 Update)
94-
RUN install_torch_env 2.9.0 0.24.0 cu126
97+
RUN install_torch_env 2.9.0 0.24.0 cu126 https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.6.3+cu126torch2.9-cp310-cp310-linux_x86_64.whl

uv/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ you can use standard uv commands. The environment is automatically detected via
3838
# Base Images
3939
### Common Packages
4040
* **Python Libraries**
41+
* **Kernel:** triton
4142
* **Data & Science:** numpy pandas scipy scikit-learn matplotlib seaborn
4243
* **Tools:** jupyterlab ipykernel tqdm rich
4344
* **CV & Utils:** opencv-python-headless pillow einops safetensors
@@ -51,5 +52,6 @@ you can use standard uv commands. The environment is automatically detected via
5152
**Image Tag:** `junwha/ddiff-base:cu12.4.1-py3.10-torch-251210`
5253
* **Python:** 3.10
5354
* **Pre-installed PyTorch Versions:** 2.4.1, 2.5.1, 2.6.0, 2.7.1, 2.9.0
55+
* **Pre-installed Flash-attention Version:** 2.8.3
5456
* **Compute Capabilities**: 7.0 7.5 8.0 8.6 8.9 9.0+PTX
5557

0 commit comments

Comments
 (0)