From 256a523351466e4ae98ea9f26ed5ec7fd79f48fa Mon Sep 17 00:00:00 2001 From: XuanYang-cn Date: Tue, 23 May 2023 16:45:22 +0800 Subject: [PATCH] Multi cherry from 2.2 to master (#1453) - RBAC supports Database validation (#1396) - Support database API (#1401) - Using database by URI (#1420) Signed-off-by: yangxuan Co-authored-by: SimFG Co-authored-by: jaime --- examples/database.py | 144 ++++++++++++++++++++++++++++++++ examples/role_and_privilege.py | 29 ++++--- pymilvus/__init__.py | 4 +- pymilvus/client/check.py | 1 + pymilvus/client/grpc_handler.py | 51 +++++++++-- pymilvus/client/prepare.py | 26 +++++- pymilvus/client/types.py | 6 ++ pymilvus/orm/connections.py | 30 +++++-- pymilvus/orm/db.py | 44 ++++++++++ pymilvus/orm/role.py | 24 ++++-- tests/test_connections.py | 2 + 11 files changed, 320 insertions(+), 41 deletions(-) create mode 100644 examples/database.py create mode 100644 pymilvus/orm/db.py diff --git a/examples/database.py b/examples/database.py new file mode 100644 index 000000000..32015bb22 --- /dev/null +++ b/examples/database.py @@ -0,0 +1,144 @@ +import random + +from pymilvus import ( + connections, + FieldSchema, CollectionSchema, DataType, + Collection, + db, +) +from pymilvus.orm import utility + +_HOST = '127.0.0.1' +_PORT = '19530' +_ROOT = "root" +_ROOT_PASSWORD = "Milvus" +_METRIC_TYPE = 'IP' +_INDEX_TYPE = 'IVF_FLAT' +_NLIST = 1024 +_NPROBE = 16 +_TOPK = 3 + +# Vector parameters +_DIM = 128 +_INDEX_FILE_SIZE = 32 # max file size of stored index + + +def connect_to_milvus(db_name="default"): + print(f"connect to milvus\n") + connections.connect(host=_HOST, + port=_PORT, + user=_ROOT, + password=_ROOT_PASSWORD, + db_name=db_name, + ) + + +def connect_to_milvus_with_uri(db_name="default"): + print(f"connect to milvus\n") + connections.connect( + alias="uri-connection", + uri="http://{}:{}/{}".format(_HOST, _PORT, db_name), + ) + + +def create_collection(collection_name, db_name): + default_fields = [ + FieldSchema(name="id", dtype=DataType.INT64, is_primary=True), + FieldSchema(name="double", dtype=DataType.DOUBLE), + FieldSchema(name="fv", dtype=DataType.FLOAT_VECTOR, dim=128) + ] + default_schema = CollectionSchema(fields=default_fields) + print(f"Create collection:{collection_name} within db:{db_name}") + return Collection(name=collection_name, schema=default_schema) + + +def insert(collection, num, dim): + data = [ + [i for i in range(num)], + [float(i) for i in range(num)], + [[random.random() for _ in range(dim)] for _ in range(num)], + ] + collection.insert(data) + return data[2] + + +def drop_index(collection): + collection.drop_index() + print("\nDrop index sucessfully") + + +def search(collection, vector_field, id_field, search_vectors): + search_param = { + "data": search_vectors, + "anns_field": vector_field, + "param": {"metric_type": _METRIC_TYPE, "params": {"nprobe": _NPROBE}}, + "limit": _TOPK, + "expr": "id >= 0"} + results = collection.search(**search_param) + for i, result in enumerate(results): + print("\nSearch result for {}th vector: ".format(i)) + for j, res in enumerate(result): + print("Top {}: {}".format(j, res)) + + +def collection_read_write(collection, db_name): + col_name = "{}:{}".format(db_name, collection.name) + vectors = insert(collection, 10000, _DIM) + collection.flush() + print("\nInsert {} rows data into collection:{}".format(collection.num_entities, col_name)) + + # create index + index_param = { + "index_type": _INDEX_TYPE, + "params": {"nlist": _NLIST}, + "metric_type": _METRIC_TYPE} + collection.create_index("fv", index_param) + print("\nCreated index:{} for collection:{}".format(collection.index().params, col_name)) + + # load data to memory + print("\nLoad collection:{}".format(col_name)) + collection.load() + # search + print("\nSearch collection:{}".format(col_name)) + search(collection, "fv", "id", vectors[:3]) + + # release memory + collection.release() + # drop collection index + collection.drop_index() + print("\nDrop collection:{}".format(col_name)) + + +if __name__ == '__main__': + # connect to milvus and using database db1 + # there will not check db1 already exists during connect + connect_to_milvus(db_name="default") + + # create collection within default + col1_db1 = create_collection("col1_db1", "default") + + # create db1 + if "db1" not in db.list_database(): + print("\ncreate database: db1") + db.create_database(db_name="db1") + + # use database db1 + db.using_database(db_name="db1") + # create collection within default + col2_db1 = create_collection("col1_db1", "db1") + + # verify read and write + collection_read_write(col2_db1, "db1") + + # list collections within db1 + print("\nlist collections of database db1:") + print(utility.list_collections()) + + print("\ndrop collection: col1_db2 from db1") + col2_db1.drop() + print("\ndrop database: db1") + db.drop_database(db_name="db1") + + # list database + print("\nlist databases:") + print(db.list_database()) diff --git a/examples/role_and_privilege.py b/examples/role_and_privilege.py index 0309ba327..c16f4d695 100644 --- a/examples/role_and_privilege.py +++ b/examples/role_and_privilege.py @@ -6,6 +6,7 @@ _CONNECTION = "demo" _FOO_CONNECTION = "foo_connection" +_DB_NAME = "foo_db" _HOST = '127.0.0.1' _PORT = '19530' _ROOT = "root" @@ -13,13 +14,14 @@ _COLLECTION_NAME = "foocol2" -def connect_to_milvus(connection=_CONNECTION, user=_ROOT, password=_ROOT_PASSWORD): +def connect_to_milvus(connection=_CONNECTION, user=_ROOT, password=_ROOT_PASSWORD, db_name="default"): print(f"connect to milvus\n") connections.connect(alias=connection, host=_HOST, port=_PORT, user=user, password=password, + db_name=db_name, ) @@ -133,9 +135,9 @@ def rbac_user(username, password, role_name, connection=_CONNECTION): is_exception = True assert is_exception role = Role(role_name, using=_CONNECTION) - role.grant("User", "*", "SelectUser") + role.grant("User", "*", "SelectUser", db_name=_DB_NAME) print(select_all_user(connection)) - role.revoke("User", "*", "SelectUser") + role.revoke("User", "*", "SelectUser", db_name=_DB_NAME) def role_example(): @@ -182,26 +184,31 @@ def privilege_example(): print(f"add user") role.add_user(username) print(f"grant privilege") - role.grant("Global", "*", privilege_create) - role.grant("Collection", object_name, privilege_insert) - # role.grant("Collection", object_name, "*") - # role.grant("Collection", "*", privilege_insert) + role.grant("Global", "*", privilege_create, db_name=_DB_NAME) + role.grant("Collection", object_name, privilege_insert, db_name=_DB_NAME) + # role.grant("Collection", object_name, "*", db_name=_DB_NAME) + # role.grant("Collection", "*", privilege_insert, db_name=_DB_NAME) + + print(f"list grants") + print(role.list_grants(db_name=_DB_NAME)) + print(f"list grant") + print(role.list_grant("Collection", object_name, db_name=_DB_NAME)) print(f"list grants") print(role.list_grants()) print(f"list grant") print(role.list_grant("Collection", object_name)) - connect_to_milvus(connection=_FOO_CONNECTION, user=username, password=password) + connect_to_milvus(connection=_FOO_CONNECTION, user=username, password=password, db_name=_DB_NAME) has_collection(_COLLECTION_NAME, connection=_FOO_CONNECTION) rbac_collection(connection=_FOO_CONNECTION) rbac_user(username, password, role_name, connection=_FOO_CONNECTION) print(f"revoke privilege") role.revoke("Global", "*", privilege_create) - role.revoke("Collection", object_name, privilege_insert) - # role.revoke("Collection", object_name, "*") - # role.revoke("Collection", "*", privilege_insert) + role.revoke("Collection", object_name, privilege_insert, db_name=_DB_NAME) + # role.revoke("Collection", object_name, "*", db_name=_DB_NAME) + # role.revoke("Collection", "*", privilege_insert, db_name=_DB_NAME) print(f"remove user") role.remove_user(username) role.drop() diff --git a/pymilvus/__init__.py b/pymilvus/__init__.py index 844a6fbd2..d45af9051 100644 --- a/pymilvus/__init__.py +++ b/pymilvus/__init__.py @@ -66,7 +66,7 @@ list_resource_groups, transfer_node, transfer_replica ) -from .orm import utility +from .orm import utility, db from .orm.search import SearchResult, Hits, Hit from .orm.schema import FieldSchema, CollectionSchema @@ -86,7 +86,7 @@ 'SearchResult', 'Hits', 'Hit', 'Replica', 'Group', 'Shard', 'FieldSchema', 'CollectionSchema', 'SearchFuture', 'MutationFuture', - 'utility', 'DefaultConfig', 'ExceptionsMessage', 'MilvusUnavailableException', 'BulkInsertState', + 'utility', 'db', 'DefaultConfig', 'ExceptionsMessage', 'MilvusUnavailableException', 'BulkInsertState', 'Role', 'create_resource_group', 'drop_resource_group', 'describe_resource_group', 'list_resource_groups', 'transfer_node', 'transfer_replica', diff --git a/pymilvus/client/check.py b/pymilvus/client/check.py index b917572a4..d0d2a02a6 100644 --- a/pymilvus/client/check.py +++ b/pymilvus/client/check.py @@ -310,6 +310,7 @@ def is_legal_operate_privilege_type(operate_privilege_type: Any) -> bool: class ParamChecker(metaclass=Singleton): def __init__(self) -> None: self.check_dict = { + "db_name": is_legal_table_name, "collection_name": is_legal_table_name, "field_name": is_legal_field_name, "dimension": is_legal_dimension, diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 086cb30dd..819727e20 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -80,6 +80,7 @@ def __init__(self, uri=Config.GRPC_URI, host="", port="", channel=None, **kwargs self._request_id = None self._user = kwargs.get("user", None) self._set_authorization(**kwargs) + self._setup_db_interceptor(kwargs.get("db_name", None)) self._setup_grpc_channel() def __get_address(self, uri: str, host: str, port: str) -> str: @@ -127,12 +128,22 @@ def _wait_for_channel_ready(self, timeout=10): def close(self): self._channel.close() + def reset_db_name(self, db_name): + self._setup_db_interceptor(db_name) + self._setup_grpc_channel() + def _setup_authorization_interceptor(self, user, password): if user and password: authorization = base64.b64encode(f"{user}:{password}".encode('utf-8')) key = "authorization" self._authorization_interceptor = interceptor.header_adder_interceptor(key, authorization) + def _setup_db_interceptor(self, db_name): + if db_name: + self._db_interceptor = interceptor.header_adder_interceptor("dbname", db_name) + else: + self._db_interceptor = None + def _setup_grpc_channel(self): """ Create a ddl grpc channel """ if self._channel is None: @@ -174,6 +185,8 @@ def _setup_grpc_channel(self): self._final_channel = self._channel if self._authorization_interceptor: self._final_channel = grpc.intercept_channel(self._final_channel, self._authorization_interceptor) + if self._db_interceptor: + self._final_channel = grpc.intercept_channel(self._final_channel, self._db_interceptor) if self._log_level: log_level_interceptor = interceptor.header_adder_interceptor("log_level", self._log_level) self._final_channel = grpc.intercept_channel(self._final_channel, log_level_interceptor) @@ -809,6 +822,28 @@ def get_loading_progress(self, collection_name, partition_names=None, timeout=No raise MilvusException(response.status.error_code, response.status.reason) return response.progress + @retry_on_rpc_failure() + def create_database(self, db_name, timeout=None): + request = Prepare.create_database_req(db_name) + status = self._stub.CreateDatabase(request, timeout=timeout) + if status.error_code != 0: + raise MilvusException(status.error_code, status.reason) + + @retry_on_rpc_failure() + def drop_database(self, db_name, timeout=None): + request = Prepare.drop_database_req(db_name) + status = self._stub.DropDatabase(request, timeout=timeout) + if status.error_code != 0: + raise MilvusException(status.error_code, status.reason) + + @retry_on_rpc_failure() + def list_database(self, timeout=None): + request = Prepare.list_database_req() + response = self._stub.ListDatabases(request, timeout=timeout) + if response.status.error_code != 0: + raise MilvusException(response.status.error_code, response.status.reason) + return list(response.db_names) + @retry_on_rpc_failure() def get_load_state(self, collection_name, partition_names=None, timeout=None): request = Prepare.get_load_state(collection_name, partition_names) @@ -1227,24 +1262,24 @@ def select_all_user(self, include_role_info, timeout=None, **kwargs): return UserInfo(resp.results) @retry_on_rpc_failure() - def grant_privilege(self, role_name, object, object_name, privilege, timeout=None, **kwargs): - req = Prepare.operate_privilege_request(role_name, object, object_name, privilege, + def grant_privilege(self, role_name, object, object_name, privilege, db_name, timeout=None, **kwargs): + req = Prepare.operate_privilege_request(role_name, object, object_name, privilege, db_name, milvus_types.OperatePrivilegeType.Grant) resp = self._stub.OperatePrivilege(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def revoke_privilege(self, role_name, object, object_name, privilege, timeout=None, **kwargs): - req = Prepare.operate_privilege_request(role_name, object, object_name, privilege, + def revoke_privilege(self, role_name, object, object_name, privilege, db_name, timeout=None, **kwargs): + req = Prepare.operate_privilege_request(role_name, object, object_name, privilege, db_name, milvus_types.OperatePrivilegeType.Revoke) resp = self._stub.OperatePrivilege(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def select_grant_for_one_role(self, role_name, timeout=None, **kwargs): - req = Prepare.select_grant_request(role_name, None, None) + def select_grant_for_one_role(self, role_name, db_name, timeout=None, **kwargs): + req = Prepare.select_grant_request(role_name, None, None, db_name) resp = self._stub.SelectGrant(req, wait_for_ready=True, timeout=timeout) if resp.status.error_code != 0: raise MilvusException(resp.status.error_code, resp.status.reason) @@ -1252,8 +1287,8 @@ def select_grant_for_one_role(self, role_name, timeout=None, **kwargs): return GrantInfo(resp.entities) @retry_on_rpc_failure() - def select_grant_for_role_and_object(self, role_name, object, object_name, timeout=None, **kwargs): - req = Prepare.select_grant_request(role_name, object, object_name) + def select_grant_for_role_and_object(self, role_name, object, object_name, db_name, timeout=None, **kwargs): + req = Prepare.select_grant_request(role_name, object, object_name, db_name) resp = self._stub.SelectGrant(req, wait_for_ready=True, timeout=timeout) if resp.status.error_code != 0: raise MilvusException(resp.status.error_code, resp.status.reason) diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 8a1ba25f5..966d7a7ef 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -794,7 +794,7 @@ def select_user_request(cls, username, include_role_info): include_role_info=include_role_info) @classmethod - def operate_privilege_request(cls, role_name, object, object_name, privilege, operate_privilege_type): + def operate_privilege_request(cls, role_name, object, object_name, privilege, db_name, operate_privilege_type): check_pass_param(role_name=role_name) check_pass_param(object=object) check_pass_param(object_name=object_name) @@ -804,12 +804,13 @@ def operate_privilege_request(cls, role_name, object, object_name, privilege, op entity=milvus_types.GrantEntity(role=milvus_types.RoleEntity(name=role_name), object=milvus_types.ObjectEntity(name=object), object_name=object_name, + db_name=db_name, grantor=milvus_types.GrantorEntity( privilege=milvus_types.PrivilegeEntity(name=privilege))), type=operate_privilege_type) @classmethod - def select_grant_request(cls, role_name, object, object_name): + def select_grant_request(cls, role_name, object, object_name, db_name): check_pass_param(role_name=role_name) if object: check_pass_param(object=object) @@ -818,7 +819,9 @@ def select_grant_request(cls, role_name, object, object_name): return milvus_types.SelectGrantRequest( entity=milvus_types.GrantEntity(role=milvus_types.RoleEntity(name=role_name), object=milvus_types.ObjectEntity(name=object) if object else None, - object_name=object_name if object_name else None)) + object_name=object_name if object_name else None, + db_name=db_name, + )) @classmethod def get_server_version(cls): @@ -887,3 +890,20 @@ def register_request(cls, user, host, **kwargs): return milvus_types.ConnectRequest( client_info=this, ) + + @classmethod + def create_database_req(cls, db_name): + check_pass_param(db_name=db_name) + req = milvus_types.CreateDatabaseRequest(db_name=db_name) + return req + + @classmethod + def drop_database_req(cls, db_name): + check_pass_param(db_name=db_name) + req = milvus_types.DropDatabaseRequest(db_name=db_name) + return req + + @classmethod + def list_database_req(cls): + req = milvus_types.ListDatabasesRequest() + return req diff --git a/pymilvus/client/types.py b/pymilvus/client/types.py index 650491b28..ef74a42af 100644 --- a/pymilvus/client/types.py +++ b/pymilvus/client/types.py @@ -537,12 +537,14 @@ class GrantItem: def __init__(self, entity): self._object = entity.object.name self._object_name = entity.object_name + self._db_name = entity.db_name self._role_name = entity.role.name self._grantor_name = entity.grantor.user.name self._privilege = entity.grantor.privilege.name def __repr__(self) -> str: s = f"GrantItem: , , " \ + f", " \ f", , " \ f"" return s @@ -555,6 +557,10 @@ def object(self): def object_name(self): return self._object_name + @property + def db_name(self): + return self._db_name + @property def role_name(self): return self._role_name diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py index 5d57dcaef..acc620ffd 100644 --- a/pymilvus/orm/connections.py +++ b/pymilvus/orm/connections.py @@ -104,13 +104,13 @@ def __verify_host_port(self, host, port): if not 0 <= int(port) < 65535: raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)") - def __parse_address_from_uri(self, uri: str) -> (str, parse.ParseResult): illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:pwd@example.com:12345'" try: parsed_uri = parse.urlparse(uri) except (Exception) as e: - raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None + raise ConnectionConfigException( + message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None if len(parsed_uri.netloc) == 0: raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None @@ -174,10 +174,12 @@ def add_connection(self, **kwargs): self._alias[alias] = alias_config - def __get_full_address(self, address: str = "", uri: str = "", host: str = "", port: str = "") -> (str, parse.ParseResult): + def __get_full_address(self, address: str = "", uri: str = "", host: str = "", port: str = "") -> ( + str, parse.ParseResult): if address != "": if not is_legal_address(address): - raise ConnectionConfigException(message=f"Illegal address: {address}, should be in form 'localhost:19530'") + raise ConnectionConfigException( + message=f"Illegal address: {address}, should be in form 'localhost:19530'") return address, None if uri != "": @@ -214,7 +216,7 @@ def remove_connection(self, alias: str): self.disconnect(alias) self._alias.pop(alias, None) - def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", **kwargs): + def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", db_name="", **kwargs): """ Constructs a milvus connection and register it under given alias. @@ -242,6 +244,8 @@ def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", **kwargs * *password* (``str``) -- Optional and required when user is provided. The password corresponding to the user. + * *db_name* (``str``) -- + Optional. default database name of this connection * *client_key_path* (``str``) -- Optional. If use tls two-way authentication, need to write the client.key path. * *client_pem_path* (``str``) -- @@ -262,6 +266,7 @@ def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", **kwargs >>> from pymilvus import connections >>> connections.connect("test", host="localhost", port="19530") """ + def connect_milvus(**kwargs): gh = GrpcHandler(**kwargs) @@ -270,7 +275,9 @@ def connect_milvus(**kwargs): gh._wait_for_channel_ready(timeout=timeout) kwargs.pop('password') + kwargs.pop('db_name', None) kwargs.pop('secure', None) + kwargs.pop("db_name", "") self._connected_alias[alias] = gh self._alias[alias] = copy.deepcopy(kwargs) @@ -313,11 +320,17 @@ def with_config(config: Tuple) -> bool: user = parsed_uri.username if parsed_uri.username is not None else user password = parsed_uri.password if parsed_uri.password is not None else password + group = parsed_uri.path.split("/") + db_name = "default" + if len(group) > 1: + db_name = group[1] + # Set secure=True if username and password are provided if len(user) > 0 and len(password) > 0: kwargs["secure"] = True - connect_milvus(**kwargs, user=user, password=password) + + connect_milvus(**kwargs, user=user, password=password, db_name=db_name) return # 2nd Priority, connection configs from env @@ -331,20 +344,19 @@ def with_config(config: Tuple) -> bool: if len(user) > 0 and len(password) > 0: kwargs["secure"] = True - connect_milvus(**kwargs, user=user, password=password) + connect_milvus(**kwargs, user=user, password=password, db_name=db_name) return # 3rd Priority, connect to cached configs with provided user and password if alias in self._alias: connect_alias = dict(self._alias[alias].items()) connect_alias["user"] = user - connect_milvus(**connect_alias, password=password, **kwargs) + connect_milvus(**connect_alias, password=password, db_name=db_name, **kwargs) return # No params, env, and cached configs for the alias raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) - def list_connections(self) -> list: """ List names of all connections. diff --git a/pymilvus/orm/db.py b/pymilvus/orm/db.py new file mode 100644 index 000000000..866fecc60 --- /dev/null +++ b/pymilvus/orm/db.py @@ -0,0 +1,44 @@ +from pymilvus import connections + + +def _get_connection(alias): + return connections._fetch_handler(alias) + + +def using_database(db_name, using="default"): + """ Using a database as a default database name within this connection + + :param db_name: Database name + :type db_name: str + + """ + _get_connection(using).reset_db_name(db_name) + + +def create_database(db_name, using="default", timeout=None): + """ Create a database using provided database name + + :param db_name: Database name + :type db_name: str + + """ + _get_connection(using).create_database(db_name, timeout=timeout) + + +def drop_database(db_name, using="default", timeout=None): + """ Drop a database using provided database name + + :param db_name: Database name + :type db_name: str + + """ + _get_connection(using).drop_database(db_name, timeout=timeout) + + +def list_database(using="default", timeout=None) -> list: + """ List databases + + :return list[str]: + List of database names, return when operation is successful + """ + return _get_connection(using).list_database(timeout=timeout) diff --git a/pymilvus/orm/role.py b/pymilvus/orm/role.py index a969652dd..b8c49d36b 100644 --- a/pymilvus/orm/role.py +++ b/pymilvus/orm/role.py @@ -131,7 +131,7 @@ def is_exist(self): roles = self._get_connection().select_one_role(self._name, False) return len(roles.groups) != 0 - def grant(self, object: str, object_name: str, privilege: str): + def grant(self, object: str, object_name: str, privilege: str, db_name: str = "default"): """ Grant a privilege for the role :param object: object type. :type object: str @@ -139,6 +139,8 @@ def grant(self, object: str, object_name: str, privilege: str): :type object_name: str :param privilege: privilege name. :type privilege: str + :param db_name: db name. + :type db_name: str :example: >>> from pymilvus import connections @@ -147,9 +149,9 @@ def grant(self, object: str, object_name: str, privilege: str): >>> role = Role(role_name) >>> role.grant("Collection", collection_name, "Insert") """ - return self._get_connection().grant_privilege(self._name, object, object_name, privilege) + return self._get_connection().grant_privilege(self._name, object, object_name, privilege, db_name) - def revoke(self, object: str, object_name: str, privilege: str): + def revoke(self, object: str, object_name: str, privilege: str, db_name: str = "default"): """ Revoke a privilege for the role :param object: object type. :type object: str @@ -157,6 +159,8 @@ def revoke(self, object: str, object_name: str, privilege: str): :type object_name: str :param privilege: privilege name. :type privilege: str + :param db_name: db name. + :type db_name: str :example: >>> from pymilvus import connections @@ -165,14 +169,16 @@ def revoke(self, object: str, object_name: str, privilege: str): >>> role = Role(role_name) >>> role.revoke("Collection", collection_name, "Insert") """ - return self._get_connection().revoke_privilege(self._name, object, object_name, privilege) + return self._get_connection().revoke_privilege(self._name, object, object_name, privilege, db_name) - def list_grant(self, object: str, object_name: str): + def list_grant(self, object: str, object_name: str, db_name: str = "default"): """ List a grant info for the role and the specific object :param object: object type. :type object: str :param object_name: identifies a specific object name. :type object_name: str + :param db_name: db name. + :type db_name: str :return a GrantInfo object :rtype GrantInfo @@ -186,10 +192,12 @@ def list_grant(self, object: str, object_name: str): >>> role = Role(role_name) >>> role.list_grant("Collection", collection_name) """ - return self._get_connection().select_grant_for_role_and_object(self._name, object, object_name) + return self._get_connection().select_grant_for_role_and_object(self._name, object, object_name, db_name) - def list_grants(self): + def list_grants(self, db_name: str = "default"): """ List a grant info for the role + :param db_name: db name. + :type db_name: str :return a GrantInfo object :rtype GrantInfo @@ -203,4 +211,4 @@ def list_grants(self): >>> role = Role(role_name) >>> role.list_grants() """ - return self._get_connection().select_grant_for_one_role(self._name) + return self._get_connection().select_grant_for_one_role(self._name, db_name) diff --git a/tests/test_connections.py b/tests/test_connections.py index 95531f861..cb862d51e 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -47,6 +47,8 @@ def no_host_or_port(self, request): {"uri": "tcp://127.0.0.1:19530"}, {"uri": "http://127.0.0.1:19530"}, {"uri": "http://example.com:80"}, + {"uri": "http://example.com:80/database1"}, + {"uri": "https://127.0.0.1:19530/databse2"}, ]) def uri(self, request): return request.param