Supercharge Python with DeepMind's JAX


Ever wished NumPy offered automatic differentiation and could run math computations on a GPU? DeepMind's JAX might be the solution.

In this piece, we'll delve into JAX's automatic differentiation capabilities and assess how its just-in-time execution compares to NumPy.

What you will learn: How to apply DeepMind's JAX to speed up numerical computation and differentiation in Python.

Notes:

  • Library versions: Python 3.11, JAX 0.4.18, jax-metal 0.0.4 (Mac M1/M2), NumPy 1.26.0, matplotlib 3.8.0
  • The performance evaluation in the section "Performance: JAX vs. NumPy" relies on an AWS m4.2xlarge EC2 instance for CPU and a p3.2xlarge instance, equipped with 8 virtual cores, 64GB of memory, and an Nvidia V100 GPU, for GPU.
  • JAX provides developers with a profiler to generate traces that can be visualized using the Perfetto visualizer.

Introduction

As a quick reminder, NumPy is a Python library for numerical and scientific computation. It equips data scientists and engineers with capabilities for working with multidimensional arrays, performing fast array operations, and handling fundamental tasks in linear algebra and statistics [ref 1].

JAX [ref 2] is a numerical computing and machine learning library in Python, developed by DeepMind, that builds upon the foundation of NumPy. JAX offers:

  • Composable function transformations.
  • Auto-vectorization of data batches, enabling parallel processing.
  • First and second-order automatic differentiation for various numerical functions.
  • Just-in-time compilation for GPU execution [ref 3].
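The features above are transformations that compose with one another. Here is a minimal sketch, assuming a toy function f(x) = sin(x)·x² picked only for illustration, that differentiates with jax.grad, vectorizes the derivative over a batch with jax.vmap, and compiles the result with jax.jit:

```python
import jax
import jax.numpy as jnp

# Toy scalar function of a single variable (illustrative choice)
def f(x):
    return jnp.sin(x) * x ** 2

# Transformations compose: differentiate, vectorize, then compile with XLA
df = jax.grad(f)                       # f'(x) for a scalar x
batched_df = jax.vmap(df)              # f'(x_i) for every element of an array
fast_batched_df = jax.jit(batched_df)  # compiled version of the batched derivative

xs = jnp.linspace(0.0, 1.0, 5)
print(fast_batched_df(xs))
```

The order matters only in what each step expects: jax.grad needs a scalar-valued function, so vmap is applied after it to handle the batch.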

Components

  • Autograd: an upgraded version of the Autograd library that improves the performance of automatic differentiation.
  • Accelerated Linear Algebra (XLA): JAX uses XLA to compile and run your NumPy code on accelerators.
  • Just-in-time compilation (JIT): compiles functions with XLA.
  • Perfetto: visualization of profiler trace data.

Installation

Here is an overview of basic steps for installing JAX. It is advisable to consult the installation guide as each environment has specific requirements [ref 4].

CPU (MacOS, Linux)

  • pip install --upgrade "jax[cpu]"

GPU (Linux/CUDA)

  1. nvcc --version   # check the installed CUDA version, used to select the JAX wheel in the next step
  2. pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

GPU (MacOS/mps)

  • python3 -m venv ~/jax-metal
  • source ~/jax-metal/bin/activate
  • python -m pip install jax-metal
  • pip install ml_dtypes==0.2.0

Conda

  • conda install jax -c conda-forge
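Whichever installation path is used, a quick sanity check can confirm which backend JAX picked up. A minimal sketch (the printed device list depends entirely on the machine and installation):

```python
import jax
import jax.numpy as jnp

# Installed version and the devices of the default backend
print(jax.__version__)
print(jax.devices())          # CPU devices, or GPU devices when CUDA/Metal is configured
print(jax.default_backend())  # name of the default platform, e.g. 'cpu' or 'gpu'

# A tiny computation dispatched to the default device
x = jnp.arange(4.0)
print(jnp.dot(x, x))
```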

Automatic differentiation

Overview

Automatic differentiation is a tool that facilitates the automatic calculation of derivatives for a specified mathematical function [ref 5].

This technique efficiently computes exact derivatives by recording intermediate values during the forward pass, which are then used during the backward pass. Essentially:

  • It takes the code that computes a function and uses it to compute the function's derivative.
  • It computes derivatives programmatically and efficiently, without requiring a closed-form expression for the derivative.

This article focuses on the Forward Mode Automatic Differentiation, which consists of replacing each primitive operation in the original program by its differential analogue.
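The idea of replacing each primitive by its differential analogue can be sketched with dual numbers, where every value carries its derivative along. This is a hypothetical pure-Python illustration (the Dual class, sin and derivative helpers are not part of JAX, and JAX is not implemented this way internally):

```python
import math
from dataclasses import dataclass

@dataclass
class Dual:
    """Pair (value, derivative) propagated through each primitive operation."""
    value: float
    deriv: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Sum rule
        return Dual(self.value + other.value, self.deriv + other.deriv)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)
    __rmul__ = __mul__

def sin(x: Dual) -> Dual:
    # Differential analogue of the primitive sin (chain rule)
    return Dual(math.sin(x.value), math.cos(x.value) * x.deriv)

def derivative(f, x: float) -> float:
    # Seed the input with derivative 1.0, run f once, read the output derivative
    return f(Dual(x, 1.0)).deriv

# f(x) = x*x + 3*sin(x), so f'(x) = 2x + 3cos(x) and f'(0) = 3
print(derivative(lambda x: x * x + 3 * sin(x), 0.0))
```

A single forward run yields the derivative with respect to one input, which is exactly why forward mode must be repeated once per input variable.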

To illustrate the concept, let's consider the function:

Let's build its forward computation graph:

fig 1. Simplified forward computation graph

Notes:

  • This computation graph does not include data type conversion (Python values to JAX or NumPy arrays).
  • The limitation of the forward mode is that computing the full gradient of a function of several variables requires re-executing the program once per input variable. The solution is to store the derivatives to be chained and computed during a backward pass: Reverse Mode Automatic Differentiation.
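JAX exposes both modes directly through jax.jvp (forward, Jacobian-vector product) and jax.vjp (reverse, vector-Jacobian product). A minimal sketch, using a function of the same form as func2 used later in this article: forward mode returns one directional derivative per pass, whereas one reverse pass returns the whole gradient.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f: R^3 -> R
    return 2.0 * x[0] ** 2 - 3.0 * x[0] * x[1] + x[2]

x = jnp.array([2.0, -1.0, 6.0])

# Forward mode: one pass per input direction (here, along x[0])
value, df_dx0 = jax.jvp(f, (x,), (jnp.array([1.0, 0.0, 0.0]),))

# Reverse mode: one backward pass yields the full gradient
value, f_vjp = jax.vjp(f, x)
(gradient,) = f_vjp(jnp.ones_like(value))

print(df_dx0)    # partial derivative along x[0]
print(gradient)  # [df/dx0, df/dx1, df/dx2]
```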

Single variable function

Let's implement a class, JaxDifferentiation, that wraps the computation of the first, second, ... derivatives of a function of a single variable, func: R→R.

The derivatives of various orders are computed in the constructor. The method __call__ returns the list [f, f', f", ...] evaluated at a given value.

import jax
from typing import Callable, List

class JaxDifferentiation(object):
    """
        Create a set of derivatives of first, second, ... order_derivative order
        :param func Differentiable function
        :param order_derivative Highest order of the derivatives
    """
    def __init__(self,
                 func: Callable[[float], float],
                 order_derivative: int):
        assert 0 <= order_derivative < 5, \
            f'Order derivatives {order_derivative} should be [0, 4]'

        # Build list of derivative f, f', f", ....
        self.derivatives: List[Callable[[float], float]] = [func]
        temp = func
        for order in range(order_derivative):
            # Compute the single variable next order derivative
            # (grad lives in the jax module, not in jax.numpy)
            temp = jax.grad(temp)
            self.derivatives.append(temp)

    def __call__(self, x: float) -> List[float]:
        """ Compute derivatives of all orders for value x"""
        return [derivative(x) for derivative in self.derivatives]

Let's compute the derivatives of the function f(x) = 2x^4 + x^3.

The first and second derivatives are provided for evaluation purposes (Oracle).

# Function definition
def func1(x: float) -> float:
    return 2.0*x**4 + x**3

# First order derivative
def dfunc1(x: float) -> float:
    return 8.0*x**3 + 3*x**2

# Second order derivative
def ddfunc1(x: float) -> float:
    return 24.0*x**2 + 6.0*x

funcs1 = [func1, dfunc1, ddfunc1]
# Derivatives up to order 2: f, f' and f"
jax_differentiation = JaxDifferentiation(func1, len(funcs1) - 1)
x = 2.0
compared = [f'{oracle}, {jax_value}, {oracle-jax_value}'
            for oracle, jax_value in zip([func(x) for func in funcs1],
                                         jax_differentiation(x))]
print(compared)

Output

Oracle, JAX, Difference
40.0, 40.0, 0.0
76.0, 76.0, 0.0
108.0, 108.0, 0.0
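Since jax.grad expects a scalar-valued function of a scalar input, the derivatives above are evaluated one point at a time. A minimal sketch of evaluating f' and f" over an array of points by combining jax.grad with jax.vmap (the names d1 and d2 are illustrative):

```python
import jax
import jax.numpy as jnp

def func1(x):
    return 2.0 * x ** 4 + x ** 3

# Vectorize the scalar derivatives over an array of inputs
d1 = jax.vmap(jax.grad(func1))            # 8x^3 + 3x^2 at each point
d2 = jax.vmap(jax.grad(jax.grad(func1)))  # 24x^2 + 6x at each point

xs = jnp.array([0.0, 1.0, 2.0])
print(d1(xs))
print(d2(xs))
```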

Multi-variable function

The next step is to evaluate the computation of the partial derivatives of a multi-variable function f(x, y, ...).

Let's consider the following function, for which the first-order partial derivatives (gradient) are provided.

# Function definition
def func2(x: List[float]) -> float:
  return 2.0*x[0]*x[0] - 3.0*x[0]*x[1] + x[2]

# Partial derivative over x
def dfunc2_x(x: List[float]) -> float:
  return 4.0*x[0] - 3.0*x[1]

# Partial derivative over y
def dfunc2_y(x: List[float]) -> float:
  return -3.0*x[0]

# Partial derivative over z
def dfunc2_z(x: List[float]) -> float:
  return 1.0        

Let's compare the output of the direct computation of the symbolic derivatives (Oracle) dfunc2_x, dfunc2_y and dfunc2_z with the partial derivatives computed by JAX.

We use the forward mode automatic differentiation function jacfwd to compute the gradient [ref 6].

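import jax
import jax.numpy as jnp

# Forward-mode Jacobian of func2 (its gradient, as func2 is scalar-valued)
# jacfwd lives in the jax module, not in jax.numpy
dfunc2 = jax.jacfwd(func2)

y = jnp.array([2.0, -1.0, 6.0])
derivatives = dfunc2(y)

print(f'df/dx: {derivatives[0]}, {dfunc2_x(y)}\n'
      f'df/dy: {derivatives[1]}, {dfunc2_y(y)}\n'
      f'df/dz: {derivatives[2]}, {dfunc2_z(y)}')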

Output (JAX, Oracle)

df/dx: 11.0, 11.0
df/dy: -6.0, -6.0
df/dz: 1.0, 1.0

Note: The reverse mode automatic differentiation JAX function, jacrev, produces the same result.
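The equivalence of both modes is easy to check, and second-order partial derivatives are also available through jax.hessian. A minimal sketch reusing func2:

```python
import jax
import jax.numpy as jnp

def func2(x):
    return 2.0 * x[0] * x[0] - 3.0 * x[0] * x[1] + x[2]

y = jnp.array([2.0, -1.0, 6.0])

# Forward and reverse modes yield the same gradient
grad_fwd = jax.jacfwd(func2)(y)
grad_rev = jax.jacrev(func2)(y)

# Matrix of second-order partial derivatives (3x3 Hessian)
hess = jax.hessian(func2)(y)
print(grad_fwd, grad_rev)
print(hess)
```

For a scalar function of many variables, jacrev is usually the cheaper choice, since one reverse pass covers all inputs.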

Performance: JAX vs. NumPy

A significant drawback of the NumPy library is its lack of GPU support. The next objective is to measure the performance gains achieved by JAX, with and without its just-in-time compiler, on both CPU and GPU.

To facilitate this, we define a class named JaxNumpyData containing two functions: np_func, which uses NumPy, and jnp_func, its JAX counterpart. These functions are applied to datasets of various sizes. The compare method extracts 19 subsets, from 5% to 95% of the initial dataset, using a basic fraction-based approach.

import numpy as np
import jax.numpy as jnp
from jax import jit
from typing import AnyStr, Callable

class JaxNumpyData(object):
    """
    Initialize the NumPy and JAX functions used to process data (arrays)
    :param np_func NumPy numerical function
    :param jnp_func Corresponding JAX numerical function
    """
    def __init__(self,
                 np_func: Callable[[np.ndarray], np.ndarray],
                 jnp_func: Callable[[jnp.ndarray], jnp.ndarray]):
        self.np_func = np_func
        self.jnp_func = jnp_func

    def compare(self, full_data_size: int, func_label: AnyStr):
        """
        Compare the NumPy and JAX computation on a given dataset
        :param full_data_size Size of the original dataset used
               to extract the sub-datasets
        :param func_label Label used for performance results and
               plotting
        """
        for index in range(1, 20):
            fraction = 0.05 * index
            data_size = int(full_data_size*fraction)

            # Execute on full_data_size*fraction elements using NumPy
            x_0 = np.linspace(0.0, 100.0, data_size)
            result1 = self.map_numpy(x_0, f'numpy_{func_label}')

            # Execute on full_data_size*fraction elements using
            # JAX and JAX-JIT
            x_1 = jnp.linspace(0.0, 100.0, data_size)
            result2 = self.map_jax(x_1, f'jax_{func_label}')
            result3 = self.map_jit(x_1, f'jit_{func_label}')

            del x_0, x_1, result1, result2, result3

    # time_it is the timing decorator listed in the Appendix;
    # it must be defined before this class
    @time_it
    def map_numpy(self, np_x: np.ndarray, label: AnyStr) -> np.ndarray:
        """ Process NumPy array np_x through NumPy function np_func """
        return self.np_func(np_x)

    @time_it
    def map_jax(self, jnp_x: jnp.ndarray, label: AnyStr) -> jnp.ndarray:
        """ Process JAX array jnp_x through JAX function jnp_func """
        return self.jnp_func(jnp_x)

    @time_it
    def map_jit(self, jnp_x: jnp.ndarray, label: AnyStr) -> jnp.ndarray:
        """ Process JAX array jnp_x through JAX function jnp_func using JIT """
        return jit(self.jnp_func)(jnp_x)

The method map_numpy (resp. map_jax and map_jit) applies the NumPy function np_func (resp. the JAX function jnp_func, without and with just-in-time compilation) to the NumPy array np_x (resp. the JAX array jnp_x).
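Two caveats are worth keeping in mind when timing JAX code this way. JAX dispatches computations asynchronously, so a timer stopped right after the call may return before the computation has finished. Also, a jitted function is traced and compiled on its first call for each new input shape, and every subset in compare has a different size. A minimal sketch of a timing pattern that separates compilation from execution, assuming a 1,000,000-element input chosen for illustration:

```python
import time
import jax
import jax.numpy as jnp

def jnp_func1(x):
    return jnp.sinh(x) + jnp.cos(x)

jitted = jax.jit(jnp_func1)          # create the compiled wrapper once
x = jnp.linspace(0.0, 10.0, 1_000_000)

# First call: includes tracing and XLA compilation for this input shape
start = time.perf_counter()
jitted(x).block_until_ready()
compile_and_run = time.perf_counter() - start

# Second call: execution only; block_until_ready waits for the
# asynchronously dispatched computation to complete
start = time.perf_counter()
result = jitted(x).block_until_ready()
run_only = time.perf_counter() - start

print(f'first call: {compile_and_run:.4f} s, cached call: {run_only:.4f} s')
```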

CPU

In this first performance test, we measure the duration to compute

f(x) = sinh(x) + cos(x)

on 1,000,000,000 values using NumPy, and JAX with and without the just-in-time compiler.

def np_func1(x: np.array) -> np.array:
    return np.sinh(x) + np.cos(x)

def jnp_func1(x: jnp.array) -> jnp.array:
    return jnp.sinh(x) + jnp.cos(x)        
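NumPy and JAX use different default floating-point precision: np.linspace produces float64 arrays, whereas JAX uses float32 unless 64-bit mode (jax_enable_x64) is enabled. This is worth keeping in mind when comparing timings. A small sketch, repeating the two functions so it runs on its own, checks that both implementations agree up to single-precision rounding on a small illustrative range:

```python
import numpy as np
import jax.numpy as jnp

def np_func1(x: np.ndarray) -> np.ndarray:
    return np.sinh(x) + np.cos(x)

def jnp_func1(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.sinh(x) + jnp.cos(x)

x_np = np.linspace(0.0, 10.0, 1_000)
x_jnp = jnp.linspace(0.0, 10.0, 1_000)

# float64 for NumPy, float32 for JAX (default configuration)
print(x_np.dtype, x_jnp.dtype)

# Both implementations agree within single-precision tolerance
print(np.allclose(np_func1(x_np), np.asarray(jnp_func1(x_jnp)), rtol=1e-5))
```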

JAX delivers a 7-fold performance improvement over NumPy. The just-in-time compiler adds another 35% improvement.

The second latency test computes the mean value

mean(x) = (1/n) Σ x_i

of NumPy and JAX arrays of 1,200,000,000 values.

def np_func2(x: np.array) -> np.array:
    return np.mean(x)

def jnp_func2(x: jnp.array) -> jnp.array:
    return jnp.mean(x)        

The just-in-time compiler outperforms both NumPy and the native JAX library on CPU.

GPU

For this last test, we execute the function:

over 200,000,000 values on an Nvidia V100 GPU.

As expected, NumPy runs on the CPU of the EC2 instance and therefore cannot match the performance of JAX running on the Nvidia GPU.
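When benchmarking on a GPU, one option to keep host-to-device transfers out of the measured computation is to move the data to the accelerator once with jax.device_put. A minimal sketch with a small illustrative array; jax.devices() lists the devices of the default backend, so it falls back to the CPU when no GPU is available:

```python
import numpy as np
import jax
import jax.numpy as jnp

# First device of the default backend (a GPU when CUDA is configured)
device = jax.devices()[0]

# Copy the NumPy array from host memory to the device once
x_host = np.linspace(0.0, 10.0, 10_000, dtype=np.float32)
x_dev = jax.device_put(x_host, device)

# Computation runs on the device holding x_dev
y = jnp.mean(x_dev).block_until_ready()
print(device, y)
```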

Conclusion

In summary, JAX offers data scientists and machine learning engineers a high-performance GPU computing tool that significantly outperforms the NumPy library. Our exploration has only touched the surface of JAX's capabilities, and I encourage readers to delve deeper into features like Autobatching, Vectorization, Generalized convolutions, and its integration with PyTorch and TensorFlow.


Thank you for reading this article. For more information ...


References

[1] NumPy user guide

[2] JAX: High-Performance Array Computing

[3] Just in time compilation with JAX

[4] Installing JAX

[5] An introduction to automatic differentiation

[6] JAX Automatic Differentiation


Appendix

We include the decorator used for timing the execution of the various functions, for reference.

import time
from typing import AnyStr, Dict, List

# Durations (in seconds) collected for each label
timing_stats: Dict[AnyStr, List[float]] = {}

def time_it(func):
    """ Decorator for timing execution of methods """
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        duration = time.time() - start
        # The label is the third positional argument: (self, data, label)
        key: AnyStr = args[2]
        print(f'{key}\t{duration:.3f} secs.')
        # Accumulate the durations recorded for this label
        timing_stats.setdefault(key, []).append(duration)
        return result
    return wrapper

---------------------------

Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design and end-to-end deployment and support with extensive knowledge in machine learning. He has been director of data engineering at Aideo Technologies since 2017 and he is the author of "Scala for Machine Learning" Packt Publishing ISBN 978-1-78712-238-3


#python #numpy #jax #automaticdifferentiation #vectorization #autograd #machinelearning
