Skip to content

feat: first stage of vmi#835

Draft
mouliangyu wants to merge 4 commits into
hw-native-sys:mainfrom
mouliangyu:feature-vmi
Draft

feat: first stage of vmi#835
mouliangyu wants to merge 4 commits into
hw-native-sys:mainfrom
mouliangyu:feature-vmi

Conversation

@mouliangyu

Copy link
Copy Markdown
Contributor

No description provided.

@reedhecre

reedhecre commented Jun 18, 2026

Copy link
Copy Markdown

Codex Review

该评论由 review 机器人自动更新。

  • PR: feat: first stage of vmi #835 feat: first stage of vmi
  • Author: mouliangyu
  • Base/Head: main / feature-vmi
  • Head SHA: 392b0e35e6e2
  • Trigger: PR 有新提交
  • Generated At: 2026-06-22T04:13:24Z
  • Previous Head SHA: 8df1852412db
  • Status: completed

Summary

发现 2 个兼容性问题:--enable-vmi 会让 mixed-backend fatobj 无法编译,且新 VMI pass 的注册遗漏了 ArithDialect 依赖。

Findings

  1. P2 全局 `--enable-vmi` 会把 mixed `pto.backend` fatobj 模式变成不可编译 tools/ptoas/ptoas.cpp:1645

这里把 --enable-vmi 直接绑定成“只要当前编译单元不是 VPTO 就报错”。但 driver 的 mixed-backend 模式会分别为每个子模块启动 EmitCBackendChildJobVPTOBackendChildJob。结果是:一个 fatobj 里只要同时存在需要 VMI 的 VPTO 子模块和普通 EmitC 子模块,就没有可用命令行了。不开 --enable-vmi,VPTO 子模块无法走 VMI pipeline;开了 --enable-vmi,EmitC 子模块会先在这里报错退出,即使它本身根本不含 VMI。这会直接阻断 mixed-backend 场景下把 VMI vector child 和 EmitC child 组合成同一个 fatobj。

  1. P3 VMI pass 注册遗漏 `ArithDialect`,对外部 pass driver 不自洽 include/PTO/Transforms/Passes.td:633

新加的 VMI pass 在 Passes.td 里只声明了 cf/func/pto/memref/scf 依赖,但 VMIToVPTO.cpp 会实际生成大量 arith::* op(最前面就有 arith::ConstantIntOp,后面还有 IndexCastOpAddIOpMaxSIOp 等)。现在之所以本地 lit 能跑,是因为新加的 pto-test-opt.cpp 手工把 ArithDialect 注册进去了;换成别的 mlir-opt 风格驱动、只依赖 pass 注册元数据时,这个 pass 不能保证自己所需的 dialect 已加载。这是新 VMI pipeline 的 pass-registration contract mismatch。

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the design and initial implementation of the VMI (Virtual Machine Interface) dialect, including its operations, types, attributes, layout assignment, and validation passes, along with extensive test coverage. The review feedback highlights several critical issues: potential compilation errors in VMILayoutAssignment.cpp due to missing overloads for MutableOperandRange in requestDataUse and requestMaskUse, a logical bug in rewriteFunctionType where stale return types could be used from firstReturnOperandsByFunc instead of dynamically querying the up-to-date ReturnOp, and a potential compilation issue in VMI.cpp when using dyn_cast_or_null on the mlir::Attribute value type.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +261 to +264

