1+ cmake_minimum_required (VERSION 3.15...3.30 )
2+ project (openequivariance_stable_ext)
3+
4+ find_package (Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module )
5+
6+ # Download LibTorch
7+ include (FetchContent )
8+
9+ FetchContent_Declare (
10+ libtorch
11+ URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.10.0%2Bcpu.zip"
12+ )
13+
14+ message (STATUS "Downloading LibTorch..." )
15+ FetchContent_MakeAvailable (libtorch)
16+
17+ set (LIBTORCH_INCLUDE_DIR "${libtorch_SOURCE_DIR} /include" )
18+ set (LIBTORCH_LIB_DIR "${libtorch_SOURCE_DIR} /lib" )
19+ find_library (TORCH_CPU_LIB NAMES torch_cpu PATHS "${LIBTORCH_LIB_DIR} " NO_DEFAULT_PATH )
20+ find_library (C10_LIB NAMES c10 PATHS "${LIBTORCH_LIB_DIR} " NO_DEFAULT_PATH )
21+
22+ message (STATUS "LibTorch Include: ${LIBTORCH_INCLUDE_DIR} " )
23+ message (STATUS "LibTorch Lib: ${LIBTORCH_LIB_DIR} " )
24+
25+ message (STATUS "Torch CPU Library: ${TORCH_CPU_LIB} " )
26+ message (STATUS "Torch C10 Library: ${C10_LIB} " )
27+
28+ # Setup Nanobind
29+ execute_process (
30+ COMMAND "${Python_EXECUTABLE} " -m nanobind --cmake_dir
31+ OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT
32+ )
33+ message (STATUS "nanobind cmake directory: ${nanobind_ROOT} " )
34+
35+ find_package (nanobind CONFIG REQUIRED )
36+
37+ set (EXT_DIR "${CMAKE_CURRENT_SOURCE_DIR} /openequivariance/extension" )
38+ set (EXT_BACKEND_DIR "${EXT_DIR} /backend" )
39+ set (EXT_JSON_DIR "${EXT_DIR} /json11" )
40+
41+ # Source files
42+ set (OEQ_SOURCES
43+ ${EXT_DIR} /libtorch_tp_jit_stable.cpp
44+ ${EXT_JSON_DIR} /json11.cpp
45+ )
46+
47+ set (OEQ_INSTALL_DIR "${CMAKE_CURRENT_SOURCE_DIR} /openequivariance/_torch/extlib" )
48+
49+ function (add_stable_extension target_name backend_define link_libraries )
50+ # Create nanobind extension
51+ nanobind_add_module (${target_name} NB_STATIC ${OEQ_SOURCES} )
52+
53+ set_target_properties (${target_name} PROPERTIES
54+ CXX_STANDARD 17
55+ CXX_STANDARD_REQUIRED ON
56+ POSITION_INDEPENDENT_CODE ON
57+ )
58+
59+ # Enforce CXX11 ABI to match LibTorch
60+ target_compile_definitions (${target_name} PRIVATE
61+ ${backend_define} =1
62+ _GLIBCXX_USE_CXX11_ABI=1
63+ INCLUDE_NB_EXTENSION
64+ )
65+
66+ target_include_directories (${target_name} PRIVATE
67+ ${EXT_DIR}
68+ ${EXT_BACKEND_DIR}
69+ ${EXT_JSON_DIR}
70+ ${LIBTORCH_INCLUDE_DIR}
71+ )
72+ target_link_libraries (${target_name} PRIVATE
73+ ${TORCH_CPU_LIB}
74+ ${C10_LIB}
75+ ${link_libraries}
76+ )
77+
78+ install (TARGETS ${target_name} LIBRARY DESTINATION "${OEQ_INSTALL_DIR} " )
79+
80+ # AOTI C++ library (identical except without nanobind and without INCLUDE_NB_EXTENSION)
81+ set (aoti_target_name ${target_name} _aoti)
82+ add_library (${aoti_target_name} SHARED ${OEQ_SOURCES} )
83+
84+ set_target_properties (${aoti_target_name} PROPERTIES
85+ CXX_STANDARD 17
86+ CXX_STANDARD_REQUIRED ON
87+ POSITION_INDEPENDENT_CODE ON
88+ )
89+
90+ target_compile_definitions (${aoti_target_name} PRIVATE
91+ ${backend_define} =1
92+ _GLIBCXX_USE_CXX11_ABI=1
93+ )
94+
95+ target_include_directories (${aoti_target_name} PRIVATE
96+ ${EXT_DIR}
97+ ${EXT_BACKEND_DIR}
98+ ${EXT_JSON_DIR}
99+ ${LIBTORCH_INCLUDE_DIR}
100+ )
101+ target_link_libraries (${aoti_target_name} PRIVATE
102+ ${TORCH_CPU_LIB}
103+ ${C10_LIB}
104+ ${link_libraries}
105+ )
106+
107+ install (TARGETS ${aoti_target_name} LIBRARY DESTINATION "${OEQ_INSTALL_DIR} " )
108+ endfunction ()
109+
110+ find_package (CUDAToolkit QUIET )
111+ find_package (hip QUIET )
112+
113+ if (CUDAToolkit_FOUND)
114+ message (STATUS "Building stable extension with CUDA backend." )
115+
116+ add_library (cuda_stub_lib SHARED ${EXT_DIR} /stubs/stream.cpp )
117+
118+ target_include_directories (cuda_stub_lib PRIVATE
119+ ${LIBTORCH_INCLUDE_DIR}
120+ )
121+
122+ set_target_properties (cuda_stub_lib PROPERTIES
123+ OUTPUT_NAME "torch_cuda"
124+ POSITION_INDEPENDENT_CODE ON
125+ CXX_STANDARD 17
126+ )
127+
128+ set (CUDA_LINK_LIBS
129+ CUDA::cudart
130+ CUDA::cuda_driver
131+ CUDA::nvrtc
132+ cuda_stub_lib
133+ )
134+ add_stable_extension (oeq_stable_cuda CUDA_BACKEND "${CUDA_LINK_LIBS} " )
135+ endif ()
136+
137+ if (hip_FOUND)
138+ message (STATUS "Building stable extension with HIP backend." )
139+
140+ add_library (hip_stub_lib SHARED ${EXT_DIR} /stubs/stream.cpp )
141+
142+ target_include_directories (hip_stub_lib PRIVATE
143+ ${LIBTORCH_INCLUDE_DIR}
144+ )
145+
146+ set_target_properties (hip_stub_lib PROPERTIES
147+ OUTPUT_NAME "torch_hip"
148+ POSITION_INDEPENDENT_CODE ON
149+ CXX_STANDARD 17
150+ )
151+
152+ set (HIP_LINK_LIBS
153+ hiprtc
154+ hip_stub_lib
155+ )
156+ add_stable_extension (torch_stable_hip HIP_BACKEND "${HIP_LINK_LIBS} " )
157+ endif ()
158+
159+ if (NOT CUDAToolkit_FOUND AND NOT hip_FOUND)
160+ message (WARNING "Neither CUDAToolkit nor HIP was found. The stable extension will not be built." )
161+ endif ()
0 commit comments