Finding the ground state of a dual-species Bose-Einstein condensate#

This example demonstrates the calculation of the ground state of a dual-species Bose-Einstein condensate. We use the Gross-Pitaevskii equation (GPE) to describe the dynamics of the two wavefunctions and perform an imaginary time evolution using the split-step method to find the ground state. For further details and for the specific configuration used in this example, refer to: Pichery et al., AVS Quantum Sci. 5, 044401 (2023); doi: 10.1116/5.0163850.

Because this example is fully three-dimensional, it is computationally demanding, and we recommend running it on a GPU in the examplescuda environment. For the same reason, the simulation loop is written in JAX using jax.lax.scan.
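For orientation, the coupled equations that the code below propagates in imaginary time \(\tau = i t\) can be sketched as follows. This form is reconstructed from the functions in this example rather than quoted from the reference: both wavefunctions are normalized to one, so the atom numbers \(N_j\) appear explicitly, and the symbols \(g_{jj}\), \(g_{12}\), \(a_{jj}\), \(a_{12}\) and \(V_j\) are our own labels for the coupling constants, scattering lengths and trap potentials defined in the following sections. After every step the wavefunctions are renormalized.

```latex
-\hbar \frac{\partial \psi_j}{\partial \tau}
  = \left[ -\frac{\hbar^2}{2 m_j} \nabla^2 + V_j(\mathbf{r})
    + N_j \, g_{jj} \, |\psi_j|^2 + N_{j'} \, g_{12} \, |\psi_{j'}|^2 \right] \psi_j,
\qquad j \in \{1, 2\}, \; j' \neq j,

g_{jj} = \frac{4 \pi \hbar^2 a_{jj}}{m_j}, \qquad
g_{12} = \frac{2 \pi \hbar^2 a_{12}}{\mu}, \qquad
\mu = \frac{m_1 m_2}{m_1 + m_2}.
```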

Imports and setup#

Import basic libraries and enable float64 precision types in JAX.

[1]:

from functools import reduce
from typing import (
    Dict, Iterable, Any, Optional, Tuple, TypedDict
)
import os

from scipy import constants # type: ignore
import numpy as np
import jax.numpy as jnp
from jax import config
import fftarray as fa

# Enable double float precision for jax
config.update("jax_enable_x64", True)

Physical setup#

Define the potential between the two interacting atomic species and basic physical constants.

[2]:
# --------------------
# Physical constants
# --------------------

hbar: float = constants.hbar
a_0: float = constants.physical_constants['Bohr radius'][0]
kb: float = constants.Boltzmann

# coupling constant (used in GPE)
def coupling_fun(m_red: float, a: float) -> float:
    return 2 * np.pi * hbar**2 * a / m_red

# Rubidium 87
m_rb87: float = 86.909 * constants.atomic_mass # The atom's mass in kg.
a_rb87: float = 98 * a_0 # s-wave scattering length
# coupling constant (used in GPE)
coupling_rb87: float = coupling_fun(0.5*m_rb87, a_rb87)

# Potassium 41
m_k41: float = 40.962 * constants.atomic_mass # The atom's mass in kg.
a_k41: float = 60 * a_0 # s-wave scattering length
# coupling constant (used in GPE)
coupling_k41: float = coupling_fun(0.5*m_k41, a_k41)

# Interspecies interaction
a_rb87_k41: float = 165.3 * a_0
reduced_mass_rb87_k41 = m_rb87 * m_k41 / (m_rb87 + m_k41)
# coupling constant (used in GPE)
coupling_rb87_k41: float = coupling_fun(reduced_mass_rb87_k41, a_rb87_k41)


# Define dual species GPE potentials
def gpe_potential_two_species(
    psi_pos_sq_1: fa.Array,
    psi_pos_sq_2: fa.Array,
    coupling_constant_1: float,
    coupling_constant_12: float,
    trap_potential_1: fa.Array,
    num_atoms_1: float,
    num_atoms_2: float,
) -> fa.Array:
    """
    Calculate the 2-species GPE potential for species number 1.
    This does not include the energies that only depend on the other species.
    This does not include the kinetic energy.
    """
    self_interaction = num_atoms_1 * coupling_constant_1 * psi_pos_sq_1
    interaction_12 = num_atoms_2 * coupling_constant_12 * psi_pos_sq_2
    return self_interaction + interaction_12 + trap_potential_1
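The intraspecies couplings above are obtained by passing half the atomic mass to coupling_fun. The following minimal sketch checks that this is the same as the familiar single-species expression 4πħ²a/m, since two identical atoms of mass m have reduced mass m/2. The constants are hard-coded CODATA 2018 values (an assumption made only so that the snippet runs without scipy), and the name g_textbook is ours.

```python
import math

# CODATA 2018 values, hard-coded so the sketch runs without scipy.
HBAR = 1.054571817e-34            # J s
BOHR_RADIUS = 5.29177210903e-11   # m
ATOMIC_MASS = 1.66053906660e-27   # kg


def coupling_fun(m_red: float, a: float) -> float:
    """Same formula as in the cell above: g = 2*pi*hbar^2 * a / m_red."""
    return 2 * math.pi * HBAR**2 * a / m_red


m_rb87 = 86.909 * ATOMIC_MASS
a_rb87 = 98 * BOHR_RADIUS

# Two identical atoms of mass m have the reduced mass m/2 ...
g_from_reduced_mass = coupling_fun(0.5 * m_rb87, a_rb87)
# ... which reproduces the single-species expression 4*pi*hbar^2*a/m.
g_textbook = 4 * math.pi * HBAR**2 * a_rb87 / m_rb87

print(g_from_reduced_mass, g_textbook)
```

For the interspecies term the true reduced mass of the Rb87-K41 pair is used instead, which is why coupling_fun takes the reduced mass as its argument.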

Quantum mechanics helpers#

This functionality is also available in the matterwave library.

Get ground state of quantum harmonic oscillator.#

This function builds the normalized harmonic-oscillator ground state as an fa.Array from the given trap frequencies and coordinate arrays.

[3]:
def ground_state_ho(
    mass: float,
    omegas: Iterable[float],
    pos_coords: Iterable[fa.Array],
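) -> fa.Array:
    """Return the normalized ground state of a harmonic trap as a product of 1d Gaussians."""
    psi = fa.full([], "pos", 1.0, xp=jnp)
    for omega, pos_1d in zip(omegas, pos_coords, strict=True):
        # 1d ground state: (m*omega/(pi*hbar))**(1/4) * exp(-m*omega*x**2/(2*hbar))
        psi = (
            psi * (mass * omega / (np.pi*hbar))**(1/4)
            * fa.exp(-mass * omega * (pos_1d**2)/(2*hbar))
        )
    # Renormalize on the discrete grid.
    norm = fa.integrate(fa.abs(psi)**2)
    return psi / fa.sqrt(norm)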
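Because ground_state_ho renormalizes on the grid as its last step, the returned state does not depend on the constant prefactor of the Gaussian. The sketch below illustrates this with plain NumPy in one dimension and dimensionless units; normalized_gaussian is a hypothetical helper written for this illustration, and width plays the role of sqrt(hbar / (mass * omega)).

```python
import numpy as np


def normalized_gaussian(x: np.ndarray, width: float, prefactor: float) -> np.ndarray:
    """Gaussian exp(-x^2 / (2 width^2)) with an arbitrary prefactor,
    renormalized on the grid so that sum(|psi|^2) * dx == 1."""
    dx = x[1] - x[0]
    psi = prefactor * np.exp(-x**2 / (2 * width**2))
    norm = np.sum(np.abs(psi)**2) * dx
    return psi / np.sqrt(norm)


x = np.linspace(-10.0, 10.0, 2001)
psi_a = normalized_gaussian(x, width=1.0, prefactor=1.0)
psi_b = normalized_gaussian(x, width=1.0, prefactor=123.4)

# Both states coincide, and the peak matches the analytic value pi**(-1/4).
print(np.allclose(psi_a, psi_b), psi_a.max(), np.pi**-0.25)
```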

Imaginary time propagation#

A basic implementation of imaginary time propagation using the split-step method.

[4]:

def split_step_imaginary_time(
    psi: fa.Array,
    V: fa.Array,
    dt: float,
    mass: float,
) -> fa.Array:
    """Perform an imaginary time split-step of second order in VPV configuration."""

    # Calculate half step imaginary time potential propagator
    V_prop = fa.exp((-0.5*dt / hbar) * V)
    # Calculate full step imaginary time kinetic propagator (k_sq = kx^2 + ky^2 + kz^2)
    k_sq = reduce(lambda a,b: a+b, [
        (2*np.pi * fa.coords_from_dim(dim, "freq", xp=jnp, dtype=jnp.float64))**2
        for dim in psi.dims
    ])
    T_prop = fa.exp(-dt * hbar * k_sq / (2*mass))

    # Apply half potential propagator
    psi = V_prop * psi.into_space("pos")
    # Apply full kinetic propagator
    psi = T_prop * psi.into_space("freq")
    # Apply half potential propagator
    psi = V_prop * psi.into_space("pos")

    # Normalize after step
    state_norm = fa.integrate(fa.abs(psi)**2)
    psi = psi / fa.sqrt(state_norm)
    return psi
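To see the same half-potential, full-kinetic, half-potential scheme at work on a problem with a known answer, here is a minimal one-dimensional sketch in plain NumPy. It is not the fftarray implementation above: it assumes dimensionless units (hbar = m = omega = 1), a single non-interacting particle in a harmonic trap, and the hypothetical function name imaginary_time_ground_state. Starting from a flat state, the energy should approach the exact ground-state value of 0.5.

```python
import numpy as np


def imaginary_time_ground_state(n=256, extent=20.0, dt=0.01, steps=2000):
    """Imaginary-time split-step for a 1d harmonic oscillator (hbar = m = omega = 1)."""
    dx = extent / n
    x = (np.arange(n) - n // 2) * dx
    k = 2 * np.pi * np.fft.fftfreq(n, d=dx)  # angular wavenumbers in FFT order
    V = 0.5 * x**2

    v_half = np.exp(-0.5 * dt * V)     # half-step potential propagator
    t_full = np.exp(-0.5 * dt * k**2)  # full-step kinetic propagator

    # Flat initial state, normalized on the grid.
    psi = np.ones(n, dtype=np.complex128)
    psi /= np.sqrt(np.sum(np.abs(psi)**2) * dx)

    for _ in range(steps):
        psi = v_half * psi
        psi = np.fft.ifft(t_full * np.fft.fft(psi))
        psi = v_half * psi
        # Imaginary-time evolution is not unitary, so renormalize every step.
        psi /= np.sqrt(np.sum(np.abs(psi)**2) * dx)

    # Energy expectation value: kinetic part in k-space, potential part in x-space.
    psi_k = np.fft.fft(psi)
    e_kin = np.sum(0.5 * k**2 * np.abs(psi_k)**2) / np.sum(np.abs(psi_k)**2)
    e_pot = np.sum(V * np.abs(psi)**2) * dx
    return x, psi, float(e_kin + e_pot)


x, psi, energy = imaginary_time_ground_state()
print(energy)  # close to the exact ground-state energy 0.5
```

The renormalization inside the loop plays the same role as the final lines of split_step_imaginary_time: without it the norm would decay exponentially.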

Energy computation#

Compute the kinetic and potential energy of a given wavefunction.

[5]:
def get_e_kin(
    psi: fa.Array,
    mass: float,
):
    """Calculate the kinetic energy of the wavefunction."""
    # Calculate k^2 = (2πf)^2
    ksq = reduce(lambda a,b: a+b, [
        (2*np.pi * fa.coords_from_dim(dim, "freq", xp=jnp, dtype=jnp.float64))**2
        for dim in psi.dims
    ])

    post_factor = hbar**2 / (2*mass)

    # Calculate |ψ(f)|^2
    wf_abs_sq = fa.abs(psi.into_space("freq"))**2

    # Calculate E_kin = <ψ|(hbar*k)^2/2m|ψ> = ∫|ψ(f)|^2 * k^2 df * hbar^2 / (2m)
    return fa.integrate(wf_abs_sq * ksq).values("freq") * post_factor

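def get_e_pot(
    psi: fa.Array,
    V: fa.Array,
):
    """Calculate the potential energy of the wavefunction in µK.

    Unlike get_e_kin, which returns joules, the result is already divided by kb * 1e-6.
    """
    return fa.integrate(
        fa.abs(psi.into_space("pos"))**2 * V,
        dtype="float64"
    ).values("pos") / (kb * 1e-6)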
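The energies in this example are reported in µK, that is, as an energy in joules divided by kb * 1e-6. As a small worked example of that conversion, the sketch below expresses the zero-point energy ħω/2 of a 384 Hz trap axis (the value used for Rb87 along z later in this example) in µK. The helper name joule_to_microkelvin is hypothetical and the constants are hard-coded CODATA values.

```python
import math

HBAR = 1.054571817e-34   # J s (CODATA 2018)
KB = 1.380649e-23        # J / K (exact in the 2019 SI)


def joule_to_microkelvin(energy_joule: float) -> float:
    """Convert an energy in joules to a temperature equivalent in µK."""
    return energy_joule / (KB * 1e-6)


omega_z = 2 * math.pi * 384.0        # rad/s
zero_point = 0.5 * HBAR * omega_z    # J
zero_point_uk = joule_to_microkelvin(zero_point)

print(zero_point_uk)  # about 9e-3 µK
```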

Definition of the state initialization and the step function#

In this example, the simulation loop is written with jax.lax.scan to achieve high performance, especially on a GPU.

calc_ground_state_two_species sets up the domain and the physical problem and then runs imaginary_time_step_dual_species in a loop to execute the optimization.

[6]:
from jax.lax import scan
# Register fftarray pytree nodes for JAX
try:
    fa.jax_register_pytree_nodes()
except ValueError:
    print(
        "JAX pytree nodes registration failed. "
        "Probably due to being already registered."
    )

class DualSpeciesProperties(TypedDict):
    psi_rb87: fa.Array
    psi_k41: fa.Array
    rb_potential: fa.Array
    k_potential: fa.Array
    num_atoms_rb87: float
    num_atoms_k41: float


def calc_ground_state_two_species(
    N_iter: int,
    plot_dir: Optional[str] = None,
) -> Tuple[DualSpeciesProperties, Dict[str, Any]]:

    if plot_dir:
        os.makedirs(plot_dir, exist_ok=True)

    dt_list = np.full(N_iter, 5e-7)

    # Number of atoms
    num_atoms_rb87 = 43900
    num_atoms_k41 = 14400

    rb_omega_x = 2*np.pi * 24.8 # rad/s
    rb_omega_y = 2*np.pi * 378.3 # rad/s
    rb_omega_z = 2*np.pi * 384.0 # rad/s

    # Scale the K41 frequencies so that m * omega**2 matches Rb87 on every axis.
    k_omega_x = rb_omega_x * np.sqrt(m_rb87/m_k41)
    k_omega_y = rb_omega_y * np.sqrt(m_rb87/m_k41)
    k_omega_z = rb_omega_z * np.sqrt(m_rb87/m_k41)

    # --------------------
    # fftarray definitions
    # --------------------

    # Define dimensions

    x_dim = fa.dim_from_constraints(
        "x",
        pos_extent=400e-6,
        n=2**9,
        freq_middle=0.,
        pos_middle=0.,
        dynamically_traced_coords=False,
    )

    y_dim = fa.dim_from_constraints(
        "y",
        pos_extent=50e-6,
        n=2**8,
        freq_middle=0.,
        pos_middle=0.,
        dynamically_traced_coords=False,
    )

    z_dim = fa.dim_from_constraints(
        "z",
        pos_extent=50e-6,
        n=2**7,
        freq_middle=0.,
        pos_middle=0.,
        dynamically_traced_coords=False,
    )

    # Define 1d arrays
    x: fa.Array = fa.coords_from_dim(x_dim, "pos", xp=jnp, dtype=jnp.float64)
    y: fa.Array = fa.coords_from_dim(y_dim, "pos", xp=jnp, dtype=jnp.float64)
    z: fa.Array = fa.coords_from_dim(z_dim, "pos", xp=jnp, dtype=jnp.float64)

    # Define 3d arrays for potential
    rb_potential = 0.5 * m_rb87 * (
        rb_omega_x**2 * x**2
        + rb_omega_y**2 * y**2
        + rb_omega_z**2 * z**2
    )
    k_potential = 0.5 * m_k41 * (
        k_omega_x**2 * x**2
        + k_omega_y**2 * y**2
        + k_omega_z**2 * z**2
    )

    # Define 3d arrays for initial wavefunction
    init_psi = fa.full((x_dim, y_dim, z_dim), "pos", 1, xp=jnp, dtype=jnp.float64)
    state_norm = fa.integrate(fa.abs(init_psi)**2)

    init_psi_rb = init_psi / fa.sqrt(state_norm)
    init_psi_k = init_psi / fa.sqrt(state_norm)

    # When using jax.lax.scan, the input fa.Array must have the same properties as
    # the output one. As the scanned method imaginary_time_step_dual_species
    # returns the fa.Array with space="pos" and factors_applied=False,
    # we transform the input state accordingly.
    init_psi_rb = init_psi_rb.into_space("pos").into_factors_applied(False)
    init_psi_k = init_psi_k.into_space("pos").into_factors_applied(False)

    init_properties: DualSpeciesProperties = {
        "psi_rb87": init_psi_rb,
        "psi_k41": init_psi_k,
        "rb_potential": rb_potential,
        "k_potential": k_potential,
        "num_atoms_rb87": num_atoms_rb87,
        "num_atoms_k41": num_atoms_k41
    }

    res: Tuple[DualSpeciesProperties, Dict[str, Any]] = scan(
        f=imaginary_time_step_dual_species, # type: ignore
        init=init_properties,
        xs=dt_list,
    )

    return res



def imaginary_time_step_dual_species(
    properties: DualSpeciesProperties,
    dt: float,
) -> Tuple[DualSpeciesProperties, Dict[str, float]]:
    """
    Perform a single imaginary time step for the dual species GPE.
    Additionally, calculate all relevant energies.
    The states are returned in position space.
    """

    psi_rb87 = properties["psi_rb87"]
    psi_k41 = properties["psi_k41"]
    rb_potential = properties["rb_potential"]
    k_potential = properties["k_potential"]
    num_atoms_rb87 = properties["num_atoms_rb87"]
    num_atoms_k41 = properties["num_atoms_k41"]

    ## Calculate the potential energy operators (used for split-step and plots)
    psi_rb87 = psi_rb87.into_space("pos")
    psi_k41 = psi_k41.into_space("pos")

    psi_pos_sq_rb87 = fa.abs(psi_rb87)**2
    psi_pos_sq_k41 = fa.abs(psi_k41)**2

    self_interaction_rb87 = num_atoms_rb87 * coupling_rb87 * psi_pos_sq_rb87
    interaction_rb87_k41 = num_atoms_k41 * coupling_rb87_k41 * psi_pos_sq_k41
    V_rb87 = self_interaction_rb87 + interaction_rb87_k41 + rb_potential

    self_interaction_k41 = num_atoms_k41 * coupling_k41 * psi_pos_sq_k41
    interaction_k41_rb87 = num_atoms_rb87 * coupling_rb87_k41 * psi_pos_sq_rb87
    V_k41 = self_interaction_k41 + interaction_k41_rb87 + k_potential

    ## Calculate energies for plotting and convergence check

    # Calculate potential energy (in µK)
    E_pot_rb87 = get_e_pot(psi_rb87, V_rb87)
    E_pot_k41 = get_e_pot(psi_k41, V_k41)

    # Calculate kinetic energy (in µK)
    E_kin_rb87: float = get_e_kin(psi_rb87, mass=m_rb87) / (kb * 1e-6)
    E_kin_k41: float = get_e_kin(psi_k41, mass=m_k41) / (kb * 1e-6)

    # Calculate the total energy
    E_tot = E_kin_rb87 + E_kin_k41 + E_pot_rb87 + E_pot_k41

    ## Imaginary time split step application

    psi_rb87 = split_step_imaginary_time(
        psi=psi_rb87,
        V=V_rb87,
        dt=dt,
        mass=m_rb87,
    )
    psi_k41 = split_step_imaginary_time(
        psi=psi_k41,
        V=V_k41,
        dt=dt,
        mass=m_k41,
    )


    return properties | {"psi_rb87": psi_rb87, "psi_k41": psi_k41}, {
        "E_kin_rb87": E_kin_rb87,
        "E_kin_k41": E_kin_k41,
        "E_pot_rb87": E_pot_rb87,
        "E_pot_k41": E_pot_k41,
        "E_tot": E_tot
    }
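The step function above follows the contract that jax.lax.scan expects: it takes a carry and one element of xs, and returns the updated carry together with a per-step output. The following pure-Python stand-in mimics those semantics without JAX (no tracing or compilation); scan_reference and toy_step are hypothetical names used only for this illustration.

```python
from typing import Callable, Dict, Iterable, List, Tuple

Carry = Dict[str, float]
Out = Dict[str, float]


def scan_reference(
    f: Callable[[Carry, float], Tuple[Carry, Out]],
    init: Carry,
    xs: Iterable[float],
) -> Tuple[Carry, Dict[str, List[float]]]:
    """Loop with scan-like semantics: thread a carry, collect per-step outputs."""
    carry = init
    collected: Dict[str, List[float]] = {}
    for x in xs:
        carry, out = f(carry, x)
        for key, value in out.items():
            collected.setdefault(key, []).append(value)
    return carry, collected


def toy_step(carry: Carry, dt: float) -> Tuple[Carry, Out]:
    """Toy stand-in for imaginary_time_step_dual_species: damp an amplitude."""
    amplitude = carry["amplitude"]
    energy = amplitude**2  # recorded before the update, as in the step function above
    new_amplitude = amplitude * (1.0 - dt)
    # Entries that are not overwritten (here "n_atoms") pass through unchanged.
    return carry | {"amplitude": new_amplitude}, {"energy": energy}


final, history = scan_reference(
    toy_step, {"amplitude": 1.0, "n_atoms": 100.0}, [0.5, 0.5, 0.5]
)
print(final)    # {'amplitude': 0.125, 'n_atoms': 100.0}
print(history)  # {'energy': [1.0, 0.25, 0.0625]}
```

As in imaginary_time_step_dual_species, the recorded quantity belongs to the state before the update, so the last recorded energy corresponds to the state entering the final step rather than to the returned final state.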

Run the simulation#

This cell takes about 1 minute for the 5000 steps on an NVIDIA A100 GPU and may take considerably longer when running only on a CPU.

[ ]:
# Reducing the number of steps in test mode
# allows running this notebook as part of the test suite.
if os.environ.get("TEST_MODE", "False") == "True":
    N_iter=2
else:
    # This branch is taken when executed in a normal notebook
    N_iter=5000

final_properties, energies = calc_ground_state_two_species(N_iter=N_iter)
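To get a feeling for why this run is demanding, the following back-of-the-envelope sketch counts the grid points defined in calc_ground_state_two_species and estimates the memory needed for one wavefunction. It assumes that the state is stored as complex128 (16 bytes per value) once it has been transformed to frequency space; temporary arrays and FFT work buffers come on top of this.

```python
# Grid sizes from the x, y and z dimension definitions above.
n_x, n_y, n_z = 2**9, 2**8, 2**7
n_points = n_x * n_y * n_z

bytes_per_value = 16  # assumption: complex128 storage
mib_per_wavefunction = n_points * bytes_per_value / 2**20

print(n_points, mib_per_wavefunction)  # 16777216 256.0
```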

Plotting the results#

[8]:
import matplotlib.pyplot as plt
from bokeh.plotting import show
from bokeh.io import output_notebook
from helpers import plt_integrated_1d_densities, plt_integrated_2d_density

output_notebook(hide_banner=True)

Final state#

We normalize the final states to their respective number of atoms:

[9]:
rb_ground_state = (
    final_properties["psi_rb87"].into_space("pos")
    * np.sqrt(final_properties["num_atoms_rb87"])
)
k_ground_state = (
    final_properties["psi_k41"].into_space("pos")
    * np.sqrt(final_properties["num_atoms_k41"])
)

Integrating over the \(z\) dimension yields one two-dimensional plot per species and space. Use the zoom tool in the plots to inspect the details of the simulation.

[10]:
show(
    plt_integrated_2d_density(
        rb_ground_state,
        red_dim_name="z",
        data_name="Rb87",
        title_prefix="N",
    )
)
[11]:
show(
    plt_integrated_2d_density(
        k_ground_state,
        red_dim_name="z",
        data_name="K41",
        title_prefix="N",
    )
)

Integrating over \(y\) and \(z\) allows us to plot the densities of both species along \(x\):

[12]:
show(
    plt_integrated_1d_densities(
        arrs={"Rb87": rb_ground_state, "K41": k_ground_state},
        red_dim_names=["y", "z"],
        # Zoom in a bit to see the spatial structures better.
        x_range_pos=(-4e-5, +4e-5),
        x_range_freq=(-1e5, +1e5),
        y_label_prefix="N",
    )
)

Energies as a function of iteration steps#

[13]:
COLORS = ["#CC6677", "#88CCEE", "#DDCC77", "#332288", "#117733"]

fig, ax1 = plt.subplots()

ax1.plot(energies["E_kin_rb87"], label="E_kin_rb87", color=COLORS[0])
ax1.plot(energies["E_kin_k41"], label="E_kin_k41", color=COLORS[1])
ax1.plot(energies["E_pot_rb87"], label="E_pot_rb87", color=COLORS[2])
ax1.plot(energies["E_pot_k41"], label="E_pot_k41", color=COLORS[3])
ax1.plot(energies["E_tot"], label="E_tot", color=COLORS[4])

ax1.set_title(f"Final energy of {energies['E_tot'][-1]:.2f} µK")
ax1.set_xlabel("Iteration step m")
ax1.set_ylabel("Energy (µK)")
ax1.set_yscale("log")
ax1.set_ylim(bottom=1e-3)
ax1.legend(loc="upper left")

ax2 = ax1.twinx()
relative_change = np.abs(np.diff(energies["E_tot"]) / energies["E_tot"][:-1])
ax2.plot(
    np.arange(1, len(energies["E_tot"])),
    relative_change,
    color="gray",
    linestyle="--",
    label=r"Rel. E_tot change $|(E_{m+1}-E_m) \:/ \: E_{m}|$",
)
ax2.set_ylabel("Relative change")
ax2.set_yscale("log")
ax2.legend(loc="upper right")

plt.tight_layout()
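The relative change plotted on the second axis can also serve as a simple stopping criterion. The sketch below, in plain Python, returns the first step at which the relative change of the total energy falls below a tolerance, using the same step convention as the plot (the change between steps m and m+1 is assigned to step m+1). The function name first_converged_step, the tolerance and the toy energy sequence are illustrative assumptions, not values from the simulation.

```python
from typing import Optional, Sequence


def first_converged_step(energies: Sequence[float], tol: float) -> Optional[int]:
    """Return the first step m+1 with |(E[m+1] - E[m]) / E[m]| < tol, else None."""
    for m in range(len(energies) - 1):
        rel = abs((energies[m + 1] - energies[m]) / energies[m])
        if rel < tol:
            return m + 1
    return None


# Toy sequence that relaxes geometrically towards 1.0.
toy_energies = [1.0 + 0.5**m for m in range(20)]
converged_at = first_converged_step(toy_energies, tol=1e-3)

print(converged_at)  # 10
```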