From 761ae55682c34d619134ffec3a1dc79aaf39c464 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=98yvind=20Eide?=
Date: Thu, 9 Jan 2025 10:25:26 +0100
Subject: [PATCH] Create end point for events

---
 src/ert/run_models/everest_run_model.py |  11 +-
 src/everest/detached/jobs/everserver.py | 127 ++++++++++++++++++++++--
 2 files changed, 129 insertions(+), 9 deletions(-)

diff --git a/src/ert/run_models/everest_run_model.py b/src/ert/run_models/everest_run_model.py
index 68c8e0bfc90..3c18629a66a 100644
--- a/src/ert/run_models/everest_run_model.py
+++ b/src/ert/run_models/everest_run_model.py
@@ -28,7 +28,7 @@
 
 from _ert.events import EESnapshot, EESnapshotUpdate, Event
 from ert.config import ErtConfig, ExtParamConfig
-from ert.ensemble_evaluator import EnsembleSnapshot, EvaluatorServerConfig
+from ert.ensemble_evaluator import EndEvent, EnsembleSnapshot, EvaluatorServerConfig
 from ert.runpaths import Runpaths
 from ert.storage import open_storage
 from everest.config import EverestConfig
@@ -103,10 +103,11 @@ def __init__(
         everest_config: EverestConfig,
         simulation_callback: SimulationCallback | None,
         optimization_callback: OptimizerCallback | None,
+        status_queue: queue.SimpleQueue[StatusEvents] | None = None,
     ):
         Path(everest_config.log_dir).mkdir(parents=True, exist_ok=True)
         Path(everest_config.optimization_output_dir).mkdir(parents=True, exist_ok=True)
