    AI in Multiple GPUs: ZeRO & FSDP

    By Editor Times Featured, March 5, 2026


    This post is part of a series about distributed AI across multiple GPUs.

    Introduction

    In the previous post, we saw how Distributed Data Parallelism (DDP) accelerates training by splitting batches across GPUs. DDP solves the throughput problem, but it introduces a new challenge: memory redundancy.

    In vanilla DDP, every GPU holds a complete copy of the model parameters, gradients, and optimizer states. For large models like GPT-3 (175B parameters), this redundancy becomes an enormous waste of precious VRAM.

    Image by author: model, gradients, and optimizer states are replicated across GPUs in regular DDP

    ZeRO (Zero Redundancy Optimizer) solves this. There are three levels:

    • ZeRO-1 partitions only the optimizer states
    • ZeRO-2 partitions optimizer states + gradients
    • ZeRO-3 partitions optimizer states + gradients + model parameters

    ZeRO isn’t a parallelism technique, because all GPUs still run the same forward and backward passes. It’s a memory optimization strategy that eliminates redundancy across GPUs, letting you train larger models on the same hardware.

    The Memory Problem in DDP

    Let’s break down what actually consumes memory during training. For a model with Ψ parameters:

    • Model parameters: Ψ values (the weights of your neural network)
    • Gradients: Ψ values (one gradient per parameter)
    • Optimizer states (Adam): 2Ψ values (the first moment m and second moment v for each parameter)
    • Activations: intermediate outputs saved during the forward pass for use in the backward pass

    The first three scale with model size and are redundant across GPUs in DDP. Activations scale with batch size, sequence length, and the number of neurons, and are unique per GPU since each GPU processes different data. ZeRO doesn’t touch activation memory.

    Let’s calculate the memory usage for a 7B-parameter model using Adam and FP32 (a quick sanity check in code follows the list):

    • Parameters: 7 billion * 4 bytes = 28 GB
    • Gradients: 7 billion * 4 bytes = 28 GB
    • Optimizer states: 7 billion * 2 * 4 bytes = 56 GB
    • Memory per GPU in DDP: 28 + 28 + 56 = 112 GB
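
    Here is that sanity check; the helper name and the 1 GB = 10^9 bytes convention are illustrative assumptions, not from the original post:

    def ddp_memory_gb(num_params: float, bytes_per_value: int = 4) -> dict:
        """Per-GPU memory in plain DDP with FP32 values and an Adam optimizer."""
        params = num_params * bytes_per_value        # the weights themselves
        grads = num_params * bytes_per_value         # one gradient per weight
        optim = num_params * 2 * bytes_per_value     # Adam: first + second moment
        gb = 1e9                                     # counting 1 GB as 10^9 bytes, as above
        return {
            "parameters_gb": params / gb,
            "gradients_gb": grads / gb,
            "optimizer_gb": optim / gb,
            "total_gb": (params + grads + optim) / gb,
        }

    print(ddp_memory_gb(7e9))
    # {'parameters_gb': 28.0, 'gradients_gb': 28.0, 'optimizer_gb': 56.0, 'total_gb': 112.0}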

    Activations add significant memory on top of this, but since they’re unique per GPU, ZeRO can’t partition them. Techniques like activation checkpointing can help: they discard some activations and recompute them as needed during the backward pass. But that’s outside the scope of this article.

    Let’s understand how ZeRO works by implementing it from the ground up, starting with ZeRO-1 and working our way to ZeRO-3.

    ZeRO-1: Optimizer State Partitioning

    In ZeRO-1, only the optimizer states are partitioned. Each GPU:

    • Still holds the full model parameters and gradients
    • Stores only 1/N of the optimizer states (N = number of GPUs)
    • Updates only the corresponding 1/N of the parameters

    This is the sequence of actions taken during training:

    1. Forward pass: each GPU processes its own micro-batch
    2. Backward pass: compute gradients
    3. All-reduce gradients: every GPU gets all the gradients
    4. Optimizer step: each GPU updates its parameter partition
    5. All-gather parameters: sync the updated model across GPUs
    Image by author: ZeRO-1 animation

    Here’s a simplified implementation:

    import torch
    import torch.distributed as dist


    class ZeRO_1:
        def __init__(self, model, optimizer_cls):
            self.model = model
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()

            self.param_shards = []  # each rank holds only its own shard of the parameters (and their optimizer states)
            self.param_metadata = []  # metadata to reconstruct full tensors from shards

            for param in self.model.parameters():
                original_shape = param.data.shape
                flat = param.data.view(-1)
                numel = flat.numel()

                remainder = numel % self.world_size
                pad_size = (self.world_size - remainder) % self.world_size
                padded_numel = numel + pad_size
                shard_size = padded_numel // self.world_size

                shard_start = self.rank * shard_size
                shard_end = shard_start + shard_size

                self.param_metadata.append(
                    {
                        "original_shape": original_shape,
                        "numel": numel,
                        "padded_numel": padded_numel,
                        "shard_size": shard_size,
                        "shard_start": shard_start,
                        "shard_end": shard_end,
                    }
                )

                if pad_size > 0:
                    flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
                else:
                    flat_padded = flat

                shard = flat_padded[shard_start:shard_end].clone()
                shard.requires_grad_(True)
                self.param_shards.append(shard)

            self.optimizer = optimizer_cls(self.param_shards)

        def training_step(self, inputs, targets, loss_fn):
            output = self.model(inputs)  # forward
            loss = loss_fn(output, targets)  # compute loss
            loss.backward()  # backward

            self._sync_gradients()  # all-reduce gradients across GPUs
            self.optimizer.step()  # update the local shard of the parameters
            self._sync_params()  # all-gather the full model parameters

            # clear gradients for the next step
            for param in self.model.parameters():
                param.grad = None

        def _sync_gradients(self):
            for idx, param in enumerate(self.model.parameters()):
                meta = self.param_metadata[idx]

                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad /= self.world_size

                # pad the flattened gradient so every rank's slice has the same length
                flat_grad = param.grad.view(-1)
                pad_size = meta["padded_numel"] - meta["numel"]
                if pad_size > 0:
                    flat_grad = torch.cat([flat_grad, flat_grad.new_zeros(pad_size)])

                self.param_shards[idx].grad = flat_grad[meta["shard_start"]:meta["shard_end"]]

        def _sync_params(self):
            for idx, param in enumerate(self.model.parameters()):
                meta = self.param_metadata[idx]

                full_flat = torch.empty(meta["padded_numel"], device=param.device, dtype=param.dtype)
                dist.all_gather_into_tensor(
                    output_tensor=full_flat,
                    input_tensor=self.param_shards[idx].data,
                )

                reconstructed = full_flat[:meta["numel"]].view(meta["original_shape"])
                param.data.copy_(reconstructed)

    Notice that the all-reduce syncs all gradients, but each GPU only uses the gradients for its own parameter partition, so it’s over-communicating. ZeRO-2 fixes this by sharding the gradients too.

    In practice, you’d never use ZeRO-1, since ZeRO-2 gives you better memory savings at essentially the same cost. But it’s still worth going over for learning purposes.

    Memory with ZeRO-1, 7B model, 8 GPUs:

    • Parameters: 28 GB (fully replicated)
    • Gradients: 28 GB (fully replicated)
    • Optimizer states: 56 GB / 8 = 7 GB
    • Total per GPU: 63 GB (down from 112 GB)

    ZeRO-2: Gradient Partitioning

    ZeRO-2 partitions both optimizer states and gradients. Since each GPU only updates a partition of the parameters, it only needs the corresponding gradients.

    ZeRO-1 uses all-reduce, which gives every GPU all of the gradients. ZeRO-2 replaces this with reduce-scatter, so each GPU receives only the gradients it actually needs. This saves both memory and communication bandwidth.

    Training steps:

    1. Forward pass: each GPU processes its own micro-batch
    2. Backward pass: compute gradients
    3. Reduce-scatter gradients: each GPU gets only its partition
    4. Optimizer step: each GPU updates its parameter partition
    5. All-gather parameters: sync the updated model across GPUs
    Image by author: ZeRO-2 animation

    The implementation is very similar to ZeRO-1, but the gradient synchronization step uses reduce-scatter instead of all-reduce:
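
    The original listing for this override isn’t reproduced here, so below is a minimal sketch of what it could look like, assuming the ZeRO_1 class and sharding metadata from the previous listing (the ZeRO_3 class further down subclasses a ZeRO_2 class like this one):

    class ZeRO_2(ZeRO_1):
        """
        ZeRO-2 sketch: shard optimizer states (stage 1) + gradients (stage 2).

        Identical to ZeRO_1 except for gradient synchronization: instead of
        all-reducing the full gradient, each rank keeps only the reduced
        gradient shard it needs for its own parameter partition.
        """

        def _sync_gradients(self):
            for idx, param in enumerate(self.model.parameters()):
                meta = self.param_metadata[idx]

                # Flatten and pad the local gradient so it splits evenly across ranks
                flat_grad = param.grad.view(-1)
                pad_size = meta["padded_numel"] - meta["numel"]
                if pad_size > 0:
                    flat_grad = torch.cat([flat_grad, flat_grad.new_zeros(pad_size)])

                # reduce-scatter: each rank receives only the sum (over ranks) of its own shard
                grad_shard = torch.empty(
                    meta["shard_size"], device=flat_grad.device, dtype=flat_grad.dtype
                )
                dist.reduce_scatter_tensor(grad_shard, flat_grad, op=dist.ReduceOp.SUM)
                grad_shard /= self.world_size

                self.param_shards[idx].grad = grad_shard
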
    But wait: if every GPU computes all of the gradients during backprop, how does this actually save VRAM? Here’s how:

    • As the parameter gradients are computed layer by layer, they’re immediately reduce-scattered and the local copy is freed (our simplified implementation doesn’t perform this).
    • During backprop, you only need the gradient of the next layer’s activations to compute the current parameters’ gradients, i.e., you don’t need the entire gradient graph.
    • That way you can free up the memory for gradients as you move backwards, keeping only the assigned partition on each GPU.

    Memory with ZeRO-2, 7B model, 8 GPUs:

    • Parameters: 28 GB (fully replicated)
    • Gradients: 28 GB / 8 = 3.5 GB
    • Optimizer states: 56 GB / 8 = 7 GB
    • Total per GPU: 38.5 GB (down from 112 GB)

    ZeRO-3: Parameter Partitioning

    ZeRO-3 partitions optimizer states, gradients, and parameters. Each GPU stores only 1/N of the entire model state.

    During the forward and backward passes, each layer needs its full parameters, but each GPU only stores a fraction of them. So we all-gather parameters just in time, use them, then discard them immediately afterward.

    Training steps:

    • Forward pass:
      • All-gather the layer’s parameters from all GPUs
      • Run the layer’s forward pass using the previous layer’s activations as input
      • Discard the gathered parameters (keep only the local partition)
      • Repeat these steps until all layers are done
    • Backward pass (per layer, in reverse):
      • All-gather the layer’s parameters again
      • Compute gradients for the current layer using the activation gradients from the next layer
      • Reduce-scatter the gradients (each GPU keeps its shard)
      • Discard the gathered parameters (keep only the local partition)
      • Repeat these steps until all layers are done
    • Each GPU runs an optimizer step on its partition
    • No final all-gather is needed, since parameters are gathered layer by layer during the forward pass
    Image by author: ZeRO-3 animation

    Here’s a simplified implementation:

    class ZeRO_3(ZeRO_2):
        """
        ZeRO-3: shard optimizer states (stage 1) + gradients (stage 2) + model parameters (stage 3).

        At rest, each rank holds only param_shards[idx], a 1/world_size slice
        of each parameter. Full parameters are materialized temporarily during
        the forward and backward passes via all_gather, then immediately freed.
        """

        def __init__(self, model, optimizer_cls):
            self.model = model
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()

            self.param_metadata = []
            shard_list = []

            self._param_to_idx = {}

            for idx, param in enumerate(self.model.parameters()):
                original_shape = param.data.shape
                flat = param.data.view(-1)
                numel = flat.numel()

                remainder = numel % self.world_size
                pad_size = (self.world_size - remainder) % self.world_size
                padded_numel = numel + pad_size
                shard_size = padded_numel // self.world_size

                shard_start = self.rank * shard_size
                shard_end = shard_start + shard_size

                self.param_metadata.append(
                    {
                        "original_shape": original_shape,
                        "numel": numel,
                        "padded_numel": padded_numel,
                        "shard_size": shard_size,
                        "shard_start": shard_start,
                        "shard_end": shard_end,
                    }
                )

                if pad_size > 0:
                    flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
                else:
                    flat_padded = flat

                shard = flat_padded[shard_start:shard_end].clone()
                shard_list.append(shard)

                # Replace the full tensor with only this rank's shard.
                # The model's param.data now points to a tiny slice; the full
                # weight will be reconstructed on demand during forward/backward.
                param.data = shard.detach()
                self._param_to_idx[param] = idx

            self.param_shards = [s.requires_grad_(True) for s in shard_list]
            self.optimizer = optimizer_cls(self.param_shards)

            self._register_hooks()

        def _gather_param(self, idx, device, dtype):
            """All-gather the full parameter tensor for parameter `idx`."""
            meta = self.param_metadata[idx]
            full_flat = torch.empty(meta["padded_numel"], device=device, dtype=dtype)
            dist.all_gather_into_tensor(
                output_tensor=full_flat,
                input_tensor=self.param_shards[idx].data,
            )
            return full_flat[: meta["numel"]].view(meta["original_shape"])

        def _gather_module_params(self, module):
            """Gather full params for every parameter that belongs directly to this module (not its children)."""
            for param in module.parameters(recurse=False):
                idx = self._param_to_idx[param]
                param.data = self._gather_param(idx, param.device, param.dtype)

        def _reshard_module_params(self, module):
            """Reshard params back to the local shard for every direct param of this module."""
            for param in module.parameters(recurse=False):
                idx = self._param_to_idx[param]
                param.data = self.param_shards[idx].data

        def _register_hooks(self):
            self._hooks = []
            for module in self.model.modules():
                # Skip container modules that have no direct parameters
                if not list(module.parameters(recurse=False)):
                    continue

                # Forward: gather -> run -> reshard
                h1 = module.register_forward_pre_hook(
                    lambda mod, _inputs: self._gather_module_params(mod)
                )
                h2 = module.register_forward_hook(
                    lambda mod, _inputs, _output: self._reshard_module_params(mod)
                )

                # Backward: gather before grad computation -> reshard after
                h3 = module.register_full_backward_pre_hook(
                    lambda mod, _grad_output: self._gather_module_params(mod)
                )
                h4 = module.register_full_backward_hook(
                    lambda mod, _grad_input, _grad_output: self._reshard_module_params(mod)
                )

                self._hooks.extend([h1, h2, h3, h4])

        def training_step(self, inputs, targets, loss_fn):
            # Hooks handle the gather/reshard around each module automatically
            output = self.model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()

            self._sync_gradients()

            # Each rank updates only its local shard
            self.optimizer.step()

            for param in self.model.parameters():
                param.grad = None

    Each layer’s parameters are gathered right before they’re needed and freed immediately after. This keeps peak memory minimal at the cost of extra communication. In practice, implementations overlap the all-gather for layer N+1 with the forward pass of layer N to hide the latency.
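
    As a rough illustration of that overlap (a self-contained sketch with made-up tensors, not part of the implementation above), the pattern is to launch the next all-gather with async_op=True and only wait on it once the result is needed:

    def overlapped_gather_step(shard_cur, shard_next, world_size):
        """Sketch: gather the current layer's weights, then prefetch the next
        layer's weights asynchronously while the current layer's compute runs."""
        full_cur = torch.empty(shard_cur.numel() * world_size,
                               device=shard_cur.device, dtype=shard_cur.dtype)
        full_next = torch.empty(shard_next.numel() * world_size,
                                device=shard_next.device, dtype=shard_next.dtype)

        # Blocking gather for the parameters needed right now
        dist.all_gather_into_tensor(full_cur, shard_cur)

        # Launch the gather for the *next* layer without blocking...
        handle = dist.all_gather_into_tensor(full_next, shard_next, async_op=True)

        # ...and run the current layer's compute while the communication is in flight
        activations = torch.relu(full_cur)  # stand-in for the real forward pass

        # Only wait when the next layer's weights are actually required
        handle.wait()
        return activations, full_next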

    Memory with ZeRO-3, 7B model, 8 GPUs:

    • Parameters: 28 GB / 8 = 3.5 GB
    • Gradients: 28 GB / 8 = 3.5 GB
    • Optimizer states: 56 GB / 8 = 7 GB
    • Total per GPU: 14 GB (down from 112 GB)

    That’s an 8x reduction in memory usage, which is exactly what we’d expect from partitioning across 8 GPUs.

    Using ZeRO in PyTorch

    PyTorch ships with two implementations of ZeRO-3: FSDP1 (older, less optimized) and FSDP2 (newer, recommended). Always use FSDP2.

    FSDP (Fully Sharded Data Parallel) handles parameter gathering, gradient scattering, communication overlap, and memory management automatically:

    from torch.distributed.fsdp import fully_shard
    
    model = Transformer()
    for layer in model.layers:
        fully_shard(layer)
    fully_shard(model)

    You should apply fully_shard layer by layer and then wrap the whole model.
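
    To see where this fits in a full run, here’s a minimal, hedged sketch of an FSDP2 training step; the ToyTransformer module, the hyperparameters, and the torchrun launch are assumptions for illustration, not from the original post:

    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.distributed.fsdp import fully_shard


    class ToyTransformer(nn.Module):
        """Hypothetical stand-in for the post's Transformer()."""

        def __init__(self, dim=512, n_layers=4):
            super().__init__()
            self.layers = nn.ModuleList(
                [nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True)
                 for _ in range(n_layers)]
            )
            self.head = nn.Linear(dim, dim)

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return self.head(x)


    def main():
        dist.init_process_group("nccl")
        torch.cuda.set_device(dist.get_rank())  # assumes one process per GPU on a single node

        model = ToyTransformer().cuda()

        # Shard each block first, then the root module (as recommended above)
        for layer in model.layers:
            fully_shard(layer)
        fully_shard(model)

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

        x = torch.randn(8, 128, 512, device="cuda")
        target = torch.randn(8, 128, 512, device="cuda")

        loss = nn.functional.mse_loss(model(x), target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        dist.destroy_process_group()


    if __name__ == "__main__":
        # Launch with: torchrun --nproc_per_node=<num_gpus> this_script.py
        main()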

    Conclusion

    ZeRO trades memory for communication, so it’s not a free lunch. It’s often not worth it for smaller models (e.g. BERT), but it’s a game changer for larger models.

    Congratulations on making it to the end! In this post, you learned about:

    • The memory redundancy problem in standard DDP
    • How ZeRO partitions optimizer states, gradients, and parameters across GPUs
    • The three stages of ZeRO and their memory/communication trade-offs
    • How to use ZeRO-3 via PyTorch’s FSDP

    In the next article, we’ll explore Tensor Parallelism, a model parallelism technique that accelerates a layer’s computation by distributing the work across GPUs.

    References

    1. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (original paper)
    2. PyTorch FSDP Tutorial
    3. FSDP API Reference
    4. The Ultra-Scale Playbook by Hugging Face


