
    On the Challenge of Converting TensorFlow Models to PyTorch

    By Editor Times Featured | December 5, 2025 | 20 min read


    In the interest of managing reader expectations and preventing disappointment, we would like to begin by stating that this post does not provide a fully satisfactory solution to the problem described in the title. We will propose and assess two possible schemes for auto-conversion of TensorFlow models to PyTorch: the first based on the Open Neural Network Exchange (ONNX) format and libraries, and the second using the Keras3 API. However, as we will see, each comes with its own set of challenges and limitations. To the best of the authors’ knowledge, at the time of this writing, there are no publicly available foolproof solutions to this problem.

    Many thanks to Rom Maltser for his contributions to this post.

    The Decline of TensorFlow

    Over the years, the field of computer science has known its fair share of “religious wars”: heated, sometimes hostile, debates among programmers and engineers over the “best” tools, languages, and methodologies. Up until a few years ago, the religious war between PyTorch and TensorFlow, two prominent open-source deep learning frameworks, loomed large. Proponents of TensorFlow would highlight its fast graph-execution mode, while those in the PyTorch camp would emphasize its “Pythonic” nature and ease of use.

    However, these days, the amount of activity in PyTorch far overshadows that of TensorFlow. This is evidenced by the number of big-tech companies that have embraced PyTorch over TensorFlow, by the number of models per framework in HuggingFace’s models repository, and by the amount of innovation and optimization in each framework. Simply put, TensorFlow is a shell of its former self. The war is over, with PyTorch the definitive winner. For a brief history of the PyTorch-TensorFlow wars and the reasons for TensorFlow’s downfall, see Pan Xinghan’s post: TensorFlow Is Dead. PyTorch Won.

    Problem: What do we do with all of our legacy TensorFlow models?!!

    In light of this new reality, many organizations that once used TensorFlow have moved all of their new AI/ML model development to PyTorch. But they are faced with a difficult challenge when it comes to their legacy code: What should they do with all of the models that have already been built and deployed in TensorFlow?

    Option 1: Do Nothing

    You might be wondering why this is even a problem: the TensorFlow models work, so let’s not touch them. While this is a valid approach, there are a number of disadvantages that should be taken into account:

    1. Reduced Maintenance: As TensorFlow continues to decline, so will its maintenance. Inevitably, things will start to break. For example, there may be issues of compatibility with newer Python packages or system libraries.
    2. Limited Ecosystem: AI/ML solutions typically involve multiple supporting software libraries and services that interface with our framework of choice, be it PyTorch or TensorFlow. Over time, we can expect to see many of these discontinue their support for TensorFlow. Case in point: HuggingFace recently announced the deprecation of its support for TensorFlow.
    3. Limited Community: The AI/ML industry owes its fast pace of development, in large part, to its community. The number of open source projects, the number of online tutorials, and the amount of activity in dedicated support channels in the AI/ML space is unparalleled. As TensorFlow declines, so will its community, and you may experience increasing difficulty getting the help you need. Needless to say, the PyTorch community is flourishing.
    4. Opportunity Cost: The PyTorch ecosystem is thriving with constant innovations and optimizations. Recent years have seen the development of flash-attention kernels, support for the eight-bit floating-point data type, graph compilation, and many other advancements that have demonstrated significant boosts to runtime performance and significant reductions in AI/ML costs. During the same time period the feature offering in TensorFlow has remained mostly static. Sticking with TensorFlow means forgoing many opportunities for AI/ML cost optimization.

    Option 2: Manually Convert TensorFlow Models to PyTorch

    The second option is to rewrite legacy TensorFlow models in PyTorch. This is probably the best option in terms of its result, but for companies that have built up technical debt over many years, converting even a single model could be a daunting task. Given the effort required, you may choose to do this only for models that are still under active development (e.g., in the model training phase). Doing this for all of the models that are already deployed may prove prohibitive.

    Option 3: Automate TensorFlow to PyTorch Conversion

    The third option, and the approach we explore in this post, is to automate the conversion of legacy TensorFlow models to PyTorch. In this manner, we hope to gain the benefit of model execution in PyTorch, but without the enormous effort of manually converting each one.

    To facilitate our discussion we will define a toy TensorFlow model and assess two proposals for converting it to PyTorch. As our runtime environment, we will use an Amazon EC2 g6e.xlarge with an NVIDIA L40S GPU, an AWS Deep Learning Ubuntu (22.04) AMI, and a Python environment that includes the TensorFlow (2.20), PyTorch (2.9), torchvision (0.24.0), and transformers (4.55.4) libraries. Please note that the code blocks we will share are intended for demonstrative purposes. Please do not interpret our use of any code, library, or platform as an endorsement of its use.

    Model Conversion: Why is it Hard?

    An AI model definition consists of two components: a model architecture and its trained weights. A model conversion solution must address both components. Conversion of the model weights is pretty straightforward; the weights are typically stored in a format that can be easily parsed into individual tensor arrays and reapplied in the framework of choice. In contrast, conversion of the model architecture presents a much greater challenge.
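
    To make the weight side concrete, here is a minimal sketch of the kind of re-layout that reapplying weights can involve. It relies on one fact about the two frameworks: a Keras Dense layer stores its kernel as (input_dim, units), while torch.nn.Linear stores its weight as (out_features, in_features). The helper names transpose_2d and dense_forward, and the use of nested lists in place of real tensors, are illustrative assumptions and not part of the conversion schemes assessed in this post.

```python
# Minimal sketch: re-laying-out a dense-layer kernel when moving weights
# between frameworks. Plain nested lists stand in for real tensors.

def transpose_2d(matrix):
    """Return the transpose of a rectangular list-of-lists matrix."""
    return [list(column) for column in zip(*matrix)]


def dense_forward(x, weight_out_in):
    """Apply y[j] = sum_i x[i] * W[j][i], i.e. the (out, in) convention."""
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in weight_out_in]


# Keras Dense kernel layout: (input_dim, units) = (3, 2)
keras_kernel = [
    [1.0, 2.0],  # weights leaving input feature 0
    [3.0, 4.0],  # weights leaving input feature 1
    [5.0, 6.0],  # weights leaving input feature 2
]

# torch.nn.Linear weight layout: (out_features, in_features) = (2, 3)
torch_weight = transpose_2d(keras_kernel)

x = [1.0, 1.0, 1.0]
print(torch_weight)
print(dense_forward(x, torch_weight))
```

    With real tensors the same step is a single transpose per dense layer, which is why the weight side of the conversion is the easy part.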

    One approach could be to create a mapping between the building blocks of the model in each of the frameworks. However, there are a number of factors that make this approach, for all intents and purposes, virtually intractable:

    • API Overlap and Proliferation: When you consider the sheer number of, often overlapping, TensorFlow APIs for building model components and then add the vast number of API controls and arguments for each layer, you can see how creating a comprehensive, one-to-one mapping can quickly get ugly.
    • Differing Implementation Approaches: At the implementation level, TensorFlow and PyTorch have fundamentally different approaches. Although usually hidden behind the top-level APIs, some assumptions require special attention. For example, while TensorFlow defaults to the “channels-last” (NHWC) format, PyTorch prefers “channels-first” (NCHW). This difference in how tensors are indexed and stored complicates the conversion of model operations, as every layer must be checked/altered for correct dimension ordering.
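
    The effect of the layout difference can be illustrated with a few lines of plain Python. The sketch below computes where one logical element (sample n, channel c, row h, column w) lands in a contiguous row-major buffer under each layout. The function names and the tiny tensor dimensions are assumptions chosen for illustration only.

```python
# Minimal sketch: the same logical element sits at different flat offsets
# depending on whether the buffer is laid out as NCHW or NHWC.

def row_major_offset(index, shape):
    """Flat offset of a multi-dimensional index in a contiguous row-major buffer."""
    offset = 0
    for i, dim in zip(index, shape):
        offset = offset * dim + i
    return offset


def permute(values, axes):
    """Reorder a tuple according to a transpose-style axes specification."""
    return tuple(values[a] for a in axes)


NCHW_TO_NHWC = (0, 2, 3, 1)  # axes argument that turns NCHW into NHWC

shape_nchw = (2, 3, 4, 5)    # N, C, H, W
index_nchw = (0, 1, 2, 3)    # n, c, h, w of one element

shape_nhwc = permute(shape_nchw, NCHW_TO_NHWC)
index_nhwc = permute(index_nchw, NCHW_TO_NHWC)

nchw_offset = row_major_offset(index_nchw, shape_nchw)
nhwc_offset = row_major_offset(index_nhwc, shape_nhwc)
print(shape_nhwc, nchw_offset, nhwc_offset)
```

    Any operation that addresses a specific axis (a convolution, a reduction, a reshape) therefore needs its axis arguments remapped, or an explicit transpose inserted, when it is moved from one framework to the other.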

    Rather than attempt conversion at the API level, an alternative approach could be to capture and convert an internal TensorFlow graph representation. However, as anyone who has ever looked under the hood of TensorFlow will tell you, this too can get pretty nasty very quickly. TensorFlow’s internal graph representation is highly complex, often including a multitude of low-level operations, control flow, and auxiliary nodes that do not have a direct equivalent in PyTorch (especially if you’re dealing with older versions of TensorFlow). Just its comprehension seems beyond normal human ability, let alone its conversion to PyTorch.

    Note that the same challenges would make it difficult for a generative AI model to perform the conversion in a manner that is fully reliable.

    Proposed Conversion Schemes

    In light of these difficulties, we abandon our attempt at implementing our own model converter and instead look to see what tools the AI/ML community has to offer. More specifically, we consider two different strategies for overcoming the challenges we described:

    1. Conversion Via a Unified Graph Representation: This solution assumes a common standard for representing an AI/ML model definition and utilities for converting models to and from this standard. The solution we will explore uses the popular ONNX format.
    2. Conversion Based on a Standardized High-level API: In this solution we simplify the conversion task by limiting our model to a defined set of high-level abstract APIs with supported implementations in each of the AI/ML frameworks of interest. For this approach, we will use the Keras3 library.

    In the next sections we will assess these strategies on a toy TensorFlow model.

    A Toy TensorFlow Model

    In the code block below we initialize and run a TensorFlow Vision Transformer (ViT) model from HuggingFace’s popular transformers library (version 4.55.4), TFViTForImageClassification. Note that in line with HuggingFace’s decision to deprecate support for TensorFlow, this class was removed from recent releases of the library. The HuggingFace TensorFlow model depends on Keras 2, which we dutifully install via the tf-keras (2.20.1) package. We set the ViTConfig.hidden_act field to “gelu_new” for ONNX compatibility:

    import tensorflow as tf
    gpu = tf.config.list_physical_devices('GPU')[0]
    tf.config.experimental.set_memory_growth(gpu, True)
    
    from transformers import ViTConfig, TFViTForImageClassification
    vit_config = ViTConfig(hidden_act="gelu_new", return_dict=False)
    tf_model = TFViTForImageClassification(vit_config)

    Model Conversion Using ONNX

    The first method we assess relies on Open Neural Network Exchange (ONNX), a community project that aims to define an open format for building AI/ML models to increase interoperability between AI/ML frameworks and reduce the dependence on any single one. Included in the ONNX API offering are utilities for converting models from common frameworks, including TensorFlow, to the ONNX format. There are also several public libraries for converting ONNX models to PyTorch. In this post we use the onnx2torch utility. Thus, model conversion from TensorFlow to PyTorch can be achieved by successively applying TensorFlow-to-ONNX conversion followed by ONNX-to-PyTorch conversion.

    To assess this solution we install the onnx (1.19.1), tf2onnx (1.16.1), and onnx2torch (1.5.15) libraries. We apply the --no-deps flag to prevent an undesired downgrade of the protobuf library:

    pip install --no-deps onnx tf2onnx onnx2torch
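
    Because the install skips dependency resolution, it can be worth confirming which versions actually ended up in the environment. The sketch below uses only the standard library’s importlib.metadata; the helper name installed_versions and the choice of packages to report are assumptions made for illustration.

```python
# Minimal sketch: report the installed version of each package of interest,
# or None when the package is not present in the current environment.
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


report = installed_versions(["onnx", "tf2onnx", "onnx2torch", "protobuf"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

    Running this before and after the install makes it easy to spot whether protobuf (or any other shared dependency) changed underneath TensorFlow.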

    The conversion scheme appears in the code block below:

    import tensorflow as tf
    import torch
    import tf2onnx, onnx2torch
    
    BATCH_SIZE = 32
    DEVICE = "cuda"
    
    spec = (tf.TensorSpec((BATCH_SIZE, 3, 224, 224), tf.float32, name="input"),)
    onnx_model, _ = tf2onnx.convert.from_keras(tf_model, input_signature=spec)
    converted_model = onnx2torch.convert(onnx_model)

    To ensure that the resultant model is indeed a PyTorch module, we run the following assertion:

    assert isinstance(converted_model, torch.nn.Module)

    Let us now assess the quality and makeup of the resultant PyTorch model.

    Numerical Precision

    To verify the validity of the converted model, we execute both the TensorFlow model and the converted model on the same input and compare the results:

    import numpy as np
    
    batch_input = np.random.randn(BATCH_SIZE, 3, 224, 224).astype(np.float32)
    
    # execute tf model
    tf_input = tf.convert_to_tensor(batch_input)
    tf_output = tf_model(tf_input, training=False)
    tf_output = tf_output[0].numpy()
    
    # execute converted model
    converted_model = converted_model.to(DEVICE)
    converted_model = converted_model.eval()
    torch_input = torch.from_numpy(batch_input).to(DEVICE)
    torch_output = converted_model(torch_input)
    torch_output = torch_output.detach().cpu().numpy()
    
    # compare results
    print("Max diff:", np.max(np.abs(tf_output - torch_output)))
    
    # sample output:
    # Max diff: 9.3877316e-07

    The outputs are certainly close enough to validate the converted model.
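
    “Close enough” can be turned into an explicit pass/fail check by fixing a tolerance up front. The sketch below is a framework-independent illustration over flat lists of floats; it applies the criterion |x - y| <= atol + rtol * |y|, which is the same one numpy.allclose uses. The function names, the tolerance values, and the sample numbers are assumptions, not measurements from this post.

```python
# Minimal sketch: an explicit tolerance check for comparing the outputs of
# the original model (reference) and the converted model (candidate).

def max_abs_diff(candidate, reference):
    """Largest element-wise absolute difference between two flat sequences."""
    if len(candidate) != len(reference):
        raise ValueError("sequences must have the same length")
    return max(abs(x - y) for x, y in zip(candidate, reference))


def all_close(candidate, reference, rtol=1e-5, atol=1e-6):
    """True when |x - y| <= atol + rtol * |y| holds for every pair."""
    if len(candidate) != len(reference):
        raise ValueError("sequences must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y)
               for x, y in zip(candidate, reference))


# Hypothetical logits, for illustration only
reference = [0.12, -1.50, 3.25]
converted = [0.1200004, -1.5000009, 3.2499995]
drifted = [0.12, -1.50, 3.26]

print(max_abs_diff(converted, reference))
print(all_close(converted, reference))
print(all_close(drifted, reference))
```

    A check of this kind is easy to drop into a regression test so that a change in library versions cannot silently degrade the converted model.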

    Model Structure

    To get a feel for the structure of the converted model, we calculate the number of trainable parameters and compare it to that of the original model:

    num_tf_params = sum([np.prod(v.shape) for v in tf_model.trainable_weights])
    num_pyt_params = sum([p.numel()
                          for p in converted_model.parameters()
                          if p.requires_grad])
    print(f"TensorFlow trainable parameters: {num_tf_params}")
    print(f"PyTorch Trainable Parameters: {num_pyt_params:,}")

    The difference in the number of trainable parameters is profound: just 589,824 in the converted model compared to over 85 million in the original model. Traversing the layers of the converted model leads to that same conclusion: the ONNX-based conversion has completely altered the model structure, rendering it essentially unrecognizable. There are a number of ramifications to this finding, including:

    1. Training/fine-tuning the converted model: Although we have shown that the converted model can be used for inference, the change in structure, particularly the fact that some of the model parameters have been baked in, means that we cannot use the converted model for training or fine-tuning.
    2. Applying pinpoint PyTorch optimizations to the model: The converted model consists of a very large number of layers, each representing a relatively low-level operation. This greatly limits our ability to replace inefficient operations with optimized PyTorch equivalents, such as torch.nn.functional.scaled_dot_product_attention (SDPA).

    Model Optimization

    We have already seen that our ability to access and modify model operations is limited, but there are a number of optimizations that we can apply that do not require such access. In the code block below, we apply PyTorch compilation and automatic mixed precision (AMP) and compare the resultant throughput to that of the TensorFlow model. For further context, we also test the runtime of the PyTorch version of the ViTForImageClassification model:

    # Set tf mixed precision policy to bfloat16
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
    
    # Set torch matmul precision to high
    torch.set_float32_matmul_precision('high')
    
    @tf.function
    def tf_infer_fn(batch):
        return tf_model(batch, training=False)
    
    def get_torch_infer_fn(model):
        def infer_fn(batch):
            with torch.inference_mode(), torch.amp.autocast(
                    DEVICE,
                    dtype=torch.bfloat16,
                    enabled=DEVICE=='cuda'
            ):
                output = model(batch)
            return output
        return infer_fn
    
    def benchmark(infer_fn, batch):
        # warm-up
        for _ in range(20):
            _ = infer_fn(batch)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
    
        iters = 100
    
        for _ in range(iters):
            _ = infer_fn(batch)
        end.record()
        torch.cuda.synchronize()
        # elapsed_time returns milliseconds; report the average per step
        return start.elapsed_time(end) / iters
    
    # assess throughput of TF model
    avg_time = benchmark(tf_infer_fn, tf_input)
    print(f"\nTensorFlow average step time: {(avg_time):.4f}")
    
    # assess throughput of converted model
    torch_infer_fn = get_torch_infer_fn(converted_model)
    avg_time = benchmark(torch_infer_fn, torch_input)
    print(f"\nConverted model average step time: {(avg_time):.4f}")
    
    # assess throughput of compiled model
    torch_infer_fn = get_torch_infer_fn(torch.compile(converted_model))
    avg_time = benchmark(torch_infer_fn, torch_input)
    print(f"\nCompiled model average step time: {(avg_time):.4f}")
    
    # assess throughput of torch ViT
    from transformers import ViTForImageClassification
    torch_model = ViTForImageClassification(vit_config).to(DEVICE)
    torch_infer_fn = get_torch_infer_fn(torch_model)
    avg_time = benchmark(torch_infer_fn, torch_input)
    print(f"\nPyTorch ViT model average step time: {(avg_time):.4f}")
    
    # assess throughput of compiled torch ViT
    torch_infer_fn = get_torch_infer_fn(torch.compile(torch_model))
    avg_time = benchmark(torch_infer_fn, torch_input)
    print(f"\nCompiled ViT model average step time: {(avg_time):.4f}")

    Note that initially PyTorch compilation fails on the converted model due to the use of the torch.Size operator in the OnnxReshape layer. While this is easily fixable (e.g., tuple([int(i) for i in shape])), it points to a deeper obstacle to optimization of the model: the reshape layer, which appears dozens of times in the model, treats shapes as PyTorch tensors residing on the GPU. This means that each call requires detaching the shape tensor from the graph and copying it to the CPU. The conclusion is that although the converted model is functionally accurate, its resultant definition is not optimized for runtime performance. This can be seen from the step time results of the different model configurations:

    [Figure: ONNX-Based Conversion Step Time Results (by Author)]

    The converted model is slower than the original TensorFlow flow and significantly slower than the PyTorch version of the ViT model.
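
    Step times are easier to compare across configurations once they are converted to throughput and relative speedup. The sketch below assumes step times in milliseconds (the unit returned by torch.cuda.Event.elapsed_time) and a batch size of 32 as used in this post. The helper names and the two step-time values are hypothetical placeholders and are not the measured results.

```python
# Minimal sketch: turn an average step time (ms) into samples per second and
# express one configuration's speed relative to a baseline.

def samples_per_second(step_time_ms, batch_size):
    """Throughput implied by an average step time given in milliseconds."""
    if step_time_ms <= 0:
        raise ValueError("step time must be positive")
    return batch_size * 1000.0 / step_time_ms


def speedup(baseline_ms, candidate_ms):
    """How many times faster the candidate is than the baseline."""
    if baseline_ms <= 0 or candidate_ms <= 0:
        raise ValueError("step times must be positive")
    return baseline_ms / candidate_ms


BATCH_SIZE = 32
baseline_ms = 40.0   # hypothetical step time, for illustration only
candidate_ms = 25.0  # hypothetical step time, for illustration only

print(samples_per_second(baseline_ms, BATCH_SIZE))
print(samples_per_second(candidate_ms, BATCH_SIZE))
print(speedup(baseline_ms, candidate_ms))
```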

    Limitations

    Although (in the case of our toy model) the ONNX-based conversion scheme works, it has a number of significant limitations:

    1. During the conversion many parameters were baked into the model, limiting its use to inference workloads only.
    2. The ONNX conversion breaks the computation graph into low-level operators in a manner that makes it difficult to apply and/or reap the benefit of some PyTorch optimizations.
    3. The reliance on ONNX means that our conversion scheme will only work on ONNX-friendly models. It will not work on models that cannot be mapped to the standard ONNX operator set (e.g., models with dynamic control flow).
    4. The conversion scheme relies on the health and maintenance of a third-party library that is not part of the official ONNX offering.

    Although the scheme works, at least for inference workloads, you may find the limitations to be too restrictive for use on your own TensorFlow models. One possible option is to abandon the ONNX-to-PyTorch conversion and perform inference using the ONNX Runtime library.

    Model Conversion Via Keras3

    Keras3 is a high-level deep learning API focused on maximizing the readability, maintainability, and ease of use of AI/ML applications. In a previous post, we evaluated Keras3 and highlighted its support for multiple backends. In this post we revisit its multi-framework support and assess whether this can be utilized for model conversion. The scheme we propose is to 1) migrate the existing TensorFlow model to Keras3 and then 2) run the model with the Keras3 PyTorch backend.

    Upgrading TensorFlow to Keras3

    Contrary to the ONNX-based conversion scheme, our current solution may require some code changes to the TensorFlow model to migrate it to Keras3. While the documentation makes it sound simple, in practice the difficulty of the migration will depend greatly on the details of the model implementation. In the case of our toy model, HuggingFace explicitly enforces the use of the legacy tf-keras, preventing the use of Keras3. To implement our scheme, we need to 1) redefine the model without this restriction, and 2) replace native TensorFlow operators with Keras3 equivalents. The code block below contains a stripped-down version of the model, including the required adjustments. To get a full grasp of the changes that were required, perform a side-by-side code comparison with the original model definition.

    import math
    import keras
    
    HIDDEN_SIZE = 768
    IMG_SIZE = 224
    PATCH_SIZE = 16
    ATTN_HEADS = 12
    NUM_LAYERS = 12
    INTER_SZ = 4*HIDDEN_SIZE
    N_LABELS = 2
    
    
    class TFViTEmbeddings(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.patch_embeddings = TFViTPatchEmbeddings()
            num_patches = self.patch_embeddings.num_patches
            self.cls_token = self.add_weight((1, 1, HIDDEN_SIZE))
            self.position_embeddings = self.add_weight((1, num_patches+1,
                                                        HIDDEN_SIZE))
    
        def call(self, pixel_values, training=False):
            bs, num_channels, height, width = pixel_values.shape
            embeddings = self.patch_embeddings(pixel_values, training=training)
            # prepend one [CLS] token per sample, then add position embeddings
            cls_tokens = keras.ops.repeat(self.cls_token, repeats=bs, axis=0)
            embeddings = keras.ops.concatenate((cls_tokens, embeddings), axis=1)
            embeddings = embeddings + self.position_embeddings
            return embeddings
    
    class TFViTPatchEmbeddings(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            patch_size = (PATCH_SIZE, PATCH_SIZE)
            image_size = (IMG_SIZE, IMG_SIZE)
            num_patches = ((image_size[1] // patch_size[1]) *
                           (image_size[0] // patch_size[0]))
            self.patch_size = patch_size
            self.num_patches = num_patches
            self.projection = keras.layers.Conv2D(
                filters=HIDDEN_SIZE,
                kernel_size=patch_size,
                strides=patch_size,
                padding="valid",
                data_format="channels_last"
            )
    
        def call(self, pixel_values, training=False):
            bs, num_channels, height, width = pixel_values.shape
            # input arrives as NCHW; the convolution expects NHWC
            pixel_values = keras.ops.transpose(pixel_values, (0, 2, 3, 1))
            projection = self.projection(pixel_values)
            num_patches = ((width // self.patch_size[1]) *
                           (height // self.patch_size[0]))
            embeddings = keras.ops.reshape(projection, (bs, num_patches, -1))
            return embeddings
    
    class TFViTSelfAttention(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.num_attention_heads = ATTN_HEADS
            self.attention_head_size = int(HIDDEN_SIZE / ATTN_HEADS)
            self.all_head_size = ATTN_HEADS * self.attention_head_size
            self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
            self.query = keras.layers.Dense(self.all_head_size, name="query")
            self.key = keras.layers.Dense(self.all_head_size, name="key")
            self.value = keras.layers.Dense(self.all_head_size, name="value")
    
        def transpose_for_scores(self, tensor, batch_size: int):
            # (bs, seq, hidden) -> (bs, heads, seq, head_size)
            tensor = keras.ops.reshape(tensor, (batch_size, -1, ATTN_HEADS,
                                                self.attention_head_size))
            return keras.ops.transpose(tensor, [0, 2, 1, 3])
    
        def call(self, hidden_states, training=False):
            bs = hidden_states.shape[0]
            mixed_query_layer = self.query(inputs=hidden_states)
            mixed_key_layer = self.key(inputs=hidden_states)
            mixed_value_layer = self.value(inputs=hidden_states)
            query_layer = self.transpose_for_scores(mixed_query_layer, bs)
            key_layer = self.transpose_for_scores(mixed_key_layer, bs)
            value_layer = self.transpose_for_scores(mixed_value_layer, bs)
            key_layer_T = keras.ops.transpose(key_layer, [0,1,3,2])
            attention_scores = keras.ops.matmul(query_layer, key_layer_T)
            dk = keras.ops.cast(self.sqrt_att_head_size,
                                dtype=attention_scores.dtype)
            attention_scores = keras.ops.divide(attention_scores, dk)
            attention_probs = keras.ops.softmax(attention_scores+1e-9, axis=-1)
            attention_output = keras.ops.matmul(attention_probs, value_layer)
            attention_output = keras.ops.transpose(attention_output,[0,2,1,3])
            attention_output = keras.ops.reshape(attention_output,
                                                 (bs, -1, self.all_head_size))
            return (attention_output,)
    
    class TFViTSelfOutput(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.dense = keras.layers.Dense(HIDDEN_SIZE)
    
        def call(self, hidden_states, input_tensor, training=False):
            return self.dense(inputs=hidden_states)
    
    class TFViTAttention(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.self_attention = TFViTSelfAttention()
            self.dense_output = TFViTSelfOutput()
    
        def call(self, input_tensor, training=False):
            self_outputs = self.self_attention(
                hidden_states=input_tensor, training=training
            )
            attention_output = self.dense_output(
                hidden_states=self_outputs[0],
                input_tensor=input_tensor,
                training=training
            )
            return (attention_output,)
    
    class TFViTIntermediate(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.dense = keras.layers.Dense(INTER_SZ)
            self.intermediate_act_fn = keras.activations.gelu
    
        def call(self, hidden_states):
            hidden_states = self.dense(hidden_states)
            hidden_states = self.intermediate_act_fn(hidden_states)
            return hidden_states
    
    class TFViTOutput(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.dense = keras.layers.Dense(HIDDEN_SIZE)
    
        def call(self, hidden_states, input_tensor, training: bool = False):
            hidden_states = self.dense(inputs=hidden_states)
            # second residual connection of the transformer block
            hidden_states = hidden_states + input_tensor
            return hidden_states
    
    class TFViTLayer(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.attention = TFViTAttention()
            self.intermediate = TFViTIntermediate()
            self.vit_output = TFViTOutput()
            self.layernorm_before = keras.layers.LayerNormalization(
                epsilon=1e-12
            )
            self.layernorm_after = keras.layers.LayerNormalization(
                epsilon=1e-12
            )
    
        def call(self, hidden_states, training=False):
            attention_outputs = self.attention(
                input_tensor=self.layernorm_before(inputs=hidden_states),
                training=training,
            )
            attention_output = attention_outputs[0]
            # first residual connection of the transformer block
            hidden_states = attention_output + hidden_states
            layer_output = self.layernorm_after(hidden_states)
            intermediate_output = self.intermediate(layer_output)
            layer_output = self.vit_output(
                hidden_states=intermediate_output,
                input_tensor=hidden_states,
                training=training
            )
            outputs = (layer_output,)
            return outputs
    
    class TFViTEncoder(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.layer = [TFViTLayer(name=f"layer_{i}")
                          for i in range(NUM_LAYERS)]
    
        def call(self, hidden_states, training=False):
            for i, layer_module in enumerate(self.layer):
                layer_outputs = layer_module(
                    hidden_states=hidden_states,
                    training=training,
                )
                hidden_states = layer_outputs[0]
            return tuple([hidden_states])
    
    class TFViTMainLayer(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.embeddings = TFViTEmbeddings()
            self.encoder = TFViTEncoder()
            self.layernorm = keras.layers.LayerNormalization(epsilon=1e-12)
    
        def call(self, pixel_values, training=False):
            embedding_output = self.embeddings(
                pixel_values=pixel_values,
                training=training,
            )
            encoder_outputs = self.encoder(
                hidden_states=embedding_output,
                training=training,
            )
            sequence_output = encoder_outputs[0]
            sequence_output = self.layernorm(inputs=sequence_output)
            return (sequence_output,)
    
    class TFViTForImageClassification(keras.Model):
        def __init__(self, *inputs, **kwargs):
            super().__init__(*inputs, **kwargs)
            self.vit = TFViTMainLayer()
            self.classifier = keras.layers.Dense(N_LABELS)
    
        def call(self, pixel_values, training=False):
            outputs = self.vit(pixel_values, training=training)
            sequence_output = outputs[0]
            # classify from the [CLS] token (first position in the sequence)
            logits = self.classifier(inputs=sequence_output[:, 0, :])
            return (logits,)

    TensorFlow to PyTorch Conversion

    The conversion sequence appears in the code block below. As before, we validate the output of the resultant model as well as the number of trainable parameters.

    # save weights of TensorFlow model
    tf_model.save_weights("model_weights.h5")
    
    import keras
    keras.config.set_backend("torch")
    
    from keras3_vit import TFViTForImageClassification as Keras3ViT
    keras3_model = Keras3ViT()
    
    # call model to initialize all layers
    keras3_model(torch_input, training=False)
    
    # load the weights from the TensorFlow model
    keras3_model.load_weights("model_weights.h5")
    
    # validate converted model
    assert isinstance(keras3_model, torch.nn.Module)
    
    keras3_model = keras3_model.to(DEVICE)
    keras3_model = keras3_model.eval()
    torch_output = keras3_model(torch_input, training=False)
    torch_output = torch_output[0].detach().cpu().numpy()
    print("Max diff:", np.max(np.abs(tf_output - torch_output)))
    
    # count only the parameters that receive gradients
    num_pyt_params = sum([p.numel()
                          for p in keras3_model.parameters()
                          if p.requires_grad])
    print(f"Keras3 Trainable Parameters: {num_pyt_params:,}")

    Training/Fine-tuning the Model

    Contrary to the ONNX-converted model, the Keras3 model maintains the same structure and trainable parameters. This allows for resuming training and/or fine-tuning on the converted model. This can be done either within the Keras3 training framework or using a standard PyTorch training loop.
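
    As a rough illustration of the second option, here is a minimal sketch of a standard PyTorch training loop. The TupleOutputModel class, the train_steps helper, the optimizer choice, and all sizes are assumptions made for the example. The stand-in module mimics only one property of the converted model, namely that the forward pass returns a tuple whose first element holds the logits. With the real Keras3 model you would presumably also pass training=True in the forward call.

```python
import torch
import torch.nn as nn


# Stand-in for the converted model (assumption): any torch.nn.Module whose
# forward pass returns a tuple (logits,), as the Keras3 ViT above does.
class TupleOutputModel(nn.Module):
    def __init__(self, in_features=16, n_labels=4):
        super().__init__()
        self.classifier = nn.Linear(in_features, n_labels)

    def forward(self, x):
        return (self.classifier(x),)


def train_steps(model, batches, lr=1e-3):
    """Run one optimizer step per (inputs, labels) batch; return the losses."""
    # optimize only the parameters that require gradients
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=lr
    )
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    losses = []
    for inputs, labels in batches:
        optimizer.zero_grad()
        logits = model(inputs)[0]  # unpack the single-element output tuple
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
model = TupleOutputModel()
inputs = torch.randn(8, 16)
labels = torch.randint(0, 4, (8,))
# repeat the same toy batch so that the loss trend is easy to inspect
losses = train_steps(model, [(inputs, labels)] * 20, lr=1e-2)
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

    Because the loop relies only on parameters() and a forward call, the same helper should apply to any model that passes the isinstance(keras3_model, torch.nn.Module) check, though we have not verified this against every Keras3 layer type.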

    Optimizing Model Layers

    Contrary to the ONNX-converted model, the coherence of the Keras3 model definition allows for easily modifying and optimizing the layer implementations. In the code block below, we replace the existing attention mechanism with PyTorch’s highly efficient SDPA operator.

    from torch.nn.functional import scaled_dot_product_attention as sdpa
    
    class TFViTSelfAttention(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.num_attention_heads = ATTN_HEADS
            self.attention_head_size = int(HIDDEN_SIZE / ATTN_HEADS)
            self.all_head_size = ATTN_HEADS * self.attention_head_size
            self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
            self.query = keras.layers.Dense(self.all_head_size, name="query")
            self.key = keras.layers.Dense(self.all_head_size, name="key")
            self.value = keras.layers.Dense(self.all_head_size, name="value")
    
        def transpose_for_scores(self, tensor, batch_size: int):
            # (batch, tokens, hidden) -> (batch, heads, tokens, head_size)
            tensor = keras.ops.reshape(tensor, (batch_size, -1, ATTN_HEADS,
                                                self.attention_head_size))
            return keras.ops.transpose(tensor, [0, 2, 1, 3])
    
        def call(self, hidden_states, training=False):
            bs = hidden_states.shape[0]
            mixed_query_layer = self.query(inputs=hidden_states)
            mixed_key_layer = self.key(inputs=hidden_states)
            mixed_value_layer = self.value(inputs=hidden_states)
            query_layer = self.transpose_for_scores(mixed_query_layer, bs)
            key_layer = self.transpose_for_scores(mixed_key_layer, bs)
            value_layer = self.transpose_for_scores(mixed_value_layer, bs)
            # fused attention; SDPA scales by 1/sqrt(head_size) by default
            sdpa_output = sdpa(query_layer, key_layer, value_layer)
            # (batch, heads, tokens, head_size) -> (batch, tokens, hidden)
            attention_output = keras.ops.transpose(sdpa_output,[0,2,1,3])
            attention_output = keras.ops.reshape(attention_output,
                                                 (bs, -1, self.all_head_size))
            return (attention_output,)
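
    Before benchmarking, it can be worth checking that the swap preserves the numerics. The sketch below compares SDPA with an explicit softmax(QK^T / sqrt(d))V computation on random tensors. The manual_attention helper and the tensor sizes are illustrative assumptions. The comparison also assumes that the layer being replaced computed plain softmax attention, with no mask and no dropout on the attention probabilities.

```python
import math

import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa


def manual_attention(q, k, v):
    """Explicit softmax(Q K^T / sqrt(d)) V over the last two dimensions."""
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)


torch.manual_seed(0)
# (batch, heads, tokens, head_size); the sizes are illustrative only
q, k, v = (torch.randn(2, 4, 10, 16) for _ in range(3))
max_diff = (sdpa(q, k, v) - manual_attention(q, k, v)).abs().max().item()
print("Max diff:", max_diff)
```

    A difference at the level of float32 rounding error would suggest that the optimized layer can reuse the weights loaded from the TensorFlow model unchanged.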

    We use the same benchmarking function from above to assess the impact of this optimization on the model’s runtime performance:

    torch_infer_fn = get_torch_infer_fn(keras3_model)
    avg_time = benchmark(torch_infer_fn, torch_input)
    print(f"Keras3 converted model average step time: {avg_time:.4f}")

    The results are captured in the table below:

    Keras3 Conversion Step Time Results (by Author)

    Using the Keras3-based model conversion scheme, and applying the SDPA optimization, we’re able to accelerate the model inference throughput by 22% compared to the original TensorFlow model.

    Model Compilation

    Another optimization we would like to apply is PyTorch compilation. Unfortunately (as of the time of this writing), PyTorch compilation in Keras3 is limited. In the case of our toy model, both our attempt to apply torch.compile directly to the model and our attempt to set the jit_compile field of the Keras3 Model.compile function failed. In both cases, the failure resulted from multiple recompilations that were triggered by the Keras3 internal machinery. While Keras3 grants access to the PyTorch ecosystem, its high-level abstraction might impose some limitations.
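
    When compilation is unreliable, one defensive pattern is to attempt it and fall back to eager execution if it fails. The sketch below is a hypothetical helper and is not the procedure we followed above. The compile_with_fallback name, the toy nn.Sequential model, and the use of the "eager" backend (chosen so the example runs without a C++ toolchain) are all assumptions. It only catches failures raised as exceptions. Excessive recompilations are often reported as warnings instead, and in recent PyTorch 2.x releases they can be inspected by setting the TORCH_LOGS="recompiles" environment variable.

```python
import torch
import torch.nn as nn


def compile_with_fallback(model, example_input, **compile_kwargs):
    """Return (callable, compiled_flag); fall back to the eager model on error."""
    try:
        compiled = torch.compile(model, **compile_kwargs)
        with torch.no_grad():
            compiled(example_input)  # compilation is triggered by the first call
        return compiled, True
    except Exception as err:  # broad on purpose: failure modes vary by version
        print(f"torch.compile failed, using eager mode: {err}")
        return model, False


# toy stand-in for the converted model (assumption)
model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 2)).eval()
x = torch.randn(4, 8)
infer_fn, is_compiled = compile_with_fallback(model, x, backend="eager")
with torch.no_grad():
    out = infer_fn(x)
    ref = model(x)
print("compiled:", is_compiled, "max diff:", (out - ref).abs().max().item())
```

    The returned callable can be passed to a benchmarking function in place of the model, so the rest of the evaluation code stays the same whether or not compilation succeeded.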

    Limitations

    Once again, we have a conversion scheme that works but has several limitations:

    1. The TensorFlow models must be Keras3-compatible. The amount of work this requires will depend on the details of your model implementation. It may require some Keras layer customization.
    2. While the resultant model is a torch.nn.Module, it’s not a “pure” PyTorch model in the sense that it’s composed of Keras3 layers and includes a lot of additional Keras3 code. This may require some adaptations to our PyTorch tooling and could impose some restrictions, as we saw when we tried to apply PyTorch compilation.
    3. The solution relies on the health and maintenance of Keras3 and its support for the TensorFlow and PyTorch backends.

    Summary

    In this post we’ve proposed and assessed two methods for auto-conversion of legacy TensorFlow models to PyTorch. We summarize our findings in the following table.

    Comparison of Conversion Schemes (by Author)

    Ultimately, the best approach, whether it’s one of the methods discussed here, manual conversion, a solution based on generative AI, or the decision not to perform conversion at all, will greatly depend on the details of the model and the situation.


