Source code for pynbody.sph.kernels
"""SPH kernel details"""
from __future__ import annotations
import numpy as np
import scipy.integrate as integrate
[docs]
class KernelBase:
    """Base class for SPH kernels.

    Subclasses must implement :meth:`get_value`; kernels with a C counterpart
    also implement :meth:`get_c_kernel_id`. The base class provides
    :meth:`get_samples`, which tabulates the kernel and caches the table.
    """

    # Class-level cache shared by every kernel instance. Maps
    # (hash(kernel), numpy dtype) -> sample array built by get_samples.
    _sample_cache = {}

    def __init__(self):
        # Power of the smoothing length appearing in the denominator of the
        # general kernel expression: 3 for 3D kernels, 2 for 2D kernels.
        self.h_power = 3
        # Maximum value of (displacement / smoothing length) for which the
        # kernel is non-zero.
        self.max_d = 2

    def _sample_cache_key(self, dtype):
        """Return the key under which this kernel's samples of ``dtype`` are cached."""
        # The dtype is part of the key: keying on the kernel alone meant that a
        # table first built as float32 was returned, still float32, to later
        # callers explicitly requesting a different dtype.
        return hash(self), np.dtype(dtype)

    def _get_samples_from_cache(self, dtype=np.float32):
        """Return the cached sample table for this kernel and ``dtype``, or None if absent."""
        return KernelBase._sample_cache.get(self._sample_cache_key(dtype))

    def get_samples(self, dtype=np.float32):
        """Return the kernel tabulated on a uniform grid in squared displacement.

        The kernel is evaluated via ``get_value(d)`` (no ``h`` passed, so the
        method's default smoothing length is used) at ``d = sqrt(x)`` for
        ``x = 0, 0.02, ..., 4.0`` (201 points).

        Parameters
        ----------
        dtype : numpy dtype-like
            dtype of the returned array (default float32).

        Returns
        -------
        numpy.ndarray
            The tabulated kernel. Tables are cached per kernel and per dtype,
            and the cached array itself is returned, so callers must not
            modify it in place.
        """
        samples = self._get_samples_from_cache(dtype)
        if samples is None:
            # NOTE: the grid covers d in [0, 2] regardless of self.max_d.
            sample_pts = np.arange(0, 4.01, 0.02)
            samples = np.array([self.get_value(x ** 0.5) for x in sample_pts], dtype=dtype)
            KernelBase._sample_cache[self._sample_cache_key(dtype)] = samples
        return samples

    def get_value(self, d, h=1) -> float:
        """Get the value of the kernel at displacement ``d`` for smoothing length ``h``.

        Parameters
        ----------
        d : float
            Displacement divided by the smoothing length.
        h : float
            Smoothing length, used for the normalisation.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def projection(self) -> KernelBase:
        """Return a 2D projection of this kernel (a ``Kernel2D`` wrapping it)."""
        return Kernel2D(self)

    def __hash__(self):
        # All instances of the same kernel class hash equally; this hash is
        # part of the sample-cache key.
        return hash(self.__class__)

    @classmethod
    def get_c_kernel_id(cls) -> int:
        """Return the C kernel id for this kernel.

        This is used to select the appropriate C code for the kernel, and must
        match the kernel id defined in the Kernel::create function in
        kernels.hpp.

        Raises
        ------
        NotImplementedError
            Always; subclasses with a C implementation must override this.
        """
        raise NotImplementedError("Subclasses must implement this method")
[docs]
class CubicSplineKernel(KernelBase):
    """Cubic spline SPH kernel with compact support d < 2 (pynbody's default kernel)."""

    def get_value(self, d, h=1):
        """Evaluate the 3D cubic spline kernel.

        Parameters
        ----------
        d : float
            Displacement divided by the smoothing length.
        h : float
            Smoothing length; the result carries a 1 / (pi * h**3) normalisation.

        Returns
        -------
        float
            Kernel value; zero for d >= 2.
        """
        normalisation = np.pi * h ** 3
        if d < 1:
            # Inner branch of the spline.
            return (1. - 1.5 * d ** 2 + 0.75 * d ** 3) / normalisation
        if d < 2:
            # Outer branch, falling smoothly to zero at d = 2.
            return 0.25 * (2. - d) ** 3 / normalisation
        return 0 / normalisation

    @classmethod
    def get_c_kernel_id(cls):
        """Return 0, the identifier of this kernel on the C side."""
        return 0
[docs]
class WendlandC2Kernel(KernelBase):
    """Wendland C2 (quintic) SPH kernel with compact support d < 2; the default kernel used by EAGLE."""

    def get_value(self, d, h=1):
        """Evaluate the 3D Wendland C2 kernel.

        Parameters
        ----------
        d : float
            Displacement divided by the smoothing length.
        h : float
            Smoothing length; the result carries a 21 / (16 * pi * h**3) normalisation.

        Returns
        -------
        float
            Kernel value; zero for d >= 2.
        """
        shape = (1. - d / 2.) ** 4 * (2. * d + 1) if d < 2 else 0
        return 21. * shape / (16. * np.pi * h ** 3)

    @classmethod
    def get_c_kernel_id(cls):
        """Return 1, the identifier of this kernel on the C side."""
        return 1
[docs]
class Kernel2D(KernelBase):
    """A 2D kernel, generated by numerically projecting an underlying 3D kernel along the line of sight."""

    def __init__(self, k_orig=None):
        """Create a 2D kernel by projecting a 3D kernel.

        Parameters
        ----------
        k_orig : KernelBase, optional
            The 3D kernel to project. If omitted (or None), a new
            CubicSplineKernel is used.
        """
        if k_orig is None:
            # Build the default per call instead of sharing one instance
            # created at class-definition time (mutable default argument).
            k_orig = CubicSplineKernel()
        self.h_power = 2
        self.max_d = k_orig.max_d
        self.k_orig = k_orig

    def projection(self):
        """Projecting an already-projected kernel is not supported.

        Raises
        ------
        ValueError
            Always.
        """
        raise ValueError("Cannot project a 2D kernel")

    def get_value(self, d, h=1):
        """Return the line-of-sight integral of the 3D kernel at projected displacement ``d``.

        Parameters
        ----------
        d : float
            Projected displacement divided by the smoothing length.
        h : float
            Smoothing length; the result scales as 1 / h**2 (``h_power`` = 2).

        Returns
        -------
        float
            The projected kernel value; zero for d >= max_d.
        """
        # The integration variable z is dimensionless (z_physical / h), as it is
        # combined with the dimensionless d. The 3D kernel vanishes beyond
        # sqrt(z**2 + d**2) = max_d, so integrate only over that extent
        # (an empty range when d >= max_d).
        z_max = np.sqrt(max(self.max_d ** 2 - d ** 2, 0.0))
        integral = integrate.quad(
            lambda z: self.k_orig.get_value(np.sqrt(z ** 2 + d ** 2), h), 0, z_max
        )[0]
        # Factor 2: the kernel is symmetric in z, so only z >= 0 was integrated.
        # Factor h: Jacobian dz_physical = h dz, turning the 1/h**3 of the 3D
        # kernel into the 1/h**2 of a 2D kernel. (At h=1 this is unchanged.)
        return 2 * h * integral

    @classmethod
    def get_c_kernel_id(cls):
        """2D kernels have no C implementation.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("2D kernels are not supported in C")

    def __hash__(self):
        # Distinguish projections of different underlying kernels, e.g. in the
        # sample cache.
        return hash((self.__class__, self.k_orig))
[docs]
def _iter_kernel_classes():
    """Yield every direct or indirect subclass of KernelBase, direct subclasses first."""
    pending = list(KernelBase.__subclasses__())
    seen = set()
    while pending:
        cls = pending.pop(0)
        if cls in seen:
            # Possible with multiple inheritance (reachable via two parents).
            continue
        seen.add(cls)
        yield cls
        pending.extend(cls.__subclasses__())


def create_kernel(spec) -> KernelBase:
    """Create a kernel object from a string, a kernel class, an existing kernel object, or None.

    This function is used by the framework to give the user flexibility in
    specifying the kernel.

    - If the input is a string, it is taken to be the name of a kernel class,
      and an instance of that class is created. The name may be given with or
      without the 'Kernel' suffix, and case is ignored. For example,
      'WendlandC2Kernel', 'wendlandc2' and 'WendlandC2' all return a
      WendlandC2Kernel instance. Subclasses of KernelBase at any depth are
      searched, direct subclasses first.
    - If the input is a subclass of KernelBase (the class itself), it is
      instantiated with no arguments.
    - If the input is an instance of KernelBase, it is returned as is.
    - If the input is None, the kernel named by ``config['sph']['kernel']`` is
      created, falling back to 'CubicSplineKernel' if that setting is absent.

    Parameters
    ----------
    spec : str, type, KernelBase or None
        The kernel specification.

    Returns
    -------
    KernelBase
        A kernel object

    Raises
    ------
    ValueError
        If ``spec`` names no known kernel, or is of an unsupported kind
        (including a class that does not derive from KernelBase).
    """
    if spec is None:
        from ..configuration import config
        return create_kernel(config['sph'].get('kernel', 'CubicSplineKernel'))
    if isinstance(spec, type):
        if not issubclass(spec, KernelBase):
            raise ValueError("Unknown kernel specification %r" % (spec,))
        return spec()
    if isinstance(spec, KernelBase):
        return spec
    if isinstance(spec, str):
        wanted = spec.lower()
        for subclass in _iter_kernel_classes():
            subclass_name = subclass.__name__
            if (subclass_name.lower() == wanted or
                    (subclass_name.endswith('Kernel') and subclass_name[:-6].lower() == wanted)):
                return subclass()
        raise ValueError("Unknown kernel '%s'" % (spec,))
    # Wrap spec in a 1-tuple: a bare tuple spec would otherwise be consumed as
    # the format arguments and raise TypeError instead of this ValueError.
    raise ValueError("Unknown kernel specification %r" % (spec,))