diff --git a/core/database_arango.py b/core/database_arango.py index 83dbac3b9..90edf033d 100644 --- a/core/database_arango.py +++ b/core/database_arango.py @@ -2,17 +2,13 @@ import datetime import json import logging -import re import sys import time -import unicodedata from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type, TypeVar if TYPE_CHECKING: + from core.schemas import entity, indicator, observable from core.schemas.graph import Relationship, TagRelationship - from core.schemas import observable - from core.schemas import entity - from core.schemas import indicator import requests from arango import ArangoClient @@ -303,41 +299,34 @@ def tag( ) -> TYetiObject: """Connects object to tag graph.""" # Import at runtime to avoid circular dependency. - from core.schemas.tag import DEFAULT_EXPIRATION_DAYS, Tag + from core.schemas import tag - expiration_days = expiration_days or DEFAULT_EXPIRATION_DAYS + expiration_days = expiration_days or tag.DEFAULT_EXPIRATION_DAYS if strict: self.clear_tags() extra_tags = set() - for tag_name in tags: - # Attempt to find replacement tag - if normalized: - nfkd_form = unicodedata.normalize("NFKD", tag_name) - nfkd_form.encode("ASCII", "ignore").decode("UTF-8") - tag_name = "".join( - [c for c in nfkd_form if not unicodedata.combining(c)] - ) - tag_name = tag_name.strip() - tag_name = re.sub(r"\s+", "_", tag_name).lower() - tag_name = re.sub(r"[^a-zA-Z0-9_]", "", tag_name) - replacements, _ = Tag.filter({"in__replaces": [tag_name]}, count=1) - tag: Optional[Tag] = None + for provided_tag_name in tags: + tag_name = tag.normalize_name(provided_tag_name) + if not tag_name: + raise RuntimeError(f"Cannot tag object with empty tag: '{provided_tag_name}' -> '{tag_name}'") + replacements, _ = tag.Tag.filter({"in__replaces": [tag_name]}, count=1) + new_tag: Optional[tag.Tag] = None if replacements: - tag = replacements[0] + new_tag = replacements[0] # Attempt to find actual tag else: - tag = Tag.find(name=tag_name) + new_tag = tag.Tag.find(name=tag_name) # Create tag - if not tag: - tag = Tag(name=tag_name).save() + if not new_tag: + new_tag = tag.Tag(name=tag_name).save() - tag_link = self.link_to_tag(tag.name) - self.tags[tag.name] = tag_link + tag_link = self.link_to_tag(new_tag.name) + self.tags[new_tag.name] = tag_link - extra_tags |= set(tag.produces) + extra_tags |= set(new_tag.produces) extra_tags -= set(tags) if extra_tags: diff --git a/core/schemas/tag.py b/core/schemas/tag.py index 60c1a604a..c6368c111 100644 --- a/core/schemas/tag.py +++ b/core/schemas/tag.py @@ -1,4 +1,6 @@ import datetime +import re +import unicodedata from typing import ClassVar from pydantic import BaseModel, Field @@ -11,6 +13,16 @@ def future(): return datetime.timedelta(days=DEFAULT_EXPIRATION_DAYS) +def normalize_name(tag_name: str) -> str: + nfkd_form = unicodedata.normalize("NFKD", tag_name) + nfkd_form.encode("ASCII", "ignore").decode("UTF-8") + tag_name = "".join( + [c for c in nfkd_form if not unicodedata.combining(c)] + ) + tag_name = tag_name.strip().lower() + tag_name = re.sub(r"\s+", "_", tag_name) + tag_name = re.sub(r"[^a-zA-Z0-9_:-]", "", tag_name) + return tag_name class Tag(BaseModel, database_arango.ArangoYetiConnector): _collection_name: ClassVar[str] = "tags" diff --git a/tests/schemas/tag.py b/tests/schemas/tag.py index 2d2c4d832..37ce04da3 100644 --- a/tests/schemas/tag.py +++ b/tests/schemas/tag.py @@ -171,18 +171,18 @@ def test_duplicate_name(self): def test_normalized_tag(self): """Tests that a tag can be normalized.""" - cases = cases = [ + cases = [ ("H@ackërS T3st", "hackers_t3st"), (" SpaCesStartEnd ", "spacesstartend"), ("!!Sp3cial##", "sp3cial"), - ("Multi Spaces", "multi_spaces"), + ("Multi Spaces After", "multi_spaces_after"), ("Élévation", "elevation"), ("UNDER_score", "under_score"), ("mixCaseMix123", "mixcasemix123"), ("MïxedÁccénts", "mixedaccents"), ("123456", "123456"), ("测试chinese", "chinese"), - ("", ""), + ("type:some-custom-type", "type:some-custom-type") ] for cmp, (tag_non_norm, tag_norm) in enumerate(cases):