[SVE] Add support for representing and creating buffer-level predicates #16966

Open

lhutton1 wants to merge 7 commits into main from predicated-buffer-load-store-support
Conversation

@lhutton1 (Contributor) commented May 3, 2024

Representation
--------------

This commit extends `BufferLoad` and `BufferStore` to accept a predicate mask argument indicating which lanes in a vectorized buffer load/store should be read/written.

As a simple example, we can load all lanes:
```
tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(1, 8))
```

Or disable loading all lanes:
```
tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(0, 8))
```
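`BufferStore` accepts the predicate in the same way; a minimal sketch (here `buf` is assumed to be a float32 buffer, and the keyword argument is assumed to mirror the load constructor):
```
value = tir.Broadcast(tir.const(1.0, "float32"), 8)
store = tir.BufferStore(buf, value, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(1, 8))
```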

In TVMScript, buffer loads and stores are currently displayed using a "short-hand" notation, e.g. `A[0:4]`, but there was no clear path for extending this notation to support predicates. Therefore, the vload/vstore notation is used, e.g. `A.vload([T.Ramp(0, 1, 4)], predicate=...)`. The TVMScript printer falls back to the vload/vstore notation whenever predicates are specified.

Creation
--------

Buffer-level predication becomes more compelling when combined with the `tir.get_active_lane_mask` intrinsic, which can be used to mask off lanes when the extent of the vectorized axis is not divisible by the vector length. A detailed example and rationale can be found in the [RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0104-scalable-vectors-in-tir.md#predication).
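For reference, the intrinsic mirrors LLVM's llvm.get.active.lane.mask: lane i of the result is active when base + i < limit. A plain Python sketch of that behaviour (illustrative only, not TVM code):
```
def active_lane_mask(base, limit, lanes):
    """Lane i is active (True) when base + i < limit."""
    return [base + i < limit for i in range(lanes)]

# For a loop of extent 14 vectorized by 4, the final iteration masks the last two lanes:
active_lane_mask(12, 14, 4)  # [True, True, False, False]
```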

Predicated buffer loads/stores are created in the `VectorizeLoop` pass via `TryPredicateBufferAccesses`. This pass aims to convert block-level predicates, e.g.

```
for i_0 in T.serial(4):
    for i_1 in T.vectorized(4):
        if i_0 * 4 + i_1 < 14:
            B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0
```

to buffer-level predicates, e.g.

```
for i_0 in T.serial(4):
    predicate = T.get_active_lane_mask("int1x4", i_0 * 4, 14)
    A_load = T.meta_var(A.vload([T.Ramp(i_0 * 4, 1, 4)], predicate=predicate))
    B.vstore([T.Ramp(i_0 * 4, 1, 4)], A_load, predicate=predicate)
```

It takes a conservative approach for now, focussing only on expressions produced by the split scheduling primitive, but more complex expressions could be supported in the future.
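For illustration, a block-level predicate like the one above typically arises from splitting a loop whose extent (here 14) is not divisible by the vector length. A sketch using the TensorIR schedule API (the module `mod`, the block name "compute", and the single-loop assumption are hypothetical):
```
import tvm

sch = tvm.tir.Schedule(mod)  # mod: an IRModule with the elementwise add over 14 elements
(i,) = sch.get_loops(sch.get_block("compute"))
i_0, i_1 = sch.split(i, factors=[None, 4])
sch.vectorize(i_1)
# During lowering, VectorizeLoop / TryPredicateBufferAccesses can then rewrite the
# resulting block-level condition into the predicated vload/vstore form shown above.
```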

`TryPredicateBufferAccesses` can be explicitly enabled or disabled with the `tir.enable_buffer_level_predication` pass context option. It is disabled by default, unless the target supports SVE, in which case it is enabled.
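A sketch of toggling the option explicitly through the pass context (assuming `mod` is an IRModule whose loops are already annotated for vectorization):
```
import tvm

with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}):
    mod = tvm.tir.transform.VectorizeLoop()(mod)
```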

Note: this commit depends on #16965, so it also contains the contents of #16965.

Co-authored-by: Elen Kalda elen.kalda@arm.com
Co-authored-by: Neil Hickey neil.hickey@arm.com

@lhutton1 force-pushed the predicated-buffer-load-store-support branch from 784a75a to efb057b on May 7, 2024 08:31
@lhutton1 marked this pull request as ready for review on May 7, 2024 11:47
@lhutton1 (Contributor Author) commented May 7, 2024

@Lunderberg (Contributor) left a comment

I like the functionality overall. I have a couple of requested changes, which mostly fall under the categories below.

  • TIR edge cases, when Target::Current() may be overridden.
  • Using Optional<PrimExpr> instead of PrimExpr.
  • Validating that `!predicate.defined()` holds for any target that does not support predication.

src/target/llvm/codegen_llvm.cc:
```
  } else {
    load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment),
                                       is_volatile);
  }
#elif TVM_LLVM_VERSION >= 80
  auto load =
```

The PR only adds CreateMaskedLoad when TVM_LLVM_VERSION >= 110. If somebody is using an older version of LLVM, it would silently ignore the predicate for the load/store. We should either support it, or throw an exception.

It looks like CreateMaskedLoad has been supported in LLVM since this commit, so I'd lean toward adding it in the other #elif branches.

@lhutton1 (Contributor Author) commented May 13, 2024

Agreed, thanks. I've added support for previous versions of LLVM. I've checked the build with the following versions of LLVM: 7*, 8*, 9*, 10, 11, 12, 13, 17

* fails to build due to other seemingly unrelated errors

```
@@ -700,5 +700,31 @@ def before(a: T.handle):
    assert "get.active.lane.mask" in ll


@pytest.mark.skipif(
```

This validates that we have the correct output for architectures that support SVE, but it doesn't test the behavior of other targets that do not (yet) support predicated loads/stores. While the VectorizeLoop pass would only insert a predicated load/store for targets that support it, the predicated load/store could still be generated in hand-written kernels, or through other transforms in the future.

Can we add a test, parametrized over each target tested in CI, which attempts to compile a PrimFunc containing predicated loads/stores? For each target that supports SVE, tvm.build should compile without error, and for each target that does not, tvm.build should raise an exception.
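A sketch of the kind of parametrized test being requested (the target strings, the `predicated_mod` fixture, and the expected error type are assumptions, not the test actually added in this PR):
```
import pytest
import tvm


@pytest.mark.parametrize(
    "target,expect_error",
    [
        ("llvm -mtriple=aarch64-linux-gnu -mattr=+sve", False),  # supports SVE
        ("llvm -mtriple=x86_64-linux-gnu", True),  # no predicated buffer access support
    ],
)
def test_build_predicated_buffer_access(target, expect_error, predicated_mod):
    # predicated_mod: a fixture providing an IRModule whose PrimFunc contains
    # predicated buffer loads/stores (hypothetical).
    if expect_error:
        with pytest.raises(tvm.TVMError):
            tvm.build(predicated_mod, target=target)
    else:
        tvm.build(predicated_mod, target=target)
```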

@lhutton1 (Contributor Author) commented May 14, 2024

Done - I wasn't able to check all targets locally, so I'm hoping they pass CI here

@lhutton1 (Contributor Author) left a comment

Thanks for taking the time to review @Lunderberg. I'm working through the comments, but wanted to ask a couple of questions and respond to some of them before I continue.

```
@@ -72,6 +72,126 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
  return Broadcast(e, CreateNewLanes(is_scalable, lanes));
}

bool EnableBufferLevelPredication() {
```

Thanks, that makes sense. Do you know if there is any general infrastructure for keeping track of the current target (which takes into account this override functionality) from within a pass? Otherwise I feel we will be duplicating this functionality in multiple places. I was thinking of something similar to LexicalOnDeviceMixin (assuming I understood it correctly).

@Lunderberg (Contributor) commented
No problem, and thank you for the revisions!

@ekalda (Contributor) left a comment

Thanks @lhutton1 for all the work on this (it's a lot of work) and @Lunderberg for constructive reviews! I've only got some minor nits.

@lhutton1 force-pushed the predicated-buffer-load-store-support branch 2 times, most recently from 834ba44 to d8795a0 on May 19, 2024 11:32
@ekalda (Contributor) left a comment

Thanks @lhutton1 LGTM! 🚀

@lhutton1 (Contributor Author) commented
@tvm-bot rerun

The follow-up commits include:
* Taking into account the possibility of the target being overridden in the vectorize pass.
* Predicate PrimExpr -> Optional<PrimExpr>
* Checking that the predicate is not used for any target that doesn't support it.
* Use the vload/vstore API as opposed to load/store.
* int1 mask -> uint1 mask for the boolean representation. This is converted to int1 in the LLVM backend.
* vload/vstore updates that were missed previously.
* int1 -> bool updates.
* Fix GPU target tests.
* Fix a test and update comments referencing the old load/store API.
* Correct doc strings.
* Correct a typo in an error message.
* Add some additional checks for BufferLoad.
@lhutton1 force-pushed the predicated-buffer-load-store-support branch from d8795a0 to cbd2e48 on May 22, 2024 09:04
@lhutton1 (Contributor Author) commented
friendly ping @Lunderberg if you have some free time
