{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Group GEMM\nThis group gemm kernel launches a fixed number of CTA to compute a group\nof gemms. The scheduling is static and we do it on device.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.\n#\n# Permission is hereby granted, free of charge, to any person obtaining\n# a copy of this software and associated documentation files\n# (the \"Software\"), to deal in the Software without restriction,\n# including without limitation the rights to use, copy, modify, merge,\n# publish, distribute, sublicense, and/or sell copies of the Software,\n# and to permit persons to whom the Software is furnished to do so,\n# subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be\n# included in all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\n# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\n# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.\n# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY\n# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,\n# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE\n# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n\nimport torch\n\nimport triton\nimport triton.language as tl\n\nDEVICE = triton.runtime.driver.active.get_active_torch_device()\n\n\n@triton.autotune(\n configs=[\n triton.Config({\n 'BLOCK_SIZE_M': 128,\n 'BLOCK_SIZE_N': 128,\n 'BLOCK_SIZE_K': 32,\n 'NUM_SM': 84,\n }),\n triton.Config({\n 'BLOCK_SIZE_M': 128,\n 'BLOCK_SIZE_N': 128,\n 'BLOCK_SIZE_K': 32,\n 'NUM_SM': 128,\n }),\n triton.Config({\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 64,\n 'BLOCK_SIZE_K': 32,\n 'NUM_SM': 84,\n }),\n triton.Config({\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 64,\n 'BLOCK_SIZE_K': 32,\n 'NUM_SM': 128,\n }),\n ],\n key=['group_size'],\n)\n@triton.jit\ndef grouped_matmul_kernel(\n # device tensor of matrices pointers\n group_a_ptrs,\n group_b_ptrs,\n group_c_ptrs,\n # device tensor of gemm sizes. its shape is [group_size, 3]\n # dim 0 is group_size, dim 1 is the values of of each gemm\n group_gemm_sizes,\n # device tensor of leading dimension sizes. its shape is [group_size, 3]\n # dim 0 is group_size, dim 1 is the values of of each gemm\n g_lds,\n # number of gemms\n group_size,\n # number of virtual SM\n NUM_SM: tl.constexpr,\n # tile sizes\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n):\n tile_idx = tl.program_id(0)\n last_problem_end = 0\n for g in range(group_size):\n # get the gemm size of the current problem\n gm = tl.load(group_gemm_sizes + g * 3)\n gn = tl.load(group_gemm_sizes + g * 3 + 1)\n gk = tl.load(group_gemm_sizes + g * 3 + 2)\n num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)\n num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)\n num_tiles = num_m_tiles * num_n_tiles\n # iterate through the tiles in the current gemm problem\n while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):\n # pick up a tile from the current gemm problem\n k = gk\n lda = tl.load(g_lds + g * 3)\n ldb = tl.load(g_lds + g * 3 + 1)\n ldc = tl.load(g_lds + g * 3 + 2)\n a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))\n b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))\n c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))\n # figure out tile coordinates\n tile_idx_in_gemm = tile_idx - last_problem_end\n tile_m_idx = tile_idx_in_gemm // num_n_tiles\n tile_n_idx = tile_idx_in_gemm % num_n_tiles\n\n # do regular gemm here\n offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]\n b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):\n # hint to Triton compiler to do proper loop pipelining\n tl.multiple_of(a_ptrs, [16, 16])\n tl.multiple_of(b_ptrs, [16, 16])\n # assume full tile for now\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += BLOCK_SIZE_K * ldb\n c = accumulator.to(tl.float16)\n\n offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]\n\n # assumes full tile for now\n tl.store(c_ptrs, c)\n\n # go to the next tile by advancing NUM_SM\n tile_idx += NUM_SM\n\n # get ready to go to the next gemm problem\n last_problem_end = last_problem_end + num_tiles\n\n\ndef group_gemm_fn(group_A, group_B):\n assert len(group_A) == len(group_B)\n group_size = len(group_A)\n\n A_addrs = []\n B_addrs = []\n C_addrs = []\n g_sizes = []\n g_lds = []\n group_C = []\n for i in range(group_size):\n A = group_A[i]\n B = group_B[i]\n assert A.shape[1] == B.shape[0]\n M, K = A.shape\n K, N = B.shape\n C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)\n group_C.append(C)\n A_addrs.append(A.data_ptr())\n B_addrs.append(B.data_ptr())\n C_addrs.append(C.data_ptr())\n g_sizes += [M, N, K]\n g_lds += [A.stride(0), B.stride(0), C.stride(0)]\n\n # note these are device tensors\n d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)\n d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)\n d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)\n d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)\n d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)\n # we use a fixed number of CTA, and it's auto-tunable\n grid = lambda META: (META['NUM_SM'], )\n grouped_matmul_kernel[grid](\n d_a_ptrs,\n d_b_ptrs,\n d_c_ptrs,\n d_g_sizes,\n d_g_lds,\n group_size,\n )\n\n return group_C\n\n\ngroup_m = [1024, 512, 256, 128]\ngroup_n = [1024, 512, 256, 128]\ngroup_k = [1024, 512, 256, 128]\ngroup_A = []\ngroup_B = []\nassert len(group_m) == len(group_n)\nassert len(group_n) == len(group_k)\ngroup_size = len(group_m)\nfor i in range(group_size):\n M = group_m[i]\n N = group_n[i]\n K = group_k[i]\n A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)\n B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)\n group_A.append(A)\n group_B.append(B)\n\ntri_out = group_gemm_fn(group_A, group_B)\nref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)]\nfor i in range(group_size):\n assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=0)\n\n\n# only launch the kernel, no tensor preparation here to remove all overhead\ndef triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size):\n grid = lambda META: (META['NUM_SM'], )\n grouped_matmul_kernel[grid](\n a_ptrs,\n b_ptrs,\n c_ptrs,\n sizes,\n lds,\n group_size,\n )\n\n\ndef torch_perf_fn(group_A, group_B):\n for a, b in zip(group_A, group_B):\n torch.matmul(a, b)\n\n\n@triton.testing.perf_report(\n triton.testing.Benchmark(\n # argument names to use as an x-axis for the plot\n x_names=['N'],\n x_vals=[2**i for i in range(7, 11)], # different possible values for `x_name`\n line_arg='provider',\n # argument name whose value corresponds to a different line in the plot\n # possible values for `line_arg``\n line_vals=['cublas', 'triton'],\n # label name for the lines\n line_names=[\"cuBLAS\", \"Triton\"],\n # line styles\n styles=[('green', '-'), ('blue', '-')],\n ylabel=\"runtime(ms)\", # label name for the y-axis\n plot_name=\"group-gemm-performance\",\n # name for the plot. Used also as a file name for saving the plot.\n args={},\n ))\ndef benchmark(N, provider):\n group_size = 4\n group_A = []\n group_B = []\n A_addrs = []\n B_addrs = []\n C_addrs = []\n g_sizes = []\n g_lds = []\n group_C = []\n for i in range(group_size):\n A = torch.rand((N, N), device=DEVICE, dtype=torch.float16)\n B = torch.rand((N, N), device=DEVICE, dtype=torch.float16)\n C = torch.empty((N, N), device=DEVICE, dtype=torch.float16)\n group_A.append(A)\n group_B.append(B)\n group_C.append(C)\n A_addrs.append(A.data_ptr())\n B_addrs.append(B.data_ptr())\n C_addrs.append(C.data_ptr())\n g_sizes += [N, N, N]\n g_lds += [N, N, N]\n\n d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)\n d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)\n d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)\n d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)\n d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)\n\n quantiles = [0.5, 0.2, 0.8]\n if provider == 'cublas':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles)\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(\n lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles)\n return ms, max_ms, min_ms\n\n\nbenchmark.run(show_plots=True, print_data=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
}