# Source code for magnumnp.field_terms.demag

#
# This file is part of the magnum.np distribution
# (https://gitlab.com/magnum.np/magnum.np).
# Copyright (c) 2023 magnum.np team.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

from magnumnp.common import logging, timedmethod, constants, Timer, complex_dtype
from .field_terms import LinearFieldTerm
import numpy as np
import torch
import torch.fft
from torch import asinh, atan, sqrt, log, abs, pi
from time import time
import os

__all__ = ["DemagField"]

def f(x, y, z):
    """Newell's auxiliary function ``f`` used for the diagonal kernel entries.

    Evaluated elementwise on tensors. ``f`` is even in all three arguments,
    so absolute values are taken first. Any NaN/inf produced by a term with a
    vanishing denominator (e.g. on the coordinate axes) is replaced by zero.
    """
    def finite(term):
        # degenerate denominators give NaN/inf; those terms are dropped
        return term.nan_to_num(posinf=0, neginf=0)

    ax, ay, az = abs(x), abs(y), abs(z)
    xx, yy, zz = ax**2, ay**2, az**2
    dist = sqrt(xx + yy + zz)

    out = 1.0 / 6.0 * (2*xx - yy - zz) * dist
    out += finite(ay / 2.0 * (zz - xx) * asinh(ay / sqrt(xx + zz)))
    out += finite(az / 2.0 * (yy - xx) * asinh(az / sqrt(xx + yy)))
    out -= finite(ax * ay * az * atan(ay*az / (ax * dist)))
    return out

def g(x, y, z):
    """Newell's auxiliary function ``g`` used for the off-diagonal kernel entries.

    Evaluated elementwise on tensors. Only ``z`` enters through its absolute
    value; ``x`` and ``y`` keep their sign. Any NaN/inf produced by a term with
    a vanishing denominator is replaced by zero.
    """
    def finite(term):
        # degenerate denominators give NaN/inf; those terms are dropped
        return term.nan_to_num(posinf=0, neginf=0)

    az = abs(z)
    xx, yy, zz = x**2, y**2, az**2
    dist = sqrt(xx + yy + zz)

    out = -x * y * dist / 3.0
    out += finite(x * y * az * asinh(az / sqrt(xx + yy)))
    out += finite(y / 6.0 * (3.0 * zz - yy) * asinh(x / sqrt(yy + zz)))
    out += finite(x / 6.0 * (3.0 * zz - xx) * asinh(y / sqrt(xx + zz)))
    out -= finite(az**3 / 6.0 * atan(x * y / (az * dist)))
    out -= finite(az * yy / 2.0 * atan(x * az / (y * dist)))
    out -= finite(az * xx / 2.0 * atan(y * az / (x * dist)))
    return out

def F1(func, x, y, z, dz, dZ):
    """Four-point difference of ``func`` along its third argument.

    Returns ``func(z+dZ) - func(z) - func(z-dz+dZ) + func(z-dz)`` (with ``x``
    and ``y`` passed through unchanged), evaluated in exactly that order.
    """
    total = func(x, y, z + dZ)
    total = total - func(x, y, z)
    total = total - func(x, y, z - dz + dZ)
    total = total + func(x, y, z - dz)
    return total

def F0(func, x, y, z, dy, dY, dz, dZ):
    """Four-point difference along the second argument, applied on top of ``F1``.

    Returns ``F1(y+dY) - F1(y) - F1(y-dy+dY) + F1(y-dy)``, where every ``F1``
    call uses the same ``func``, ``x``, ``z``, ``dz`` and ``dZ``.
    """
    acc = F1(func, x, y + dY, z, dz, dZ)
    acc = acc - F1(func, x, y, z, dz, dZ)
    acc = acc - F1(func, x, y - dy + dY, z, dz, dZ)
    acc = acc + F1(func, x, y - dy, z, dz, dZ)
    return acc

