
    PyTorch NaNs Are Silent Killers — So I Built a 3ms Hook to Catch Them at the Exact Layer

    By Editor Times Featured | April 28, 2026 | 12 min read


    • NaNs don’t originate where they appear; they silently propagate across layers
    • torch.autograd.set_detect_anomaly is too slow and often misleading for real debugging
    • A forward hook-based detector can catch NaNs at the exact layer and batch where they first occur
    • Overhead is ~3–4 ms per forward pass, far lower than anomaly detection (especially on GPU)
    • Gradient explosion is the real root cause in most cases; catching it early prevents NaNs entirely
    • The system logs structured events (layer, batch, stats) for precise debugging
    • Designed for production: thread-safe, memory-bounded, and scalable

    It was batch 47,000. A ResNet variant I had been training for six hours on a custom medical imaging dataset. The loss was converging cleanly (1.4, 1.1, 0.87, 0.73) and then, nothing. Not an error. Not a crash. Just nan.

    I added torch.autograd.set_detect_anomaly(True) and restarted. The training slowed to a crawl, roughly 7–10× longer per batch on CPU alone, and after three hours I finally got a stack trace pointing to a layer that, frankly, looked fine. The real culprit was a learning rate scheduler interacting badly with a custom normalization layer two layers upstream. set_detect_anomaly had pointed me at the symptom, not the source.

    That debugging session cost me most of a day. So I built something better.

    NaNs don’t crash your model; they quietly corrupt it. By the time you notice, you’re already debugging the wrong layer.

    Full code: https://github.com/Emmimal/pytorch-nan-detector/


    The Problem with set_detect_anomaly

    PyTorch ships with torch.autograd.set_detect_anomaly(True), which is the standard recommendation for debugging NaN issues. It works by recording a traceback for every forward operation and checking the output of each backward function for NaN. This is powerful, but it comes with serious costs that make it unsuitable for anything beyond a quick local sanity check.

    The core issue is that this bookkeeping runs for every single operation: a Python stack trace is captured in the forward pass, and every gradient is inspected in the backward pass. On GPU, inspecting a gradient means waiting for the kernels that produce it, which breaks the asynchronous execution pipeline. The result, as reported in the PyTorch documentation and widely observed in practice, is an overhead that ranges from roughly 10–15× on CPU to 50–100× on GPU for larger models [1][2].

    There is a second problem: set_detect_anomaly points you at where the NaN propagated to in the backward pass, not necessarily where it originated. If a NaN enters your network at layer 3 of a 50-layer model, the backward pass will surface an error somewhere in the gradient computation for a later layer, and you are left working backward from there.
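To make the "symptom, not source" point concrete, here is a minimal sketch using the context-manager form, torch.autograd.detect_anomaly(). The toy tensor and the square root of a negative number are assumptions chosen only to force a NaN; they are not taken from the article's repository. The forward pass produces nan without complaint, and the error only appears once backward runs.

```python
import torch

# Toy input chosen so that sqrt() yields nan in the forward pass
x = torch.tensor([-1.0], requires_grad=True)

message = None
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x)                          # forward: nan, but no error raised
    forward_is_nan = bool(torch.isnan(y).all())
    try:
        y.sum().backward()                     # backward: anomaly mode raises here
    except RuntimeError as err:
        message = str(err)

print("forward produced nan:", forward_is_nan)
print("backward error:", message)
```

The error text names the backward function that returned nan, which is useful in a two-line script but says nothing about which upstream layer first went wrong in a deep model.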

    My benchmark, run on a small CPU MLP (64→256→256→10), measured:

    Method                        Mean latency   Overhead vs baseline
    No detection                  ~0.60 ms       baseline
    NaNDetector (forward hooks)   ~3–4 ms        ~5–6×
    set_detect_anomaly            ~7–8 ms        ~12–13×

    Forward hook-based NaN detection adds ~3 ms per pass, while set_detect_anomaly adds ~7 ms: a small gap here, but a major slowdown at scale, especially on GPU. Image by Author

    In this small model the absolute difference is modest. At scale (a transformer with hundreds of millions of parameters on multiple GPUs) the gap is the difference between a training run that completes and one that does not.


    The Approach: Forward Hooks

    PyTorch NaN detection architecture diagram showing forward hooks, gradient monitoring, and training loop integration
    End-to-end NaN detection pipeline: forward hooks catch activation issues, gradient norm guard detects instability early, and structured events enable precise debugging. Image by Author

    PyTorch’s register_forward_hook API lets you attach a callback to any nn.Module that fires every time that module completes a forward pass [3]. The callback receives the module itself, its inputs, and its outputs. This means you can inspect every tensor flowing through every layer in real time, with no impact on the computation graph and no retained activations.

    The key insight is that you only need to do the NaN check, not replay the computation. A check against torch.isnan() and torch.isinf() on an output tensor is a pair of small elementwise kernels and typically completes in microseconds. On GPU, reading the boolean result in Python does wait for the tensor to be computed, so the check is cheap but not entirely free.

    def make_hook(layer_name):  # closure binds the layer name for this hook
        def hook(module, inputs, output):
            if torch.isnan(output).any():
                print(f"NaN detected in {layer_name}")
        return hook
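A minimal runnable sketch of the same idea, wired across a whole model, might look like the following. The BadLog layer, the attach_nan_hooks helper, and the offenders list are hypothetical names introduced for illustration and are not part of the article's repository; only register_forward_hook, named_modules, and torch.isfinite are real PyTorch calls.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class BadLog(nn.Module):
    """Hypothetical layer that always emits NaN (log of a negative number)."""

    def forward(self, x):
        return torch.log(-torch.abs(x) - 1.0)


def attach_nan_hooks(model, offenders):
    """Hook every leaf module; append its name to `offenders` on non-finite output."""
    handles = []
    for name, module in model.named_modules():
        if any(True for _ in module.children()):
            continue  # containers only forward to their children

        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                offenders.append(name)

        handles.append(module.register_forward_hook(hook))
    return handles


model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(4, 8)),
    ("bad", BadLog()),
    ("fc2", nn.Linear(8, 2)),
]))

offenders = []
handles = attach_nan_hooks(model, offenders)
model(torch.randn(3, 4))
for h in handles:
    h.remove()  # detach the hooks once the check is done

first_bad_layer = offenders[0]
print("first non-finite output at:", first_bad_layer)
```

Because hooks fire in execution order, the first name recorded is the layer where the NaN entered; later entries (here fc2) are only propagation.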

    That’s the core of the idea. What follows is the production-hardened version.


    The Implementation

    The full source is available at: https://github.com/Emmimal/pytorch-nan-detector/

    I’ll walk through the four components that matter.

    Component 1: The NaNEvent dataclass

    When a NaN is detected, you need more than a print statement. You need a structured record you can inspect after the fact, log to disk, or send to an alerting system.

    from dataclasses import dataclass, field

    @dataclass
    class NaNEvent:
        batch_idx: int
        layer_name: str
        module_type: str
        input_has_nan: bool
        output_has_nan: bool
        input_has_inf: bool
        output_has_inf: bool
        output_shape: tuple
        output_stats: dict = field(default_factory=dict)
        is_backward: bool = False

    The output_stats field contains the min, max, and mean of the finite values in the output tensor at the moment of detection. This is surprisingly useful: a layer output where 3 values are NaN but the rest are finite tells a different story than one that is all NaN.

    The is_backward flag distinguishes whether the event was caught in a forward hook or a backward hook, which matters for root cause analysis.
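As a rough illustration of how such statistics could be computed, the sketch below masks out non-finite entries before reducing. The function name finite_stats and the n_nonfinite key are assumptions made for this example; the repository may compute and label its statistics differently.

```python
import torch


def finite_stats(t: torch.Tensor) -> dict:
    """Min/max/mean over finite entries only; None values if nothing is finite."""
    finite = t[torch.isfinite(t)]
    if finite.numel() == 0:
        return {"min": None, "max": None, "mean": None, "n_nonfinite": t.numel()}
    return {
        "min": finite.min().item(),
        "max": finite.max().item(),
        "mean": finite.float().mean().item(),
        "n_nonfinite": t.numel() - finite.numel(),
    }


# Two finite values, one NaN, one Inf
out = torch.tensor([1.0, float("nan"), 3.0, float("inf")])
stats = finite_stats(out)
print(stats)
```

A partially corrupted tensor yields real numbers plus a count of bad entries, while a fully corrupted one yields None values, which is the distinction the paragraph above describes.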

    Component 2: Thread-safe hook registration

    The most important production consideration is thread safety. Hooks can fire concurrently whenever the model runs on more than one thread, for example under nn.DataParallel (which drives its replicas from Python threads) or in a multi-threaded serving process. If you mutate self.triggered = True and self.event = ev without a lock, you will get race conditions in those setups.

    self._lock = threading.Lock()
    
    def _make_fwd_hook(self, layer_name: str):
        def hook(module, inputs, output):
            with self._lock:
                if self.triggered and self.stop_on_first:
                    return
                current_batch = self._batch_idx
            # ... tensor checks happen outside the lock
            if out_nan or out_inf:
                self._record_event(...)   # lock re-acquired inside
        return hook

    The tensor checks themselves happen outside the lock because torch.isnan() is read-only and thread-safe. Only the shared state mutations are locked.
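The locking pattern can be tried in isolation with the standard library alone. In this sketch, FirstEventRecorder is a hypothetical stand-in for the detector's shared state (it is not the repository's class), and the sixteen threads simulate hooks firing at once.

```python
import threading


class FirstEventRecorder:
    """Hypothetical stand-in for the detector's shared state."""

    def __init__(self):
        self._lock = threading.Lock()
        self.triggered = False
        self.event = None
        self.attempts = 0

    def record(self, layer_name: str) -> bool:
        # All reads and writes of shared state happen under the lock
        with self._lock:
            self.attempts += 1
            if self.triggered:
                return False  # someone else already recorded the first event
            self.triggered = True
            self.event = layer_name
            return True


recorder = FirstEventRecorder()
results = []
results_lock = threading.Lock()


def worker(i: int) -> None:
    won = recorder.record(f"layer{i}")
    with results_lock:
        results.append(won)


threads = [threading.Thread(target=worker, args=(i,)) for i in range(16)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print("first event:", recorder.event, "| winners:", sum(results))
```

Exactly one thread wins the right to record the event, regardless of scheduling order, because the check and the write happen inside the same critical section.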

    Component 3: Bounded memory

    A subtle issue with long training runs: if you accumulate overhead timings in an unbounded list, you will eventually exhaust memory on runs lasting millions of batches. The fix is a simple cap:

    _OVERHEAD_CAP = 1000
    
    with self._lock:
        if len(self._overhead_ms) < self._OVERHEAD_CAP:
            self._overhead_ms.append(elapsed)

    The same logic applies to all_events when stop_on_first=False: a max_events parameter (default 100) prevents unbounded accumulation during pathological runs.
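A related way to bound memory, shown here only as an alternative sketch, is collections.deque with maxlen. Note the behavioral difference from the cap above: the cap keeps the first N samples and ignores the rest, whereas a deque keeps the most recent N. BoundedTimings is a hypothetical class for illustration, not code from the repository.

```python
from collections import deque


class BoundedTimings:
    """Keeps the most recent `cap` samples plus a running total for the mean."""

    def __init__(self, cap: int = 1000):
        self._recent = deque(maxlen=cap)  # old samples fall off automatically
        self.count = 0
        self.total_ms = 0.0

    def add(self, elapsed_ms: float) -> None:
        self._recent.append(elapsed_ms)
        self.count += 1
        self.total_ms += elapsed_ms

    @property
    def recent(self) -> list:
        return list(self._recent)

    @property
    def mean_ms(self) -> float:
        return self.total_ms / self.count if self.count else 0.0


timings = BoundedTimings(cap=3)
for ms in [1.0, 2.0, 3.0, 4.0, 5.0]:
    timings.add(ms)

print(timings.recent, timings.mean_ms)
```

Tracking the count and total separately keeps the reported mean exact over the whole run even though only a fixed number of samples is stored.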

    Component 4: Gradient norm guard

    The most common real-world path to a NaN is not a bug that directly produces nan. It is a learning rate that is too high, causing gradient norms to explode to inf, which then propagates into the weights and produces NaN activations on the next forward pass. By the time your forward hook fires, you are already one step too late.

    The check_grad_norms() method addresses this by walking all parameters after loss.backward() and logging a GradEvent for any parameter whose gradient norm exceeds a threshold:

    def check_grad_norms(self) -> bool:
        if self.grad_norm_warn is None:
            return False
        for name, module in self.model.named_modules():
            for pname, param in module.named_parameters(recurse=False):
                if param.grad is None:
                    continue
                norm = param.grad.detach().float().norm().item()
                if not math.isfinite(norm) or norm > self.grad_norm_warn:
                    # log GradEvent

    In the demo below, this method catches gradient explosion at batch 1, one full training step before the NaN would have appeared in the forward pass.

    Exploding gradient norms detected early during training before NaN appears in forward pass
    Gradient norms explode at batch 1, caught early before NaNs propagate into activations. Image by Author
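The gradient-norm check can also be sketched as a standalone function. Here find_large_grads is a hypothetical helper that returns the offending parameters instead of logging a GradEvent, and the loss is scaled by 1e6 purely to force large gradients in a toy model.

```python
import math

import torch
import torch.nn as nn


def find_large_grads(model: nn.Module, threshold: float) -> list:
    """Return (name, norm) for each parameter whose grad norm is non-finite or too large."""
    flagged = []
    for name, param in model.named_parameters():
        if param.grad is None:
            continue  # parameter did not take part in this backward pass
        norm = param.grad.detach().float().norm().item()
        if not math.isfinite(norm) or norm > threshold:
            flagged.append((name, norm))
    return flagged


torch.manual_seed(0)
model = nn.Linear(4, 1)
x = torch.ones(2, 4)

# Scale the loss so the gradients are deliberately huge
loss = (model(x) * 1e6).sum()
loss.backward()

flagged = find_large_grads(model, threshold=10.0)
for name, norm in flagged:
    print(f"{name}: grad norm {norm:.3g}")
```

Calling this between loss.backward() and optimizer.step() gives a chance to skip or clip the update before the oversized gradients reach the weights.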

    Usage

    Basic: context manager

    from nan_detector import NaNDetector
    
    with NaNDetector(model) as det:
        for batch_idx, (x, y) in enumerate(loader):
            det.set_batch(batch_idx)
            optimizer.zero_grad()  # clear gradients from the previous step
            loss = criterion(model(x), y)
            loss.backward()
            det.check_grad_norms()
            optimizer.step()
            if det.triggered:
                print(det.event)
                break

    When the detector fires, det.event contains the full NaNEvent with layer name, module type, batch index, and output statistics.

    Production: drop-in training loop

    from nan_detector import train_with_nan_guard
    
    losses, event = train_with_nan_guard(
        model, loader, criterion, optimizer,
        device="cuda",
        grad_norm_warn=50.0,
    )
    
    if event:
        print(f"NaN at batch {event.batch_idx}, layer {event.layer_name}")

    Advanced: backward hooks + readable layer names

    For catching gradient NaNs directly (not just norm warnings), enable check_backward=True. Use OrderedDict when building Sequential models to get readable names in all log output:

    from collections import OrderedDict
    import torch.nn as nn
    
    model = nn.Sequential(OrderedDict([
        ("fc1",   nn.Linear(16, 32)),
        ("relu1", nn.ReLU()),
        ("fc2",   nn.Linear(32, 1)),
    ]))
    
    with NaNDetector(model, check_backward=True, grad_norm_warn=10.0) as det:
        ...

    Without OrderedDict, PyTorch names layers by index (0.weight, 2.bias). With it, you get fc1.weight, fc2.bias: a small thing that saves real time when debugging deep models.
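The naming difference is easy to confirm with plain PyTorch, independent of the detector. The two models below are toy examples; only nn.Sequential and named_parameters are involved.

```python
from collections import OrderedDict

import torch.nn as nn

# Positional construction: layers are named by index
unnamed = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))

# OrderedDict construction: layers keep the keys you give them
named = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(16, 32)),
    ("relu1", nn.ReLU()),
    ("fc2", nn.Linear(32, 1)),
]))

unnamed_params = [n for n, _ in unnamed.named_parameters()]
named_params = [n for n, _ in named.named_parameters()]

print(unnamed_params)
print(named_params)
```

ReLU has no parameters, so index 1 (or relu1) never shows up in either list, which is why the indexed names jump from 0 to 2.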

    Skipping layers

    Some layer types are ones you may want to exclude, either because they cannot introduce a NaN on their own (nn.Dropout in eval mode is an identity) or because you expect transient non-finite values from them (for example, certain normalization layers on the first forward pass, before running stats are established). Skip them with:

    det = NaNDetector(model, skip_types=(nn.Dropout, nn.BatchNorm1d))

    Demo Output

    Running the three demos produces the following output:

    ────────────────────────────────────────────────────────────
      Demo 1: Forward NaN detection + loss curve plot
    ────────────────────────────────────────────────────────────
    [NaNDetector] Attached 5 hooks.
    ============================================================
      NaN/Inf detected! [FORWARD PASS]
      Batch     : 12
      Layer     : layer4
      Type      : Linear
      Flags     : NaN in INPUT, NaN in OUTPUT
      Out shape : (8, 1)
      Out stats : min=n/a (all non-finite)  max=n/a (all non-finite)  mean=n/a (all non-finite)
    ============================================================
    [NaNDetector] Detached. Avg overhead: 0.109 ms/forward-pass
    
    ────────────────────────────────────────────────────────────
      Demo 2: Backward / grad-norm detection + grad norm plot
    ────────────────────────────────────────────────────────────
    [NaNDetector] Attached 8 hooks (+ backward).
    [GradNorm WARNING] batch=1  layer=fc1.weight  norm=inf  threshold=10.0
    [GradNorm WARNING] batch=1  layer=fc1.bias    norm=inf  threshold=10.0
    [GradNorm WARNING] batch=1  layer=fc2.weight  norm=inf  threshold=10.0
    [GradNorm WARNING] batch=1  layer=fc2.bias    norm=4.37e+18  threshold=10.0
      Caught at batch 1
    Training loss curve showing smooth convergence followed by sudden NaN failure during model training
    Loss drops steadily, then collapses into NaN at batch 12, immediately caught by the detector.

    The hook overhead of 0.109 ms per forward pass in Demo 1 is the real number you can cite. The benchmark figure of ~3 ms reflects a larger model with 5 registered hook callbacks firing on each pass, which is the more realistic production scenario.


    Known Limitations

    Forward hooks see activations, not all computation. If a NaN originates inside a custom torch.autograd.Function’s backward() method, or inside a C++/CUDA extension that does not surface through named nn.Module submodules, the forward hook will not catch it. Use check_backward=True for gradient-side coverage, and grad_norm_warn for early warning.

    Overhead scales with model depth. The benchmark was run on a 5-layer MLP. A 200-layer transformer will have 200 hook callbacks firing per forward pass. The overhead is still sub-millisecond per hook, but it accumulates. Mitigate by using skip_types to exclude parameter-free layers like ReLU and Dropout, and optionally LayerNorm, if overhead becomes a concern.

    CPU benchmark ratios are noisy. The overhead ratio of NaNDetector relative to the baseline varied between 5× and 6× across runs in my testing, because CPU microbenchmarks at sub-millisecond scale are sensitive to OS scheduling and cache state. The absolute millisecond numbers are more stable. The 50–100× figure cited for GPU is drawn from the PyTorch documentation and community benchmarks [1][2], not my own GPU measurements.


    What This Does Not Replace

    This is a debugging and monitoring tool, not a substitute for good training hygiene. The standard recommendations still apply: gradient clipping (torch.nn.utils.clip_grad_norm_), careful learning rate scheduling, input normalization, and weight initialization. NaNDetector tells you where and when the problem occurred. It does not tell you why, and fixing the root cause still requires engineering judgment.
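For reference, gradient clipping slots in between backward() and step(). The sketch below uses a toy model and a deliberately inflated loss (both assumptions for illustration) to show that clip_grad_norm_ returns the norm measured before clipping and rescales the gradients in place.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

optimizer.zero_grad()
# Inflate the loss so the raw gradients are far above the clipping threshold
loss = (model(torch.ones(2, 4)) * 1e6).sum()
loss.backward()

# Returns the total norm measured BEFORE clipping; gradients are rescaled in place
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Recompute the combined norm to confirm the gradients were scaled down
clipped_norm = torch.sqrt(sum(p.grad.norm() ** 2 for p in model.parameters()))

optimizer.step()
print(f"before: {float(total_norm):.3g}  after: {float(clipped_norm):.3g}")
```

Logging the returned pre-clip norm is a cheap way to see how often clipping is actually active during a run.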

    If you are hitting NaNs in mixed-precision training (fp16/bf16), the most common culprits are loss scaling overflow and layer norm instability, and those are worth investigating directly before reaching for a debugging hook.


    Benchmark Methodology

    All benchmarks were run on CPU (Windows 11, PyTorch 2.x) using a 4-layer MLP with input size 64, two hidden layers of 256, and output size 10. Batch size was 64. Each method ran 30 forward passes. The first pass was included in the mean, since cold-start effects are real and should be counted. Times were measured with time.perf_counter() around the forward call only, not including data loading or loss computation.

    The full benchmark function is included in the source and can be run with benchmark(n_batches=30, batch_size=64).
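The measurement approach described above can be sketched with the standard library. This is not the repository's benchmark() function: mean_latency_ms and fake_forward are hypothetical names, and the workload is a pure-Python stand-in for a model's forward call.

```python
import statistics
import time


def mean_latency_ms(forward, n_batches: int = 30):
    """Time each call with perf_counter; the first (cold) call is kept in the mean."""
    samples = []
    for _ in range(n_batches):
        start = time.perf_counter()
        forward()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.mean(samples), samples


def fake_forward():
    # Stand-in workload; replace with a real model(x) call
    sum(i * i for i in range(10_000))


mean_ms, samples = mean_latency_ms(fake_forward, n_batches=30)
print(f"mean latency: {mean_ms:.3f} ms over {len(samples)} passes")
```

When timing a GPU model this way, a torch.cuda.synchronize() call before reading the clock would be needed, since kernel launches return before the work finishes.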


    References

    [1] PyTorch Documentation. “Autograd Mechanics: Anomaly Detection.” pytorch.org. Available at: https://pytorch.org/docs/stable/autograd.html#anomaly-detection

    [2] PyTorch Documentation. torch.autograd.set_detect_anomaly. pytorch.org. Available at: https://docs.pytorch.org/docs/stable/autograd.html

    [3] PyTorch Documentation. torch.nn.Module.register_forward_hook. pytorch.org. Available at: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook

    [4] PyTorch Documentation. torch.nn.Module.register_full_backward_hook. pytorch.org. Available at: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook

    [5] PyTorch Documentation. “Gradient Clipping: clip_grad_norm_.” pytorch.org. Available at: https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html

    [6] Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., … & Chintala, S. (2019). PyTorch: An imperative style, high-performance deep learning library. arXiv preprint arXiv:1912.01703. https://doi.org/10.48550/arXiv.1912.01703

    [7] Python Software Foundation. threading: Thread-based parallelism. Python 3 Documentation. Available at: https://docs.python.org/3/library/threading.html

    [8] Python Software Foundation. dataclasses: Data Classes. Python 3 Documentation. Available at: https://docs.python.org/3/library/dataclasses.html

    [9] Hunter, J. D. (2007). Matplotlib: A 2D graphics environment. Computing in Science & Engineering, 9(3), 90–95. https://doi.org/10.1109/MCSE.2007.55


    Disclosure

    I built and wrote about this tool myself. There is no sponsorship, no affiliation with PyTorch or the PyTorch Foundation, and no financial relationship with any company mentioned in this article. The benchmarks were run on my own hardware and are reproducible using the code in the repository linked above.

    All code in this article is original. The tool was written from scratch; no existing open-source NaN detection library was used as a base. If you use this in your own work, attribution is appreciated but not required; the code is MIT licensed.

    The benchmark comparison against set_detect_anomaly is based on my own measurements on a specific hardware configuration. Results will differ by model architecture, hardware, and PyTorch version. The 50–100× GPU overhead figure is drawn from PyTorch’s official documentation [1][2] and is not my own GPU measurement.

    Full source code, including all three demos and the benchmark function: https://github.com/Emmimal/pytorch-nan-detector/


