Acoustic Hologram Optimisation Using Automatic Differentiation in JAX
In December 2020, Tatsuki Fushimi, Kenta Yamamoto & Yoichi Ochiai submitted a preprint, later published in Scientific Reports, on using automatic differentiation to optimize acoustic holograms produced by phased arrays.
In the following, we will implement the main algorithm discussed in the paper using jax: this work provides a good case study for demonstrating the use of jax for scientific applications not related to machine learning.
Problem setup
Assume we have a transducer located at \(x_t\), transmitting a monochromatic (single-frequency) signal with wavenumber
\[k = \frac{2\pi f_0}{c_0},\]where \(f_0\) is the transmit frequency and \(c_0\) is the speed of sound of the homogeneous medium. Then one can use a simplified version of the Rayleigh integral (see [Sapozhnikov et al., 2015] for a more general discussion) to calculate the pressure field at a location \(x_c\)
\[p_{c,t} = \frac{P_{ref}}{\|x_c - x_t\|}D(\theta)e^{j(k\|x_c-x_t\|+ \phi_t)}\]where \(P_{ref}\) is a reference pressure amplitude (the on-axis amplitude at unit distance from the transducer), \(\phi_t\) is the phase of the transmitted wave and
\[D(\theta) = \frac{2J_1(k r \sin\theta)}{k r \sin\theta}\]is the directivity factor of a circular piston transducer of radius \(r\), with \(D(0) = 1\) by continuity. It depends on the angle \(\theta\) between the transducer normal and the vector \(x_c - x_t\) pointing from the transducer to the target. Here, \(J_1\) is the Bessel function of the first kind of order 1.
Note that the pressure is expressed as a complex number, as is customary for time-harmonic fields, so that the phase relationships between the field at different locations are encoded implicitly.
The last step is to use the superposition principle, which follows from the linearity of the wave equation, to sum the contributions of the \(M\) transducers in the phased array:
\[p(x_c) = \sum_{t=1}^{M} p_{c,t}\]
Optimization
All that is left before we can use automatic differentiation is to define a loss function. The authors chose to optimize the amplitude \(|p(x_c)|\) of the field by matching it against a known non-negative target amplitude \(A(x_c)\). With a squared-error distance, the loss function becomes
\[\mathcal L(p) = \frac{1}{|\Omega|}\int_\Omega (A(x) - |p(x)|)^2\, dx \approx \frac{1}{|X|}\sum_{x_c \in X} (A(x_c) - |p(x_c)|)^2\]for a sufficiently dense set of positions \(X \subset \Omega\). We take these to be equispaced points (i.e. pixels), so that the field can be compared directly with a digital image.
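For reference, the discretised Diff-PAT loss above takes only a few lines in JAX. This is a minimal sketch. It assumes the complex field p and the target amplitude A are already sampled on the same set of points \(X\). The name mse_amplitude_loss is ours.

```python
import jax.numpy as jnp

def mse_amplitude_loss(p, A):
    """Mean squared error between a target amplitude A and the field amplitude |p|.

    p: complex pressure sampled at the points of X (any shape)
    A: non-negative target amplitude, same shape as p
    """
    return jnp.mean((A - jnp.abs(p)) ** 2)

# Usage: two points, the first matched exactly, the second off by 0.5
p = jnp.array([1j, 0.5 + 0j])
A = jnp.array([1.0, 0.0])
print(mse_amplitude_loss(p, A))  # 0.125
```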
Implementation
First of all, let’s import the required libraries:
Afterwards, we define some parameters that we will use throughout the following sections.
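One possible set of imports and parameters is sketched below. All numerical values are hypothetical placeholders, not necessarily those of the paper: a 40 kHz, 16 × 16 array radiating in air onto a 10 cm × 10 cm plane placed 10 cm away. The variable names are also ours.

```python
import jax  # grad, vmap and jit are used throughout
import jax.numpy as jnp

# NOTE: hypothetical values chosen for illustration only
f0 = 40e3                    # transmit frequency [Hz]
c0 = 343.0                   # speed of sound in air [m/s]
k = 2 * jnp.pi * f0 / c0     # wavenumber [rad/m]
r = 5e-3                     # transducer (piston) radius [m]
P_ref = 1.0                  # reference pressure amplitude (arbitrary units)

# 16 x 16 planar array at z = 0 with 10.5 mm pitch; all transducers face +z
n_side, pitch = 16, 10.5e-3
coords = (jnp.arange(n_side) - (n_side - 1) / 2) * pitch
tx, ty = jnp.meshgrid(coords, coords, indexing="ij")
x_t = jnp.stack([tx.ravel(), ty.ravel(), jnp.zeros(n_side**2)], axis=-1)  # (256, 3)
normal = jnp.array([0.0, 0.0, 1.0])

# Target plane: 64 x 64 pixels spanning 10 cm x 10 cm at z = 10 cm
n_pix, width, z_plane = 64, 0.1, 0.1
px = jnp.linspace(-width / 2, width / 2, n_pix)
gx, gy = jnp.meshgrid(px, px, indexing="ij")
x_c = jnp.stack([gx.ravel(), gy.ravel(), jnp.full(n_pix**2, z_plane)], axis=-1)  # (4096, 3)
```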
Forward functions
We can start implementing some functions! While doing that, we can focus on a single transducer and a single target point, as we will parallelize (actually, vectorize) everything later on using jax.vmap.
The first function evaluates the angle \(\theta\) between the transducer normal and the vector \(x_c - x_t\) pointing from the transducer to the target. It does so by taking the arc-cosine of the dot product of the two normalized vectors:
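A minimal sketch, assuming positions and the normal are given as length-3 arrays. The name angle_to_normal is hypothetical. The clip guards against round-off pushing the cosine marginally outside [-1, 1], where arccos would return NaN:

```python
import jax.numpy as jnp

def angle_to_normal(x_t, x_c, normal):
    """Angle theta between the transducer normal and the vector x_c - x_t."""
    d = x_c - x_t
    cos_theta = jnp.dot(d, normal) / (jnp.linalg.norm(d) * jnp.linalg.norm(normal))
    # Clip to guard against round-off pushing the cosine slightly outside [-1, 1]
    return jnp.arccos(jnp.clip(cos_theta, -1.0, 1.0))

# Usage: a target 45 degrees off-axis from a transducer at the origin facing +z
theta = angle_to_normal(jnp.zeros(3), jnp.array([1.0, 0.0, 1.0]), jnp.array([0.0, 0.0, 1.0]))
print(theta)  # approximately pi/4
```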
At this point, we hit the first problem. The directivity factor requires the Bessel function of the first kind of order 1, \(J_1\), which jax.numpy does not provide. However, according to Wikipedia, the following relationship holds:
\[-J_1(x) = \frac{d}{dx} J_0(x)\]where \(J_0(x)\) is the Bessel function of the first kind of order 0. Care is needed here: jax.numpy.i0 is not \(J_0\) but the modified Bessel function \(I_0\), whose derivative is \(I_1\) rather than \(-J_1\). Given a differentiable implementation of \(J_0\), however, \(J_1\) follows directly by automatic differentiation with jax.grad.
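One way to obtain a differentiable \(J_0\) without relying on jax.numpy.i0 (which computes the modified function \(I_0\)) is through its integral representation \(J_0(x) = \frac{1}{\pi}\int_0^\pi \cos(x\sin\tau)\,d\tau\), evaluated with the trapezoidal rule. Because the integrand is smooth and \(\pi\)-periodic, the rule converges very quickly. This technique is our own choice, not necessarily the original notebook's. The 64 quadrature nodes are an assumption, but they are ample for the small arguments \(k r \sin\theta\) met here:

```python
import jax
import jax.numpy as jnp

# Trapezoidal-rule nodes on [0, pi). The integrand below is smooth and
# pi-periodic, so this rule converges very quickly; 64 nodes is our choice.
_N_QUAD = 64
_TAU = jnp.arange(_N_QUAD) * jnp.pi / _N_QUAD

def bessel_j0(x):
    """J0(x) = (1/pi) * integral_0^pi cos(x sin(tau)) dtau, for scalar float x."""
    return jnp.mean(jnp.cos(x * jnp.sin(_TAU)))

def bessel_j1(x):
    """J1(x) = -dJ0/dx, obtained with automatic differentiation (scalar float x)."""
    return -jax.grad(bessel_j0)(x)

print(bessel_j0(1.0), bessel_j1(1.0))  # approx. 0.7652 and 0.4401
```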
At this point, we can implement the directivity function \(D(\theta)\).
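A sketch of \(D(\theta)\). It assumes \(J_1\) is obtained by differentiating \(J_0\) with jax.grad, as described above, with \(J_0\) evaluated from its integral representation \(\frac{1}{\pi}\int_0^\pi \cos(x\sin\tau)\,d\tau\) (our choice). The formula is 0/0 on axis, so the sketch returns the limit \(D = 1\) there. A second jnp.where ensures that 0/0 is never evaluated:

```python
import jax
import jax.numpy as jnp

_TAU = jnp.arange(64) * jnp.pi / 64  # trapezoidal nodes for the J0 integral

def bessel_j1(x):
    """J1 = -dJ0/dx, with J0 from its integral representation (scalar x)."""
    j0 = lambda z: jnp.mean(jnp.cos(z * jnp.sin(_TAU)))
    return -jax.grad(j0)(x)

def directivity(theta, k, r):
    """Circular-piston directivity D(theta) = 2 J1(k r sin theta) / (k r sin theta)."""
    x = k * r * jnp.sin(theta)
    # D -> 1 as x -> 0. The inner `where` replaces x by a safe value, so the
    # division below never sees 0/0 and no NaN is produced.
    small = jnp.abs(x) < 1e-6
    x_safe = jnp.where(small, 1.0, x)
    return jnp.where(small, 1.0, 2.0 * bessel_j1(x_safe) / x_safe)

# Usage with hypothetical values: 40 kHz in air (k ~ 732.73 rad/m), 5 mm radius
print(directivity(0.0, 732.73, 5e-3))        # 1.0 (on-axis limit)
print(directivity(jnp.pi / 6, 732.73, 5e-3))  # smaller off axis
```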
With all the main ingredients set up, we can finally write the function that evaluates the beam pattern of a single transducer at a single location, i.e. the complex field \(p_{c,t}\):
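A sketch of the single-transducer field \(p_{c,t}\), following the formula in the problem setup and assuming a unit-length normal. The compact Bessel and directivity helpers are repeated so that the block runs on its own. All names are ours:

```python
import jax
import jax.numpy as jnp

_TAU = jnp.arange(64) * jnp.pi / 64  # trapezoidal nodes for the J0 integral

def bessel_j1(x):
    j0 = lambda z: jnp.mean(jnp.cos(z * jnp.sin(_TAU)))  # J0 integral representation
    return -jax.grad(j0)(x)                               # J1 = -dJ0/dx

def directivity(theta, k, r):
    x = k * r * jnp.sin(theta)
    small = jnp.abs(x) < 1e-6                              # D -> 1 as x -> 0
    x_safe = jnp.where(small, 1.0, x)
    return jnp.where(small, 1.0, 2.0 * bessel_j1(x_safe) / x_safe)

def transducer_field(phase, x_t, x_c, normal, k, r, P_ref=1.0):
    """Complex pressure p_{c,t} at x_c from one transducer at x_t (unit-length normal)."""
    d = x_c - x_t
    dist = jnp.linalg.norm(d)
    theta = jnp.arccos(jnp.clip(jnp.dot(d, normal) / dist, -1.0, 1.0))
    return P_ref / dist * directivity(theta, k, r) * jnp.exp(1j * (k * dist + phase))

# Usage: on-axis point 10 cm in front of a transducer at the origin
normal = jnp.array([0.0, 0.0, 1.0])
p = transducer_field(0.0, jnp.zeros(3), jnp.array([0.0, 0.0, 0.1]), normal, k=732.73, r=5e-3)
print(jnp.abs(p))  # ~10, i.e. P_ref / distance
```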
Vectorization
Adding the contributions of all the transducers is easily done by vectorizing the function above with respect to the transducer positions and phases, using jax.vmap, and summing the results:
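A sketch of this step. A compact version of the single-transducer model is repeated so that the block runs on its own, with \(P_{ref} = 1\) as an assumption. total_field is a hypothetical name. Mapping over axis 0 of both the phases and the transducer positions, while broadcasting everything else, gives all \(M\) contributions at once:

```python
import jax
import jax.numpy as jnp

_TAU = jnp.arange(64) * jnp.pi / 64  # trapezoidal nodes for the J0 integral

def bessel_j1(x):
    j0 = lambda z: jnp.mean(jnp.cos(z * jnp.sin(_TAU)))  # J0 integral representation
    return -jax.grad(j0)(x)                               # J1 = -dJ0/dx

def directivity(theta, k, r):
    x = k * r * jnp.sin(theta)
    small = jnp.abs(x) < 1e-6                              # D -> 1 as x -> 0
    x_safe = jnp.where(small, 1.0, x)
    return jnp.where(small, 1.0, 2.0 * bessel_j1(x_safe) / x_safe)

def transducer_field(phase, x_t, x_c, normal, k, r):
    d = x_c - x_t
    dist = jnp.linalg.norm(d)
    theta = jnp.arccos(jnp.clip(jnp.dot(d, normal) / dist, -1.0, 1.0))
    return directivity(theta, k, r) / dist * jnp.exp(1j * (k * dist + phase))  # P_ref = 1

def total_field(phases, x_ts, x_c, normal, k, r):
    """Field at a single target x_c: map over (phase, position) pairs, then sum."""
    per_transducer = jax.vmap(transducer_field, in_axes=(0, 0, None, None, None, None))
    return jnp.sum(per_transducer(phases, x_ts, x_c, normal, k, r))

# Usage: two transducers 2 cm apart, target 10 cm in front of their midpoint
x_ts = jnp.array([[-0.01, 0.0, 0.0], [0.01, 0.0, 0.0]])
x_c = jnp.array([0.0, 0.0, 0.1])
normal = jnp.array([0.0, 0.0, 1.0])
in_phase = total_field(jnp.zeros(2), x_ts, x_c, normal, 732.73, 5e-3)
anti_phase = total_field(jnp.array([0.0, jnp.pi]), x_ts, x_c, normal, 732.73, 5e-3)
print(jnp.abs(in_phase), jnp.abs(anti_phase))  # constructive vs (near-)destructive
```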
Similarly, we can get the field at all the target positions by vectorizing the resulting function with respect to the target location \(x_c\).
Let’s look at the hologram for the initial, flat phase distribution (the same phase for every transducer):
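A sketch of the second vectorization and of the flat-phase hologram. The helpers are repeated so the block is self-contained. The geometry (an 8 × 8 array and a 32 × 32 pixel plane 10 cm away) is hypothetical and deliberately small:

```python
import jax
import jax.numpy as jnp

_TAU = jnp.arange(64) * jnp.pi / 64  # trapezoidal nodes for the J0 integral

def bessel_j1(x):
    j0 = lambda z: jnp.mean(jnp.cos(z * jnp.sin(_TAU)))  # J0 integral representation
    return -jax.grad(j0)(x)                               # J1 = -dJ0/dx

def directivity(theta, k, r):
    x = k * r * jnp.sin(theta)
    small = jnp.abs(x) < 1e-6                              # D -> 1 as x -> 0
    x_safe = jnp.where(small, 1.0, x)
    return jnp.where(small, 1.0, 2.0 * bessel_j1(x_safe) / x_safe)

def transducer_field(phase, x_t, x_c, normal, k, r):
    d = x_c - x_t
    dist = jnp.linalg.norm(d)
    theta = jnp.arccos(jnp.clip(jnp.dot(d, normal) / dist, -1.0, 1.0))
    return directivity(theta, k, r) / dist * jnp.exp(1j * (k * dist + phase))  # P_ref = 1

def total_field(phases, x_ts, x_c, normal, k, r):
    # Sum over transducers at a single target (vmap over phases and positions)
    per_t = jax.vmap(transducer_field, in_axes=(0, 0, None, None, None, None))
    return jnp.sum(per_t(phases, x_ts, x_c, normal, k, r))

# Vectorize over the target positions (rows of x_cs); everything else is shared
field_on_targets = jax.vmap(total_field, in_axes=(None, None, 0, None, None, None))

# Hypothetical geometry: 8 x 8 array (10.5 mm pitch) at z = 0, 32 x 32 pixels at z = 10 cm
coords = (jnp.arange(8) - 3.5) * 10.5e-3
tx, ty = jnp.meshgrid(coords, coords, indexing="ij")
x_ts = jnp.stack([tx.ravel(), ty.ravel(), jnp.zeros(64)], axis=-1)
px = jnp.linspace(-0.05, 0.05, 32)
gx, gy = jnp.meshgrid(px, px, indexing="ij")
x_cs = jnp.stack([gx.ravel(), gy.ravel(), jnp.full(32 * 32, 0.1)], axis=-1)
normal = jnp.array([0.0, 0.0, 1.0])

# Hologram amplitude for the flat (all-zero) phase distribution, as a 32 x 32 image
flat_phases = jnp.zeros(64)
hologram = jnp.abs(field_on_targets(flat_phases, x_ts, x_cs, normal, 732.73, 5e-3)).reshape(32, 32)
```

The hologram array can then be displayed as an image, for instance with matplotlib's imshow.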
Loss function
Let’s start with a very simple target image to match:
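The text does not specify the target image. As a hypothetical stand-in, a binary square on the pixel grid is enough to exercise the pipeline (square_target is our name):

```python
import jax.numpy as jnp

def square_target(n_pix, half_width):
    """Hypothetical target amplitude A: 1 inside a centred square, 0 elsewhere.

    n_pix: image size in pixels
    half_width: half side of the square, as a fraction of the image half-size
    """
    u = jnp.linspace(-1.0, 1.0, n_pix)
    ux, uy = jnp.meshgrid(u, u, indexing="ij")
    inside = (jnp.abs(ux) <= half_width) & (jnp.abs(uy) <= half_width)
    return inside.astype(jnp.float32)

A = square_target(64, 0.3)
print(A.shape, A.sum())
```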
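We are going to optimize a loss function slightly different from the one in the Diff-PAT paper, namely the cross-correlation between the target and the field amplitude, defined as
\[\mathcal L = \sum_i A(x_i) |p(x_i)|.\]Since this is a similarity measure, it should be maximized (equivalently, its negative minimized). It is implemented as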
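A sketch of this loss. Since \(p_{c,t}\) depends on \(\phi_t\) only through the factor \(e^{j\phi_t}\), the field can be written as \(p = G e^{j\phi}\), where G[c, t] is the zero-phase field of transducer t at pixel c. Using this matrix form is our choice, and it reduces the loss to a one-liner. The random G and target only make the block self-contained and carry no physical meaning:

```python
import jax
import jax.numpy as jnp

def neg_cross_correlation(phases, G, A):
    """Negated cross-correlation -sum_i A(x_i) |p(x_i)|, with p = G exp(j * phases).

    phases: (M,) transducer phases
    G:      (N, M) complex zero-phase field of each transducer at each pixel
    A:      (N,) target amplitude, i.e. the flattened target image
    """
    p = G @ jnp.exp(1j * phases)
    return -jnp.sum(A * jnp.abs(p))

# Stand-in data so the block runs on its own (not a physical propagation matrix)
kg1, kg2, ka = jax.random.split(jax.random.PRNGKey(0), 3)
G = jax.random.normal(kg1, (100, 16)) + 1j * jax.random.normal(kg2, (100, 16))
A = jax.random.uniform(ka, (100,))
print(neg_cross_correlation(jnp.zeros(16), G, A))
```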
Note that the loss is a function of the vector of transducer phases \(\phi_t\). To optimize it, we compute its gradient with respect to the phases using automatic differentiation:
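A sketch of the gradient computation, assuming the same matrix form \(p = G e^{j\phi}\) and random stand-in data. jax.value_and_grad differentiates with respect to the first argument (the phases) and also returns the loss value. As a sanity check, a global phase shift leaves \(|p|\) unchanged, so the gradient components should sum to numerically zero:

```python
import jax
import jax.numpy as jnp

def neg_cross_correlation(phases, G, A):
    # p = G exp(j*phases); negated cross-correlation -sum_i A_i |p_i|
    return -jnp.sum(A * jnp.abs(G @ jnp.exp(1j * phases)))

# Loss value and gradient with respect to the first argument (the phases) only
loss_and_grad = jax.value_and_grad(neg_cross_correlation)

# Stand-in data (random, for illustration only)
kg1, kg2, ka, kp = jax.random.split(jax.random.PRNGKey(0), 4)
G = jax.random.normal(kg1, (100, 16)) + 1j * jax.random.normal(kg2, (100, 16))
A = jax.random.uniform(ka, (100,))
phases = jax.random.uniform(kp, (16,), maxval=2 * jnp.pi)

value, grads = loss_and_grad(phases, G, A)
# Global-phase invariance: the gradient components should sum to ~0
print(value, jnp.sum(grads))
```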
Optimization
We are now all set up to optimize the loss function. All we need is an update function that takes the current vector of phases and updates it using the gradient. We will use the Adam optimizer, as in the Diff-PAT paper.
As is customary in JAX, we can use jax.jit to just-in-time compile this function for faster execution.
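A sketch of the jitted update. It uses the Adam implementation in jax.example_libraries.optimizers, the small optimizer module that ships with JAX; optax is a more complete alternative. The step size 0.05, the matrix form \(p = G e^{j\phi}\) and the random stand-in data are all assumptions:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

def neg_cross_correlation(phases, G, A):
    # p = G exp(j*phases); negate so that minimizing maximizes sum_i A_i |p_i|
    return -jnp.sum(A * jnp.abs(G @ jnp.exp(1j * phases)))

# Adam from the optimizer module shipped with JAX (the step size is our choice)
opt_init, opt_update, get_params = optimizers.adam(step_size=0.05)

@jax.jit
def update(step, opt_state, G, A):
    """One Adam step on the phases; returns the new state and the loss before the step."""
    value, grads = jax.value_and_grad(neg_cross_correlation)(get_params(opt_state), G, A)
    return opt_update(step, grads, opt_state), value

# Usage with stand-in data (random G and target, for illustration only)
kg1, kg2, ka = jax.random.split(jax.random.PRNGKey(0), 3)
G = jax.random.normal(kg1, (100, 16)) + 1j * jax.random.normal(kg2, (100, 16))
A = jax.random.uniform(ka, (100,))
opt_state = opt_init(jnp.zeros(16))
opt_state, loss_before = update(0, opt_state, G, A)
print(loss_before, get_params(opt_state))
```

On the very first call, Adam's bias correction makes each phase move by almost exactly the step size, in the direction opposite to the sign of its gradient.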
All that is left to do now is to wrap the update function in a loop that runs for a fixed number of iterations. Note that we explicitly define a random seed for the random number generator. This aids reproducibility and is in any case required in JAX, whose random functions take an explicit PRNG key.
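An end-to-end sketch that ties the pieces together on a small hypothetical setup (8 × 8 array, 32 × 32 pixel plane, square target), so that it runs in seconds on a CPU. Here the explicit PRNG key is used to draw random initial phases, which is our choice; starting from the flat phases works too. All names and values are illustrative:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

# Forward model (P_ref = 1), with J1 = -dJ0/dx and J0 from its integral representation
_TAU = jnp.arange(64) * jnp.pi / 64

def bessel_j1(x):
    j0 = lambda z: jnp.mean(jnp.cos(z * jnp.sin(_TAU)))
    return -jax.grad(j0)(x)

def directivity(theta, k, r):
    x = k * r * jnp.sin(theta)
    small = jnp.abs(x) < 1e-6
    x_safe = jnp.where(small, 1.0, x)
    return jnp.where(small, 1.0, 2.0 * bessel_j1(x_safe) / x_safe)

def transducer_field(phase, x_t, x_c, normal, k, r):
    d = x_c - x_t
    dist = jnp.linalg.norm(d)
    theta = jnp.arccos(jnp.clip(jnp.dot(d, normal) / dist, -1.0, 1.0))
    return directivity(theta, k, r) / dist * jnp.exp(1j * (k * dist + phase))

def hologram_field(phases, x_ts, x_cs, normal, k, r):
    # vmap over transducers (then sum), then vmap over target positions
    over_t = jax.vmap(transducer_field, in_axes=(0, 0, None, None, None, None))
    over_c = jax.vmap(lambda x_c: jnp.sum(over_t(phases, x_ts, x_c, normal, k, r)))
    return over_c(x_cs)

# Hypothetical setup: 8 x 8 array at z = 0, 32 x 32 pixels at z = 10 cm, square target
k, r, n_pix = 732.73, 5e-3, 32
normal = jnp.array([0.0, 0.0, 1.0])
coords = (jnp.arange(8) - 3.5) * 10.5e-3
tx, ty = jnp.meshgrid(coords, coords, indexing="ij")
x_ts = jnp.stack([tx.ravel(), ty.ravel(), jnp.zeros(64)], axis=-1)
px = jnp.linspace(-0.05, 0.05, n_pix)
gx, gy = jnp.meshgrid(px, px, indexing="ij")
x_cs = jnp.stack([gx.ravel(), gy.ravel(), jnp.full(n_pix**2, 0.1)], axis=-1)
target = ((jnp.abs(gx) < 0.02) & (jnp.abs(gy) < 0.02)).astype(jnp.float32).ravel()

def loss(phases):
    # Negative cross-correlation: minimizing it maximizes sum_i A_i |p_i|
    return -jnp.sum(target * jnp.abs(hologram_field(phases, x_ts, x_cs, normal, k, r)))

opt_init, opt_update, get_params = optimizers.adam(step_size=0.05)

@jax.jit
def update(step, opt_state):
    value, grads = jax.value_and_grad(loss)(get_params(opt_state))
    return opt_update(step, grads, opt_state), value

# Optimization loop; the explicit PRNG key makes the random initial phases reproducible
key = jax.random.PRNGKey(42)
opt_state = opt_init(jax.random.uniform(key, (64,), maxval=2 * jnp.pi))
losses = []
for step in range(100):
    opt_state, value = update(step, opt_state)
    losses.append(float(value))
phases = get_params(opt_state)
print(losses[0], losses[-1])
```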
Results
Once the optimization is over (which should be relatively fast, especially if you are running JAX on a GPU), we can visualize the results:
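It is fairly close to the target hologram, but not quite. One could experiment with different loss functions or with different initial phases. Note, however, that we are currently controlling only the phases of the transducers. If one could also control their amplitudes, the field would be a linear function of the complex excitations \(P_{t}e^{j\phi_t}\). Matching a prescribed complex field (both amplitude and phase) in the least-squares sense would then be a convex linear least-squares problem, with a unique solution whenever the propagation operator has full column rank. Matching only the amplitude \(|p|\), as in the losses above, remains non-convex even then, because the phase of the target field is left free (it is a phase-retrieval problem).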
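To illustrate the convex case, consider a sketch that assumes a full complex target field p_target is available and that the complex excitations \(u_t = P_t e^{j\phi_t}\) are free. Then \(\min_u \|G u - p_{target}\|^2\) is ordinary linear least squares, solved directly with jnp.linalg.lstsq. G is a random stand-in for the propagation matrix, and the amplitude limits of real hardware are ignored:

```python
import jax
import jax.numpy as jnp

# Stand-in propagation matrix: 100 pixels x 16 transducers (random, for illustration)
kg1, kg2, ku1, ku2 = jax.random.split(jax.random.PRNGKey(0), 4)
G = jax.random.normal(kg1, (100, 16)) + 1j * jax.random.normal(kg2, (100, 16))

# A complex target field that some set of complex excitations u_true can produce
u_true = jax.random.normal(ku1, (16,)) + 1j * jax.random.normal(ku2, (16,))
p_target = G @ u_true

# With free amplitudes and phases, min_u ||G u - p_target||^2 is linear least squares
u_ls, *_ = jnp.linalg.lstsq(G, p_target)
amplitudes, phases = jnp.abs(u_ls), jnp.angle(u_ls)  # P_t and phi_t
print(jnp.max(jnp.abs(u_ls - u_true)))  # small: the minimiser is recovered
```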
Conclusions
In this tutorial, we have reproduced the Diff-PAT algorithm, and we have shown how JAX can be used to easily and efficiently prototype algorithms that are relevant for numerical physics methods, by exploiting its ability to conveniently transform functions in several ways.
A Jupyter notebook implementing this tutorial can be found at the following GitHub repo.
The findings from Fushimi et al. could also be extended in a number of ways. For example, the hologram produced by a planar wavefront could be efficiently propagated in the Fourier domain: this is implemented in the angular_spectrum function of the jwave package.
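For illustration, a minimal angular spectrum propagator can be written with jnp.fft. This is our own sketch of the method, not the jwave API. It assumes a square sampling grid with spacing dx and a homogeneous medium with wavenumber k, and it simply discards evanescent components. The name angular_spectrum_propagate is hypothetical:

```python
import jax.numpy as jnp

def angular_spectrum_propagate(field, dx, k, z):
    """Propagate a complex field sampled on an (ny, nx) planar grid by a distance z.

    Each plane-wave component exp(j(kx x + ky y)) is multiplied by exp(j kz z),
    with kz = sqrt(k^2 - kx^2 - ky^2); evanescent components (kz^2 <= 0) are discarded.
    """
    ny, nx = field.shape
    kx = 2 * jnp.pi * jnp.fft.fftfreq(nx, d=dx)
    ky = 2 * jnp.pi * jnp.fft.fftfreq(ny, d=dx)
    KX, KY = jnp.meshgrid(kx, ky, indexing="xy")  # both of shape (ny, nx)
    kz2 = k**2 - KX**2 - KY**2
    propagating = kz2 > 0
    kz = jnp.sqrt(jnp.where(propagating, kz2, 0.0))
    transfer = jnp.where(propagating, jnp.exp(1j * kz * z), 0.0)
    return jnp.fft.ifft2(jnp.fft.fft2(field) * transfer)

# Usage: a uniform plane wave only picks up the phase factor exp(j k z)
k, z = 732.73, 0.05
out = angular_spectrum_propagate(jnp.ones((64, 64), dtype=jnp.complex64), dx=1e-3, k=k, z=z)
```

The sign of the exponent matches the \(e^{jk\|x_c - x_t\|}\) convention used in the problem setup.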
References
- Fushimi, T., Yamamoto, K. & Ochiai, Y. Acoustic hologram optimisation using automatic differentiation. Sci Rep 11, 12678 (2021). https://doi.org/10.1038/s41598-021-91880-2