void requestDataUse(OpOperand &operand, VMILayoutAttr layout) {
if (isa<VMIVRegType>(operand.get().getType()))
dataUseRequests.push_back(DataUseRequest{&operand, layout});

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The requestDataUse function is called with MutableOperandRange arguments (e.g., compress.getSourceMutable()), but it is only defined to accept OpOperand &. This will cause a compilation error because MutableOperandRange does not implicitly convert to OpOperand &. Adding an overload that accepts MutableOperandRange and forwards the first element resolves this issue.

  void requestDataUse(OpOperand &operand, VMILayoutAttr layout) {
    if (isa<VMIVRegType>(operand.get().getType()))
      dataUseRequests.push_back(DataUseRequest{&operand, layout});
  }

  void requestDataUse(MutableOperandRange range, VMILayoutAttr layout) {
    if (!range.empty())
      requestDataUse(range[0], layout);
  }

Comment on lines +294 to +305
LogicalResult requestMaskUse(OpOperand &operand, VMILayoutAttr layout,
StringRef granularity, Operation *op) {
if (!isa<VMIMaskType>(operand.get().getType()))
return success();
if (!layout || granularity.empty())
return op->emitError()
<< kVMIDiagLayoutContractPrefix
<< "cannot infer concrete mask use layout or granularity";
maskUseRequests.push_back(
MaskUseRequest{&operand, layout, granularity.str()});
return success();
}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To support passing MutableOperandRange directly to requestMaskUse (such as gather.getMaskMutable()), add an overload that accepts MutableOperandRange and forwards the first element.

  LogicalResult requestMaskUse(OpOperand &operand, VMILayoutAttr layout,
                               StringRef granularity, Operation *op) {
    if (!isa<VMIMaskType>(operand.get().getType()))
      return success();
    if (!layout || granularity.empty())
      return op->emitError()
             << kVMIDiagLayoutContractPrefix
             << "cannot infer concrete mask use layout or granularity";
    maskUseRequests.push_back(
        MaskUseRequest{&operand, layout, granularity.str()});
    return success();
  }

  LogicalResult requestMaskUse(MutableOperandRange range, VMILayoutAttr layout,
                               StringRef granularity, Operation *op) {
    if (range.empty())
      return success();
    return requestMaskUse(range[0], layout, granularity, op);
  }

Comment on lines +1260 to +1263
auto it = firstReturnOperandsByFunc.find(func);
if (it != firstReturnOperandsByFunc.end()) {
for (Value operand : it->second)
results.push_back(operand.getType());

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In rewriteFunctionType, reading the return types from firstReturnOperandsByFunc can lead to stale types. If any return operand was replaced or materialized during insertDataUseMaterializations or insertMaskUseMaterializations (via request.operand->set(...)), the ReturnOp's operands will point to new values, but firstReturnOperandsByFunc will still hold the old values. This causes a mismatch between the rewritten function signature and the actual returned values. Dynamically walking the function to find the first ReturnOp ensures we always use the up-to-date types.

      func::ReturnOp returnOp;
      func.walk([&](func::ReturnOp op) {
        if (!returnOp)
          returnOp = op;
      });
      if (returnOp) {
        for (Value operand : returnOp.getOperands())
          results.push_back(operand.getType());
      } else {

Comment thread lib/PTO/IR/VMI.cpp
Comment on lines +136 to +138
auto layoutAttr = dyn_cast_or_null<VMILayoutAttr>(layout);
if (!layoutAttr)
return failure();

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using dyn_cast_or_null on a value type like mlir::Attribute can cause compilation issues in some LLVM/MLIR versions because it expects a pointer type. It is safer and more idiomatic to check if the attribute is non-null first, and then use dyn_cast.

Suggested change
auto layoutAttr = dyn_cast_or_null<VMILayoutAttr>(layout);
if (!layoutAttr)
return failure();
if (!layout)
return failure();
auto layoutAttr = dyn_cast<VMILayoutAttr>(layout);

@mouliangyu mouliangyu marked this pull request as draft June 18, 2026 01:40

@zhangstevenunity zhangstevenunity left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed the core of this first VMI stage — the dialect (VMI.cpp, VMITypeDefs/VMIOps.td), VMILayoutAssignment, and VMIToVPTO — against the VPTO contracts.

The design holds up well: the closed-set layout model (contiguous / deinterleaved={2,4}) with a precise lane-map, the union-find layout assignment that hard-errors on genuine conflicts, and the two-tier "permissive verifier + strict lowering-capability gate" all read as deliberate, and unsupported cases are generally bailed explicitly rather than silently miscompiled.

Three concrete correctness issues below, all on currently-untested paths:

  1. addWhileConstraints mis-aligns the two scf.while carry groups.
  2. truncf hardcodes rnd="R", emitting an illegal pto.vcvt for f32 -> HiFloat8.
  3. Float iota builds a FloatAttr with mismatched (IEEEdouble) semantics.

(Layout-coalescing cost in loops, integer width conversion, and the histogram/half layout appear to be deliberately out of first-stage scope per the design docs, so not raised here.)

if (index < beforeArgs.size() &&
failed(uniteEquivalentValues(anchor, beforeArgs[index], whileOp)))
return failure();
if (conditionOp && index < conditionOp.getArgs().size() &&

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addWhileConstraints mis-aligns the two independent scf.while carry groups.

scf.while has two separate carries:

  • before group: inits[i] == beforeArgs[i] == the after-region scf.yield operand i
  • after group: scf.condition operand j == afterArg[j] == result[j]

These lists can differ in arity and element type (the before region computes the forwarded scf.condition operands). This loop anchors all five lists on the single inits-indexed anchor, so conditionOp.getArgs()[index] (L864), afterBlock.getArgument(index) (L868) and whileOp.getResult(index) (L876) get united against the wrong carry group. For any non-pass-through scf.while this yields a spurious "conflicting layouts" error, or silently forces one carry's layout onto an unrelated value. Only the degenerate 1-element pass-through is tested (vmi_layout_assignment_scf_while.pto).

Suggest two independent unions: the before group over {inits[i], beforeArgs[i], yield.getOperand(i)}, and the after group over {conditionOp.getArgs()[j], afterArgs[j], results[j]} — the after group must not be anchored on inits.

return rewriter.notifyMatchFailure(op,
"failed to build truncf masks");

StringAttr rnd = rewriter.getStringAttr("R");

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

truncf hardcodes rnd = "R", but the 8-bit (P0-P3) branch is shared by fp8 and HiFloat8, and the pto.vcvt contract for f32 -> HiFloat8 requires rnd in {A, H} (lookupVcvtContract(F32, HiF8) -> allowedRnd="AH", enforced by VcvtOp::verify).

VMITruncFOp::verify accepts an hif8 result and checkSupportedTruncFShape checks only result bit-width/factor (not element type), so pto.vmi.truncf : ...f32 -> ...hif8 (deinterleaved=4 source) passes every VMI gate and then emits a pto.vcvt {rnd="R"} that fails its own verifier. No test covers f32 -> hif8 (existing truncf tests are f32->f16 / f32->fp8E*).

Suggest selecting rnd from the target type (R for f8E4M3/f8E5M2/f16/bf16; A or H for HiFloat8), or rejecting HiFloat8 in checkSupportedTruncFShape with a clear unsupported error.

Comment thread lib/PTO/Transforms/VMIToVPTO.cpp Outdated
if (auto floatType = dyn_cast<FloatType>(type)) {
return rewriter
.create<arith::ConstantOp>(
loc, FloatAttr::get(floatType,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm::APFloat(static_cast<double>(value)) carries IEEEdouble semantics, but floatType here is f16/bf16/f32. FloatAttr::get(Type, APFloat) requires the value's semantics to match the type (verifyInvariants checks &type.getFloatSemantics() == &value.getSemantics()), so this asserts in an asserts-enabled build (and yields an invalid attribute otherwise). Reachable for any non-f64 float iota with a nonzero lane offset (e.g. vmi_to_vpto_iota_f16_deint2_asc) and via the factor/zero constants in createIotaDeinterleavedChunk.

Use the semantics-aware overload FloatAttr::get(floatType, static_cast<double>(value)) (which converts to the type's semantics), or build APFloat(floatType.getFloatSemantics(), ...) and round.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants