[Codegen, CUDA] Add FP8 Tensor Core Codegen #16950

Open · wants to merge 11 commits into main
Conversation

@LeiWang1999 (Contributor) commented Apr 29, 2024

Major changes of this pull request:

  • Relax the fp8-related test requirement requires_cuda_compute_version from 9 to 8.9, since the sm_89 Ada architecture also supports fp8 tensor cores (the platform I tested on).
  • Improve fp8 vector load/store capabilities: previously TVM only supported float8x4/x2/x1 loads; this PR adds support for float8x8 and float8x16 loads.
  • Refactor the interface of the get_mma_intrin_group and get_mma_intrin functions. The prior implementation assumed that input A and input B share the same datatype, but fp8 tensor cores can process combinations such as e5m2×e5m2, e5m2×e4m3, e4m3×e4m3, or e4m3×e5m2 (see the sketch after this list). Note: this change may affect code in MLC that uses get_mma_intrin_group.
  • Implement support for fp8 mma code generation and the associated tests.
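
With the refactored interface the A and B operand dtypes are passed separately, so mixed fp8 combinations can be requested. A minimal sketch (the keyword arguments mirror the call in the full script below; the mixed e4m3/e5m2 pairing is shown purely for illustration):

from tvm.tir.tensor_intrin.cuda import get_mma_intrin_group

# e4m3 for A, e5m2 for B: a combination the old same-dtype interface could not express
intrin_group = get_mma_intrin_group(
    "shared",
    "global",
    a_dtype="e4m3_float8",
    b_dtype="e5m2_float8",
    out_dtype="float32",
    trans_a=False,
    trans_b=True,
    not_use_mma_store_intrinic=False,
)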

Correctness check:

import tvm
from tvm import te
import numpy as np
import tvm.testing
from tvm.script import tir as T
import os
from tvm.tir.tensor_intrin.cuda import (
    get_mma_intrin_group,
    shared_16x16_to_ldmatrix_32x8_layout,
    shared_32x16_to_ldmatrix_32x16_layout,
    shared_16x32_to_ldmatrix_32x16_layout,
)

log_path = "instance/progress/fp8_matmul"
count = 0


def write_code(code, path, fname):
    global count
    fname = str(count) + "." + fname
    count += 1
    # create the output directory if it does not exist
    if not os.path.exists(path):
        os.makedirs(path)
    fname = os.path.join(path, fname)
    with open(fname, "w") as f:
        f.write(code)

def write_sch(sch, path, fname):
    py_fname = fname + ".py"
    write_code(sch.mod["main"].script(), path, py_fname)
    cu_fname = fname + ".cu"
    write_code(sch.mod.astext(), path, cu_fname)


M = 1024
N = 1024
K = 1024

BM = 64
BN = 64
BK = 64
warp_size = 32
block_row_warps = 2
block_col_warps = 4

indtype = "e4m3_float8"
out_dtype = "float32"
# indtype = "int8"
# out_dtype = "int32"
intrin_group = get_mma_intrin_group(
    "shared",
    "global",
    a_dtype=indtype,
    b_dtype=indtype,
    out_dtype=out_dtype,
    trans_a=False,
    trans_b=True,
    not_use_mma_store_intrinic=False,
)
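# intrin_group is a dict of tensor intrinsics; the keys used below are
# "load_a"/"load_b" (ldmatrix loads), "init", "compute" (the fp8 mma), and "store".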

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [M, K], dtype=indtype)
        B = T.match_buffer(b, [N, K], dtype=indtype)
        C = T.match_buffer(c, [M, N], dtype=out_dtype)

        for i, j, k in T.grid(M, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.cast(0, out_dtype)  # init in out_dtype (works for float32 and int32)
                C[vi, vj] = C[vi, vj] + \
                    A[vi, vk].astype(out_dtype) * B[vj, vk].astype(out_dtype)


ir_module = MyModule
print(ir_module)
sch = tvm.tir.Schedule(ir_module, debug_mask="all")
write_sch(sch, log_path, "original")

block_b = sch.get_block("B")
# C_wrap = sch.cache_write(block_b, 0, "local")
write_sch(sch, log_path, "cache_related")

(i, j, k) = sch.get_loops(block_b)
by, i = sch.split(i, factors=[None, BM])
bx, j = sch.split(j, factors=[None, BN])
bk, k = sch.split(k, factors=[None, BK])

write_sch(sch, log_path, "split_inner_loops")

sch.reorder(by, bx, bk, i, j, k)
write_sch(sch, log_path, "reorder_inner_loops")

sch.bind(bx, "blockIdx.x")
sch.bind(by, "blockIdx.y")

write_sch(sch, log_path, "block_bind")

# Each thread block computes a BM x BN x BK = 64x64x64 sub-problem. The i and j
# loops are distributed over block_row_warps (2) x block_col_warps (4) warps,
# bound to threadIdx.z and threadIdx.y; each warp has warp_size = 32 lanes (threadIdx.x).

# Each warp therefore produces a 32x16 output tile, which is further tiled into
# 16x16 mma fragments below.

# i: 64 -> block_row_warps (2) x 32
block_b_tz, block_b_inner_i = sch.split(
    i, factors=[block_row_warps, None])

# j: 64 -> block_col_warps (4) x 16
block_b_ty, block_b_inner_j = sch.split(
    j, factors=[block_col_warps, None])
# k: 64 (reduced in BK-sized chunks per block iteration)
sch.reorder(block_b_tz, block_b_ty, bk, block_b_inner_i, block_b_inner_j, k)

write_sch(sch, log_path, "split_outer_loops")

sch.bind(block_b_tz, "threadIdx.z")
sch.bind(block_b_ty, "threadIdx.y")

write_sch(sch, log_path, "thread_bind")

# schedule the shared memory

def fetch_to_shared(block, idx):
    block_read = sch.cache_read(block, idx, "shared")
    sch.compute_at(block_read, bk)
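    # 16-element fp8 vectorized copy (float8x16), relying on the wider fp8
    # vector load/store support added in this PR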
    vector_size = 16
    fused = sch.fuse(*sch.get_loops(block_read)[-2:])
    _, f_1, f_2, f_3 = sch.split(
        fused, factors=[None, block_col_warps, warp_size, vector_size])
    sch.bind(f_2, "threadIdx.x")
    sch.bind(f_1, "threadIdx.y")
    sch.vectorize(f_3)
    offset = 0
    sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset)

# schedule A
fetch_to_shared(block_b, 0)
# schedule B
fetch_to_shared(block_b, 1)
write_sch(sch, log_path, "shared_memory_schedule")


# blockize for mma tensorize

mma_m = 16
mma_n = 16
mma_k = 32
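# Warp-level mma tile of 16x16x32: with 8-bit operands the k dimension is 32,
# matching the shared_16x32 / shared_32x16 ldmatrix layouts used below.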

block_b_inner_i, block_b_inner_i_tc = sch.split(
    block_b_inner_i, factors=[None, mma_m])
block_b_inner_j, block_b_inner_j_tc = sch.split(
    block_b_inner_j, factors=[None, mma_n])
k, k_tc = sch.split(k, factors=[None, mma_k])

sch.reorder(block_b_inner_i, block_b_inner_j,
            k, block_b_inner_i_tc, block_b_inner_j_tc, k_tc)

write_sch(sch, log_path, "mma_tile")

# block_inner = sch.blockize(block_b_inner_i_tc)
# block_outer, block_inner = block_inner, block_b
write_sch(sch, log_path, "blockize")

A_warp = sch.cache_read(block_b, 0, "warp")
B_warp = sch.cache_read(block_b, 1, "warp")
sch.compute_at(A_warp, k)
sch.compute_at(B_warp, k)
C_warp = sch.cache_write(block_b, 0, "warp")
sch.reverse_compute_at(C_warp, block_b_ty)
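# The warp-scope caches above become the ldmatrix / mma register fragments that
# are tensorized below.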
write_sch(sch, log_path, "cache_read_write_warp")

ii, jj = sch.get_loops(C_warp)[-2:]
io, ii = sch.split(ii, factors=[None, mma_m])
jo, ji = sch.split(jj, factors=[None, mma_n])
sch.reorder(io, jo, ii, ji)
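# C_warp is now tiled into 16x16 accumulator fragments, matching the mma store
# intrinsic applied at the end.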


def tile_wmma_fragment(block_read, height, width):
    i, j = sch.get_loops(block_read)[-2:]
    # i0, i1 = sch.split(i, factors=[None, height])
    # j0, j1 = sch.split(j, factors=[None, width])
    # sch.reorder(i0, j0, i1, j1)
    return i

loop_a = tile_wmma_fragment(A_warp, mma_m, mma_k)

loop_b = tile_wmma_fragment(B_warp, mma_n, mma_k)

write_sch(sch, log_path, "tile_fragment")


# Split out the init block so that intrin_group["init"] can be tensorized
# separately from the mma compute.
block_init_c = sch.decompose_reduction(block_b, bk)
write_sch(sch, log_path, "decompose_reduction")

def index_map_A(i, j):
    return (
        i // 16,
        j // 32,
        *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32),
    )

def index_map_B(i, j):
    return (
        i // 32,
        j // 16,
        *shared_32x16_to_ldmatrix_32x16_layout(i % 32, j % 16),
    )

def index_map_C(i, j):
    return (
        i // 16,
        j // 16,
        *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
    )


# With trans_b=True, B is accessed as B[n, k], so its warp fragment uses the same
# 16x32 layout as A; index_map_B above would apply to a non-transposed B.
sch.transform_layout(A_warp, ("write", 0), index_map_A)
sch.transform_layout(B_warp, ("write", 0), index_map_A)
sch.transform_layout(C_warp, ("read", 0), index_map_C)

write_sch(sch, log_path, "transform_layout")

sch.tensorize(loop_a, intrin_group["load_a"])
sch.tensorize(loop_b, intrin_group["load_b"])
write_sch(sch, log_path, "tensorize_ldmatrix")

# _test_block = sch.get_block("")
sch.tensorize(block_b_inner_i_tc, intrin_group["compute"])

sch.tensorize(sch.get_loops(block_init_c)[-2], intrin_group["init"])
sch.tensorize(sch.get_loops(C_warp)[-2], intrin_group["store"])

write_sch(sch, log_path, "tensorize")

ctx = tvm.cuda(0)
cuda_mod = tvm.build(sch.mod, target="cuda")

write_code(cuda_mod.imported_modules[0].get_source(), log_path, "tmp.cu")

def map_numpy_type(intype):
    # Map TVM fp8 dtype names to the corresponding numpy dtype names
    # (the float8_e4m3fn / float8_e5m2 numpy dtypes come from the ml_dtypes package).
    typemap = {
        "e4m3_float8": "float8_e4m3fn",
        "e5m2_float8": "float8_e5m2",
    }
    return typemap.get(intype, intype)

numpytype_a = map_numpy_type(indtype)
numpytype_b = map_numpy_type(indtype)
numpytype_c = map_numpy_type(out_dtype)
# B is stored as [N, K] to match the declared buffer; the reference matmul uses b.T
a = np.random.uniform(low=-5, high=5, size=(M * K)).reshape((M, K)).astype(numpytype_a)
b = np.random.uniform(low=-5, high=5, size=(N * K)).reshape((N, K)).astype(numpytype_b)
out = np.matmul(a, b.T)

print("numpy_simulated:", out)

cuda_a = tvm.nd.array(a, ctx)
cuda_b = tvm.nd.array(b, ctx)
cuda_c = tvm.nd.array(np.zeros((M, N)).astype(numpytype_c), ctx)
cuda_mod(cuda_a, cuda_b, cuda_c)

print("codegen:", cuda_c)
num_flops = 2 * M * K * N
num_runs = 1
timer_cuda_mod = cuda_mod.time_evaluator(
    cuda_mod.entry_name, ctx, number=num_runs)

t = timer_cuda_mod(cuda_a, cuda_b, cuda_c).mean

GFLOPS = num_flops / (t * 1e3) / 1e6
print("average time cost of %d runs = %g ms, %g GFLOPS." %
      (num_runs, t * 1e3, GFLOPS))

Expected output:

numpy_simulated: [[-410.33817   -30.429443 -470.51312  ...   64.58632  -381.49658
    14.920105]
 [  56.357788  744.9746    -29.630783 ...  -44.779022  298.5943
   -24.109558]
 [  77.765305 -426.8894    286.35736  ...   10.655792 -129.63507
   232.30026 ]
 ...
 [  39.094635  -47.508118 -225.59912  ...  775.10614  -109.92264
   268.50952 ]
 [-813.8422    111.21069  -316.5697   ...  455.90875   -37.09839
   478.28406 ]
 [ 122.78345   148.104     340.1291   ... -304.5721   -115.578735
  -639.9563  ]]
codegen: [[-410.28125    -30.441406  -470.09375   ...   64.66406   -381.5
    14.8203125]
 [  56.367188   744.8125     -29.597656  ...  -44.695312   298.625
   -24.148438 ]
 [  77.65625   -426.71875    286.3125    ...   10.746094  -129.6875
   232.34375  ]
 ...
 [  39.191406   -47.539062  -225.57812   ...  774.9375    -109.875
   268.46875  ]
 [-813.625      111.109375  -316.46875   ...  455.96875    -37.08203
   478.0625   ]
 [ 122.75       148.10938    339.84375   ... -304.5       -115.546875
  -639.8125   ]]

Please CC @yzh119

@Hzfengsy (Member) commented:

cc @vinx13
