
[WebGPU] Support dot4I8Packed(int8x4, int8x4) as a pure extern method #16976

Open · wants to merge 3 commits into main
Conversation

@Jiawei-Shao (Contributor) commented May 8, 2024

This patch adds support for `dot4I8Packed(int8x4, int8x4)` as a pure extern method of the WebGPU target. In the generated WGSL shader, `int8x4` is translated into `u32`, and `dot4I8Packed(int8x4, int8x4)` is translated into the WGSL built-in function `dot4I8Packed(u32, u32)`.
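For intuition, here is a plain-Python sketch (not TVM or WGSL code) of what the WGSL built-in `dot4I8Packed(u32, u32)` computes: each `u32` packs four signed 8-bit lanes, and the result is their `int32` dot product. The helper names below are my own, chosen for illustration.

```python
def unpack_i8x4(x: int) -> list[int]:
    """Extract four signed 8-bit lanes from a 32-bit unsigned integer."""
    lanes = []
    for shift in (0, 8, 16, 24):
        b = (x >> shift) & 0xFF
        lanes.append(b - 256 if b >= 128 else b)  # sign-extend each byte
    return lanes

def dot4_i8_packed(a: int, b: int) -> int:
    """Dot product of two packed int8x4 values, as dot4I8Packed computes."""
    return sum(x * y for x, y in zip(unpack_i8x4(a), unpack_i8x4(b)))
```

For example, `dot4_i8_packed(0x01010101, 0x01010101)` multiplies four pairs of lanes that are all 1 and returns 4, while a `0xFF` byte contributes as the signed value -1.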

Here is an example of using `dot4I8Packed` with the WebGPU target:

```python
import tvm
from tvm import te

# The target was not shown in the original snippet; "webgpu" is assumed here.
tgt = tvm.target.Target("webgpu")

n = te.var("n")
A = te.placeholder((n,), "int8x4", name="A")
B = te.placeholder((n,), "int8x4", name="B")
C = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("int32", "dot4I8Packed", A[i], B[i]), name="C")
s = te.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest")
```

Issue: #16627

This patch adds support for `__dp4a(int8x4, int8x4)` as a pure extern method of the WebGPU target. In the generated WGSL shader, `int8x4` is translated into `u32`, and `__dp4a(int8x4, int8x4)` is translated into the WGSL built-in function `dot4I8Packed(u32, u32)`.

Here is an example of using `__dp4a` with the WebGPU target:

```python
n = te.var("n")
A = te.placeholder((n,), "int8x4", name="A")
B = te.placeholder((n,), "int8x4", name="B")
C = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i]), name="C")
s = te.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest")
```

Issue: apache#16627
```cpp
// args[0] of a call_pure_extern node holds the extern function name,
// so a valid two-operand call carries three args in total.
if (op->args.size() != 3) {
  LOG(FATAL) << "__dp4a can only accept 2 parameters (now: " << op->args.size() - 1 << ")";
} else {
  os << "dot4I8Packed(" << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ")";
}
```
Member:

Given that we are calling pure_extern, how about we use dot4I8Packed directly in code instead of __dp4a? If there is a builtin high-level intrinsic function tir.dp4a, we can do the intrinsic translation to dot4I8Packed in https://github.com/apache/tvm/blob/main/src/target/source/intrin_rule_webgpu.cc

Contributor (Author):

Done. Thanks for your suggestion!
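The codegen check quoted earlier relies on TVM's call convention for `call_pure_extern`: `args[0]` holds the extern function name, so two user operands yield three args. A plain-Python sketch of that emission logic (illustrative only, not TVM internals):

```python
def emit_dot4i8packed(args: list[str]) -> str:
    """Emit a WGSL dot4I8Packed call from [name, operand1, operand2] args.

    Mirrors the codegen check above: args[0] is the extern name, so a
    two-operand call must have exactly three entries.
    """
    if len(args) != 3:
        raise ValueError(
            f"dot4I8Packed can only accept 2 parameters (now: {len(args) - 1})"
        )
    return f"dot4I8Packed({args[1]}, {args[2]})"
```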

@Jiawei-Shao Jiawei-Shao changed the title [WebGPU] Support __dp4a(int8x4, int8x4) as a pure extern method [WebGPU] Support dot4I8Packed(int8x4, int8x4) as a pure extern method May 9, 2024
```diff
@@ -113,6 +113,8 @@ TVM_REGISTER_OP("tir.trunc")
 // extra dispatch
 TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchFastErf);

+TVM_REGISTER_OP("tir.dot4I8Packed").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchPureExtern<Direct>);
```
@tqchen (Member) commented May 9, 2024:
Sorry, I was not being clear: for TIR, it is better to have a common name dp4a (as this intrinsic is shared across backends).

Contributor (Author):
Oh, I find there is no tir.dp4a in TVM right now, and dp4a is always called through call_pure_extern() (see the vulkan and cuda backends).

Do you mean we should add tir.dp4a to TVM, or still support dp4a as a pure extern call, like how dp4a is supported in codegen_spirv.cc?

@tqchen (Member) commented May 10, 2024:

We can add a tir.dp4a intrinsic and use it to lower to the various backends.
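The lowering tqchen describes can be sketched as a per-target dispatch from one common TIR intrinsic name to each backend's native function. This is a hypothetical plain-Python illustration, not TVM's actual FLowerIntrinsic machinery; the per-backend names come from this thread (dot4I8Packed for WebGPU, __dp4a for CUDA).

```python
# Hypothetical lowering table: common intrinsic name -> backend extern name.
LOWER_TABLE = {
    "webgpu": {"tir.dp4a": "dot4I8Packed"},
    "cuda":   {"tir.dp4a": "__dp4a"},
}

def lower_intrinsic(target: str, op: str) -> str:
    """Map a target-agnostic TIR intrinsic to the backend's extern name."""
    try:
        return LOWER_TABLE[target][op]
    except KeyError:
        raise NotImplementedError(f"{op} has no lowering rule for target {target}")
```

The point of the common name is that frontends emit one `tir.dp4a` and each backend supplies its own row in the table, instead of frontends hard-coding `__dp4a` or `dot4I8Packed`.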
