Skip to content

Commit

Permalink
[Enhancement] Support is_role_in_session function (#32984)
Browse files Browse the repository at this point in the history
Signed-off-by: HangyuanLiu <[email protected]>
  • Loading branch information
HangyuanLiu authored Oct 20, 2023
1 parent a7dec81 commit b6954d3
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,9 @@ public class FunctionSet {
// dict query function
public static final String DICT_MAPPING = "dict_mapping";

//user and role function
public static final String IS_ROLE_IN_SESSION = "is_role_in_session";

public static final String QUARTERS_ADD = "quarters_add";
public static final String QUARTERS_SUB = "quarters_sub";
public static final String WEEKS_ADD = "weeks_add";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ public boolean canExecuteAs(UserIdentity currentUser, Set<Long> roleIds, UserIde
}
}

public boolean allowGrant(UserIdentity currentUser, Set<Long> roleIds, ObjectType type,
public boolean allowGrant(UserIdentity currentUser, Set<Long> roleIds, ObjectType type,
List<PrivilegeType> wants, List<PEntryObject> objects) {
try {
PrivilegeCollectionV2 collection = mergePrivilegeCollection(currentUser, roleIds);
Expand Down Expand Up @@ -974,6 +974,17 @@ public RolePrivilegeCollectionV2 getRolePrivilegeCollection(long roleId) {
}
}

public void getRecursiveRole(Set<String> roleNames, Long roleId) {
RolePrivilegeCollectionV2 rolePrivilegeCollection = getRolePrivilegeCollection(roleId);
if (rolePrivilegeCollection != null) {
roleNames.add(rolePrivilegeCollection.getName());

for (Long parentId : rolePrivilegeCollection.getParentRoleIds()) {
getRecursiveRole(roleNames, parentId);
}
}
}

public RolePrivilegeCollectionV2 getRolePrivilegeCollectionUnlocked(long roleId, boolean exceptionIfNotExists)
throws PrivilegeException {
RolePrivilegeCollectionV2 collection = roleIdToPrivilegeCollection.get(roleId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1400,6 +1400,17 @@ private void checkFunction(String fnName, FunctionCallExpr node, Type[] argument
}
break;
}

case FunctionSet.IS_ROLE_IN_SESSION: {
if (node.getChildren().size() != 1) {
throw new SemanticException("IS_ROLE_IN_SESSION currently only supports a single parameter");
}

if (!(node.getChild(0) instanceof StringLiteral)) {
throw new SemanticException("IS_ROLE_IN_SESSION currently only supports constant parameters");
}
break;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import com.starrocks.connector.PartitionUtil;
import com.starrocks.connector.hive.Partition;
import com.starrocks.privilege.AccessDeniedException;
import com.starrocks.privilege.AuthorizationMgr;
import com.starrocks.privilege.ObjectType;
import com.starrocks.privilege.PrivilegeType;
import com.starrocks.qe.ConnectContext;
Expand Down Expand Up @@ -95,6 +96,7 @@
import java.time.temporal.IsoFields;
import java.time.temporal.TemporalAdjusters;
import java.time.temporal.TemporalUnit;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -1326,5 +1328,18 @@ public static ConstantOperator urlExtractParameter(ConstantOperator url, Constan
}
return ConstantOperator.createNull(Type.VARCHAR);
}

@ConstantFunction(name = "is_role_in_session", argTypes = {VARCHAR}, returnType = BOOLEAN)
public static ConstantOperator isRoleInSession(ConstantOperator role) {
AuthorizationMgr manager = GlobalStateMgr.getCurrentState().getAuthorizationMgr();
Set<String> roleNames = new HashSet<>();
ConnectContext connectContext = ConnectContext.get();

for (Long roleId : connectContext.getCurrentRoleIds()) {
manager.getRecursiveRole(roleNames, roleId);
}

return ConstantOperator.createBoolean(roleNames.contains(role.getVarchar()));
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@
import com.starrocks.sql.ast.CreateRoleStmt;
import com.starrocks.sql.ast.CreateUserStmt;
import com.starrocks.sql.ast.DropUserStmt;
import com.starrocks.sql.ast.GrantRoleStmt;
import com.starrocks.sql.ast.SetDefaultRoleStmt;
import com.starrocks.sql.ast.SetUserPropertyStmt;
import com.starrocks.sql.ast.StatementBase;
import com.starrocks.sql.ast.UserIdentity;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorFunctions;
import com.starrocks.utframe.UtFrameUtils;
import org.junit.AfterClass;
import org.junit.Assert;
Expand Down Expand Up @@ -480,4 +483,92 @@ public void testSortUserIdentity() throws Exception {
Assert.assertEquals(Arrays.asList(
"'sort_user'@'10.1.1.1'", "'sort_user'@'10.1.1.2'", "'sort_user'@['host01']", "'sort_user'@'%'"), l);
}

@Test
public void testIsRoleInSession() throws Exception {
AuthenticationMgr masterManager = ctx.getGlobalStateMgr().getAuthenticationMgr();
AuthorizationMgr authorizationManager = ctx.getGlobalStateMgr().getAuthorizationMgr();

String sql = "create role test_in_role_r1";
CreateRoleStmt createStmt =
(CreateRoleStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
authorizationManager.createRole(createStmt);

sql = "create role test_in_role_r2";
createStmt = (CreateRoleStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
authorizationManager.createRole(createStmt);

sql = "create role test_in_role_r3";
createStmt = (CreateRoleStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
authorizationManager.createRole(createStmt);

sql = "create role test_in_role_r4";
createStmt = (CreateRoleStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
authorizationManager.createRole(createStmt);

sql = "grant test_in_role_r3 to role test_in_role_r2";
GrantRoleStmt grantRoleStmt = (GrantRoleStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
authorizationManager.grantRole(grantRoleStmt);

sql = "grant test_in_role_r2 to role test_in_role_r1";
grantRoleStmt = (GrantRoleStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
authorizationManager.grantRole(grantRoleStmt);

sql = "create user test_in_role_u1 default role test_in_role_r1";
CreateUserStmt stmt = (CreateUserStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
masterManager.createUser(stmt);

ctx.setCurrentUserIdentity(new UserIdentity("test_in_role_u1", "%"));
ctx.setCurrentRoleIds(new UserIdentity("test_in_role_u1", "%"));

Assert.assertTrue(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r1")).getBoolean());
Assert.assertTrue(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r2")).getBoolean());
Assert.assertTrue(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r3")).getBoolean());

Assert.assertFalse(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r4")).getBoolean());

sql = "create user test_in_role_u2 default role test_in_role_r2";
stmt = (CreateUserStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
masterManager.createUser(stmt);

ctx.setCurrentUserIdentity(new UserIdentity("test_in_role_u2", "%"));
ctx.setCurrentRoleIds(new UserIdentity("test_in_role_u2", "%"));

Assert.assertFalse(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r1")).getBoolean());
Assert.assertTrue(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r2")).getBoolean());
Assert.assertTrue(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r3")).getBoolean());

ctx.setCurrentRoleIds(new HashSet<>());

Assert.assertFalse(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r1")).getBoolean());
Assert.assertFalse(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r2")).getBoolean());
Assert.assertFalse(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r3")).getBoolean());


sql = "select is_role_in_session(v1) from (select 1 as v1) t";
try {
stmt = (CreateUserStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
Assert.fail();
} catch (AnalysisException e) {
Assert.assertTrue(e.getMessage().contains("IS_ROLE_IN_SESSION currently only supports constant parameters"));
}

sql = "select is_role_in_session(\"a\", \"b\") from (select 1 as v1) t";
try {
stmt = (CreateUserStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
Assert.fail();
} catch (AnalysisException e) {
Assert.assertTrue(e.getMessage().contains("IS_ROLE_IN_SESSION currently only supports a single parameter"));
}
}
}
3 changes: 3 additions & 0 deletions gensrc/script/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,4 +1014,7 @@
# struct functions
[170500, 'row', 'ANY_STRUCT', ['ANY_ELEMENT', "..."], 'StructFunctions::new_struct'],
[170501, 'named_struct', 'ANY_STRUCT', ['ANY_ELEMENT', "..."], 'StructFunctions::named_struct'],

# user function
[180000, 'is_role_in_session', 'BOOLEAN', ['VARCHAR'], 'nullptr']
]
70 changes: 70 additions & 0 deletions test/sql/test_rbac/R/test_is_role_in_session
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
-- name: test_is_role_in_session
drop role if exists r1;
-- result:
-- !result
create role r1;
-- result:
-- !result
drop role if exists r2;
-- result:
-- !result
create role r2;
-- result:
-- !result
drop role if exists r3;
-- result:
-- !result
create role r3;
-- result:
-- !result
drop user if exists u1;
-- result:
-- !result
create user u1;
-- result:
-- !result
grant impersonate on user root to u1;
-- result:
-- !result
grant r3 to role r2;
-- result:
-- !result
grant r2 to role r1;
-- result:
-- !result
grant r1 to u1;
-- result:
-- !result
execute as u1 with no revert;
-- result:
-- !result
select is_role_in_session("r1");
-- result:
0
-- !result
select is_role_in_session("r2");
-- result:
0
-- !result
select is_role_in_session("r3");
-- result:
0
-- !result
set role all;
-- result:
-- !result
select is_role_in_session("r1");
-- result:
1
-- !result
select is_role_in_session("r2");
-- result:
1
-- !result
select is_role_in_session("r3");
-- result:
1
-- !result
execute as root with no revert;
-- result:
-- !result
25 changes: 25 additions & 0 deletions test/sql/test_rbac/T/test_is_role_in_session
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
-- name: test_is_role_in_session
drop role if exists r1;
create role r1;
drop role if exists r2;
create role r2;
drop role if exists r3;
create role r3;
drop user if exists u1;
create user u1;
grant impersonate on user root to u1;

grant r3 to role r2;
grant r2 to role r1;
grant r1 to u1;

execute as u1 with no revert;
select is_role_in_session("r1");
select is_role_in_session("r2");
select is_role_in_session("r3");
set role all;
select is_role_in_session("r1");
select is_role_in_session("r2");
select is_role_in_session("r3");

execute as root with no revert;

0 comments on commit b6954d3

Please sign in to comment.