Feature/andberge/misc (#4)
* handle potential error in handshake

* misc whitespace, print to log, check None cleanup

* bundle exceptions for server responses (py>=3.11 feature)

* use std timeout

* cleanup deps and allow py 3.12

* rm poetry

* simplify

* append short id avoiding name collision
andberge authored Jan 10, 2025
1 parent c3d0d09 commit 609ab60
Showing 9 changed files with 1,115 additions and 2,616 deletions.
12 changes: 4 additions & 8 deletions .github/workflows/ci.yml
@@ -24,7 +24,7 @@ jobs:

strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
@@ -34,25 +34,21 @@ jobs:
- name: Checkout LFS objects
run: git lfs pull


- name: Install poetry
run: |
curl -sSL https://install.python-poetry.org | python3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'poetry'

- name: Install testing deps
run: |
pip install poetry
poetry config virtualenvs.create false --local
poetry install --with dev --no-interaction
docker compose -f tests/compose.yml up --detach
- name: Test
run: |
poetry run python -m pytest -v -x -s -color=yes --cov-report=term-missing --cov-report=xml --cov=pyetp tests/
python -m pytest -v -x -s --color=yes --cov-report=term-missing --cov-report=xml --cov=pyetp tests/
# - name: Comment coverage
# if: ${{ github.event_name == 'pull_request'}}
3,457 changes: 955 additions & 2,502 deletions poetry.lock

Large diffs are not rendered by default.

95 changes: 70 additions & 25 deletions pyetp/client.py
@@ -5,12 +5,11 @@
import typing as T
import uuid
from collections import defaultdict
from contextlib import asynccontextmanager
from types import TracebackType

import numpy as np
import websockets
import xtgeo
from async_timeout import timeout
from etpproto.connection import (CommunicationProtocol, ConnectionType,
ETPConnection)
from etpproto.messages import Message, MessageFlags
@@ -23,6 +22,28 @@
from pyetp.config import SETTINGS
from pyetp.types import *
from pyetp.uri import DataObjectURI, DataspaceURI
from pyetp.utils import short_id

try:
# for py >3.11, we can raise grouped exceptions
from builtins import ExceptionGroup # type: ignore
except ImportError:
def ExceptionGroup(msg, errors):
return errors[0]

try:
from asyncio import timeout
except ImportError:
import async_timeout

@asynccontextmanager
async def timeout(delay: T.Optional[float]) -> T.Any:
try:
async with async_timeout.timeout(delay):
yield None
except asyncio.CancelledError as e:
raise asyncio.TimeoutError(f'Timeout ({delay}s)') from e


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -38,11 +59,14 @@ def __init__(self, message: str, code: int):
super().__init__(f"{message} ({code=:})")

@classmethod
def from_proto(cls, msg: ProtocolException):
assert msg.error is not None or msg.errors is not None, "passed no error info"
error = msg.error or list(msg.errors.values())[0]
def from_proto(cls, error: ErrorInfo):
assert error is not None, "passed no error info"
return cls(error.message, error.code)

@classmethod
def from_protos(cls, errors: T.Iterable[ErrorInfo]):
return list(map(cls.from_proto, errors))


def get_all_etp_protocol_classes():
"""Update protocol - all exception protocols are now per message"""
@@ -104,25 +128,37 @@ async def _send(self, body: ETPModel):
async def _recv(self, correlation_id: int) -> ETPModel:
assert correlation_id in self._recv_events, "trying to recv response on non-existing message"

try:
async with timeout(self.timeout):
await self._recv_events[correlation_id].wait()
except asyncio.CancelledError as e:
raise TimeoutError(f'Timeout before reciving {correlation_id=}') from e
async with timeout(self.timeout):
await self._recv_events[correlation_id].wait()

# cleanup
bodies = self._clear_msg_on_buffer(correlation_id)

for body in bodies:
if isinstance(body, ProtocolException):
logger.debug(body)
raise ETPError.from_proto(body)
# error handling
errors = self._parse_error_info(bodies)

if len(errors) == 1:
raise ETPError.from_proto(errors.pop())
elif len(errors) > 1:
raise ExceptionGroup("Server responded with ETPErrors:", ETPError.from_protos(errors))

if len(bodies) > 1:
logger.warning(f"Recived {len(bodies)} messages, but only expected one")

# ok
return bodies[0]

@staticmethod
def _parse_error_info(bodies: list[ETPModel]) -> list[ErrorInfo]:
# returns all error infos from bodies
errors = []
for body in bodies:
if isinstance(body, ProtocolException):
if body.error is not None:
errors.append(body.error)
errors.extend(body.errors.values())
return errors

async def close(self, reason=''):
if self.ws.closed:
self.__recvtask.cancel("stopped")
@@ -328,7 +364,7 @@ async def put_data_objects(self, *objs: DataObject):
PutDataObjectsResponse

response = await self.send(
PutDataObjects(dataObjects={p.resource.uri: p for p in objs})
PutDataObjects(dataObjects={f"{p.resource.name}_{short_id()}": p for p in objs})
)
# logger.info(f"objects {response=:}")
assert isinstance(response, PutDataObjectsResponse), "Expected PutDataObjectsResponse"
@@ -452,7 +488,7 @@ async def get_surface_value_x_y(self, epc_uri: T.Union[DataObjectURI, str], gri_
arr = await self.get_subarray(uid, [min_x_ind, min_y_ind], [count_x, count_y])
new_x_ori = xori+(min_x_ind*xinc)
new_y_ori = yori+(min_y_ind*yinc)
regridded = xtgeo.RegularSurface(
regridded = RegularSurface(
ncol=arr.shape[0],
nrow=arr.shape[1],
xori=new_x_ori,
@@ -465,7 +501,7 @@ async def get_surface_value_x_y(self, epc_uri: T.Union[DataObjectURI, str], gri_
return regridded.get_value_from_xy((x, y))

async def get_xtgeo_surface(self, epc_uri: T.Union[DataObjectURI, str], gri_uri: T.Union[DataObjectURI, str], crs_uri: T.Union[DataObjectURI, str, None] = None):
if isinstance(crs_uri, type(None)):
if crs_uri is None:
logger.debug("NO crs")
gri, = await self.get_resqml_objects(gri_uri)
crs_uuid = gri.grid2d_patch.geometry.local_crs.uuid
@@ -742,8 +778,10 @@ async def put_epc_mesh(
else:
time_indices = [-1]
cprop0s, props, propertykind0 = utils_xml.convert_epc_mesh_property_to_resqml_mesh(epc_filename, hexa, propname, uns, epc)
if isinstance(cprop0s, type(None)):

if cprop0s is None:
continue

cprop_uris = []
for cprop0, prop, time_index in zip(cprop0s, props, time_indices):
assert isinstance(cprop0, ro.ContinuousProperty) or isinstance(cprop0, ro.DiscreteProperty), "prop must be a Property"
@@ -1064,15 +1102,15 @@ def __await__(self):
# async with connect(...) as ...:

async def __aenter__(self):

headers = {}
if isinstance(self.authorization, str):
token = self.authorization
headers["Authorization"] = self.authorization
elif isinstance(self.authorization, SecretStr):
token = self.authorization.get_secret_value()
headers = {}
if isinstance(self.authorization, type(None)) is False:
headers["Authorization"] = token
if isinstance(self.data_partition, str):
headers["Authorization"] = self.authorization.get_secret_value()
if self.data_partition is not None:
headers["data-partition-id"] = self.data_partition

ws = await websockets.connect(
self.server_url,
subprotocols=[ETPClient.SUB_PROTOCOL], # type: ignore
@@ -1081,8 +1119,15 @@ async def __aenter__(self):
ping_timeout=self.timeout,
open_timeout=None,
)

self.client = ETPClient(ws, default_dataspace_uri=self.default_dataspace_uri, timeout=self.timeout)
await self.client.request_session()

try:
await self.client.request_session()
except Exception as e:
# aexit not called if raised in aenter - so manual cleanup here needed
await self.client.close("Failed to request session")
raise e

return self.client

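The reworked `_recv` raises a single `ETPError` when the server returns one `ErrorInfo`, and an `ExceptionGroup` of `ETPError`s when it returns several; on Python < 3.11 the shim at the top of the file falls back to re-raising only the first error. A minimal caller-side sketch, assuming Python >= 3.11 and an already-connected `client` (the `put_data_objects` call is just a stand-in for any request):

```python
from pyetp.client import ETPError


async def store_objects(client, objs):
    # Sketch only: `client` is assumed to come from `async with connect(...) as client`.
    try:
        await client.put_data_objects(*objs)
    except* ETPError as eg:
        # except* wraps even a single ETPError in a group, so eg.exceptions
        # always holds every error bundled from the server's responses.
        for err in eg.exceptions:
            print("server error:", err)
```

On older interpreters a plain `except ETPError` is enough, since the fallback `ExceptionGroup` shim re-raises only `errors[0]`.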
10 changes: 5 additions & 5 deletions pyetp/config.py
@@ -1,6 +1,9 @@
from typing import Optional

from pydantic import AnyUrl, BaseSettings, Field, RedisDsn, SecretStr

from pyetp.uri import DataspaceURI
from typing import Optional


class WebSocketUrl(AnyUrl):
allowed_schemes = {'wss', 'ws'}
@@ -18,7 +21,6 @@ class Config:
application_name: str = Field(default='etpClient')
application_version: str = Field(default='0.0.1')


dataspace: str = Field(default='demo/pss-data-gateway')
etp_url: WebSocketUrl = Field(default='wss://host.com')
etp_timeout: float = Field(default=60., description="Timeout in seconds")
@@ -29,6 +31,4 @@ def duri(self):
return DataspaceURI.from_name(self.dataspace)




SETTINGS = Settings()
SETTINGS = Settings()
1 change: 1 addition & 0 deletions pyetp/types.py
@@ -18,6 +18,7 @@
from etptypes.energistics.etp.v12.datatypes.data_array_types.data_array_metadata import \
DataArrayMetadata
from etptypes.energistics.etp.v12.datatypes.data_value import DataValue
from etptypes.energistics.etp.v12.datatypes.error_info import ErrorInfo
from etptypes.energistics.etp.v12.datatypes.message_header import MessageHeader
from etptypes.energistics.etp.v12.datatypes.object.data_object import \
DataObject
6 changes: 6 additions & 0 deletions pyetp/utils.py
@@ -0,0 +1,6 @@
import random
import string


def short_id(length=8):
return ''.join([random.choice(string.ascii_letters + string.digits + '-_') for _ in range(length)])
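`short_id` is what `put_data_objects` now appends to each resource name to keep the keys of the `dataObjects` map unique. A quick illustrative sketch (the `my_surface` name is made up):

```python
from pyetp.utils import short_id

name = "my_surface"
key = f"{name}_{short_id()}"    # e.g. "my_surface_k3B-_x9Q": 8 random chars by default
print(key)
print(f"{name}_{short_id(4)}")  # the suffix length is configurable
```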
69 changes: 36 additions & 33 deletions pyetp/utils_arrays.py
@@ -1,14 +1,15 @@

import typing as T

import numpy as np
from scipy import interpolate
import xtgeo
from xtgeo import RegularSurface

from .types import (AnyArray, AnyArrayType, ArrayOfBoolean, ArrayOfDouble,
ArrayOfFloat, ArrayOfInt, ArrayOfLong, DataArray,
DataArrayMetadata)
from pyetp.types import (AnyArray, AnyArrayType, ArrayOfBoolean, ArrayOfDouble,
ArrayOfFloat, ArrayOfInt, ArrayOfLong, DataArray,
DataArrayMetadata)

SUPPORED_ARRAY_TYPES = T.Union[ArrayOfFloat , ArrayOfBoolean, ArrayOfInt, ArrayOfLong , ArrayOfDouble]
SUPPORED_ARRAY_TYPES = T.Union[ArrayOfFloat, ArrayOfBoolean, ArrayOfInt, ArrayOfLong, ArrayOfDouble]

_ARRAY_MAP_TYPES: dict[AnyArrayType, np.dtype[T.Any]] = {
AnyArrayType.ARRAY_OF_FLOAT: np.dtype(np.float32),
@@ -44,7 +45,7 @@ def get_cls(dtype: np.dtype):
return _ARRAY_MAP[get_transport(dtype)]


def get_dtype(item: T.Union[AnyArray ,AnyArrayType]):
def get_dtype(item: T.Union[AnyArray, AnyArrayType]):
atype = item if isinstance(item, AnyArrayType) else get_transport_from_name(item.item.__class__.__name__)

if atype not in _ARRAY_MAP_TYPES:
@@ -73,34 +74,36 @@ def to_data_array(data: np.ndarray):
data=AnyArray(item=cls(values=data.flatten().tolist()))
)


def mid_point_rectangle(arr: np.ndarray):
all_x=arr[:,0]
all_y= arr[:,1]
all_x = arr[:, 0]
all_y = arr[:, 1]
min_x = np.min(all_x)
min_y = np.min(all_y)
mid_x = ((np.max(all_x)-min_x)/2)+min_x
mid_y = ((np.max(all_y)-min_y)/2)+min_y
return np.array([mid_x, mid_y])


def grid_xtgeo(data: np.ndarray):
max_x = np.nanmax(data[:,0])
max_y = np.nanmax(data[:,1])
min_x = np.nanmin(data[:,0])
min_y = np.nanmin(data[:,1])
u_x = np.sort(np.unique(data[:,0]))
u_y = np.sort(np.unique(data[:,1]))
xinc = u_x[1]- u_x[0]
yinc = u_y[1]- u_y[0]
max_x = np.nanmax(data[:, 0])
max_y = np.nanmax(data[:, 1])
min_x = np.nanmin(data[:, 0])
min_y = np.nanmin(data[:, 1])
u_x = np.sort(np.unique(data[:, 0]))
u_y = np.sort(np.unique(data[:, 1]))
xinc = u_x[1] - u_x[0]
yinc = u_y[1] - u_y[0]
grid_x, grid_y = np.mgrid[
min_x: max_x + xinc: xinc,
min_y: max_y + yinc: yinc,
]

interp = interpolate.LinearNDInterpolator(data[:,:-1], data[:,-1], fill_value=np.nan, rescale=False)
z = interp(np.array([grid_x.flatten(), grid_y.flatten() ]).T )
zz = np.reshape(z,grid_x.shape)
interp = interpolate.LinearNDInterpolator(data[:, :-1], data[:, -1], fill_value=np.nan, rescale=False)
z = interp(np.array([grid_x.flatten(), grid_y.flatten()]).T)
zz = np.reshape(z, grid_x.shape)

surf = xtgeo.RegularSurface(
return RegularSurface(
ncol=grid_x.shape[0],
nrow=grid_x.shape[1],
xori=min_x,
@@ -110,22 +113,22 @@ def grid_xtgeo(data: np.ndarray):
rotation=0.0,
values=zz,
)
return surf

def get_cells_positions(points: np.ndarray, n_cells:int, n_cell_per_pos:int,layers_per_sediment_unit:int,n_node_per_pos:int,node_index: int):
results = np.zeros((int(n_cells/n_cell_per_pos),3), dtype=np.float64)
grid_x_pos = np.unique(points[:,0])
grid_y_pos = np.unique(points[:,1])

def get_cells_positions(points: np.ndarray, n_cells: int, n_cell_per_pos: int, layers_per_sediment_unit: int, n_node_per_pos: int, node_index: int):
results = np.zeros((int(n_cells/n_cell_per_pos), 3), dtype=np.float64)
grid_x_pos = np.unique(points[:, 0])
grid_y_pos = np.unique(points[:, 1])
counter = 0
# find cell index and location

for y_ind in range(0,len(grid_y_pos)-1):
for x_ind in range(0,len(grid_x_pos)-1):
top_depth= []
for y_ind in range(0, len(grid_y_pos)-1):
for x_ind in range(0, len(grid_x_pos)-1):
top_depth = []
for corner_x in range(layers_per_sediment_unit):
for corner_y in range(layers_per_sediment_unit):
node_indx = (( (y_ind+corner_y)*len(grid_x_pos) + (x_ind+corner_x) ) * n_node_per_pos)+ node_index
top_depth.append( points[node_indx])
results[counter,0:2] = mid_point_rectangle(np.array(top_depth))
counter+=1
return results
node_indx = (((y_ind+corner_y)*len(grid_x_pos) + (x_ind+corner_x)) * n_node_per_pos) + node_index
top_depth.append(points[node_indx])
results[counter, 0:2] = mid_point_rectangle(np.array(top_depth))
counter += 1
return results
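`grid_xtgeo` interpolates scattered (x, y, z) samples onto the regular lattice implied by their unique x/y positions and wraps the result in an `xtgeo.RegularSurface`. A small sketch with synthetic data, assuming `xtgeo` and `scipy` are installed (the 3 x 3 lattice and the z = x + y values are made up):

```python
import numpy as np

from pyetp.utils_arrays import grid_xtgeo

# 3 x 3 lattice of (x, y, z) points spaced 100 apart; z = x + y for illustration
xs, ys = np.meshgrid(np.arange(0.0, 300.0, 100.0), np.arange(0.0, 300.0, 100.0))
data = np.column_stack([xs.ravel(), ys.ravel(), (xs + ys).ravel()])

surf = grid_xtgeo(data)  # xtgeo.RegularSurface with xinc = yinc = 100
print(surf.ncol, surf.nrow, surf.xinc, surf.yinc)
print(surf.values)       # interpolated z grid (NaN/masked outside the input's convex hull)
```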