This Python package implements the activation function transformations and weight initializations used in Deep Kernel Shaping (DKS) and Tailored Activation Transformations (TAT). DKS and TAT, introduced in the DKS paper and the TAT paper respectively, are methods for constructing or transforming neural networks so that they are much easier to train. For example, when used in conjunction with K-FAC, these methods can train deep vanilla convnets (without skip connections or normalization layers) as fast as standard ResNets of the same depth.
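To give a rough feel for what an "activation function transformation" can look like, here is a minimal sketch in plain Python. It does not use this package's API: the names `tailored_leaky_relu` and `gaussian_second_moment` are hypothetical, and the example only shows one simple ingredient, rescaling a leaky ReLU so that its output has unit second moment under a standard Gaussian input. The transformations in this package also choose parameters such as the negative slope from properties of the whole network, which this sketch does not attempt.

```python
import math


def tailored_leaky_relu(x, negative_slope):
    """Leaky ReLU rescaled so that E[f(z)^2] = 1 for z ~ N(0, 1).

    Hypothetical helper for illustration only; not part of this package.
    """
    scale = math.sqrt(2.0 / (1.0 + negative_slope ** 2))
    return scale * (x if x >= 0.0 else negative_slope * x)


def gaussian_second_moment(fn, num_points=20001, limit=10.0):
    """Approximates E[fn(z)^2] for z ~ N(0, 1) with the trapezoidal rule."""
    step = 2.0 * limit / (num_points - 1)
    total = 0.0
    for i in range(num_points):
        z = -limit + i * step
        # Endpoints get half weight in the trapezoidal rule.
        weight = 0.5 if i in (0, num_points - 1) else 1.0
        density = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
        total += weight * fn(z) ** 2 * density * step
    return total


# An unscaled leaky ReLU shrinks the second moment; the rescaled one keeps it at 1.
plain = gaussian_second_moment(lambda z: z if z >= 0.0 else 0.5 * z)
scaled = gaussian_second_moment(lambda z: tailored_leaky_relu(z, 0.5))
print(plain, scaled)
```

The scaling constant follows from E[max(z, 0)^2] = E[min(z, 0)^2] = 1/2 for a standard Gaussian z, so the unscaled function has second moment (1 + slope^2) / 2.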
The package supports the JAX, PyTorch, and TensorFlow tensor programming frameworks.
Questions/comments about the code can be sent to dks-dev@google.com.