feat: first stage of vmi#835
Codex Review
This comment is automatically updated by the review bot.
Summary: found 2 compatibility issues.
Findings
Here, the
The newly added VMI pass in
Code Review
This pull request introduces the design and initial implementation of the VMI (Virtual Machine Interface) dialect, including its operations, types, attributes, layout assignment, and validation passes, along with extensive test coverage. The review feedback highlights several critical issues: potential compilation errors in VMILayoutAssignment.cpp due to missing overloads for MutableOperandRange in requestDataUse and requestMaskUse, a logical bug in rewriteFunctionType where stale return types could be used from firstReturnOperandsByFunc instead of dynamically querying the up-to-date ReturnOp, and a potential compilation issue in VMI.cpp when using dyn_cast_or_null on the mlir::Attribute value type.
    void requestDataUse(OpOperand &operand, VMILayoutAttr layout) {
      if (isa<VMIVRegType>(operand.get().getType()))
        dataUseRequests.push_back(DataUseRequest{&operand, layout});
The requestDataUse function is called with MutableOperandRange arguments (e.g., compress.getSourceMutable()), but it is only defined to accept OpOperand &. This will cause a compilation error because MutableOperandRange does not implicitly convert to OpOperand &. Adding an overload that accepts MutableOperandRange and forwards the first element resolves this issue.
    void requestDataUse(OpOperand &operand, VMILayoutAttr layout) {
      if (isa<VMIVRegType>(operand.get().getType()))
        dataUseRequests.push_back(DataUseRequest{&operand, layout});
    }
    void requestDataUse(MutableOperandRange range, VMILayoutAttr layout) {
      if (!range.empty())
        requestDataUse(range[0], layout);
    }

    LogicalResult requestMaskUse(OpOperand &operand, VMILayoutAttr layout,
                                 StringRef granularity, Operation *op) {
      if (!isa<VMIMaskType>(operand.get().getType()))
        return success();
      if (!layout || granularity.empty())
        return op->emitError()
               << kVMIDiagLayoutContractPrefix
               << "cannot infer concrete mask use layout or granularity";
      maskUseRequests.push_back(
          MaskUseRequest{&operand, layout, granularity.str()});
      return success();
    }
To support passing MutableOperandRange directly to requestMaskUse (such as gather.getMaskMutable()), add an overload that accepts MutableOperandRange and forwards the first element.
    LogicalResult requestMaskUse(OpOperand &operand, VMILayoutAttr layout,
                                 StringRef granularity, Operation *op) {
      if (!isa<VMIMaskType>(operand.get().getType()))
        return success();
      if (!layout || granularity.empty())
        return op->emitError()
               << kVMIDiagLayoutContractPrefix
               << "cannot infer concrete mask use layout or granularity";
      maskUseRequests.push_back(
          MaskUseRequest{&operand, layout, granularity.str()});
      return success();
    }
    LogicalResult requestMaskUse(MutableOperandRange range, VMILayoutAttr layout,
                                 StringRef granularity, Operation *op) {
      if (range.empty())
        return success();
      return requestMaskUse(range[0], layout, granularity, op);
    }

    auto it = firstReturnOperandsByFunc.find(func);
    if (it != firstReturnOperandsByFunc.end()) {
      for (Value operand : it->second)
        results.push_back(operand.getType());
In rewriteFunctionType, reading the return types from firstReturnOperandsByFunc can lead to stale types. If any return operand was replaced or materialized during insertDataUseMaterializations or insertMaskUseMaterializations (via request.operand->set(...)), the ReturnOp's operands will point to new values, but firstReturnOperandsByFunc will still hold the old values. This causes a mismatch between the rewritten function signature and the actual returned values. Dynamically walking the function to find the first ReturnOp ensures we always use the up-to-date types.
    func::ReturnOp returnOp;
    func.walk([&](func::ReturnOp op) {
      if (!returnOp)
        returnOp = op;
    });
    if (returnOp) {
      for (Value operand : returnOp.getOperands())
        results.push_back(operand.getType());
    } else {

    auto layoutAttr = dyn_cast_or_null<VMILayoutAttr>(layout);
    if (!layoutAttr)
      return failure();
Using dyn_cast_or_null on a value type like mlir::Attribute can cause compilation issues in some LLVM/MLIR versions because it expects a pointer type. It is safer and more idiomatic to check if the attribute is non-null first, and then use dyn_cast.
    - auto layoutAttr = dyn_cast_or_null<VMILayoutAttr>(layout);
    - if (!layoutAttr)
    -   return failure();
    + if (!layout)
    +   return failure();
    + auto layoutAttr = dyn_cast<VMILayoutAttr>(layout);
zhangstevenunity left a comment
Reviewed the core of this first VMI stage — the dialect (VMI.cpp, VMITypeDefs/VMIOps.td), VMILayoutAssignment, and VMIToVPTO — against the VPTO contracts.
The design holds up well: the closed-set layout model (contiguous / deinterleaved={2,4}) with a precise lane-map, the union-find layout assignment that hard-errors on genuine conflicts, and the two-tier "permissive verifier + strict lowering-capability gate" all read as deliberate, and unsupported cases are generally bailed explicitly rather than silently miscompiled.
Three concrete correctness issues below, all on currently-untested paths:
- addWhileConstraints mis-aligns the two scf.while carry groups.
- truncf hardcodes rnd="R", emitting an illegal pto.vcvt for f32 -> HiFloat8.
- Float iota builds a FloatAttr with mismatched (IEEEdouble) semantics.
(Layout-coalescing cost in loops, integer width conversion, and the histogram/half layout appear to be deliberately out of first-stage scope per the design docs, so not raised here.)
    if (index < beforeArgs.size() &&
        failed(uniteEquivalentValues(anchor, beforeArgs[index], whileOp)))
      return failure();
    if (conditionOp && index < conditionOp.getArgs().size() &&
addWhileConstraints mis-aligns the two independent scf.while carry groups.
scf.while has two separate carries:
- before group: inits[i] == beforeArgs[i] == the after-region scf.yield operand i
- after group: scf.condition operand j == afterArg[j] == result[j]
These lists can differ in arity and element type (the before region computes the forwarded scf.condition operands). This loop anchors all five lists on the single inits-indexed anchor, so conditionOp.getArgs()[index] (L864), afterBlock.getArgument(index) (L868) and whileOp.getResult(index) (L876) get united against the wrong carry group. For any non-pass-through scf.while this yields a spurious "conflicting layouts" error, or silently forces one carry's layout onto an unrelated value. Only the degenerate 1-element pass-through is tested (vmi_layout_assignment_scf_while.pto).
Suggest two independent unions: the before group over {inits[i], beforeArgs[i], yield.getOperand(i)}, and the after group over {conditionOp.getArgs()[j], afterArgs[j], results[j]} — the after group must not be anchored on inits.
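A minimal, self-contained sketch of the two-group idea, not the pass's actual code. It models values as plain integer ids and uses a toy union-find; `UnionFind`, `WhileValues`, and this standalone `addWhileConstraints` are hypothetical stand-ins for the pass's own data structures and for `uniteEquivalentValues`.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Toy union-find over integer value ids (hypothetical stand-in for the
// layout equivalence classes used by the real pass).
struct UnionFind {
  std::vector<int> parent;
  explicit UnionFind(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);
  }
  int find(int x) {
    while (parent[x] != x) {
      parent[x] = parent[parent[x]]; // path halving
      x = parent[x];
    }
    return x;
  }
  void unite(int a, int b) { parent[find(a)] = find(b); }
  bool same(int a, int b) { return find(a) == find(b); }
};

// Hypothetical flattened view of one scf.while. Assumption: the three lists
// in each group have equal length, but the two groups may differ in arity.
struct WhileValues {
  std::vector<int> inits, beforeArgs, yieldOperands;  // before group
  std::vector<int> conditionArgs, afterArgs, results; // after group
};

void addWhileConstraints(UnionFind &uf, const WhileValues &w) {
  // Before group: inits[i] == beforeArgs[i] == after-region yield operand i.
  for (std::size_t i = 0; i < w.inits.size(); ++i) {
    uf.unite(w.inits[i], w.beforeArgs[i]);
    uf.unite(w.inits[i], w.yieldOperands[i]);
  }
  // After group: anchored on the scf.condition operand j, never on inits.
  for (std::size_t j = 0; j < w.conditionArgs.size(); ++j) {
    uf.unite(w.conditionArgs[j], w.afterArgs[j]);
    uf.unite(w.conditionArgs[j], w.results[j]);
  }
}
```

With one init but two forwarded condition operands, the two groups end up in separate equivalence classes, which is the behavior the single-anchor loop cannot express.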
    return rewriter.notifyMatchFailure(op,
                                       "failed to build truncf masks");

    StringAttr rnd = rewriter.getStringAttr("R");
truncf hardcodes rnd = "R", but the 8-bit (P0-P3) branch is shared by fp8 and HiFloat8, and the pto.vcvt contract for f32 -> HiFloat8 requires rnd in {A, H} (lookupVcvtContract(F32, HiF8) -> allowedRnd="AH", enforced by VcvtOp::verify).
VMITruncFOp::verify accepts an hif8 result and checkSupportedTruncFShape checks only result bit-width/factor (not element type), so pto.vmi.truncf : ...f32 -> ...hif8 (deinterleaved=4 source) passes every VMI gate and then emits a pto.vcvt {rnd="R"} that fails its own verifier. No test covers f32 -> hif8 (existing truncf tests are f32->f16 / f32->fp8E*).
Suggest selecting rnd from the target type (R for f8E4M3/f8E5M2/f16/bf16; A or H for HiFloat8), or rejecting HiFloat8 in checkSupportedTruncFShape with a clear unsupported error.
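The first option could look roughly like the sketch below. It is standalone C++ rather than MLIR code: `ElemKind` and `selectTruncFRnd` are hypothetical names, and picking "A" over "H" for HiFloat8 is an arbitrary assumption, since the contract quoted above allows either.

```cpp
#include <cassert>
#include <optional>
#include <string_view>

// Hypothetical stand-in for the truncf result element type.
enum class ElemKind { F16, BF16, F8E4M3, F8E5M2, HiF8, Other };

// Picks the rnd string for the emitted pto.vcvt from the result element kind.
// Returns std::nullopt for kinds the lowering should reject explicitly.
std::optional<std::string_view> selectTruncFRnd(ElemKind resultKind) {
  switch (resultKind) {
  case ElemKind::F16:
  case ElemKind::BF16:
  case ElemKind::F8E4M3:
  case ElemKind::F8E5M2:
    return std::string_view("R");
  case ElemKind::HiF8:
    // Assumption: "A" is chosen here; "H" is equally legal per the contract.
    return std::string_view("A");
  case ElemKind::Other:
    return std::nullopt;
  }
  return std::nullopt;
}
```

In the real pattern, the `std::nullopt` case would map to a `notifyMatchFailure` with an explicit unsupported-type message instead of emitting a `pto.vcvt` that fails its own verifier.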
    if (auto floatType = dyn_cast<FloatType>(type)) {
      return rewriter
          .create<arith::ConstantOp>(
              loc, FloatAttr::get(floatType,
llvm::APFloat(static_cast<double>(value)) carries IEEEdouble semantics, but floatType here is f16/bf16/f32. FloatAttr::get(Type, APFloat) requires the value's semantics to match the type (verifyInvariants checks &type.getFloatSemantics() == &value.getSemantics()), so this asserts in an asserts-enabled build (and yields an invalid attribute otherwise). Reachable for any non-f64 float iota with a nonzero lane offset (e.g. vmi_to_vpto_iota_f16_deint2_asc) and via the factor/zero constants in createIotaDeinterleavedChunk.
Use the semantics-aware overload FloatAttr::get(floatType, static_cast<double>(value)) (which converts to the type's semantics), or build APFloat(floatType.getFloatSemantics(), ...) and round.