-
+        status_queue = queue.SimpleQueue() if status_queue is None else status_queue
         assert everest_config.environment is not None
         logging.getLogger(EVEREST).info(
             "Using random seed: %d. To deterministically reproduce this experiment, "
@@ -136,7 +137,6 @@ def __init__(
         self._status: SimulationStatus | None = None
 
         storage = open_storage(config.ens_path, mode="w")
-        status_queue: queue.SimpleQueue[StatusEvents] = queue.SimpleQueue()
         super().__init__(
             config,
             storage,
@@ -152,12 +152,14 @@ def create(
         ever_config: EverestConfig,
         simulation_callback: SimulationCallback | None = None,
         optimization_callback: OptimizerCallback | None = None,
+        status_queue: queue.SimpleQueue[StatusEvents] | None = None,
     ) -> EverestRunModel:
         return cls(
             config=everest_to_ert_config(ever_config),
             everest_config=ever_config,
             simulation_callback=simulation_callback,
             optimization_callback=optimization_callback,
+            status_queue=status_queue,
         )
 
     @classmethod
@@ -222,6 +224,9 @@ def run_experiment(
                 self._exit_code = EverestExitCode.TOO_FEW_REALIZATIONS
             case _:
                 self._exit_code = EverestExitCode.COMPLETED
+        # Compare against COMPLETED explicitly: the truthiness of an exit
+        # code depends on the enum's numeric values, not on success.
+        self.send_event(EndEvent(failed=self._exit_code != EverestExitCode.COMPLETED))
 
     def _create_optimizer(self) -> BasicOptimizer:
         RESULT_COLUMNS = {
diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py
index 33f5a96c8b8..a37a7a15400 100755
--- a/src/everest/detached/jobs/everserver.py
+++ b/src/everest/detached/jobs/everserver.py
@@ -1,12 +1,16 @@
 import argparse
+import asyncio
 import datetime
 import json
 import logging
+import multiprocessing as mp
 import os
+import queue
 import socket
 import ssl
 import threading
 import traceback
+import uuid
 from base64 import b64encode
 from functools import partial
 from pathlib import Path
@@ -19,7 +23,7 @@
 from cryptography.hazmat.primitives.asymmetric import rsa
 from cryptography.x509.oid import NameOID
 from dns import resolver, reversename
-from fastapi import Depends, FastAPI, HTTPException, Request, status
+from fastapi import Depends, FastAPI, HTTPException, Request, WebSocket, status
 from fastapi.encoders import jsonable_encoder
 from fastapi.responses import (
     JSONResponse,
@@ -32,7 +36,8 @@
 )
 
 from ert.config import QueueSystem
-from ert.ensemble_evaluator import EvaluatorServerConfig
+from ert.ensemble_evaluator import EndEvent, EvaluatorServerConfig
+from ert.run_models import StatusEvents
 from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel
 from everest import export_to_csv, export_with_progress
 from everest.config import EverestConfig, ServerConfig
@@ -51,6 +56,35 @@
 from everest.util import makedirs_if_needed, version_info
 
 
+class EndTaskEvent:
+    """Sentinel appended to the shared event list once the run has ended.
+
+    It tells the ``/events`` websocket handler to stop streaming; it is never
+    sent to clients.
+    """
+
+
+class Subscriber:
+    """Read cursor of one websocket client into the shared event list."""
+
+    def __init__(self) -> None:
+        # Index of the next entry of shared_data["events"] to deliver.
+        self.index = 0
+        self._event = asyncio.Event()
+
+    def notify(self) -> None:
+        """Wake up a coroutine blocked in :meth:`wait_for_event`."""
+        # NOTE(review): asyncio.Event is not thread-safe. If the web server's
+        # event loop runs in another thread than the caller, this must be
+        # scheduled with loop.call_soon_threadsafe() -- confirm.
+        self._event.set()
+
+    async def wait_for_event(self) -> None:
+        """Block until :meth:`notify` is called, then re-arm the event."""
+        await self._event.wait()
+        self._event.clear()
+
+
 def _get_machine_name() -> str:
     """Returns a name that can be used to identify this machine in a network
 
@@ -153,6 +187,44 @@ def get_opt_progress(
         progress = get_opt_status(server_config["optimization_output_dir"])
         return JSONResponse(jsonable_encoder(progress))
 
+    @app.websocket("/events")
+    async def websocket_endpoint(websocket: WebSocket):
+        """Stream every recorded status event to the client as JSON.
+
+        Streaming stops when the ``EndTaskEvent`` sentinel is reached.
+        """
+        subscriber_id = str(uuid.uuid4())
+        await websocket.accept()
+        try:
+            while True:
+                event = await get_event(subscriber_id=subscriber_id)
+                if isinstance(event, EndTaskEvent):
+                    break
+                await websocket.send_json(event)
+                await asyncio.sleep(0.1)
+        finally:
+            # Drop the cursor so that finished or disconnected clients do not
+            # accumulate in the shared subscriber registry.
+            shared_data["subscribers"].pop(subscriber_id, None)
+
+    async def get_event(subscriber_id: str) -> object:
+        """Return the next undelivered entry for ``subscriber_id``.
+
+        Waits until a new entry is appended to ``shared_data["events"]`` when
+        the subscriber has already consumed all of them.
+        """
+        subscribers = shared_data["subscribers"]
+        if subscriber_id not in subscribers:
+            subscribers[subscriber_id] = Subscriber()
+        subscriber = subscribers[subscriber_id]
+
+        while subscriber.index >= len(shared_data["events"]):
+            await subscriber.wait_for_event()
+
+        event = shared_data["events"][subscriber.index]
+        subscriber.index += 1
+        return event
+
     uvicorn.run(
         app,
         host="0.0.0.0",
@@ -235,6 +307,11 @@ def make_handler_config(
 
 
 def main():
+    asyncio.run(everserver_main())
+
+
+async def everserver_main():
+    """Entry point coroutine: runs the experiment and relays its status events."""
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument("--config-file", type=str)
     arg_parser.add_argument("--debug", action="store_true")
@@ -272,6 +349,10 @@ def main():
     shared_data = {
         SIM_PROGRESS_ENDPOINT: {},
         STOP_ENDPOINT: False,
+        # Status events (JSON-encoded) in arrival order, ended by EndTaskEvent.
+        "events": [],
+        # Maps a websocket subscriber id to its Subscriber cursor.
+        "subscribers": {},
     }
 
     server_config = {
@@ -296,14 +377,14 @@ def main():
             message=traceback.format_exc(),
         )
         return
-
+    status_queue: mp.Queue[StatusEvents] = mp.Queue()
     try:
         update_everserver_status(status_path, ServerStatus.running)
-
         run_model = EverestRunModel.create(
             config,
             simulation_callback=partial(_sim_monitor, shared_data=shared_data),
             optimization_callback=partial(_opt_monitor, shared_data=shared_data),
+            status_queue=status_queue,
         )
         if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL:
             evaluator_server_config = EvaluatorServerConfig()
@@ -311,8 +392,42 @@ def main():
             evaluator_server_config = EvaluatorServerConfig(
                 custom_port_range=range(49152, 51819), use_ipc_protocol=False
             )
-
-        run_model.run_experiment(evaluator_server_config)
+        loop = asyncio.get_running_loop()
+        simulation_future = loop.run_in_executor(
+            None,
+            lambda: run_model.run_experiment(evaluator_server_config),
+        )
+        while True:
+            try:
+                item: StatusEvents = status_queue.get(block=False)
+            except queue.Empty:
+                # A crashed run never emits EndEvent; stop polling so that the
+                # exception is raised by `await simulation_future` below.
+                if (
+                    simulation_future.done()
+                    and simulation_future.exception() is not None
+                ):
+                    break
+                await asyncio.sleep(0.01)
+                continue
+
+            event = jsonable_encoder(item)
+            shared_data["events"].append(event)
+            # Iterate over a copy: the registry may change while we notify.
+            for sub in list(shared_data["subscribers"].values()):
+                sub.notify()
+            await asyncio.sleep(0.1)
+
+            if isinstance(item, EndEvent):
+                break
+
+        # Publish the sentinel where subscribers read it (the shared event
+        # list) so that every websocket handler terminates.
+        shared_data["events"].append(EndTaskEvent())
+        for sub in list(shared_data["subscribers"].values()):
+            sub.notify()
+
+        await simulation_future
         status, message = _get_optimization_status(run_model.exit_code, shared_data)
         if status != ServerStatus.completed: