Skip to content

Commit

Permalink
Make auth_method a more flexible function object
Browse files Browse the repository at this point in the history
  • Loading branch information
NelsonVides committed Jan 20, 2025
1 parent 278e462 commit 805bd52
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
13 changes: 8 additions & 5 deletions src/escalus_session.erl
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,19 @@ authenticate(Client = #client{props = Props}) ->
%% but as a default we use plain, as it incurrs lower load and better logs (no hashing)
%% for common setups. If a different mechanism is required then it should be configured on the
%% user specification.
{M, F} = proplists:get_value(auth, Props, {escalus_auth, auth_plain}),
PropsAfterAuth = case apply(M, F, [Client, Props]) of
ok -> Props;
{ok, P} when is_list(P) -> P
end,
PropsAfterAuth = apply_auth_method(Client, Props),
escalus_connection:reset_parser(Client),
Client1 = escalus_session:start_stream(Client#client{props = PropsAfterAuth}),
escalus_session:stream_features(Client1, []),
Client1.

apply_auth_method(Client, Props) ->
Fun = proplists:get_value(auth, Props, fun escalus_auth:auth_plain/2),
case apply(Fun, [Client, Props]) of
ok -> Props;
{ok, P} when is_list(P) -> P
end.

-spec bind(client()) -> client().
bind(Client = #client{props = Props0}) ->
Resource = proplists:get_value(resource, Props0, ?DEFAULT_RESOURCE),
Expand Down
38 changes: 21 additions & 17 deletions src/escalus_users.erl
Original file line number Diff line number Diff line change
Expand Up @@ -154,46 +154,50 @@ get_server(Config, User) ->
get_wspath(Config, User) ->
get_user_option(wspath, User, escalus_wspath, Config, undefined).

-spec get_auth_method(escalus:config(), user()) -> {module(), atom()}.
-spec get_auth_method(escalus:config(), user()) ->
fun((escalus_connection:client(), escalus_users:user_spec()) -> ok | {ok, escalus_users:user_spec()}).
get_auth_method(Config, User) ->
AuthMethod = get_user_option(auth_method, User,
escalus_auth_method, Config,
<<"PLAIN">>),
get_auth_method(AuthMethod).

-spec get_auth_method(binary() | {module(), atom()}) -> {module(), atom()}.
-spec get_auth_method(binary() | {module(), atom()}) ->
fun((escalus_connection:client(), escalus_users:user_spec()) -> ok | {ok, escalus_users:user_spec()}).
get_auth_method(<<"PLAIN">>) ->
{escalus_auth, auth_plain};
fun escalus_auth:auth_plain/2;
get_auth_method(<<"DIGEST-MD5">>) ->
{escalus_auth, auth_digest_md5};
fun escalus_auth:auth_digest_md5/2;
get_auth_method(<<"SASL-ANON">>) ->
{escalus_auth, auth_sasl_anon};
fun escalus_auth:auth_sasl_anon/2;
%% SCRAM Regular
get_auth_method(<<"SCRAM-SHA-1">>) ->
{escalus_auth, auth_sasl_scram_sha1};
fun escalus_auth:auth_sasl_scram_sha1/2;
get_auth_method(<<"SCRAM-SHA-224">>) ->
{escalus_auth, auth_sasl_scram_sha224};
fun escalus_auth:auth_sasl_scram_sha224/2;
get_auth_method(<<"SCRAM-SHA-256">>) ->
{escalus_auth, auth_sasl_scram_sha256};
fun escalus_auth:auth_sasl_scram_sha256/2;
get_auth_method(<<"SCRAM-SHA-384">>) ->
{escalus_auth, auth_sasl_scram_sha384};
fun escalus_auth:auth_sasl_scram_sha384/2;
get_auth_method(<<"SCRAM-SHA-512">>) ->
{escalus_auth, auth_sasl_scram_sha512};
fun escalus_auth:auth_sasl_scram_sha512/2;
%% SCRAM PLUS
get_auth_method(<<"SCRAM-SHA-1-PLUS">>) ->
{escalus_auth, auth_sasl_scram_sha1_plus};
fun escalus_auth:auth_sasl_scram_sha1_plus/2;
get_auth_method(<<"SCRAM-SHA-224-PLUS">>) ->
{escalus_auth, auth_sasl_scram_sha224_plus};
fun escalus_auth:auth_sasl_scram_sha224_plus/2;
get_auth_method(<<"SCRAM-SHA-256-PLUS">>) ->
{escalus_auth, auth_sasl_scram_sha256_plus};
fun escalus_auth:auth_sasl_scram_sha256_plus/2;
get_auth_method(<<"SCRAM-SHA-384-PLUS">>) ->
{escalus_auth, auth_sasl_scram_sha384_plus};
fun escalus_auth:auth_sasl_scram_sha384_plus/2;
get_auth_method(<<"SCRAM-SHA-512-PLUS">>) ->
{escalus_auth, auth_sasl_scram_sha512_plus};
fun escalus_auth:auth_sasl_scram_sha512_plus/2;
get_auth_method(<<"X-OAUTH">>) ->
{escalus_auth, auth_sasl_oauth};
fun escalus_auth:auth_sasl_oauth/2;
get_auth_method({Mod, Fun}) when is_atom(Mod), is_atom(Fun) ->
{Mod, Fun}.
fun Mod:Fun/2;
get_auth_method(Fun) when is_function(Fun, 2) ->
Fun.

-spec get_usp(escalus:config(), user()) -> [binary() | xmpp_domain()].
get_usp(Config, User) ->
Expand Down

0 comments on commit 805bd52

Please sign in to comment.