[query] Use sourcecode.Enclosing to handle timed blocks implicitly (#14683)

### Change Description

This change exists as part of larger refactoring work. Herein, I've replaced
hard-coded context strings passed to `ExecutionTimer.time` with implicit
contexts, drawing inspiration from scalatest.

These contexts are now supplied after entering functions like `Compile` and
`Emit` instead of before (see `ExecuteContext.time`). By sprinkling calls to
`time` throughout the codebase after entering functions, we obtain a nice trace
of the timings with `sourcecode.Enclosing`, minus the previous verbosity.
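
For illustration, here's a minimal sketch of the mechanism (illustrative names,
not Hail's actual signatures): an implicit `sourcecode.Enclosing` parameter is
materialised by the compiler with the fully-qualified name of the enclosing
definition at each call site, so callers no longer pass a label string by hand.

```scala
import sourcecode.Enclosing

object TimingSketch {
  // The compiler fills in `E` at each call site with the enclosing
  // definition's path, e.g. "is.hail.expr.ir.Compile.apply".
  def time[A](block: => A)(implicit E: Enclosing): A = {
    println(s"timing: ${E.value}")
    block
  }
}
```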

See [1] for more information about what pre-built macros are available. We can
always build our own later. See comments in [pull request id] for example output.
Note that `ExecutionTimer.time` still accepts a string to support uses like
`Optimise` and `LoweringPass`, where those contexts are provided already.
It is also exception-safe now.
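
As a hedged sketch of what exception safety means here (again with illustrative
names, not Hail's actual implementation): the elapsed time is recorded on both
normal and exceptional exit.

```scala
import scala.collection.mutable

object TimerSketch {
  private val timings = mutable.Map.empty[String, Long]

  // The `finally` clause records the elapsed time even if `block`
  // throws, so a failing stage still produces a timing entry.
  def time[A](context: String)(block: => A): A = {
    val start = System.nanoTime()
    try block
    finally timings(context) = System.nanoTime() - start
  }
}
```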

This change exposed many similarities between the implementations of query
execution across all three backends. I've stopped short of full unification,
which is a larger undertaking; instead, I've simplified the duplicated result
encoding and moved it into the various backend API implementations.

More interesting changes are to `ExecuteContext`, which now supports:
- `time`, as discussed above
- `local`, a generalisation for temporarily overriding properties of an
  `ExecuteContext` (inspired by [2]); see the sketch after this list. While
  I've long wanted this for testing, we were doing some questionable things
  when reporting timings back to Python, for which locally overriding the
  `timer` of a `ctx` has been very useful. We also follow this pattern for
  local regions.
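
Here's a hedged sketch of the `local` pattern, by analogy with
`Control.Monad.Reader.local` [2]; the fields are illustrative, not
`ExecuteContext`'s actual shape:

```scala
// Run `f` against a copy of the context with selected fields
// overridden; the original context is left untouched.
case class Ctx(timer: String, region: String) {
  def local[A](
    timer: String = this.timer,
    region: String = this.region,
  )(f: Ctx => A): A =
    f(copy(timer = timer, region = region))
}
```

This mirrors how the new `/execute` handler below runs the parsed IR under a
locally overridden timer: `ctx.local(timer = timer)(ctx => ...)` swaps in a
fresh timer while reusing everything else on the context.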

[1] https://github.com/com-lihaoyi/sourcecode
[2] https://hackage.haskell.org/package/mtl-2.3.1/docs/Control-Monad-Reader.html#v:local

### Security Assessment

This change has no security impact as it's confined to refactoring of
existing non-security-related code.
ehigham authored Oct 17, 2024
1 parent 1c04a66 commit a863a4b
Showing 38 changed files with 4,113 additions and 4,221 deletions.
2 changes: 2 additions & 0 deletions hail/build.sc
@@ -113,6 +113,7 @@ object Deps {
val log4j = ivy"org.apache.logging.log4j:log4j-1.2-api:2.17.2"
val hadoopClient = ivy"org.apache.hadoop:hadoop-client:3.3.4"
val jackson = ivy"com.fasterxml.jackson.core:jackson-core:2.15.2"
val sourcecode = ivy"com.lihaoyi::sourcecode:0.4.2"

object Plugins {
val betterModadicFor = ivy"com.olegpy::better-monadic-for:0.3.1"
@@ -200,6 +201,7 @@ object main extends RootModule with HailScalaModule { outer =>
Deps.jna,
Deps.json4s.excludeOrg("com.fasterxml.jackson.core"),
Deps.zstd,
Deps.sourcecode
)

override def runIvyDeps: T[Agg[Dep]] = Agg(
2 changes: 1 addition & 1 deletion hail/python/hail/backend/backend.py
@@ -186,7 +186,7 @@ def execute(self, ir: BaseIR, timed: bool = False) -> Any:
return (value, timings) if timed else value

@abc.abstractmethod
def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, str]:
def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, Optional[dict]]:
pass

def _render_ir(self, ir):
18 changes: 15 additions & 3 deletions hail/python/hail/backend/py4j_backend.py
@@ -4,7 +4,7 @@
import socketserver
import sys
from threading import Thread
from typing import Mapping, Set, Tuple
from typing import Mapping, Optional, Set, Tuple

import orjson
import py4j
@@ -156,6 +156,18 @@ def connect_logger(utils_package_object, host, port):
}


def parse_timings(str: Optional[str]) -> Optional[dict]:
def parse(node):
return {
'name': node[0],
'total_time': node[1],
'self_time': node[2],
'children': [parse(c) for c in node[3]],
}

return None if str is None else parse(orjson.loads(str))


class Py4JBackend(Backend):
@abc.abstractmethod
def __init__(self, jvm: JVMView, jbackend: JavaObject, jhc: JavaObject):
@@ -211,7 +223,7 @@ def logger(self):
self._logger = Log4jLogger(self._utils_package_object)
return self._logger

def _rpc(self, action, payload) -> Tuple[bytes, str]:
def _rpc(self, action, payload) -> Tuple[bytes, Optional[dict]]:
data = orjson.dumps(payload)
path = action_routes[action]
port = self._backend_server_port
@@ -221,7 +233,7 @@ def _rpc(self, action, payload) -> Tuple[bytes, str]:
raise fatal_error_from_java_error_triplet(
error_json['short'], error_json['expanded'], error_json['error_id']
)
return resp.content, resp.headers.get('X-Hail-Timings', '')
return resp.content, parse_timings(resp.headers.get('X-Hail-Timings', None))

def persist_expression(self, expr):
t = expr.dtype
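For illustration (an assumed payload, not taken from this commit): given the array layout `parse_timings` expects, a header value of `["execute", 120, 5, [["Compile", 80, 80, []]]]` would parse to `{'name': 'execute', 'total_time': 120, 'self_time': 5, 'children': [{'name': 'Compile', 'total_time': 80, 'self_time': 80, 'children': []}]}`.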
6 changes: 3 additions & 3 deletions hail/python/hail/backend/service_backend.py
@@ -349,7 +349,7 @@ async def _run_on_batch(
progress: Optional[BatchProgressBar] = None,
driver_cores: Optional[Union[int, str]] = None,
driver_memory: Optional[str] = None,
) -> Tuple[bytes, str]:
) -> Tuple[bytes, Optional[dict]]:
timings = Timings()
async with TemporaryDirectory(ensure_exists=False) as iodir:
with timings.step("write input"):
@@ -414,7 +414,7 @@ async def _run_on_batch(

with timings.step("read output"):
result_bytes = await retry_transient_errors(self._read_output, iodir + '/out', iodir + '/in')
return result_bytes, str(timings.to_dict())
return result_bytes, timings.to_dict()

async def _read_output(self, output_uri: str, input_uri: str) -> bytes:
try:
@@ -462,7 +462,7 @@ def _cancel_on_ctrl_c(self, coro: Awaitable[T]) -> T:
self._batch_was_submitted = False
raise

def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, str]:
def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, Optional[dict]]:
return self._cancel_on_ctrl_c(self._async_rpc(action, payload))

async def _async_rpc(self, action: ActionTag, payload: ActionPayload):
3 changes: 2 additions & 1 deletion hail/python/hail/expr/expressions/expression_utils.py
@@ -161,7 +161,8 @@ def eval_timed(expression):
uid = Env.get_uid()
ir = expression._indices.source.select_globals(**{uid: expression}).index_globals()[uid]._ir

return Env.backend().execute(MakeTuple([ir]), timed=True)[0]
(value, timings) = Env.backend().execute(MakeTuple([ir]), timed=True)
return value[0], timings


@typecheck(expression=expr_any)
8 changes: 5 additions & 3 deletions hail/python/hailtop/utils/utils.py
@@ -1122,9 +1122,11 @@ def step(self, name: str):
d: Dict[str, int] = {}
self.timings[name] = d
d['start_time'] = time_msecs()
yield
d['finish_time'] = time_msecs()
d['duration'] = d['finish_time'] - d['start_time']
try:
yield
finally:
d['finish_time'] = time_msecs()
d['duration'] = d['finish_time'] - d['start_time']

def to_dict(self):
return self.timings
68 changes: 32 additions & 36 deletions hail/src/main/scala/is/hail/backend/Backend.scala
@@ -3,7 +3,7 @@ package is.hail.backend
import is.hail.asm4s._
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{
BaseIR, CodeCacheKey, CompiledFunction, IRParser, IRParserEnvironment, LoweringAnalyses,
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
SortField, TableIR, TableReader,
}
import is.hail.expr.ir.functions.IRFunctionRegistry
@@ -30,6 +30,7 @@ import java.nio.charset.StandardCharsets
import com.fasterxml.jackson.core.StreamReadConstraints
import org.json4s._
import org.json4s.jackson.{JsonMethods, Serialization}
import sourcecode.Enclosing

object Backend {

@@ -46,6 +47,25 @@ object Backend {
irID += 1
irID
}

def encodeToOutputStream(
ctx: ExecuteContext,
t: PTuple,
off: Long,
bufferSpecString: String,
os: OutputStream,
): Unit = {
val bs = BufferSpec.parseOrDefault(bufferSpecString)
assert(t.size == 1)
val elementType = t.fields(0).typ
val codec = TypedCodecSpec(
EType.fromPythonTypeEncoding(elementType.virtualType),
elementType.virtualType,
bs,
)
assert(t.isFieldDefined(off, 0))
codec.encode(ctx, elementType, t.loadField(off, 0), os)
}
}

abstract class BroadcastValue[T] { def value: T }
@@ -169,41 +189,41 @@ abstract class Backend {
def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
: TableStage

def withExecuteContext[T](methodName: String)(f: ExecuteContext => T): T
def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T

private[this] def jsonToBytes(f: => JValue): Array[Byte] =
JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8)

final def valueType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext("valueType") { ctx =>
withExecuteContext { ctx =>
IRParser.parse_value_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
}
}

final def tableType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext("tableType") { ctx =>
withExecuteContext { ctx =>
IRParser.parse_table_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
}
}

final def matrixTableType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext("matrixTableType") { ctx =>
withExecuteContext { ctx =>
IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
}
}

final def blockMatrixType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext("blockMatrixType") { ctx =>
withExecuteContext { ctx =>
IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
}
}

def loadReferencesFromDataset(path: String): Array[Byte] = {
withExecuteContext("loadReferencesFromDataset") { ctx =>
withExecuteContext { ctx =>
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
rgs.foreach(addReference)

@@ -221,14 +241,14 @@
mtContigs: Array[String],
parInput: Array[String],
): Array[Byte] =
withExecuteContext("fromFASTAFile") { ctx =>
withExecuteContext { ctx =>
val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput)
rg.toJSONString.getBytes(StandardCharsets.UTF_8)
}

def parseVCFMetadata(path: String): Array[Byte] = jsonToBytes {
withExecuteContext("parseVCFMetadata") { ctx =>
withExecuteContext { ctx =>
val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
implicit val formats = defaultJSONFormats
Extraction.decompose(metadata)
@@ -237,7 +257,7 @@

def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String)
: Array[Byte] =
withExecuteContext("importFam") { ctx =>
withExecuteContext { ctx =>
LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missingValue).getBytes(
StandardCharsets.UTF_8
)
@@ -251,7 +271,7 @@
returnType: String,
bodyStr: String,
): Unit = {
withExecuteContext("pyRegisterIR") { ctx =>
withExecuteContext { ctx =>
IRFunctionRegistry.registerIR(
ctx,
name,
@@ -264,31 +284,7 @@
}
}

def execute(
ir: String,
timed: Boolean,
)(
consume: (ExecuteContext, Either[Unit, (PTuple, Long)], String) => Unit
): Unit = ()

def encodeToOutputStream(
ctx: ExecuteContext,
t: PTuple,
off: Long,
bufferSpecString: String,
os: OutputStream,
): Unit = {
val bs = BufferSpec.parseOrDefault(bufferSpecString)
assert(t.size == 1)
val elementType = t.fields(0).typ
val codec = TypedCodecSpec(
EType.fromPythonTypeEncoding(elementType.virtualType),
elementType.virtualType,
bs,
)
assert(t.isFieldDefined(off, 0))
codec.encode(ctx, elementType, t.loadField(off, 0), os)
}
def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
}

trait BackendWithCodeCache {
29 changes: 23 additions & 6 deletions hail/src/main/scala/is/hail/backend/BackendServer.scala
@@ -1,14 +1,18 @@
package is.hail.backend

import is.hail.expr.ir.{IRParser, IRParserEnvironment}
import is.hail.utils._

import scala.util.control.NonFatal

import java.net.InetSocketAddress
import java.nio.charset.StandardCharsets
import java.util.concurrent._

import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer}
import org.json4s._
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.compact

case class IRTypePayload(ir: String)
case class LoadReferencesFromDatasetPayload(path: String)
@@ -84,15 +88,28 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
try {
val body = using(exchange.getRequestBody)(JsonMethods.parse(_))
if (exchange.getRequestURI.getPath == "/execute") {
val config = body.extract[ExecutePayload]
backend.execute(config.ir, config.timed) { (ctx, res, timings) =>
exchange.getResponseHeaders().add("X-Hail-Timings", timings)
val ExecutePayload(irStr, streamCodec, timed) = body.extract[ExecutePayload]
backend.withExecuteContext { ctx =>
val (res, timings) = ExecutionTimer.time { timer =>
ctx.local(timer = timer) { ctx =>
val irData = IRParser.parse_value_ir(
irStr,
IRParserEnvironment(ctx, irMap = backend.persistedIR.toMap),
)
backend.execute(ctx, irData)
}
}

if (timed) {
exchange.getResponseHeaders.add("X-Hail-Timings", compact(timings.toJSON))
}

res match {
case Left(_) => exchange.sendResponseHeaders(200, -1L)
case Right((t, off)) =>
exchange.sendResponseHeaders(200, 0L) // 0 => an arbitrarily long response body
using(exchange.getResponseBody()) { os =>
backend.encodeToOutputStream(ctx, t, off, config.stream_codec, os)
using(exchange.getResponseBody) { os =>
Backend.encodeToOutputStream(ctx, t, off, streamCodec, os)
}
}
}
@@ -126,7 +143,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
exchange.sendResponseHeaders(200, response.length)
using(exchange.getResponseBody())(_.write(response))
} catch {
case t: Throwable =>
case NonFatal(t) =>
val (shortMessage, expandedMessage, errorId) = handleForPython(t)
val errorJson = JObject(
"short" -> JString(shortMessage),