Skip to content

Commit

Permalink
Support is_role_in_session function
Browse files Browse the repository at this point in the history
Signed-off-by: HangyuanLiu <[email protected]>
  • Loading branch information
HangyuanLiu committed Oct 17, 2023
1 parent 74f3311 commit f847a0d
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,9 @@ public class FunctionSet {
// dict query function
public static final String DICT_MAPPING = "dict_mapping";

//user and role function
public static final String IS_ROLE_IN_SESSION = "is_role_in_session";

public static final String QUARTERS_ADD = "quarters_add";
public static final String QUARTERS_SUB = "quarters_sub";
public static final String WEEKS_ADD = "weeks_add";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1400,6 +1400,17 @@ private void checkFunction(String fnName, FunctionCallExpr node, Type[] argument
}
break;
}

case FunctionSet.IS_ROLE_IN_SESSION: {
if (node.getChildren().size() != 1) {
throw new SemanticException("IS_ROLE_IN_SESSION currently only supports a single parameter");
}

if (!(node.getChild(0) instanceof StringLiteral)) {
throw new SemanticException("IS_ROLE_IN_SESSION currently only supports constant parameters");
}
break;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,15 @@
import com.starrocks.connector.PartitionUtil;
import com.starrocks.connector.hive.Partition;
import com.starrocks.privilege.AccessDeniedException;
import com.starrocks.privilege.AuthorizationMgr;
import com.starrocks.privilege.ObjectType;
import com.starrocks.privilege.PrivilegeException;
import com.starrocks.privilege.PrivilegeType;
import com.starrocks.privilege.RolePrivilegeCollectionV2;
import com.starrocks.qe.ConnectContext;
import com.starrocks.server.GlobalStateMgr;
import com.starrocks.sql.analyzer.Authorizer;
import com.starrocks.sql.analyzer.SemanticException;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.collections4.SetUtils;
Expand Down Expand Up @@ -95,6 +99,7 @@
import java.time.temporal.IsoFields;
import java.time.temporal.TemporalAdjusters;
import java.time.temporal.TemporalUnit;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -1326,5 +1331,34 @@ public static ConstantOperator urlExtractParameter(ConstantOperator url, Constan
}
return ConstantOperator.createNull(Type.VARCHAR);
}

private static void getRecursiveRole(AuthorizationMgr manager, List<String> roleName, Long roleId) {
try {
RolePrivilegeCollectionV2 rolePrivilegeCollection =
manager.getRolePrivilegeCollectionUnlocked(roleId, false);
if (rolePrivilegeCollection != null) {
roleName.add(rolePrivilegeCollection.getName());

for (Long parentId : rolePrivilegeCollection.getParentRoleIds()) {
getRecursiveRole(manager, roleName, parentId);
}
}
} catch (PrivilegeException e) {
throw new SemanticException(e.getMessage());
}
}

@ConstantFunction(name = "is_role_in_session", argTypes = {VARCHAR}, returnType = BOOLEAN)
public static ConstantOperator isRoleInSession(ConstantOperator role) {
AuthorizationMgr manager = GlobalStateMgr.getCurrentState().getAuthorizationMgr();
List<String> roleName = new ArrayList<>();
ConnectContext connectContext = ConnectContext.get();

for (Long roleId : connectContext.getCurrentRoleIds()) {
getRecursiveRole(manager, roleName, roleId);
}

return ConstantOperator.createBoolean(roleName.contains(role.getVarchar()));
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@
import com.starrocks.sql.ast.CreateRoleStmt;
import com.starrocks.sql.ast.CreateUserStmt;
import com.starrocks.sql.ast.DropUserStmt;
import com.starrocks.sql.ast.GrantRoleStmt;
import com.starrocks.sql.ast.SetDefaultRoleStmt;
import com.starrocks.sql.ast.SetUserPropertyStmt;
import com.starrocks.sql.ast.StatementBase;
import com.starrocks.sql.ast.UserIdentity;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorFunctions;
import com.starrocks.utframe.UtFrameUtils;
import org.junit.AfterClass;
import org.junit.Assert;
Expand Down Expand Up @@ -480,4 +483,92 @@ public void testSortUserIdentity() throws Exception {
Assert.assertEquals(Arrays.asList(
"'sort_user'@'10.1.1.1'", "'sort_user'@'10.1.1.2'", "'sort_user'@['host01']", "'sort_user'@'%'"), l);
}

@Test
public void testIsRoleInSession() throws Exception {
AuthenticationMgr masterManager = ctx.getGlobalStateMgr().getAuthenticationMgr();
AuthorizationMgr authorizationManager = ctx.getGlobalStateMgr().getAuthorizationMgr();

String sql = "create role test_in_role_r1";
CreateRoleStmt createStmt =
(CreateRoleStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
authorizationManager.createRole(createStmt);

sql = "create role test_in_role_r2";
createStmt = (CreateRoleStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
authorizationManager.createRole(createStmt);

sql = "create role test_in_role_r3";
createStmt = (CreateRoleStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
authorizationManager.createRole(createStmt);

sql = "create role test_in_role_r4";
createStmt = (CreateRoleStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
authorizationManager.createRole(createStmt);

sql = "grant test_in_role_r3 to role test_in_role_r2";
GrantRoleStmt grantRoleStmt = (GrantRoleStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
authorizationManager.grantRole(grantRoleStmt);

sql = "grant test_in_role_r2 to role test_in_role_r1";
grantRoleStmt = (GrantRoleStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
authorizationManager.grantRole(grantRoleStmt);

sql = "create user test_in_role_u1 default role test_in_role_r1";
CreateUserStmt stmt = (CreateUserStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
masterManager.createUser(stmt);

ctx.setCurrentUserIdentity(new UserIdentity("test_in_role_u1", "%"));
ctx.setCurrentRoleIds(new UserIdentity("test_in_role_u1", "%"));

Assert.assertTrue(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r1")).getBoolean());
Assert.assertTrue(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r2")).getBoolean());
Assert.assertTrue(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r3")).getBoolean());

Assert.assertFalse(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r4")).getBoolean());

sql = "create user test_in_role_u2 default role test_in_role_r2";
stmt = (CreateUserStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
masterManager.createUser(stmt);

ctx.setCurrentUserIdentity(new UserIdentity("test_in_role_u2", "%"));
ctx.setCurrentRoleIds(new UserIdentity("test_in_role_u2", "%"));

Assert.assertFalse(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r1")).getBoolean());
Assert.assertTrue(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r2")).getBoolean());
Assert.assertTrue(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r3")).getBoolean());

ctx.setCurrentRoleIds(new HashSet<>());

Assert.assertFalse(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r1")).getBoolean());
Assert.assertFalse(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r2")).getBoolean());
Assert.assertFalse(ScalarOperatorFunctions.isRoleInSession(
ConstantOperator.createVarchar("test_in_role_r3")).getBoolean());


sql = "select is_role_in_session(v1) from (select 1 as v1) t";
try {
stmt = (CreateUserStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
Assert.fail();
} catch (AnalysisException e) {
Assert.assertTrue(e.getMessage().contains("IS_ROLE_IN_SESSION currently only supports constant parameters"));
}

sql = "select is_role_in_session(\"a\", \"b\") from (select 1 as v1) t";
try {
stmt = (CreateUserStmt) UtFrameUtils.parseStmtWithNewParser(sql, ctx);
Assert.fail();
} catch (AnalysisException e) {
Assert.assertTrue(e.getMessage().contains("IS_ROLE_IN_SESSION currently only supports a single parameter"));
}
}
}

0 comments on commit f847a0d

Please sign in to comment.