    Learning Triton One Kernel at a Time: Softmax

By Editor Times Featured | November 23, 2025 | 11 min read


In the previous article of this series, we covered one of the most fundamental operations in computer science: matrix multiplication. It is heavily used in neural networks to compute the activations of linear layers. However, activations on their own are difficult to interpret, since their values and statistics (mean, variance, min-max amplitude) can fluctuate wildly from layer to layer. This is one of the reasons why we use activation functions, for example the logistic function (aka sigmoid), which projects any real number into the [0; 1] range.

The softmax function, also known as the normalised exponential function, is a multi-dimensional generalisation of the sigmoid. It converts a vector of raw scores (logits) into a probability distribution over M classes. We can interpret it as a weighted average that behaves as a smooth function and can be conveniently differentiated. It is a crucial component of dot-product attention, language modeling, and multinomial logistic regression.

In this article, we'll cover:

1. Implementing an efficient softmax kernel in Triton.
2. Implementing the backward pass (autograd).
3. Optimisation: cache modifiers and auto-tuning.

If you aren't familiar with Triton yet, refer to the previous articles!

Disclaimer: all the illustrations and animations are made by the author unless specified otherwise.

    Definition

The softmax function is defined as follows:
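The formula itself was an image in the original and did not survive extraction; in standard notation, for a vector of logits z with M components, it reads:

```latex
\mathrm{softmax}(z)_i \;=\; \frac{e^{z_i}}{\sum_{j=1}^{M} e^{z_j}}, \qquad i = 1, \dots, M
```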

The normalisation ensures that the output vector sums to 1, so that it can be interpreted as a valid probability distribution.

Note that this formulation of the softmax is highly sensitive to numerical overflow. Recall that the maximum value a standard float16 can represent is 65 504, which is roughly exp(11). This means that any input value greater than ~11 will cause exp(z_i) to exceed the representable range, leading to overflow.

A common trick to mitigate this issue is to subtract the maximum value of the input vector from every element, such that the new maximum is 0 before exponentiation and 1 after.
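As a quick illustration (a NumPy sketch of my own, not from the original article), the naive formula blows up in float16 while the max-shifted version stays well behaved:

```python
import numpy as np

# float16 tops out at 65504 ~ exp(11): exp(20) overflows to inf.
z = np.array([1.0, 20.0], dtype=np.float16)

with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(z) / np.exp(z).sum()  # inf / inf -> nan

# Shifting by the max keeps every exponent in (0, 1], so nothing overflows.
shifted = z - z.max()
stable = np.exp(shifted) / np.exp(shifted).sum()

assert np.isnan(naive[1])             # the naive version is unusable
assert np.isclose(stable.sum(), 1.0)  # the shifted version is a valid distribution
```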

    Naive Implementation

As you can see, computing the softmax involves two reduction operations, a max and a sum. A naive algorithm requires three separate passes over the input vector: first to compute the maximum, then the sum, and finally the normalised outputs.

Here's what a naive Numpy implementation looks like:
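The original listing did not survive extraction; a minimal sketch of the three-pass algorithm (the helper name `naive_softmax` is mine) might look like this:

```python
import numpy as np

def naive_softmax(x: np.ndarray) -> np.ndarray:
    """Three-pass softmax: one pass for the max, one for the sum, one to normalise."""
    m = x.max()                  # pass 1: global maximum (for numerical stability)
    d = np.exp(x - m).sum()      # pass 2: sum of shifted exponentials
    return np.exp(x - m) / d     # pass 3: normalised outputs

x = np.random.randn(512).astype(np.float32)
y = naive_softmax(x)
assert np.isclose(y.sum(), 1.0)
```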

A recurrent theme in this Triton series is minimising high-latency global memory access. Our current Numpy implementation requires three separate memory reads of the full input vector, which is highly inefficient.

Online Softmax

Fortunately, we can use a clever trick, known as the online softmax, to fuse the max and sum steps, reducing the number of memory reads to two.

First, we define the sum of exponentials recursively. In the following set of equalities, m_i refers to the maximum over x up to the i-th index.
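The equalities were an image in the original; the standard online-softmax recurrence, consistent with the surrounding text (with m_0 = -∞ and d_0 = 0), is:

```latex
m_i = \max(m_{i-1},\, x_i), \qquad
d_i = d_{i-1}\, e^{\,m_{i-1} - m_i} + e^{\,x_i - m_i}
```

so that after processing the last element, d_N equals the full sum of exponentials shifted by the global maximum m_N.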

This equality allows us to compute the sum of exponentials iteratively using the maximum value seen so far. We can leverage it to fuse the first and second loops of the naive implementation and compute the maximum and sum of exponentials in a single pass.

Our algorithm becomes:

This translates easily to Numpy:
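The Numpy listing was lost in extraction; a minimal sketch of the fused, two-pass version (the name `online_softmax` is mine) could look like this:

```python
import numpy as np

def online_softmax(x: np.ndarray) -> np.ndarray:
    """Two-pass softmax: the max and sum reductions are fused into one pass."""
    m = -np.inf  # running maximum
    d = 0.0      # running sum of exponentials
    for xi in x:  # pass 1: fused max + sum
        m_new = max(m, xi)
        # rescale the running sum when the maximum changes
        d = d * np.exp(m - m_new) + np.exp(xi - m_new)
        m = m_new
    return np.exp(x - m) / d  # pass 2: normalise

x = np.random.randn(256).astype(np.float32)
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(x), ref, atol=1e-6)
```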

Now that we understand the main principles behind the softmax, we'll implement it in Triton, starting with the simple, single-block version and building up to the online, multi-block formulation. Ultimately, we want our kernel to behave like a PyTorch module and be compatible with autograd.

Unfortunately, from PyTorch's point of view, Triton kernels behave like black boxes: the operations they perform are not traced by autograd. This requires us to implement the backward pass ourselves and explicitly specify how gradients should be computed. Let's brush up on our beloved chain rule and derive the softmax gradient.

    Gradient

Since the outputs of the softmax are strictly positive, we can use the logarithmic derivative to make the derivation of the gradient easier. Here, we take the derivative of the log of the output and apply the chain rule:

From there, we rearrange the terms and follow these steps:

Now assume that we have some upstream gradient, for example generated by a loss function L (e.g. a cross-entropy loss). We get the following expression for the gradient:
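The equation block was an image in the original; writing s = softmax(z) for the outputs and g_j = ∂L/∂s_j for the upstream gradient, the standard result the text refers to as equations (9) and (10) is:

```latex
\frac{\partial s_j}{\partial z_i} = s_j\,(\delta_{ij} - s_i)
\qquad\Longrightarrow\qquad
\frac{\partial L}{\partial z_i}
= \sum_j g_j\, s_j\,(\delta_{ij} - s_i)
= s_i \Bigl( g_i - \sum_j g_j\, s_j \Bigr)
```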

The simplification of the left term in (9) is due to the fact that δ_ij is only equal to 1 when j = i, collapsing the sum over j to a single term.
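As a sanity check (my own, not from the original article), we can verify the backward formula of equation (10) against finite differences in plain NumPy:

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())
    return e / e.sum()

def softmax_backward(s: np.ndarray, g: np.ndarray) -> np.ndarray:
    """Equation (10): dL/dz_i = s_i * (g_i - sum_j g_j * s_j)."""
    return s * (g - np.dot(g, s))

rng = np.random.default_rng(0)
z = rng.standard_normal(8)
g = rng.standard_normal(8)  # arbitrary upstream gradient

analytic = softmax_backward(softmax(z), g)

# Finite-difference check of dL/dz with L = <g, softmax(z)>
eps, numeric = 1e-6, np.zeros_like(z)
for i in range(len(z)):
    zp, zm = z.copy(), z.copy()
    zp[i] += eps
    zm[i] -= eps
    numeric[i] = (g @ softmax(zp) - g @ softmax(zm)) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-5)
```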

    Triton Implementation

    Single Block Softmax

Now that we've worked through the derivation of the gradient, we can write the forward and backward softmax kernels. First, let's focus on the PyTorch wrapper to understand how the single-block implementation works at a high level. Given a 2D input tensor, the forward and backward kernels will process all rows in parallel.

For simplicity, we'll define the BLOCK_SIZE to be large enough to handle all columns at once. Specifically, we'll set it to the next power of two greater than or equal to the number of columns, as required by Triton.

Then, we'll define our `grid` to be the number of rows (it can also handle a batch dimension).

The PyTorch wrapper for our SoftmaxSingleBlock is a class inheriting from torch.autograd.Function that implements forward and backward. Both methods take a ctx argument, which we'll use to cache the softmax outputs during the forward pass and reuse them during the backward pass.
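The wrapper listing was lost in extraction; a minimal sketch of the ctx-caching pattern, with plain PyTorch ops standing in for the Triton kernel launches, could look like this:

```python
import torch

class SoftmaxSingleBlock(torch.autograd.Function):
    """Sketch of the wrapper pattern; torch ops stand in for the Triton kernels."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Stable row-wise softmax (the Triton forward kernel would run here,
        # one row per program in the grid)
        z = x - x.max(dim=-1, keepdim=True).values
        out = torch.exp(z) / torch.exp(z).sum(dim=-1, keepdim=True)
        ctx.save_for_backward(out)  # cache softmax outputs for the backward pass
        return out

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        (s,) = ctx.saved_tensors
        # Equation (10), applied row-wise: dL/dz = s * (g - sum_j g_j * s_j)
        return s * (grad_out - (grad_out * s).sum(dim=-1, keepdim=True))

x = torch.randn(4, 16, requires_grad=True)
y = SoftmaxSingleBlock.apply(x)
loss = (y * torch.randn_like(y)).sum()
loss.backward()  # x.grad is now computed by our custom backward
```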

Both kernels are quite simple: we start by loading the row inputs using the same syntax as in my previous vector addition article. Notice that BLOCK_SIZE and num_warps are computed using a calculate_settings function. This function comes from the Unsloth library and has been reused in other kernel libraries such as LigerKernel (which the kernels in this article are loosely based on); it provides heuristics to tune both variables:

    def calculate_settings(n: int) -> tuple[int, int]:
        MAX_FUSED_SIZE = 65536  # maximum block size on Nvidia GPUs
        BLOCK_SIZE = next_power_of_2(n)
        if BLOCK_SIZE > MAX_FUSED_SIZE:
            # we remove this assertion later in this article
            raise RuntimeError(
                f"Cannot launch Triton kernel since n = {n} exceeds "
                f"the maximum CUDA block size = {MAX_FUSED_SIZE}."
            )
        num_warps = 4
        if BLOCK_SIZE >= 32768:
            num_warps = 32
        elif BLOCK_SIZE >= 8192:
            num_warps = 16
        elif BLOCK_SIZE >= 2048:
            num_warps = 8
        return BLOCK_SIZE, num_warps

Then, we implement the regular softmax for the forward pass and equation (10) for the backward pass. The only novelty here compared to previous articles is the use of cache modifiers, which tell the compiler how to cache and evict data. For now, we'll only focus on three cache modifiers:

• .ca (Cache at all levels): tells the compiler to load the data into both L1 and L2 cache, suggesting that it might be reused soon. This modifier should be used when the data is small enough to fit into L1 (~128–192KB per SM on an A100) and will likely be accessed repeatedly.
• .cs (Streaming): treat the data as streaming; it will be used once and then discarded to free up space in L1.
• .wb (Write-back): normal cached write; the data will remain in the cache hierarchy, which is good if the output may be reused.

In the following kernels, we'll use the .ca modifier for loads, since we perform several operations on the loaded data. For stores, we'll use .cs in the forward pass, since the outputs won't be immediately reused, and .wb in the backward pass, since in the context of autograd (i.e. the chain rule) gradient outputs will be consumed by downstream kernels.
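The kernel listings themselves were lost in extraction; as a rough sketch (function and argument names are my own, and the exact Liger-style details are assumptions), the single-block forward kernel with the cache modifiers described above plausibly looks like this:

```python
import triton
import triton.language as tl

@triton.jit
def softmax_fwd_kernel(
    x_ptr, out_ptr, x_row_stride, out_row_stride, n_cols, BLOCK_SIZE: tl.constexpr
):
    # One program per row; the whole row fits in a single block.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols

    # .ca: keep the row cached, we touch it for the max, the sum and the division
    x = tl.load(
        x_ptr + row * x_row_stride + cols,
        mask=mask, other=-float("inf"), cache_modifier=".ca",
    )
    z = x - tl.max(x, axis=0)        # subtract the row max for numerical stability
    num = tl.exp(z)
    out = num / tl.sum(num, axis=0)  # normalise to a probability distribution

    # .cs: outputs are streamed out, this kernel never re-reads them
    tl.store(
        out_ptr + row * out_row_stride + cols,
        out, mask=mask, cache_modifier=".cs",
    )
```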

    Multi-Block Softmax

Now, let's take a look at the online formulation of the softmax. In this section, we implement a multi-block variant of the previous kernel. This version will use BLOCK_SIZE < n_cols; in other words, we'll only load a tile of BLOCK_SIZE elements at a time, similar to how we handled tiled GEMM in the previous tutorial. Now you might ask: "how do we pick the block size?"

This is a great occasion to introduce Triton's autotune utility. Provided with a list of configurations, autotune will perform a grid search to determine and cache the best configuration for a specific input shape. This process is repeated every time a new input shape is passed to the kernel.

Here, we perform a grid search over the block size and number of warps using the following utility function:

    from itertools import product

    # --- Multi Block Tuning ---
    BLOCK_SIZES = [256, 512, 1024, 2048, 4096, 8192]
    NUM_WARPS = [2, 4, 8, 16]

    def get_autotune_config(
        block_sizes: list[int], num_warps: list[int]
    ) -> list[triton.Config]:
        return [
            triton.Config(kwargs={"BLOCK_SIZE": bs}, num_warps=nw)
            for (bs, nw) in product(block_sizes, num_warps)
        ]

We can now decorate our multi-block kernels with autotune and pass the list of configs; key="n_cols" indicates that the optimal config depends on the number of columns of the input.

The implementation of these kernels is conceptually very close to the online softmax we covered before; the main difference is that we iterate over tiles (not over single elements as in Numpy), which requires some adjustments. For instance, we add a sum over the tile in the d update, and the backward kernel now requires two iterations as well.
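To make the tile-wise d update concrete, here is a NumPy sketch of my own (not from the original article) of the online softmax iterating over tiles rather than single elements:

```python
import numpy as np

def online_softmax_tiled(x: np.ndarray, block_size: int = 128) -> np.ndarray:
    """Tile-wise online softmax: the running max/sum update sums over a whole tile."""
    m, d = -np.inf, 0.0
    for start in range(0, len(x), block_size):
        tile = x[start:start + block_size]
        m_new = max(m, tile.max())
        # same recurrence as before, but with a sum over the tile in the d update
        d = d * np.exp(m - m_new) + np.exp(tile - m_new).sum()
        m = m_new
    return np.exp(x - m) / d

x = np.random.randn(1000)
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax_tiled(x, 128), ref)
```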

Note: the PyTorch wrapper is exactly the same, except we delete the line where BLOCK_SIZE and num_warps are declared (since they're picked by autotune).

    Testing and Benchmarking

We can now execute a forward and backward pass with both kernels and ensure they match the PyTorch baselines:

    from copy import deepcopy

    import torch

    def validate_kernel(kernel_fn: callable) -> None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        torch.random.manual_seed(0)

        # Generate inputs
        x = torch.randn((256, 512), device=device)  # triton input
        x.requires_grad = True
        xt = deepcopy(x)  # torch input

        triton_output = kernel_fn(x)
        torch_output = torch.softmax(xt, dim=1)
        torch.testing.assert_close(triton_output, torch_output)  # test fwd kernel

        # Set up fake labels
        y = torch.zeros_like(x)
        inds = (torch.arange(0, y.shape[0]), torch.randint(0, 3, (y.shape[0],)))
        y[inds] = 1

        # Define loss and run backward pass
        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(torch_output, y)
        loss.backward()

        # Save gradient tensor for later
        torch_xgrad = xt.grad.detach().clone()
        triton_loss = loss_fn(triton_output, y)
        triton_loss.backward()
        torch.testing.assert_close(x.grad, torch_xgrad)  # test grad outputs

    validate_kernel(softmax_sb)
    validate_kernel(softmax_mb)

Finally, we benchmark our implementation against the PyTorch baseline using the following snippet:

    # --- Source: Triton softmax tutorial ---
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],  # argument names to use as an x-axis for the plot
            x_vals=[
                128 * i for i in range(2, 100)
            ],  # different possible values for `x_name`
            line_arg="provider",  # argument name whose value corresponds to a different line in the plot
            line_vals=[
                "triton_single_block",
                "triton_multi_block",
                "torch",
            ],  # possible values for `line_arg`
            line_names=[
                "Triton_single_block",
                "Triton_multi_block",
                "Torch",
            ],  # label names for the lines
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="GB/s",  # label name for the y-axis
            plot_name="softmax-performance",  # name for the plot, also used as a file name when saving
            args={"M": 4096},  # values for function arguments not in `x_names` and `y_name`
        )
    )
    def benchmark(M, N, provider):
        x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
        stream = getattr(torch, DEVICE.type).Stream()
        getattr(torch, DEVICE.type).set_stream(stream)
        if provider == "torch":
            ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
        if provider == "triton_single_block":
            torch.cuda.synchronize()
            ms = triton.testing.do_bench(lambda: softmax_sb(x))
            torch.cuda.synchronize()
        if provider == "triton_multi_block":
            torch.cuda.synchronize()
            ms = triton.testing.do_bench(lambda: softmax_mb(x))
            torch.cuda.synchronize()
        # one read and one write per element, hence the factor of 2
        gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms)

    benchmark.run(show_plots=True, print_data=True)

Good news! Our single-block kernel consistently outperforms the PyTorch baseline, while the multi-block variant falls off for inputs with more than 6k columns:

Considering larger inputs, we can make a few observations:

1. The multi-block kernel eventually stabilises around 900GB/s of throughput, surpassing the PyTorch baseline for inputs with more than 30k columns.
2. Interestingly, it looks like the multi-block variant will dominate for inputs with more than 60k columns.
3. Although we exceed the maximum block size with the single-block variant, the kernel still runs smoothly. Indeed, Triton automatically manages the block size under the hood: when n_cols is larger than the hardware limit, Triton will split the input and iterate over it. However, this appears to be slower than the multi-block approach.

To go further, we could combine both approaches in a single kernel that explicitly selects the optimal variant based on the input size. This way, we would benefit from the high performance of the single-block kernel for small inputs and the higher throughput of the multi-block variant for inputs with more than 60k columns.

This concludes the third episode of this Triton series, thanks again for your support!

In the next article, we'll leverage the online softmax formulation in the context of Flash Attention.

Until next time! 👋

