[FEATURE]add cube tile ops(TLOAD>MAT, TSTORE>ACC, TSTORE.MAT, TSTORE_FP)#814
pbbb205 wants to merge 12 commits into
Conversation
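Codex Review
This comment is automatically updated by the review bot.
Summary
Found 4 substantive issues, mainly incorrect template selection for TSTORE.ACC/TLOAD.MAT and inconsistent contracts for TSTORE_FP/TLOAD.MAT.
Findings
The only condition that distinguishes these two constraints is a comparison of
In both the template constraints and the comments, the PR treats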
Code Review
This pull request translates comments to English and introduces several new Cube Matrix templates for TLOAD and TSTORE operations (such as GM to MAT, ACC to GM, MAT to GM, and ACC+FP to GM) along with constraint checks and comprehensive LIT tests. The review feedback highlights a critical correctness bug in template_tstore_mat_to_gm where the intermediate UB copy step is omitted, a potential AttributeError in template_tstore_fp_acc_to_gm from using .element_type instead of .dtype on a PartitionTensorView, incomplete layout constraint checks for NZ formats, and several unused variables across the templates.
    mat_ptr = src.as_ptr()
    gm_ptr = dst.as_ptr()

    # NOTE: There is no direct MAT -> GM DMA path. Data is written to UB via
    # mte_l1_ub, then forwarded to GM via copy_ubuf_to_gm.
    len_burst = k * elem_bytes

    pto.mte_l1_ub(mat_ptr, gm_ptr, len_burst, nburst=(m, 0, 0))
There is a critical correctness bug in template_tstore_mat_to_gm. The comment states that data is written to UB via mte_l1_ub and then forwarded to GM via copy_ubuf_to_gm. However, the code calls mte_l1_ub with gm_ptr as the destination (which is a GM pointer, not a UB pointer), and completely omits the copy_ubuf_to_gm step. This will lead to compilation failures or hardware faults. Please allocate a temporary UB tile, copy from MAT to UB, and then copy from UB to GM.
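The staging the comment asks for can be sketched in plain Python. This is only a model of the data path, not the pto API: `Buffer`, `dma_copy`, `store_mat_to_gm`, and the `ALLOWED_PATHS` table are hypothetical names, and the set of legal paths is an assumption taken from the NOTE in the template.

```python
# Hypothetical stand-ins: these names model the data path only, not the pto API.
# Assumption (from the template's NOTE): there is no direct MAT -> GM path.
ALLOWED_PATHS = {("mat", "ub"), ("ub", "gm")}


class Buffer:
    """A byte buffer tagged with the memory space it lives in."""

    def __init__(self, space, size):
        self.space = space
        self.data = bytearray(size)


def dma_copy(src, dst, nbytes):
    """Copy nbytes from src to dst, rejecting paths the model does not allow."""
    if (src.space, dst.space) not in ALLOWED_PATHS:
        raise ValueError(f"no DMA path {src.space} -> {dst.space}")
    dst.data[:nbytes] = src.data[:nbytes]


def store_mat_to_gm(mat, gm, nbytes):
    """Stage the tile through a temporary UB buffer: MAT -> UB -> GM."""
    ub = Buffer("ub", nbytes)
    dma_copy(mat, ub, nbytes)
    dma_copy(ub, gm, nbytes)


mat = Buffer("mat", 8)
mat.data[:] = bytes(range(8))
gm = Buffer("gm", 8)
store_mat_to_gm(mat, gm, 8)
```

In this model, passing a GM buffer directly as the destination of the MAT copy raises immediately, which mirrors the mismatch the comment points out between the template's NOTE and its single `mte_l1_ub` call.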
    dst_dtype = dst.element_type

    if pto.constexpr(dst_dtype == pto.bf16):
        quant_mode = "f32_bf16"
    else:
        quant_mode = "f32_f16"
In template_tstore_fp_acc_to_gm, dst is a PartitionTensorView. Calling dst.element_type is highly likely to raise an AttributeError because PartitionTensorView objects use .dtype to represent their data type (as seen in the constraints), whereas Tile objects use .element_type. Please use dst.dtype instead.
    - dst_dtype = dst.element_type
    - if pto.constexpr(dst_dtype == pto.bf16):
    -     quant_mode = "f32_bf16"
    - else:
    -     quant_mode = "f32_f16"
    + dst_dtype = dst.dtype
    + if pto.constexpr(dst_dtype == pto.bf16):
    +     quant_mode = 'f32_bf16'
    + else:
    +     quant_mode = 'f32_f16'
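The accessor mismatch described in this comment can be illustrated outside the DSL. The sketch below is ordinary Python, not TileLang DSL code: `FakeTile`, `FakeView`, `dtype_of`, and `quant_mode_for` are hypothetical names, and dtypes are modeled as plain strings. It only shows how a single lookup could tolerate both attribute spellings. Whether such a helper is usable inside a template is not something the source confirms.

```python
class FakeTile:
    """Stand-in for a Tile, which the review says exposes `element_type`."""

    def __init__(self, element_type):
        self.element_type = element_type


class FakeView:
    """Stand-in for a PartitionTensorView, which the review says exposes `dtype`."""

    def __init__(self, dtype):
        self.dtype = dtype


def dtype_of(obj):
    """Return the element dtype regardless of which attribute the object uses."""
    for attr in ("dtype", "element_type"):
        if hasattr(obj, attr):
            return getattr(obj, attr)
    raise AttributeError(f"{type(obj).__name__} exposes neither dtype nor element_type")


def quant_mode_for(dst):
    """Mirror the branch in the template: bf16 maps to f32_bf16, anything else to f32_f16."""
    return "f32_bf16" if dtype_of(dst) == "bf16" else "f32_f16"


print(quant_mode_for(FakeView("bf16")))
print(quant_mode_for(FakeTile("f16")))
```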
    def _constraint_tload_mat_base(src, dst) -> bool:
        """TLOAD.MAT base constraint check"""
        # dst must be MemorySpace.MAT
        dst_space = dst.memory_space
        if dst_space is None:
            return False
        dst_space_value = dst_space.value if hasattr(dst_space, "value") else dst_space
        if dst_space_value not in {"mat", "MAT"}:
            return False
In _constraint_tload_mat_base, there is no check to ensure that the source src is in the GM memory space. Since TLOAD.MAT is specifically for loading from GM to L1 MAT, we should verify that src is indeed in GM to prevent incorrect matching of other memory spaces.
    def _constraint_tload_mat_base(src, dst) -> bool:
        """TLOAD.MAT base constraint check"""
        # src must be GM
        src_space = src.memory_space
        src_space_value = 'gm' if src_space is None else (src_space.value if hasattr(src_space, 'value') else src_space)
        if src_space_value not in {'gm', 'GM'}:
            return False
        # dst must be MemorySpace.MAT
        dst_space = dst.memory_space
        if dst_space is None:
            return False
        dst_space_value = dst_space.value if hasattr(dst_space, 'value') else dst_space
        if dst_space_value not in {'mat', 'MAT'}:
            return False

    def _constraint_tload_mat_nd2nz(src, dst) -> bool:
        """TLOAD.MAT ND2NZ fractal load constraint"""
        if not _constraint_tload_mat_base(src, dst):
            return False
        # dst layout must be col_major (NZ format)
        config = dst.config
        if config is None:
            return False
        b_layout = config.b_layout
        if b_layout is None:
            return False
        b_layout_value = b_layout.value if hasattr(b_layout, "value") else b_layout
        # COL_MAJOR corresponds to NZ format
        if b_layout_value not in {"col_major", "COL_MAJOR"}:
            return False
        return True
In _constraint_tload_mat_nd2nz, only b_layout is checked to be col_major. However, NZ format is defined by both b_layout=col_major and s_layout=row_major. We should check both to ensure strict layout matching.
    def _constraint_tload_mat_nd2nz(src, dst) -> bool:
        """TLOAD.MAT ND2NZ fractal load constraint"""
        if not _constraint_tload_mat_base(src, dst):
            return False
        # dst layout must be col_major (NZ format)
        config = dst.config
        if config is None:
            return False
        b_layout = config.b_layout
        s_layout = config.s_layout
        if b_layout is None or s_layout is None:
            return False
        b_layout_value = b_layout.value if hasattr(b_layout, 'value') else b_layout
        s_layout_value = s_layout.value if hasattr(s_layout, 'value') else s_layout
        # COL_MAJOR + ROW_MAJOR corresponds to NZ format
        if b_layout_value not in {'col_major', 'COL_MAJOR'} or s_layout_value not in {'row_major', 'ROW_MAJOR'}:
            return False
        return True

    def _constraint_tload_mat_dn2nz(src, dst) -> bool:
        """TLOAD.MAT DN2NZ fractal load constraint"""
        if not _constraint_tload_mat_base(src, dst):
            return False
        config = dst.config
        if config is None:
            return False
        b_layout = config.b_layout
        if b_layout is None:
            return False
        b_layout_value = b_layout.value if hasattr(b_layout, "value") else b_layout
        if b_layout_value not in {"col_major", "COL_MAJOR"}:
            return False
        return True
In _constraint_tload_mat_dn2nz, only b_layout is checked to be col_major. However, NZ format is defined by both b_layout=col_major and s_layout=row_major. We should check both to ensure strict layout matching.
    def _constraint_tload_mat_dn2nz(src, dst) -> bool:
        """TLOAD.MAT DN2NZ fractal load constraint"""
        if not _constraint_tload_mat_base(src, dst):
            return False
        config = dst.config
        if config is None:
            return False
        b_layout = config.b_layout
        s_layout = config.s_layout
        if b_layout is None or s_layout is None:
            return False
        b_layout_value = b_layout.value if hasattr(b_layout, 'value') else b_layout
        s_layout_value = s_layout.value if hasattr(s_layout, 'value') else s_layout
        if b_layout_value not in {'col_major', 'COL_MAJOR'} or s_layout_value not in {'row_major', 'ROW_MAJOR'}:
            return False
        return True

    m, k = dst.valid_shape
    dtype = dst.element_type
    elem_bytes = pto.bytewidth(dtype)

    m, k = dst.valid_shape
    dtype = dst.element_type
    def _constraint_tstore_acc_nz2nz(src, dst) -> bool:
        """TSTORE.ACC NZ2NZ constraint"""
        if not _constraint_tstore_acc_base(src, dst):
            return False
        # dst must be NZ layout (fractal)
        config = dst.config
        if config is None:
            return False
        # Check for fractal or special NZ layout marking
        s_layout = config.s_layout
        if s_layout is None:
            return False
        s_layout_value = s_layout.value if hasattr(s_layout, "value") else s_layout
        if s_layout_value not in {"row_major", "ROW_MAJOR"}:
            return False
        return True
In _constraint_tstore_acc_nz2nz, only s_layout is checked to be row_major. However, NZ format is defined by both b_layout=col_major and s_layout=row_major. We should check both to ensure strict layout matching.
    def _constraint_tstore_acc_nz2nz(src, dst) -> bool:
        """TSTORE.ACC NZ2NZ constraint"""
        if not _constraint_tstore_acc_base(src, dst):
            return False
        # dst must be NZ layout (fractal)
        config = dst.config
        if config is None:
            return False
        b_layout = config.b_layout
        s_layout = config.s_layout
        if b_layout is None or s_layout is None:
            return False
        b_layout_value = b_layout.value if hasattr(b_layout, 'value') else b_layout
        s_layout_value = s_layout.value if hasattr(s_layout, 'value') else s_layout
        if b_layout_value not in {'col_major', 'COL_MAJOR'} or s_layout_value not in {'row_major', 'ROW_MAJOR'}:
            return False
        return True

    m, n = src.valid_shape
    dtype = src.element_type
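The three NZ layout suggestions above repeat the same normalize-then-compare logic, so one possible refactor is a shared predicate. This is a sketch under assumptions: `Layout`, `enum_value`, and `is_nz_layout` are hypothetical names, configs are modeled with `SimpleNamespace`, and the definition of NZ (`b_layout=col_major` plus `s_layout=row_major`) is taken from the review comments rather than from the library.

```python
from enum import Enum
from types import SimpleNamespace


class Layout(Enum):
    """Hypothetical stand-in for the layout enum used by tile configs."""

    ROW_MAJOR = "row_major"
    COL_MAJOR = "col_major"


def enum_value(x):
    """Unwrap enum members to their value; pass plain strings (and None) through."""
    return x.value if hasattr(x, "value") else x


def is_nz_layout(config):
    """NZ as described in the review: b_layout=col_major and s_layout=row_major."""
    if config is None:
        return False
    b_layout = enum_value(getattr(config, "b_layout", None))
    s_layout = enum_value(getattr(config, "s_layout", None))
    return (b_layout in {"col_major", "COL_MAJOR"}
            and s_layout in {"row_major", "ROW_MAJOR"})


nz = SimpleNamespace(b_layout=Layout.COL_MAJOR, s_layout="row_major")
print(is_nz_layout(nz))
```

With such a helper, each of the `nd2nz`, `dn2nz`, and `nz2nz` constraints would reduce to its base check followed by `is_nz_layout(dst.config)`, assuming ordinary helper calls are permitted where the constraints are evaluated.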
    }

    // CHECK-LABEL: tload_mat_dn2nz_f16
    // CHECK: TLOAD
    \ No newline at end of file
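The path to test is not pto -> emitc but pto -> expand tileop -> vpto, so the test needs to set --pto-arch=a5 --pto-backend=vpto.
The CHECK lines should verify that the output contains the corresponding MTE instructions.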
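What "the output contains the corresponding MTE instructions" means for a check can be modeled with a simplified stand-in for FileCheck's ordered `CHECK:` matching. This is not FileCheck itself (no regexes, no `CHECK-LABEL` scoping). `check_in_order` is a hypothetical helper, and the sample text is made up for illustration rather than real ptoas output. The instruction name `mte_l0c_gm` is the one mentioned in this PR's NZ2DN test header.

```python
def check_in_order(output, patterns):
    """Return True if every pattern occurs in output, in the given order."""
    pos = 0
    for pattern in patterns:
        found = output.find(pattern, pos)
        if found < 0:
            return False
        pos = found + len(pattern)  # later patterns must match after this one
    return True


# Made-up text for illustration; not real ptoas output.
sample = "func tstore_acc_nz2dn\n  mte_l0c_gm ...\n"

print(check_in_order(sample, ["tstore_acc_nz2dn", "mte_l0c_gm"]))
```

A check that only looks for the tile-op name (for example `TLOAD`) would pass on the emitc path too, which is why the reviewer asks for the lowered MTE instruction to be the match target.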
    // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
    // See LICENSE in the root of the software repository for the full text of the License.

    // RUN: not ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s
    # See LICENSE in the root of the software repository for the full text of the License.

    - """`pto.tload` 的 TileLang DSL 模板"""
    + """TileLang DSL template for `pto.tload`"""
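An end-to-end ST case is missing. You could construct a GM ---tload---> MAT ---tstore---> GM ST case to verify this; otherwise there is no proof that the feature works correctly.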
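The shape of such a round-trip check can be sketched on the host in plain Python. This models only the data movement and is not a real ST case, which would run the actual tload/tstore ops. `nd_to_nz`, `nz_to_nd`, and `roundtrip` are hypothetical helpers, and the 2x2 block size is a toy value rather than the hardware fractal size. The layout assumption follows the review's description of NZ: blocks ordered column-major, elements inside each block row-major.

```python
def nd_to_nz(matrix, block_rows, block_cols):
    """Flatten a row-major matrix into NZ order (models tload GM -> MAT)."""
    rows, cols = len(matrix), len(matrix[0])
    if rows % block_rows or cols % block_cols:
        raise ValueError("shape must be a multiple of the block size")
    flat = []
    for bj in range(cols // block_cols):          # block columns, outermost
        for bi in range(rows // block_rows):      # block rows inside one block column
            for r in range(block_rows):           # row-major inside the block
                for c in range(block_cols):
                    flat.append(matrix[bi * block_rows + r][bj * block_cols + c])
    return flat


def nz_to_nd(flat, rows, cols, block_rows, block_cols):
    """Inverse of nd_to_nz (models tstore MAT -> GM)."""
    matrix = [[None] * cols for _ in range(rows)]
    values = iter(flat)
    for bj in range(cols // block_cols):
        for bi in range(rows // block_rows):
            for r in range(block_rows):
                for c in range(block_cols):
                    matrix[bi * block_rows + r][bj * block_cols + c] = next(values)
    return matrix


def roundtrip(gm_in, block_rows, block_cols):
    """GM --tload--> MAT --tstore--> GM, on the host model."""
    mat = nd_to_nz(gm_in, block_rows, block_cols)
    return nz_to_nd(mat, len(gm_in), len(gm_in[0]), block_rows, block_cols)


gm_in = [[r * 4 + c for c in range(4)] for r in range(4)]
assert roundtrip(gm_in, 2, 2) == gm_in
```

The equality assertion at the end is the core of the ST case: whatever layout the intermediate MAT tile uses, the data written back to GM must match the input.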
Root cause: tstore_template.py had `raise NotImplementedError` on line 543, which crashed the entire module import in TileLang DSL v1 (arbitrary external calls are not supported). This made ALL templates in the file unavailable, cascading to tstore, tload, tstore_fp, and tmrgsort failures.

Changes:
- tstore_template.py: replace `raise NotImplementedError` with `pass` (template is unreachable anyway due to dtypes=[])
- tstore_template.py: replace `dst.dtype` with `fp.element_type` in tstore_fp template (PartitionTensorView .dtype not supported in DSL v1)
- PTO.cpp: add ACC element type (f32/i32) check to A5 TStoreFPOp verifier (tstore_fp_invalid_dtype test expects verifier to reject f16 ACC)
- PTOOps.td: add OpPipeInterface + getPipe() to TStoreFPOp (returns PIPE_FIX, same pipe used by other ACC-store variants)

Co-Authored-By: Claude <noreply@anthropic.com>
…ule import

The _freeze_dtypes function requires at least one signature tuple. dtypes=[] causes: "dtypes must contain at least one signature tuple", which crashes the entire tstore_template.py import, losing ALL templates in the file (including valid tstore_acc, tstore_fp templates).

Replace the disabled template + constraint + decorator with a comment explaining why it's omitted entirely.

Co-Authored-By: Claude <noreply@anthropic.com>
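The failure mode in this commit message, where one invalid registration takes down every template in the file, can be reproduced with a minimal model. Everything here is a hypothetical stand-in: this `_freeze_dtypes` only mirrors the quoted error message, `template` is an invented decorator that validates at decoration time, and the dtype tuples are made up. Module import is modeled by `exec` on source text.

```python
def _freeze_dtypes(dtypes):
    """Hypothetical stand-in that mirrors the quoted error message."""
    if not dtypes:
        raise ValueError("dtypes must contain at least one signature tuple")
    return tuple(tuple(sig) for sig in dtypes)


def template(dtypes):
    """Invented decorator: validation runs at decoration time, i.e. during import."""
    frozen = _freeze_dtypes(dtypes)

    def register(fn):
        fn._dtypes = frozen
        return fn

    return register


GOOD_SOURCE = '''
@template(dtypes=[("f32", "f32")])
def tstore_acc(src, dst):
    pass

@template(dtypes=[("f32", "f16")])
def tstore_fp(src, dst, fp):
    pass
'''

# Same module plus one disabled template registered with an empty dtypes list.
BAD_SOURCE = GOOD_SOURCE + '''
@template(dtypes=[])
def tstore_mat(src, dst):
    pass
'''


def load_templates(source):
    """Model an import: either the whole module loads, or nothing is available."""
    namespace = {"template": template}
    try:
        exec(source, namespace)
    except ValueError as err:
        return {}, str(err)
    found = {name: obj for name, obj in namespace.items() if hasattr(obj, "_dtypes")}
    return found, None


print(load_templates(BAD_SOURCE))
print(sorted(load_templates(GOOD_SOURCE)[0]))
```

In this model the two valid templates are lost together with the invalid one, which is consistent with the commit's choice to drop the disabled template entirely instead of registering it with an empty signature list.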
Need to see stderr output from ptoas to diagnose why all 7 cube ckernel template tests fail with 'stdin is empty'. Will restore 2>/dev/null after fixing the actual issue.

Co-Authored-By: Claude <noreply@anthropic.com>
    mode_value = self._require_string_expr(mode, f"{context} mode")
    if mode_value in {"f32_f16", "f32_bf16"}:
        self._require_fixpipe_scalar_payload(payload, f"{context} payload")
        # f32_f16 / f32_bf16 pre_quant can be either a scalar (for inline
Don't change this; the underlying vpto instruction itself does not accept pointer input.
    // Test TSTORE.ACC NZ2DN - tile-op expand to vpto mte_l0c_gm with nz2dn mode
    // Pipeline: ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics

    // RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - | FileCheck %s
The lit test could be renamed; it no longer has anything to do with emitc.
Zhendong404 left a comment
The tload/tstore ST cases in the PTO-ISA repository that relate to this PR need to be migrated here.
    # always matches the dst dtype in the tstore_fp signature.
    fp_dtype = fp.element_type

    if pto.constexpr(fp_dtype == pto.bf16):
tstore_fp can never select either of the f32_bf16/f32_f16 modes.
    template <typename TileData, typename GlobalData, typename FpTileData, AtomicType atomicType = AtomicType::AtomicNone,
              ReluPreMode reluPreMode = ReluPreMode::NoRelu, STPhase Phase = STPhase::Unspecified>
    PTO_INTERNAL void TSTORE_IMPL(GlobalData &dst, TileData &src, FpTileData &fp)
    {
        static_assert(TileData::Loc == pto::TileType::Acc, "Source TileType only suport Acc!");
        using DstT = typename GlobalData::RawDType;
        using L0cT = typename TileData::DType;
        CheckStaticAcc<TileData, GlobalData, true>();
        if constexpr (AtomicType::AtomicAdd == atomicType) {
            SetAtomicAdd<DstT>();
        }
        constexpr QuantMode_t quantPre = GetVectorPreQuantModeGm<L0cT, DstT>();
        TStoreAccFp<GlobalData, TileData, FpTileData, quantPre, reluPreMode>(
            dst.data(), src.data(), fp.data(), dst.GetShape(pto::GlobalTensorDim::DIM_0),
            dst.GetShape(pto::GlobalTensorDim::DIM_1), dst.GetShape(pto::GlobalTensorDim::DIM_2),
            dst.GetShape(pto::GlobalTensorDim::DIM_3), dst.GetShape(pto::GlobalTensorDim::DIM_4),
            dst.GetStride(pto::GlobalTensorDim::DIM_0), dst.GetStride(pto::GlobalTensorDim::DIM_1),
            dst.GetStride(pto::GlobalTensorDim::DIM_2), dst.GetStride(pto::GlobalTensorDim::DIM_3),
            dst.GetStride(pto::GlobalTensorDim::DIM_4), src.GetValidRow(), src.GetValidCol());
        if constexpr (atomicType == AtomicType::AtomicAdd) {
            set_atomic_none();
        }
    }
add cube tile ops(TLOAD>MAT, TSTORE>ACC, TSTORE.MAT, TSTORE_FP)