def newell(func, x, y, z, dx, dy, dz, dX, dY, dZ):
    """Newell-type near-field kernel entry built from ``func`` (``f`` or ``g``).

    Applies the four-point difference along ``x`` on top of ``F0`` (which in
    turn differences along ``y`` and, via ``F1``, along ``z``). The sizes of
    the two cells are ``(dx, dy, dz)`` and ``(dX, dY, dZ)``. The result is
    scaled by ``-1 / (4*pi*dx*dy*dz)``.
    """
    acc = F0(func, x, y, z, dy, dY, dz, dZ)
    acc = acc - F0(func, x - dx, y, z, dy, dY, dz, dZ)
    acc = acc - F0(func, x + dX, y, z, dy, dY, dz, dZ)
    acc = acc + F0(func, x - dx + dX, y, z, dy, dY, dz, dZ)
    return -acc / (4.*pi*dx*dy*dz)

def dipole_f(x, y, z, dx, dy, dz, dX, dY, dZ):
    """Point-dipole (far-field) approximation of a diagonal kernel entry.

    Evaluates ``(2x^2 - y^2 - z^2) / r^5 * dx*dy*dz / (4*pi)`` elementwise.
    ``z`` is first shifted by ``dZ/2 - dz/2`` to account for differing cell
    heights. Entries at zero separation, where the expression is singular
    (0 * inf = NaN), are set to 0.

    Works for tensors of any shape.
    """
    z = z + dZ/2. - dz/2. # diff of cell centers for non-equidistant demag
    r2 = x**2 + y**2 + z**2
    res = (2.*x**2 - y**2 - z**2) * pow(r2, -5./2.)
    # Zero only the singular entries. Hard-coding index [0,0,0] would also wipe
    # a genuine far-field value, e.g. when a shifted PBC image puts a distant
    # cell at that index, and it would fail for non-3-D inputs.
    res = res.masked_fill(r2 == 0, 0.)
    return res * dx*dy*dz / (4.*pi)

def dipole_g(x, y, z, dx, dy, dz, dX, dY, dZ):
    """Point-dipole (far-field) approximation of an off-diagonal kernel entry.

    Evaluates ``3xy / r^5 * dx*dy*dz / (4*pi)`` elementwise. ``z`` is first
    shifted by ``dZ/2 - dz/2`` to account for differing cell heights. Entries
    at zero separation, where the expression is singular (0 * inf = NaN), are
    set to 0.

    Works for tensors of any shape.
    """
    z = z + dZ/2. - dz/2. # diff of cell centers for non-equidistant demag
    r2 = x**2 + y**2 + z**2
    res = 3.*x*y * pow(r2, -5./2.)
    # Zero only the singular entries. Hard-coding index [0,0,0] would also wipe
    # a genuine far-field value, e.g. when a shifted PBC image puts a distant
    # cell at that index, and it would fail for non-3-D inputs.
    res = res.masked_fill(r2 == 0, 0.)
    return res * dx*dy*dz / (4.*pi)

def demag_f(x, y, z, dx, dy, dz, dX, dY, dZ, p):
    """Diagonal kernel entries: Newell near field, point-dipole far field.

    An entry counts as "near" when its squared distance divided by the larger
    squared cell diagonal is below ``p**2``. Near entries are evaluated with
    ``newell(f, ...)``; all others keep the ``dipole_f`` value.
    """
    diag2 = max(dx**2 + dy**2 + dz**2, dX**2 + dY**2 + dZ**2)
    is_near = (x**2 + y**2 + z**2) / diag2 < p**2
    kernel = dipole_f(x, y, z, dx, dy, dz, dX, dY, dZ)
    kernel[is_near] = newell(f, x[is_near], y[is_near], z[is_near], dx, dy, dz, dX, dY, dZ)
    return kernel

def demag_g(x, y, z, dx, dy, dz, dX, dY, dZ, p):
    """Off-diagonal kernel entries: Newell near field, point-dipole far field.

    An entry counts as "near" when its squared distance divided by the larger
    squared cell diagonal is below ``p**2``. Near entries are evaluated with
    ``newell(g, ...)``; all others keep the ``dipole_g`` value.
    """
    diag2 = max(dx**2 + dy**2 + dz**2, dX**2 + dY**2 + dZ**2)
    is_near = (x**2 + y**2 + z**2) / diag2 < p**2
    kernel = dipole_g(x, y, z, dx, dy, dz, dX, dY, dZ)
    kernel[is_near] = newell(g, x[is_near], y[is_near], z[is_near], dx, dy, dz, dX, dY, dZ)
    return kernel


