{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%%shell\n# Installs the latest dev build of TVM from PyPI, with CUDA enabled. To use this,\n# you must request a Google Colab instance with a GPU by going to Runtime ->\n# Change runtime type -> Hardware accelerator -> GPU. If you wish to build from\n# source, see https://tvm.apache.org/docs/install/from_source.html\npip install tlcpack-nightly-cu113 --pre -f https://tlcpack.ai/wheels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Tuning High Performance Convolution on NVIDIA GPUs\n**Author**: [Lianmin Zheng](https://github.com/merrymercy)\n\nThis is an advanced tutorial on writing a high-performance tunable template for\nNVIDIA GPUs. By running the auto-tuner on this template, we can outperform the\nvendor-provided library cuDNN in many cases.\n\nNote that this tutorial will not run on Windows or recent versions of macOS. To\nget it to run, you will need to wrap the body of this tutorial in an\n`if __name__ == \"__main__\":` block.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies\nTo use the autotvm package in TVM, we need to install some extra dependencies\n(change \"3\" to \"2\" if you use Python 2):\n\n```bash\npip3 install --user psutil xgboost tornado cloudpickle\n```\nTo make TVM run faster during tuning, it is recommended to use Cython\nas the FFI of TVM. In the root directory of TVM, execute\n\n```bash\npip3 install --user cython\nsudo make cython3\n```\nNow return to the Python code. 
Import packages.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import logging\nimport sys\nimport numpy as np\n\nimport tvm\nfrom tvm import te, topi, testing\nfrom tvm.topi.testing import conv2d_nchw_python\nimport tvm.testing\n\nfrom tvm import autotvm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1: Define the search space\nThere are plenty of useful schedule primitives in TVM. You can also find\nsome tutorials that describe them in more detail, such as\n(1). `opt-conv-gpu`\n(2). [Optimizing DepthwiseConv on NVIDIA GPU](https://tvm.apache.org/2017/08/22/Optimize-Deep-Learning-GPU-Operators-with-TVM-A-Depthwise-Convolution-Example)\n\nHowever, their implementations are manually tuned for some special input\nshapes. In this section, we build a search space large enough to cover\nthe techniques used in these tutorials. Then we rely on the efficient auto-tuner\nto search through this space and pick some good configurations.\n\nIf you are familiar with writing CUDA schedules, you will find that the following\ntemplate is very general. In fact, this template can easily be modified\nto tune other operators such as depthwise convolution and GEMM.\nTo fully understand this template, you should be familiar with\nthe schedule primitives and the auto-tuning API. 
You can refer to the above\ntutorials and the autotvm tutorial.\n\nIt is worth noting that the search space for a conv2d operator\ncan be very large (on the order of 10^9 for some input shapes).\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@autotvm.template(\"tutorial/conv2d_no_batching\")\ndef conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):\n assert N == 1, \"Only consider batch_size = 1 in this template\"\n\n data = te.placeholder((N, CI, H, W), name=\"data\")\n kernel = te.placeholder((CO, CI, KH, KW), name=\"kernel\")\n conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype=\"float32\")\n s = te.create_schedule([conv.op])\n\n ##### space definition begin #####\n n, f, y, x = s[conv].op.axis\n rc, ry, rx = s[conv].op.reduce_axis\n\n cfg = autotvm.get_config()\n cfg.define_split(\"tile_f\", f, num_outputs=4)\n cfg.define_split(\"tile_y\", y, num_outputs=4)\n cfg.define_split(\"tile_x\", x, num_outputs=4)\n cfg.define_split(\"tile_rc\", rc, num_outputs=3)\n cfg.define_split(\"tile_ry\", ry, num_outputs=3)\n cfg.define_split(\"tile_rx\", rx, num_outputs=3)\n cfg.define_knob(\"auto_unroll_max_step\", [0, 512, 1500])\n cfg.define_knob(\"unroll_explicit\", [0, 1])\n ##### space definition end #####\n\n # inline padding\n pad_data = s[conv].op.input_tensors[0]\n s[pad_data].compute_inline()\n data, raw_data = pad_data, data\n\n output = conv\n OL = s.cache_write(conv, \"local\")\n\n # create cache stage\n AA = s.cache_read(data, \"shared\", [OL])\n WW = s.cache_read(kernel, \"shared\", [OL])\n AL = s.cache_read(AA, \"local\", [OL])\n WL = s.cache_read(WW, \"local\", [OL])\n\n # tile and bind spatial axes\n n, f, y, x = s[output].op.axis\n bf, vf, tf, fi = cfg[\"tile_f\"].apply(s, output, f)\n by, vy, ty, yi = cfg[\"tile_y\"].apply(s, output, y)\n bx, vx, tx, xi = cfg[\"tile_x\"].apply(s, output, x)\n kernel_scope = n # this is the scope to attach global 
config inside this kernel\n\n s[output].bind(bf, te.thread_axis(\"blockIdx.z\"))\n s[output].bind(by, te.thread_axis(\"blockIdx.y\"))\n s[output].bind(bx, te.thread_axis(\"blockIdx.x\"))\n s[output].bind(vf, te.thread_axis(\"vthread\"))\n s[output].bind(vy, te.thread_axis(\"vthread\"))\n s[output].bind(vx, te.thread_axis(\"vthread\"))\n s[output].bind(tf, te.thread_axis(\"threadIdx.z\"))\n s[output].bind(ty, te.thread_axis(\"threadIdx.y\"))\n s[output].bind(tx, te.thread_axis(\"threadIdx.x\"))\n s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)\n s[OL].compute_at(s[output], tx)\n\n # tile reduction axes (each axis uses the split knob defined for it above)\n n, f, y, x = s[OL].op.axis\n rc, ry, rx = s[OL].op.reduce_axis\n rco, rcm, rci = cfg[\"tile_rc\"].apply(s, OL, rc)\n ryo, rym, ryi = cfg[\"tile_ry\"].apply(s, OL, ry)\n rxo, rxm, rxi = cfg[\"tile_rx\"].apply(s, OL, rx)\n s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)\n\n s[AA].compute_at(s[OL], rxo)\n s[WW].compute_at(s[OL], rxo)\n s[AL].compute_at(s[OL], rxm)\n s[WL].compute_at(s[OL], rxm)\n\n # cooperative fetching: the threads of a block jointly load the shared-memory tiles\n for load in [AA, WW]:\n     n, f, y, x = s[load].op.axis\n     fused = s[load].fuse(n, f, y, x)\n     tz, fused = s[load].split(fused, nparts=cfg[\"tile_f\"].size[2])\n     ty, fused = s[load].split(fused, nparts=cfg[\"tile_y\"].size[2])\n     tx, fused = s[load].split(fused, nparts=cfg[\"tile_x\"].size[2])\n     s[load].bind(tz, te.thread_axis(\"threadIdx.z\"))\n     s[load].bind(ty, te.thread_axis(\"threadIdx.y\"))\n     s[load].bind(tx, te.thread_axis(\"threadIdx.x\"))\n\n # tune unroll\n s[output].pragma(kernel_scope, \"auto_unroll_max_step\", cfg[\"auto_unroll_max_step\"].val)\n s[output].pragma(kernel_scope, \"unroll_explicit\", cfg[\"unroll_explicit\"].val)\n\n return s, [raw_data, kernel, conv]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2: Search through the space\nWe pick the last layer of ResNet as the test case.\nSince our space is very large, `XGBTuner` is the most suitable tuner\nfor our case. 
Here we only do 5 trials for demonstration.\nIn practice, running 1000 trials can usually find some good kernels\nfor this template.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# logging config (for printing tuning log to screen)\nlogging.getLogger(\"autotvm\").setLevel(logging.DEBUG)\nlogging.getLogger(\"autotvm\").addHandler(logging.StreamHandler(sys.stdout))\n\n# the last layer in resnet\nN, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)\ntask = autotvm.task.create(\n \"tutorial/conv2d_no_batching\", args=(N, H, W, CO, CI, KH, KW, strides, padding), target=\"cuda\"\n)\nprint(task.config_space)\n\n# Use the local GPU and measure each config 3 times (repeat=3) to reduce variance\n# The timeout for compiling a program is 10 seconds, the timeout for running is 4 seconds\nmeasure_option = autotvm.measure_option(\n builder=autotvm.LocalBuilder(),\n runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4),\n)\n\nrecord_file = None\n# Begin tuning, log records to file `conv2d.log`\n# During tuning we will also try many invalid configs, so you are expected to\n# see many error reports. 
As long as you can see non-zero GFLOPS, it is okay.\n\n# We do not run the tuning on our web server since it takes too long.\n# Uncomment the following lines to run it yourself.\n\n# tuner = autotvm.tuner.XGBTuner(task)\n# record_file = \"conv2d.log\"\n# tuner.tune(\n#     n_trial=5,\n#     measure_option=measure_option,\n#     callbacks=[autotvm.callback.log_to_file(record_file)],\n# )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can inspect the best config from the log file, check correctness,\nand measure the running time.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# inspect the best config\ndispatch_context = autotvm.apply_history_best(record_file)\nbest_config = dispatch_context.query(task.target, task.workload)\nprint(\"\\nBest config:\")\nprint(best_config)\n\n# apply history best from log file\nwith autotvm.apply_history_best(record_file):\n    with tvm.target.Target(\"cuda\"):\n        s, arg_bufs = conv2d_no_batching(N, H, W, CO, CI, KH, KW, strides, padding)\n        func = tvm.build(s, arg_bufs)\n\n# check correctness\na_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)\nw_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)\nc_np = conv2d_nchw_python(a_np, w_np, strides, padding)\n\ndev = tvm.cuda()\na_tvm = tvm.nd.array(a_np, device=dev)\nw_tvm = tvm.nd.array(w_np, device=dev)\nc_tvm = tvm.nd.empty(c_np.shape, device=dev)\nfunc(a_tvm, w_tvm, c_tvm)\n\ntvm.testing.assert_allclose(c_np, c_tvm.numpy(), rtol=1e-2)\n\n# Evaluate running time. Here we run the kernel a large number of times (number=400)\n# to reduce the noise and the overhead of kernel launch. 
You can also use nvprof to validate the result.\nevaluator = func.time_evaluator(func.entry_name, dev, number=400)\nprint(\"Time cost of this operator: %f\" % evaluator(a_tvm, w_tvm, c_tvm).mean)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.20" } }, "nbformat": 4, "nbformat_minor": 0 }