diff --git a/linkml_runtime/utils/schemaview.py b/linkml_runtime/utils/schemaview.py index 78f94240..49343632 100644 --- a/linkml_runtime/utils/schemaview.py +++ b/linkml_runtime/utils/schemaview.py @@ -1435,7 +1435,7 @@ def slot_range_as_union(self, slot: SlotDefinition) -> List[ElementName]: if x.range: range_union_of.append(x.range) return range_union_of - + def get_classes_by_slot(self, slot: SlotDefinition, include_induced: bool = False) -> List[ClassDefinitionName]: """Get all classes that use a given slot, either as a direct or induced slot. @@ -1443,21 +1443,21 @@ def get_classes_by_slot(self, slot: SlotDefinition, include_induced: bool = Fals :param include_induced: supplement all direct slots with induced slots, defaults to False :return: list of slots, either direct, or both direct and induced """ - slots_list = [] # list of all direct or induced slots + direct_classes_list = [] # list of classes associated with slot directly + induced_classes_list = [] # list of classes associated with slot indirectly for c_name, c in self.all_classes().items(): - # check if slot is direct specification on class if slot.name in c.slots: - slots_list.append(c_name) - - # include induced classes also if requested - if include_induced: - for c_name, c in self.all_classes().items(): + direct_classes_list.append(c_name) + elif include_induced: for ind_slot in self.class_induced_slots(c_name): if ind_slot.name == slot.name: - slots_list.append(c_name) + induced_classes_list.append(c_name) - return list(dict.fromkeys(slots_list)) + if include_induced: + return list(set(direct_classes_list + induced_classes_list)) + else: + return list(set(direct_classes_list)) @lru_cache() def get_slots_by_enum(self, enum_name: ENUM_NAME = None) -> List[SlotDefinition]: diff --git a/tests/test_utils/test_schemaview.py b/tests/test_utils/test_schemaview.py index aca871ff..aa9f5c66 100644 --- a/tests/test_utils/test_schemaview.py +++ b/tests/test_utils/test_schemaview.py @@ -735,7 +735,7 @@ def test_get_classes_by_slot(self): actual_result = sv.get_classes_by_slot(slot, include_induced=True) expected_result = ["Person", "Adult"] - self.assertListEqual(actual_result, expected_result) + self.assertListEqual(sorted(actual_result), sorted(expected_result)) def test_materialize_patterns(self): sv = SchemaView(SCHEMA_WITH_STRUCTURED_PATTERNS)