Skip to content

Commit

Permalink
LazyProxy and LazyProxyMultiton patterns (#269)
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins authored Feb 3, 2022
1 parent 1746456 commit 6ac51e4
Show file tree
Hide file tree
Showing 9 changed files with 488 additions and 38 deletions.
5 changes: 5 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
History
=======

X.Y.Z (YYYY-MM-DD)
------------------
* Add Multiton, LazyProxy and LazyProxyMultiton patterns (:pr:`269`)


0.3.2 (2022-13-01)
------------------
* Support numba >= 0.54 (:pr:`264`)
Expand Down
43 changes: 40 additions & 3 deletions africanus/experimental/rime/fused/specification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
from importlib import import_module
import inspect
from itertools import groupby # noqa
import multiprocessing
from pathlib import Path
import re

Expand All @@ -11,6 +11,7 @@
from africanus.experimental.rime.fused import terms as term_mod
from africanus.experimental.rime.fused.transformers.core import Transformer
from africanus.experimental.rime.fused import transformers as transformer_mod
from africanus.util.patterns import LazyProxy


TERM_STRING_REGEX = re.compile("([A-Z])(pq|p|q)")
Expand Down Expand Up @@ -232,9 +233,17 @@ def __init__(self, specification, terms=None, transformers=None):
except KeyError as e:
raise RimeSpecificationError(f"Can't find a type for {str(e)}")

Pool = multiprocessing.get_context("spawn").Pool
pool = LazyProxy((Pool, RimeSpecification._finalise_pool), 4)

# Create the terms
terms = []
global_kw = {"corrs": corrs, "stokes": stokes, "feed_type": feed_type}
global_kw = {
"corrs": corrs,
"stokes": stokes,
"feed_type": feed_type,
"process_pool": pool
}

for cls, cfg in zip(term_types, term_cfgs):
if cfg == "pq":
Expand Down Expand Up @@ -284,8 +293,36 @@ def __init__(self, specification, terms=None, transformers=None):
raise RimeSpecificationError(
"RIME must at least contain a Brightness term")

transformers = []

for cls in transformer_types.values():
init_sig = inspect.signature(cls.__init__)
cls_kw = {}

for a, p in list(init_sig.parameters.items())[1:]:
if p.kind not in {p.POSITIONAL_ONLY,
p.POSITIONAL_OR_KEYWORD}:
raise RimeSpecification(
f"{cls}.__init__{init_sig} may not contain "
f"*args or **kwargs")

try:
cls_kw[a] = available_kw[a]
except KeyError:
raise RimeSpecificationError(
f"{cls}.__init__{init_sig} wants argument {a} "
f"but it is not available. "
f"Available args: {available_kw}")

transformer = cls(**cls_kw)
transformers.append(transformer)

self.terms = terms
self.transformers = [cls() for cls in transformer_types.values()]
self.transformers = transformers

@staticmethod
def _finalise_pool(pool):
pool.terminate()

@staticmethod
def _feed_type(corrs):
Expand Down
3 changes: 3 additions & 0 deletions africanus/experimental/rime/fused/transformers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def __new__(mcls, name, bases, namespace):


class Transformer(metaclass=TransformerMetaClass):
def __init__(self):
pass

def __repr__(self):
return self.__class__.__name__

Expand Down
6 changes: 5 additions & 1 deletion africanus/experimental/rime/fused/transformers/parangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
class ParallacticTransformer(Transformer):
OUTPUTS = ["feed_parangle", "beam_parangle"]

def __init__(self, process_pool):
self.pool = process_pool

def init_fields(self, typingctx,
utime, ufeed, uantenna,
antenna_position, phase_dir,
Expand All @@ -29,7 +32,8 @@ def init_fields(self, typingctx,
@njit(inline="never")
def parangle_stub(time, antenna, phase_dir):
with objmode(out=parangle_dt):
out = casa_parallactic_angles(time, antenna, phase_dir)
out = self.pool.apply(casa_parallactic_angles,
(time, antenna, phase_dir))

return out

Expand Down
9 changes: 3 additions & 6 deletions africanus/rime/jax/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


try:
import jax.numpy as np
import jax.numpy as jnp
except ImportError as e:
opt_import_error = e
else:
Expand All @@ -14,11 +14,8 @@

@requires_optional('jax', opt_import_error)
def phase_delay(lm, uvw, frequency):
out_dtype = np.result_type(lm, uvw, frequency, np.complex64)

one = lm.dtype.type(1.0)
neg_two_pi_over_c = lm.dtype.type(minus_two_pi_over_c)
complex_one = out_dtype.type(1j)

l = lm[:, 0, None, None] # noqa
m = lm[:, 1, None, None]
Expand All @@ -27,10 +24,10 @@ def phase_delay(lm, uvw, frequency):
v = uvw[None, :, 1, None]
w = uvw[None, :, 2, None]

n = np.sqrt(one - l**2 - m**2) - one
n = jnp.sqrt(one - l**2 - m**2) - one

real_phase = (neg_two_pi_over_c *
(l * u + m * v + n * w) *
frequency[None, None, :])

return np.exp(complex_one*real_phase)
return jnp.exp(jnp.complex64(1j)*real_phase)
15 changes: 7 additions & 8 deletions africanus/rime/jax/tests/test_jax_phase_delay.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
# -*- coding: utf-8 -*-


import numpy as onp
import numpy as np
import pytest

from africanus.rime.phase import phase_delay as np_phase_delay
from africanus.rime.jax.phase import phase_delay


@pytest.mark.parametrize("dtype", [onp.float32, onp.float64])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_jax_phase_delay(dtype):
jax = pytest.importorskip('jax')
np = pytest.importorskip('jax.numpy')

onp.random.seed(0)
np.random.seed(0)

uvw = onp.random.random(size=(100, 3)).astype(dtype)
lm = onp.random.random(size=(10, 2)).astype(dtype)*0.001
frequency = onp.linspace(.856e9, .856e9*2, 64).astype(dtype)
uvw = np.random.random(size=(100, 3)).astype(dtype)
lm = np.random.random(size=(10, 2)).astype(dtype)*0.001
frequency = np.linspace(.856e9, .856e9*2, 64).astype(dtype)

# Compute complex phase
np_complex_phase = np_phase_delay(lm, uvw, frequency)
complex_phase = jax.jit(phase_delay)(lm, uvw, frequency)

onp.testing.assert_array_almost_equal(complex_phase, np_complex_phase)
np.testing.assert_array_almost_equal(complex_phase, np_complex_phase)
expected_ctype = np.result_type(dtype, np.complex64)
assert np_complex_phase.dtype == expected_ctype
assert complex_phase.dtype == expected_ctype
Loading

0 comments on commit 6ac51e4

Please sign in to comment.