diff --git a/src/fastcs/attributes.py b/src/fastcs/attributes.py index cf783e0d..6c36ce98 100644 --- a/src/fastcs/attributes.py +++ b/src/fastcs/attributes.py @@ -68,9 +68,10 @@ def __init__( handler: Any = None, description: str | None = None, ) -> None: - assert issubclass( - datatype.dtype, ATTRIBUTE_TYPES - ), f"Attr type must be one of {ATTRIBUTE_TYPES}, received type {datatype.dtype}" + assert issubclass(datatype.dtype, ATTRIBUTE_TYPES), ( + f"Attr type must be one of {ATTRIBUTE_TYPES}, " + "received type {datatype.dtype}" + ) self._datatype: DataType[T] = datatype self._access_mode: AttrMode = access_mode self._group = group diff --git a/src/fastcs/backend.py b/src/fastcs/backend.py index 35ff2de0..fa5fc1ed 100644 --- a/src/fastcs/backend.py +++ b/src/fastcs/backend.py @@ -83,9 +83,9 @@ def _link_attribute_sender_class(single_mapping: SingleMapping) -> None: for attr_name, attribute in single_mapping.attributes.items(): match attribute: case AttrW(sender=Sender()): - assert ( - not attribute.has_process_callback() - ), f"Cannot assign both put method and Sender object to {attr_name}" + assert not attribute.has_process_callback(), ( + f"Cannot assign both put method and Sender object to {attr_name}" + ) callback = _create_sender_callback(attribute, single_mapping.controller) attribute.set_process_callback(callback) diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index 3e98ba4f..467612bd 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -3,13 +3,13 @@ import enum from abc import abstractmethod from collections.abc import Awaitable, Callable -from dataclasses import dataclass, field +from dataclasses import dataclass from functools import cached_property from typing import Generic, TypeVar import numpy as np -T = TypeVar("T", int, float, bool, str, enum.IntEnum, np.ndarray) +T = TypeVar("T", int, float, bool, str, enum.Enum, np.ndarray) ATTRIBUTE_TYPES: tuple[type] = T.__constraints__ # type: ignore @@ -21,10 +21,6 @@ class DataType(Generic[T]): """Generic datatype mapping to a python type, with additional metadata.""" - # We move this to each datatype so that we can have positional - # args in subclasses. - allowed_values: list[T] | None = field(init=False, default=None) - @property @abstractmethod def dtype(self) -> type[T]: # Using property due to lack of Generic ClassVars @@ -34,15 +30,7 @@ def validate(self, value: T) -> T: """Validate a value against fields in the datatype.""" if not isinstance(value, self.dtype): raise ValueError(f"Value {value} is not of type {self.dtype}") - if ( - hasattr(self, "allowed_values") - and self.allowed_values is not None - and value not in self.allowed_values - ): - raise ValueError( - f"Value {value} is not in the allowed values for this " - f"datatype {self.allowed_values}." - ) + return value @property @@ -79,8 +67,6 @@ def initial_value(self) -> T_Numerical: class Int(_Numerical[int]): """`DataType` mapping to builtin ``int``.""" - allowed_values: list[int] | None = None - @property def dtype(self) -> type[int]: return int @@ -91,7 +77,6 @@ class Float(_Numerical[float]): """`DataType` mapping to builtin ``float``.""" prec: int = 2 - allowed_values: list[float] | None = None @property def dtype(self) -> type[float]: @@ -102,10 +87,6 @@ def dtype(self) -> type[float]: class Bool(DataType[bool]): """`DataType` mapping to builtin ``bool``.""" - znam: str = "OFF" - onam: str = "ON" - allowed_values: list[bool] | None = None - @property def dtype(self) -> type[bool]: return bool @@ -119,8 +100,6 @@ def initial_value(self) -> bool: class String(DataType[str]): """`DataType` mapping to builtin ``str``.""" - allowed_values: list[str] | None = None - @property def dtype(self) -> type[str]: return str @@ -130,33 +109,30 @@ def initial_value(self) -> str: return "" -T_Enum = TypeVar("T_Enum", bound=enum.IntEnum) +T_Enum = TypeVar("T_Enum", bound=enum.Enum) @dataclass(frozen=True) -class Enum(DataType[enum.IntEnum]): - enum_cls: type[enum.IntEnum] - - @cached_property - def is_string_enum(self) -> bool: - return all(isinstance(member.value, str) for member in self.members) +class Enum(Generic[T_Enum], DataType[T_Enum]): + enum_cls: type[T_Enum] def __post_init__(self): - if not issubclass(self.enum_cls, enum.IntEnum): - raise ValueError("Enum class has to take an IntEnum.") - if {member.value for member in self.members} != set(range(len(self.members))): - raise ValueError("Enum values must be contiguous.") + if not issubclass(self.enum_cls, enum.Enum): + raise ValueError("Enum class has to take an Enum.") + + def index_of(self, value: T_Enum) -> int: + return self.members.index(value) @cached_property - def members(self) -> list[enum.IntEnum]: + def members(self) -> list[T_Enum]: return list(self.enum_cls) @property - def dtype(self) -> type[enum.IntEnum]: + def dtype(self) -> type[T_Enum]: return self.enum_cls @property - def initial_value(self) -> enum.IntEnum: + def initial_value(self) -> T_Enum: return self.members[0] diff --git a/src/fastcs/transport/epics/ioc.py b/src/fastcs/transport/epics/ioc.py index ae293ac2..28784e83 100644 --- a/src/fastcs/transport/epics/ioc.py +++ b/src/fastcs/transport/epics/ioc.py @@ -8,15 +8,13 @@ from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import BaseController, Controller -from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm -from fastcs.exceptions import FastCSException +from fastcs.datatypes import DataType, T from fastcs.transport.epics.util import ( - MBB_MAX_CHOICES, - MBB_STATE_FIELDS, - get_cast_method_from_epics_type, - get_cast_method_to_epics_type, - get_record_metadata_from_attribute, - get_record_metadata_from_datatype, + builder_callable_from_attribute, + get_callable_from_epics_type, + get_callable_to_epics_type, + record_metadata_from_attribute, + record_metadata_from_datatype, ) from .options import EpicsIOCOptions @@ -160,76 +158,34 @@ def _create_and_link_attribute_pvs(pv_prefix: str, controller: Controller) -> No def _create_and_link_read_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrR[T] ) -> None: - cast_method = get_cast_method_to_epics_type(attribute.datatype) + cast_to_epics_type = get_callable_to_epics_type(attribute.datatype) async def async_record_set(value: T): - record.set(cast_method(value)) + record.set(cast_to_epics_type(value)) - record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute) + record = _make_record(f"{pv_prefix}:{pv_name}", attribute) _add_attr_pvi_info(record, pv_prefix, attr_name, "r") attribute.set_update_callback(async_record_set) -def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper: - match attribute.datatype: - case Bool(): - record = builder.boolIn( - pv, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case Int(): - record = builder.longIn( - pv, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case Float(): - record = builder.aIn( - pv, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case String(): - record = builder.longStringIn( - pv, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case Enum(): - if len(attribute.datatype.members) > MBB_MAX_CHOICES: - raise RuntimeError( - f"Received an `Enum` datatype on attribute {attribute} " - f"with more elements than the epics limit `{MBB_MAX_CHOICES}` " - f"for `mbbIn`. Use an `Int or `String with `allowed_values`." - ) - state_keys = dict( - zip( - MBB_STATE_FIELDS, - [member.name for member in attribute.datatype.members], - strict=False, - ) - ) - record = builder.mbbIn( - pv, - **state_keys, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case WaveForm(): - record = builder.WaveformIn( - pv, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case _: - raise FastCSException( - f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" - ) +def _make_record( + pv: str, + attribute: AttrR | AttrW | AttrRW, + on_update: Callable | None = None, +) -> RecordWrapper: + builder_callable = builder_callable_from_attribute(attribute, on_update is None) + datatype_record_metadata = record_metadata_from_datatype(attribute.datatype) + attribute_record_metadata = record_metadata_from_attribute(attribute) + + update = {"always_update": True, "on_update": on_update} if on_update else {} + + record = builder_callable( + pv, **update, **datatype_record_metadata, **attribute_record_metadata + ) def datatype_updater(datatype: DataType): - for name, value in get_record_metadata_from_datatype(datatype).items(): + for name, value in record_metadata_from_datatype(datatype).items(): record.set_field(name, value) attribute.add_update_datatype_callback(datatype_updater) @@ -239,102 +195,22 @@ def datatype_updater(datatype: DataType): def _create_and_link_write_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrW[T] ) -> None: - cast_method = get_cast_method_from_epics_type(attribute.datatype) + cast_from_epics_type = get_callable_from_epics_type(attribute.datatype) + cast_to_epics_type = get_callable_to_epics_type(attribute.datatype) async def on_update(value): - await attribute.process_without_display_update(cast_method(value)) + await attribute.process_without_display_update(cast_from_epics_type(value)) async def async_write_display(value: T): - record.set(cast_method(value), process=False) + record.set(cast_to_epics_type(value), process=False) - record = _get_output_record( - f"{pv_prefix}:{pv_name}", attribute, on_update=on_update - ) + record = _make_record(f"{pv_prefix}:{pv_name}", attribute, on_update=on_update) _add_attr_pvi_info(record, pv_prefix, attr_name, "w") attribute.set_write_display_callback(async_write_display) -def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any: - match attribute.datatype: - case Bool(): - record = builder.boolOut( - pv, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case Int(): - record = builder.longOut( - pv, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case Float(): - record = builder.aOut( - pv, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case String(): - record = builder.longStringOut( - pv, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case Enum(): - if len(attribute.datatype.members) > MBB_MAX_CHOICES: - raise RuntimeError( - f"Received an `Enum` datatype on attribute {attribute} " - f"with more elements than the epics limit `{MBB_MAX_CHOICES}` " - f"for `mbbOut`. Use an `Int or `String with `allowed_values`." - ) - - state_keys = dict( - zip( - MBB_STATE_FIELDS, - [member.name for member in attribute.datatype.members], - strict=False, - ) - ) - record = builder.mbbOut( - pv, - **state_keys, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case WaveForm(): - record = builder.WaveformOut( - pv, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - - case _: - raise FastCSException( - f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" - ) - - def datatype_updater(datatype: DataType): - for name, value in get_record_metadata_from_datatype(datatype).items(): - record.set_field(name, value) - - attribute.add_update_datatype_callback(datatype_updater) - return record - - def _create_and_link_command_pvs(pv_prefix: str, controller: Controller) -> None: for single_mapping in controller.get_controller_mappings(): path = single_mapping.controller.path diff --git a/src/fastcs/transport/epics/util.py b/src/fastcs/transport/epics/util.py index de79b551..8b5d3038 100644 --- a/src/fastcs/transport/epics/util.py +++ b/src/fastcs/transport/epics/util.py @@ -1,8 +1,11 @@ from collections.abc import Callable from dataclasses import asdict -from fastcs.attributes import Attribute +from softioc import builder + +from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm +from fastcs.exceptions import FastCSException _MBB_FIELD_PREFIXES = ( "ZR", @@ -42,13 +45,13 @@ } -def get_record_metadata_from_attribute( +def record_metadata_from_attribute( attribute: Attribute[T], ) -> dict[str, str | None]: return {"DESC": attribute.description} -def get_record_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: +def record_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: arguments = { DATATYPE_FIELD_TO_RECORD_FIELD[field]: value for field, value in asdict(datatype).items() @@ -63,36 +66,71 @@ def get_record_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: "supports to 1D arrays" ) arguments["length"] = datatype.shape[0] + case Enum(): + if len(datatype.members) <= MBB_MAX_CHOICES: + state_keys = dict( + zip( + MBB_STATE_FIELDS, + [member.name for member in datatype.members], + strict=False, + ) + ) + arguments.update(state_keys) return arguments -def get_cast_method_to_epics_type(datatype: DataType[T]) -> Callable[[T], object]: +def get_callable_from_epics_type(datatype: DataType[T]) -> Callable[[object], T]: match datatype: case Enum(): - def cast_to_epics_type(value) -> str | int: - return datatype.validate(value).value + def cast_from_epics_type(value: object) -> T: + return datatype.validate(datatype.members[value]) + case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): - def cast_to_epics_type(value) -> object: + def cast_from_epics_type(value) -> T: return datatype.validate(value) case _: raise ValueError(f"Unsupported datatype {datatype}") - return cast_to_epics_type + return cast_from_epics_type -def get_cast_method_from_epics_type(datatype: DataType[T]) -> Callable[[object], T]: +def get_callable_to_epics_type(datatype: DataType[T]) -> Callable[[T], object]: match datatype: - case Enum(enum_cls): - - def cast_from_epics_type(value: object) -> T: - return datatype.validate(enum_cls(value)) + case Enum(): + def cast_to_epics_type(value) -> object: + return datatype.index_of(datatype.validate(value)) case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): - def cast_from_epics_type(value) -> T: + def cast_to_epics_type(value) -> object: return datatype.validate(value) case _: raise ValueError(f"Unsupported datatype {datatype}") - return cast_from_epics_type + return cast_to_epics_type + + +def builder_callable_from_attribute( + attribute: AttrR | AttrW | AttrRW, make_in_record: bool +): + match attribute.datatype: + case Bool(): + return builder.boolIn if make_in_record else builder.boolOut + case Int(): + return builder.longIn if make_in_record else builder.longOut + case Float(): + return builder.aIn if make_in_record else builder.aOut + case String(): + return builder.longStringIn if make_in_record else builder.longStringOut + case Enum(): + if len(attribute.datatype.members) > MBB_MAX_CHOICES: + return builder.longIn if make_in_record else builder.longOut + else: + return builder.mbbIn if make_in_record else builder.mbbOut + case WaveForm(): + return builder.WaveformIn if make_in_record else builder.WaveformOut + case _: + raise FastCSException( + f"EPICS unsupported datatype on {attribute}: {attribute.datatype}" + ) diff --git a/src/fastcs/transport/tango/util.py b/src/fastcs/transport/tango/util.py index 8e46f63c..3e4bb472 100644 --- a/src/fastcs/transport/tango/util.py +++ b/src/fastcs/transport/tango/util.py @@ -66,7 +66,7 @@ def get_cast_method_to_tango_type(datatype: DataType[T]) -> Callable[[T], object case Enum(): def cast_to_tango_type(value) -> int: - return datatype.validate(value).value + return datatype.index_of(datatype.validate(value)) case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): def cast_to_tango_type(value) -> object: @@ -78,10 +78,10 @@ def cast_to_tango_type(value) -> object: def get_cast_method_from_tango_type(datatype: DataType[T]) -> Callable[[object], T]: match datatype: - case Enum(enum_cls): + case Enum(): def cast_from_tango_type(value: object) -> T: - return datatype.validate(enum_cls(value)) + return datatype.validate(datatype.members[value]) case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): diff --git a/tests/conftest.py b/tests/conftest.py index 01e45838..3c04c7ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,11 +28,6 @@ class BackendTestController(TestController): read_bool: AttrR = AttrR(Bool()) write_bool: AttrW = AttrW(Bool(), handler=TestSender()) read_string: AttrRW = AttrRW(String()) - big_enum: AttrR = AttrR( - Int( - allowed_values=list(range(17)), - ), - ) @pytest.fixture diff --git a/tests/test_launch.py b/tests/test_launch.py index 61b516fc..075f6a82 100644 --- a/tests/test_launch.py +++ b/tests/test_launch.py @@ -102,7 +102,7 @@ def test_over_defined_schema(): def test_version(): impl_version = "0.0.1" - expected = f"SingleArg: {impl_version}\n" f"FastCS: {__version__}\n" + expected = f"SingleArg: {impl_version}\nFastCS: {__version__}\n" app = _launch(SingleArg, version=impl_version) result = runner.invoke(app, ["version"]) assert result.exit_code == 0 diff --git a/tests/transport/epics/test_gui.py b/tests/transport/epics/test_gui.py index 71db7cb8..45faba94 100644 --- a/tests/transport/epics/test_gui.py +++ b/tests/transport/epics/test_gui.py @@ -50,7 +50,6 @@ def test_get_components(controller): ) ], ), - SignalR(name="BigEnum", read_pv="DEVICE:BigEnum", read_widget=TextRead()), SignalR(name="ReadBool", read_pv="DEVICE:ReadBool", read_widget=LED()), SignalR( name="ReadInt", diff --git a/tests/transport/epics/test_ioc.py b/tests/transport/epics/test_ioc.py index 78fd9aae..eee2b226 100644 --- a/tests/transport/epics/test_ioc.py +++ b/tests/transport/epics/test_ioc.py @@ -24,13 +24,12 @@ _add_sub_controller_pvi_info, _create_and_link_read_pv, _create_and_link_write_pv, - _get_input_record, - _get_output_record, + _make_record, ) from fastcs.transport.epics.util import ( MBB_STATE_FIELDS, - get_record_metadata_from_attribute, - get_record_metadata_from_datatype, + record_metadata_from_attribute, + record_metadata_from_datatype, ) DEVICE = "DEVICE" @@ -51,16 +50,16 @@ def record_input_from_enum(enum_cls: type[enum.IntEnum]) -> dict[str, str]: @pytest.mark.asyncio async def test_create_and_link_read_pv(mocker: MockerFixture): - get_input_record = mocker.patch("fastcs.transport.epics.ioc._get_input_record") + make_record = mocker.patch("fastcs.transport.epics.ioc._make_record") add_attr_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_attr_pvi_info") - record = get_input_record.return_value + record = make_record.return_value attribute = AttrR(Int()) attribute.set_update_callback = mocker.MagicMock() _create_and_link_read_pv("PREFIX", "PV", "attr", attribute) - get_input_record.assert_called_once_with("PREFIX:PV", attribute) + make_record.assert_called_once_with("PREFIX:PV", attribute) add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r") # Extract the callback generated and set in the function and call it @@ -81,11 +80,6 @@ class ColourEnum(enum.IntEnum): "attribute,record_type,kwargs", ( (AttrR(String()), "longStringIn", {}), - ( - AttrR(String(allowed_values=[member.name for member in list(ColourEnum)])), - "longStringIn", - {}, - ), ( AttrR(Enum(ColourEnum)), "mbbIn", @@ -99,36 +93,36 @@ class ColourEnum(enum.IntEnum): (AttrR(WaveForm(np.int32, (10,))), "WaveformIn", {}), ), ) -def test_get_input_record( +def test_make_input_record( attribute: AttrR, record_type: str, kwargs: dict[str, Any], mocker: MockerFixture, ): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") pv = "PV" - _get_input_record(pv, attribute) + _make_record(pv, attribute) + kwargs.update(record_metadata_from_datatype(attribute.datatype)) + kwargs.update(record_metadata_from_attribute(attribute)) getattr(builder, record_type).assert_called_once_with( pv, - **get_record_metadata_from_attribute(attribute), - **get_record_metadata_from_datatype(attribute.datatype), **kwargs, ) -def test_get_input_record_raises(mocker: MockerFixture): +def test_make_record_raises(mocker: MockerFixture): # Pass a mock as attribute to provoke the fallback case matching on datatype with pytest.raises(FastCSException): - _get_input_record("PV", mocker.MagicMock()) + _make_record("PV", mocker.MagicMock()) @pytest.mark.asyncio async def test_create_and_link_write_pv(mocker: MockerFixture): - get_output_record = mocker.patch("fastcs.transport.epics.ioc._get_output_record") + make_record = mocker.patch("fastcs.transport.epics.ioc._make_record") add_attr_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_attr_pvi_info") - record = get_output_record.return_value + record = make_record.return_value attribute = AttrW(Int()) attribute.process_without_display_update = mocker.AsyncMock() @@ -136,9 +130,7 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): _create_and_link_write_pv("PREFIX", "PV", "attr", attribute) - get_output_record.assert_called_once_with( - "PREFIX:PV", attribute, on_update=mocker.ANY - ) + make_record.assert_called_once_with("PREFIX:PV", attribute, on_update=mocker.ANY) add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "w") # Extract the write update callback generated and set in the function and call it @@ -149,7 +141,7 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): record.set.assert_called_once_with(1, process=False) # Extract the on update callback generated and set in the function and call it - on_update_callback = get_output_record.call_args[1]["on_update"] + on_update_callback = make_record.call_args[1]["on_update"] await on_update_callback(1) attribute.process_without_display_update.assert_called_once_with(1) @@ -159,31 +151,30 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): "attribute,record_type,kwargs", ( ( - AttrR(Enum(enum.IntEnum("ONOFF_STATES", {"DISABLED": 0, "ENABLED": 1}))), + AttrW(Enum(enum.IntEnum("ONOFF_STATES", {"DISABLED": 0, "ENABLED": 1}))), "mbbOut", {"ZRST": "DISABLED", "ONST": "ENABLED"}, ), - (AttrR(String(allowed_values=SEVENTEEN_VALUES)), "longStringOut", {}), ), ) -def test_get_output_record( +def test_make_output_record( attribute: AttrW, record_type: str, kwargs: dict[str, Any], mocker: MockerFixture, ): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") update = mocker.MagicMock() pv = "PV" - _get_output_record(pv, attribute, on_update=update) + _make_record(pv, attribute, on_update=update) + + kwargs.update(record_metadata_from_datatype(attribute.datatype)) + kwargs.update(record_metadata_from_attribute(attribute)) + kwargs.update({"always_update": True, "on_update": update}) getattr(builder, record_type).assert_called_once_with( pv, - always_update=True, - on_update=update, - **get_record_metadata_from_attribute(attribute), - **get_record_metadata_from_datatype(attribute.datatype), **kwargs, ) @@ -191,7 +182,7 @@ def test_get_output_record( def test_get_output_record_raises(mocker: MockerFixture): # Pass a mock as attribute to provoke the fallback case matching on datatype with pytest.raises(FastCSException): - _get_output_record("PV", mocker.MagicMock(), on_update=mocker.MagicMock()) + _make_record("PV", mocker.MagicMock(), on_update=mocker.MagicMock()) class EpicsAssertableController(AssertableController): @@ -203,11 +194,6 @@ class EpicsAssertableController(AssertableController): read_string = AttrRW(String()) enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) one_d_waveform = AttrRW(WaveForm(np.int32, (10,))) - big_enum = AttrR( - Int( - allowed_values=list(range(17)), - ), - ) @pytest.fixture() @@ -216,7 +202,8 @@ def controller(class_mocker: MockerFixture): def test_ioc(mocker: MockerFixture, controller: Controller): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + ioc_builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") add_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_pvi_info") add_sub_controller_pvi_info = mocker.patch( "fastcs.transport.epics.ioc._add_sub_controller_pvi_info" @@ -227,20 +214,18 @@ def test_ioc(mocker: MockerFixture, controller: Controller): # Check records are created builder.boolIn.assert_called_once_with( f"{DEVICE}:ReadBool", - **get_record_metadata_from_attribute(controller.attributes["read_bool"]), - **get_record_metadata_from_datatype( - controller.attributes["read_bool"].datatype - ), + **record_metadata_from_attribute(controller.attributes["read_bool"]), + **record_metadata_from_datatype(controller.attributes["read_bool"].datatype), ) builder.longIn.assert_any_call( f"{DEVICE}:ReadInt", - **get_record_metadata_from_attribute(controller.attributes["read_int"]), - **get_record_metadata_from_datatype(controller.attributes["read_int"].datatype), + **record_metadata_from_attribute(controller.attributes["read_int"]), + **record_metadata_from_datatype(controller.attributes["read_int"].datatype), ) builder.aIn.assert_called_once_with( f"{DEVICE}:ReadWriteFloat_RBV", - **get_record_metadata_from_attribute(controller.attributes["read_write_float"]), - **get_record_metadata_from_datatype( + **record_metadata_from_attribute(controller.attributes["read_write_float"]), + **record_metadata_from_datatype( controller.attributes["read_write_float"].datatype ), ) @@ -248,20 +233,15 @@ def test_ioc(mocker: MockerFixture, controller: Controller): f"{DEVICE}:ReadWriteFloat", always_update=True, on_update=mocker.ANY, - **get_record_metadata_from_attribute(controller.attributes["read_write_float"]), - **get_record_metadata_from_datatype( + **record_metadata_from_attribute(controller.attributes["read_write_float"]), + **record_metadata_from_datatype( controller.attributes["read_write_float"].datatype ), ) - builder.longIn.assert_any_call( - f"{DEVICE}:BigEnum", - **get_record_metadata_from_attribute(controller.attributes["big_enum"]), - **get_record_metadata_from_datatype(controller.attributes["big_enum"].datatype), - ) builder.longIn.assert_any_call( f"{DEVICE}:ReadWriteInt_RBV", - **get_record_metadata_from_attribute(controller.attributes["read_write_int"]), - **get_record_metadata_from_datatype( + **record_metadata_from_attribute(controller.attributes["read_write_int"]), + **record_metadata_from_datatype( controller.attributes["read_write_int"].datatype ), ) @@ -269,39 +249,31 @@ def test_ioc(mocker: MockerFixture, controller: Controller): f"{DEVICE}:ReadWriteInt", always_update=True, on_update=mocker.ANY, - **get_record_metadata_from_attribute(controller.attributes["read_write_int"]), - **get_record_metadata_from_datatype( + **record_metadata_from_attribute(controller.attributes["read_write_int"]), + **record_metadata_from_datatype( controller.attributes["read_write_int"].datatype ), ) builder.mbbIn.assert_called_once_with( f"{DEVICE}:Enum_RBV", - ZRST="RED", - ONST="GREEN", - TWST="BLUE", - **get_record_metadata_from_attribute(controller.attributes["enum"]), - **get_record_metadata_from_datatype(controller.attributes["enum"].datatype), + **record_metadata_from_attribute(controller.attributes["enum"]), + **record_metadata_from_datatype(controller.attributes["enum"].datatype), ) builder.mbbOut.assert_called_once_with( f"{DEVICE}:Enum", - ZRST="RED", - ONST="GREEN", - TWST="BLUE", - **get_record_metadata_from_attribute(controller.attributes["enum"]), - **get_record_metadata_from_datatype(controller.attributes["enum"].datatype), always_update=True, on_update=mocker.ANY, + **record_metadata_from_attribute(controller.attributes["enum"]), + **record_metadata_from_datatype(controller.attributes["enum"].datatype), ) builder.boolOut.assert_called_once_with( f"{DEVICE}:WriteBool", always_update=True, on_update=mocker.ANY, - **get_record_metadata_from_attribute(controller.attributes["write_bool"]), - **get_record_metadata_from_datatype( - controller.attributes["write_bool"].datatype - ), + **record_metadata_from_attribute(controller.attributes["write_bool"]), + **record_metadata_from_datatype(controller.attributes["write_bool"].datatype), ) - builder.Action.assert_any_call(f"{DEVICE}:Go", on_update=mocker.ANY) + ioc_builder.Action.assert_any_call(f"{DEVICE}:Go", on_update=mocker.ANY) # Check info tags are added add_pvi_info.assert_called_once_with(f"{DEVICE}:PVI") @@ -423,7 +395,8 @@ class ControllerLongNames(Controller): def test_long_pv_names_discarded(mocker: MockerFixture): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + ioc_builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") long_name_controller = ControllerLongNames() long_attr_name = "attr_r_with_reallyreallyreallyreallyreallyreallyreally_long_name" long_rw_name = "attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV" @@ -438,17 +411,17 @@ def test_long_pv_names_discarded(mocker: MockerFixture): f"{DEVICE}:{short_pv_name}", always_update=True, on_update=mocker.ANY, - **get_record_metadata_from_datatype( + **record_metadata_from_datatype( long_name_controller.attr_rw_short_name.datatype ), - **get_record_metadata_from_attribute(long_name_controller.attr_rw_short_name), + **record_metadata_from_attribute(long_name_controller.attr_rw_short_name), ) builder.longIn.assert_called_once_with( f"{DEVICE}:{short_pv_name}_RBV", - **get_record_metadata_from_datatype( + **record_metadata_from_datatype( long_name_controller.attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV.datatype ), - **get_record_metadata_from_attribute( + **record_metadata_from_attribute( long_name_controller.attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV ), ) @@ -477,12 +450,12 @@ def test_long_pv_names_discarded(mocker: MockerFixture): assert long_name_controller.command_short_name.fastcs_method.enabled long_command_name = ( - "command_with_" "reallyreallyreallyreallyreallyreallyreally_long_name" + "command_with_reallyreallyreallyreallyreallyreallyreally_long_name" ) assert not getattr(long_name_controller, long_command_name).fastcs_method.enabled short_command_pv_name = "command_short_name".title().replace("_", "") - builder.Action.assert_called_once_with( + ioc_builder.Action.assert_called_once_with( f"{DEVICE}:{short_command_pv_name}", on_update=mocker.ANY, ) @@ -497,17 +470,17 @@ def test_long_pv_names_discarded(mocker: MockerFixture): def test_update_datatype(mocker: MockerFixture): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") pv_name = f"{DEVICE}:Attr" attr_r = AttrR(Int()) - record_r = _get_input_record(pv_name, attr_r) + record_r = _make_record(pv_name, attr_r) builder.longIn.assert_called_once_with( pv_name, - **get_record_metadata_from_attribute(attr_r), - **get_record_metadata_from_datatype(attr_r.datatype), + **record_metadata_from_attribute(attr_r), + **record_metadata_from_datatype(attr_r.datatype), ) record_r.set_field.assert_not_called() attr_r.update_datatype(Int(units="m", min=-3)) @@ -521,12 +494,12 @@ def test_update_datatype(mocker: MockerFixture): attr_r.update_datatype(String()) # type: ignore attr_w = AttrW(Int()) - record_w = _get_output_record(pv_name, attr_w, on_update=mocker.ANY) + record_w = _make_record(pv_name, attr_w, on_update=mocker.ANY) builder.longIn.assert_called_once_with( pv_name, - **get_record_metadata_from_attribute(attr_w), - **get_record_metadata_from_datatype(attr_w.datatype), + **record_metadata_from_attribute(attr_w), + **record_metadata_from_datatype(attr_w.datatype), ) record_w.set_field.assert_not_called() attr_w.update_datatype(Int(units="m", min=-3)) diff --git a/tests/transport/graphQL/test_graphQL.py b/tests/transport/graphQL/test_graphQL.py index 0f61dd7c..8ba57eeb 100644 --- a/tests/transport/graphQL/test_graphQL.py +++ b/tests/transport/graphQL/test_graphQL.py @@ -24,11 +24,6 @@ class RestAssertableController(AssertableController): read_bool = AttrR(Bool()) write_bool = AttrW(Bool(), handler=TestSender()) read_string = AttrRW(String()) - big_enum = AttrR( - Int( - allowed_values=list(range(17)), - ), - ) @pytest.fixture(scope="class") @@ -136,15 +131,6 @@ def test_write_bool(self, client, assertable_controller): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, value) - def test_big_enum(self, client, assertable_controller): - expect = 0 - path = ["bigEnum"] - query = f"query {{ {nest_query(path)} }}" - with assertable_controller.assert_read_here(["big_enum"]): - response = client.post("/graphql", json={"query": query}) - assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, expect) - def test_go(self, client, assertable_controller): path = ["go"] mutation = f"mutation {{ {nest_query(path)} }}" diff --git a/tests/transport/rest/test_rest.py b/tests/transport/rest/test_rest.py index e237b1eb..e69f76a3 100644 --- a/tests/transport/rest/test_rest.py +++ b/tests/transport/rest/test_rest.py @@ -26,11 +26,6 @@ class RestAssertableController(AssertableController): enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) one_d_waveform = AttrRW(WaveForm(np.int32, (10,))) two_d_waveform = AttrRW(WaveForm(np.int32, (10, 10))) - big_enum = AttrR( - Int( - allowed_values=list(range(17)), - ), - ) @pytest.fixture(scope="class") @@ -102,13 +97,6 @@ def test_enum(self, assertable_controller, client): assert isinstance(enum_attr.get(), enum_cls) assert enum_attr.get() == enum_cls(2) - def test_big_enum(self, assertable_controller, client): - expect = 0 - with assertable_controller.assert_read_here(["big_enum"]): - response = client.get("/big-enum") - assert response.status_code == 200 - assert response.json()["value"] == expect - def test_1d_waveform(self, assertable_controller, client): attribute = assertable_controller.attributes["one_d_waveform"] expect = np.zeros((10,), dtype=np.int32) diff --git a/tests/transport/tango/test_dsr.py b/tests/transport/tango/test_dsr.py index b4beed27..e0d3aa78 100644 --- a/tests/transport/tango/test_dsr.py +++ b/tests/transport/tango/test_dsr.py @@ -27,11 +27,6 @@ class TangoAssertableController(AssertableController): enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) one_d_waveform = AttrRW(WaveForm(np.int32, (10,))) two_d_waveform = AttrRW(WaveForm(np.int32, (10, 10))) - big_enum = AttrR( - Int( - allowed_values=list(range(17)), - ), - ) @pytest.fixture(scope="class") @@ -49,7 +44,6 @@ def tango_context(self, assertable_controller): def test_list_attributes(self, tango_context): assert list(tango_context.get_attribute_list()) == [ - "BigEnum", "Enum", "OneDWaveform", "ReadBool", @@ -132,12 +126,6 @@ def test_enum(self, assertable_controller, tango_context): assert isinstance(enum_attr.get(), enum_cls) assert enum_attr.get() == enum_cls(1) - def test_big_enum(self, assertable_controller, tango_context): - expect = 0 - with assertable_controller.assert_read_here(["big_enum"]): - result = tango_context.read_attribute("BigEnum").value - assert result == expect - def test_1d_waveform(self, assertable_controller, tango_context): expect = np.zeros((10,), dtype=np.int32) with assertable_controller.assert_read_here(["one_d_waveform"]):