New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WebGPU] Support dot4I8Packed(int8x4, int8x4)
as a pure extern method
#16976
base: main
Are you sure you want to change the base?
Conversation
This patch adds the support of `__dp4a(int8x4, int8x4)` as a pure extern method of WebGPU target. In the generated WGSL shader, `int8x4` will be translated into `u32`, and `__dp4a(int8x4, int8x4)` will be translated into the WGSL built-in function `dot4I8Packed(u32, u32)`. Here is an example to use `__dp4a` in WebGPU target: ``` n = te.var("n") A = te.placeholder((n,), "int8x4", name="A") B = te.placeholder((n,), "int8x4", name="B") C = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("int32", "__dp4a", A[i], B[i]), name="C") s = te.create_schedule(C.op) bx, tx = s[C].split(C.op.axis[0], factor=64) s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest") ``` Issue: apache#16627
src/target/source/codegen_webgpu.cc
Outdated
if (op->args.size() != 3) { | ||
LOG(FATAL) << "__dp4a can only accept 2 parameters (now: " << op->args.size() - 1 << ")"; | ||
} else { | ||
os << "dot4I8Packed(" << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ")"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
given we are calling pure_extern
about we use dot4I8Packed
directly in code? instead of __dp4a
. If there is a builtin high-level intrinsic function tir.dp4a
, we can do the intrinsic translation to this dot4I8Packed
in https://github.com/apache/tvm/blob/main/src/target/source/intrin_rule_webgpu.cc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Thanks for your suggestion!
__dp4a(int8x4, int8x4)
as a pure extern methoddot4I8Packed(int8x4, int8x4)
as a pure extern method
@@ -113,6 +113,8 @@ TVM_REGISTER_OP("tir.trunc") | |||
// extra dispatch | |||
TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchFastErf); | |||
|
|||
TVM_REGISTER_OP("tir.dot4I8Packed").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchPureExtern<Direct>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry i was not being clear, for tir, it is better to have a common name dp4a (as this intrinsic shared across backends)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can add tir.dp4a
intrinsic, and use it to lower to various places
This patch adds the support of
dot4I8Packed(int8x4, int8x4)
as a pure extern method of WebGPU target. In the generated WGSL shader,int8x4
will be translated intou32
, anddot4I8Packed(int8x4, int8x4)
will be translated into the WGSL built-in functiondot4I8Packed(u32, u32)
.Here is an example to use
dot4I8Packed
in WebGPU target:Issue: #16627