class DemagField(LinearFieldTerm):
    r"""
    Demagnetization Field:

    The dipole-dipole interaction gives rise to a long-range interaction.
    The integral formulation of the corresponding Maxwell equations can be
    represented as convolution of the magnetization
    :math:`\vec{M} = M_s \; \vec{m}` with a proper demagnetization kernel
    :math:`\vec{N}`

    .. math::
        \vec{h}^\text{dem}_{\vec{i}} = \sum\limits_{\vec{j}} \vec{N}_{\vec{i} - \vec{j}} \, \vec{M}_{\vec{j}},

    The convolution can be evaluated efficiently using an FFT method.

    :param p: near-field radius in units of the (largest) cell diagonal.
              Kernel entries closer than this are computed with Newell's
              equations; all others use the point-dipole approximation
              (default = 20)
    :type p: int, optional
    :param cache_dir: directory in which the computed kernel is stored and
                      from which it is reloaded; no caching if None
                      (default = None)
    :type cache_dir: str, optional
    """
    def __init__(self, p = 20, cache_dir = None):
        self._p = p
        self._cache_dir = cache_dir

    def _shape(self, state):
        """Return the FFT grid size for each of the three mesh axes.

        The rules are:

        - Axes with a single cell keep size 1.
        - Non-periodic axes are zero-padded to ``2*n`` so the cyclic FFT
          convolution does not wrap around.
        - Periodic axes keep size ``n``.
        """
        # TODO: try padding to 2N-1 for small N like mumax does
        s = [1, 1, 1]
        for i in range(3):
            n = state.mesh.n[i]
            if n == 1:
                continue
            s[i] = 2*n if state.mesh.pbc[i] == 0 else n # no need to pad if nonzero pbc
        return s

    def _init_N_component(self, state, perm, func):
        """Compute one kernel component in Fourier space.

        :param perm: order in which the mesh axes are fed to ``func`` as (x, y, z)
        :param func: ``demag_f`` (diagonal) or ``demag_g`` (off-diagonal)
        :return: real part of the rFFT (over all non-singleton axes) of the
                 real-space kernel, summed over the pseudo-PBC images
        """
        dx = np.array(state.mesh.dx, dtype=float)
        dx /= dx.min() # rescale dx to avoid NaNs when using single precision
        shape = self._shape(state)
        ij = [torch.fft.fftfreq(n, 1/n) for n in shape] # local indices
        ij = torch.meshgrid(*ij, indexing='ij')
        x, y, z = [ij[ind]*dx[ind] for ind in perm]
        Lx = [state.mesh.n[ind]*dx[ind] for ind in perm]
        dx = [dx[ind] for ind in perm]

        # offset of pseudo PBC images: integers in [-pbc, pbc] along each axis
        offsets = [torch.arange(-state.mesh.pbc[ind], state.mesh.pbc[ind]+1) for ind in perm]
        offsets = torch.stack(torch.meshgrid(*offsets, indexing="ij"), dim=-1).flatten(end_dim=-2)
        Nc = torch.zeros(shape)
        for offset in offsets:
            Nc += func(x + offset[0]*Lx[0], y + offset[1]*Lx[1], z + offset[2]*Lx[2], *dx, *dx, self._p)

        dim = [i for i in range(3) if state.mesh.n[i] > 1]
        if len(dim) > 0:
            Nc = torch.fft.rfftn(Nc, dim = dim)
        return Nc.real.clone()

    def _init_N(self, state):
        """Return the symmetric 3x3 kernel (Fourier space) as nested lists.

        If ``cache_dir`` contains a kernel for this mesh, it is loaded from
        there. Otherwise the kernel is computed with float64 as the default
        dtype, cast back to the caller's default dtype, and saved to
        ``cache_dir`` when that is set.
        """
        name = "/N_%s.pt" % str(state.mesh).replace(" ","")
        cache_file = None if self._cache_dir is None else self._cache_dir + name
        if cache_file is not None and os.path.isfile(cache_file):
            # NOTE(review): torch.load unpickles the file; only point cache_dir at trusted locations
            [Nxx,Nxy,Nxz,Nyy,Nyz,Nzz] = torch.load(cache_file, map_location=state.device)
            logging.info("[DEMAG]: Use cached demag kernel from '%s'" % cache_file)
        else:
            dtype = torch.get_default_dtype()
            torch.set_default_dtype(torch.float64) # always use double precision
            try:
                time_kernel = time()
                Nxx = self._init_N_component(state, [0,1,2], demag_f).to(dtype=dtype)
                Nxy = self._init_N_component(state, [0,1,2], demag_g).to(dtype=dtype)
                Nxz = self._init_N_component(state, [0,2,1], demag_g).to(dtype=dtype)
                Nyy = self._init_N_component(state, [1,2,0], demag_f).to(dtype=dtype)
                Nyz = self._init_N_component(state, [1,2,0], demag_g).to(dtype=dtype)
                Nzz = self._init_N_component(state, [2,0,1], demag_f).to(dtype=dtype)
                logging.info(f"[DEMAG]: Time calculation of demag kernel = {time() - time_kernel} s")
            finally:
                # restore dtype even if the kernel computation raises, so the
                # global default is not left at float64
                torch.set_default_dtype(dtype)

            # cache demag tensor
            if cache_file is not None:
                os.makedirs(self._cache_dir, exist_ok=True)
                torch.save([Nxx,Nxy,Nxz,Nyy,Nyz,Nzz], cache_file)
                logging.info("[DEMAG]: Save demag kernel to '%s'" % cache_file)

        return [[Nxx, Nxy, Nxz], [Nxy, Nyy, Nyz], [Nxz, Nyz, Nzz]]

    @timedmethod
    def h(self, state):
        """Return the demagnetization field ``N * (Ms m)``, shaped like ``state.m``.

        The kernel is built lazily on the first call. The convolution is
        evaluated with real FFTs over the non-singleton mesh axes.
        """
        if not hasattr(self, "_N"):
            self._N = self._init_N(state)

        dim = [i for i in range(3) if state.mesh.n[i] > 1]
        shape = self._shape(state)
        s = [shape[i] for i in dim]
        if len(dim) == 0: # single spin TODO: remove this when torch issue #96518 has been solved
            N = torch.stack([torch.stack(self._N[0], dim=-1), torch.stack(self._N[1], dim=-1), torch.stack(self._N[2], dim=-1)], dim=-1)
            # scale by Ms exactly like the FFT branch below (M = Ms * m)
            return (N * (state.material["Ms"] * state.m)).sum(dim=-1)

        hx = torch.zeros(self._N[0][0].shape, dtype=complex_dtype[self._N[0][0].dtype], device=state.device)
        hy = torch.zeros_like(hx)
        hz = torch.zeros_like(hx)
        for ax in range(3):
            m_pad_fft1D = torch.fft.rfftn(state.material["Ms"] * state.m[:,:,:,(ax,)], dim = dim, s = s).squeeze(-1)
            hx += self._N[0][ax] * m_pad_fft1D
            hy += self._N[1][ax] * m_pad_fft1D
            hz += self._N[2][ax] * m_pad_fft1D

        # Pass s explicitly. Without it irfftn assumes an even length on the
        # last transformed axis, which is wrong for periodic (unpadded) axes
        # with an odd number of cells.
        n = state.mesh.n
        hx = torch.fft.irfftn(hx, s = s, dim = dim)[:n[0],:n[1],:n[2]]
        hy = torch.fft.irfftn(hy, s = s, dim = dim)[:n[0],:n[1],:n[2]]
        hz = torch.fft.irfftn(hz, s = s, dim = dim)[:n[0],:n[1],:n[2]]
        return torch.stack([hx, hy, hz], dim=3)