Skip to content

Commit

Permalink
Merge pull request #424 from hairyhum/gen_consumer_callback_stop
Browse files Browse the repository at this point in the history
Allow GenConsumer callbacks to return :stop replies
  • Loading branch information
joshuawscott authored Jan 7, 2021
2 parents 301c483 + 6523b1f commit 1d8c0e5
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 49 deletions.
119 changes: 77 additions & 42 deletions lib/kafka_ex/gen_consumer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ defmodule KafkaEx.GenConsumer do
"""
@callback init(topic :: binary, partition :: non_neg_integer) ::
{:ok, state :: term}
| {:stop, reason :: term}

@doc """
Invoked when the server is started. `start_link/5` will block until it
Expand All @@ -255,7 +256,7 @@ defmodule KafkaEx.GenConsumer do
topic :: binary,
partition :: non_neg_integer,
extra_args :: map()
) :: {:ok, state :: term}
) :: {:ok, state :: term} | {:stop, reason :: term}

@doc """
Invoked for each message set consumed from a Kafka topic partition.
Expand Down Expand Up @@ -287,6 +288,8 @@ defmodule KafkaEx.GenConsumer do
"""
@callback handle_call(call :: term, from :: GenServer.from(), state :: term) ::
{:reply, reply_value :: term, new_state :: term}
| {:stop, reason :: term, reply_value :: term, new_state :: term}
| {:stop, reason :: term, new_state :: term}

@doc """
Invoked by `KafkaEx.GenConsumer.cast/2`.
Expand All @@ -296,6 +299,7 @@ defmodule KafkaEx.GenConsumer do
"""
@callback handle_cast(cast :: term, state :: term) ::
{:noreply, new_state :: term}
| {:stop, reason :: term, new_state :: term}

@doc """
Invoked by sending messages to the consumer.
Expand All @@ -305,6 +309,7 @@ defmodule KafkaEx.GenConsumer do
"""
@callback handle_info(info :: term, state :: term) ::
{:noreply, new_state :: term}
| {:stop, reason :: term, new_state :: term}

defmacro __using__(_opts) do
quote do
Expand Down Expand Up @@ -541,44 +546,49 @@ defmodule KafkaEx.GenConsumer do
api_versions = Keyword.get(opts, :api_versions, %{})
api_versions = Map.merge(default_api_versions, api_versions)

{:ok, consumer_state} =
consumer_module.init(topic, partition, extra_consumer_args)
case consumer_module.init(topic, partition, extra_consumer_args) do
{:ok, consumer_state} ->
worker_opts = Keyword.take(opts, [:uris, :use_ssl, :ssl_options])

worker_opts = Keyword.take(opts, [:uris, :use_ssl, :ssl_options])
{:ok, worker_name} =
KafkaEx.create_worker(
:no_name,
[consumer_group: group_name] ++ worker_opts
)

{:ok, worker_name} =
KafkaEx.create_worker(
:no_name,
[consumer_group: group_name] ++ worker_opts
)
default_fetch_options = [
auto_commit: false,
worker_name: worker_name
]

default_fetch_options = [
auto_commit: false,
worker_name: worker_name
]
given_fetch_options = Keyword.get(opts, :fetch_options, [])

given_fetch_options = Keyword.get(opts, :fetch_options, [])
fetch_options = Keyword.merge(default_fetch_options, given_fetch_options)

state = %State{
consumer_module: consumer_module,
consumer_state: consumer_state,
commit_interval: commit_interval,
commit_threshold: commit_threshold,
auto_offset_reset: auto_offset_reset,
worker_name: worker_name,
group: group_name,
topic: topic,
partition: partition,
generation_id: generation_id,
member_id: member_id,
fetch_options: fetch_options,
api_versions: api_versions
}
fetch_options =
Keyword.merge(default_fetch_options, given_fetch_options)

state = %State{
consumer_module: consumer_module,
consumer_state: consumer_state,
commit_interval: commit_interval,
commit_threshold: commit_threshold,
auto_offset_reset: auto_offset_reset,
worker_name: worker_name,
group: group_name,
topic: topic,
partition: partition,
generation_id: generation_id,
member_id: member_id,
fetch_options: fetch_options,
api_versions: api_versions
}

Process.flag(:trap_exit, true)
Process.flag(:trap_exit, true)

{:ok, state, 0}
{:ok, state, 0}

{:stop, reason} ->
{:stop, reason}
end
end

def handle_call(:partition, _from, state) do
Expand All @@ -597,14 +607,23 @@ defmodule KafkaEx.GenConsumer do
# which we turn into a timeout = 0 clause so that we continue to consume.
# any other GenServer flow control could have unintended consequences,
# so we leave that for later consideration
{:reply, reply, new_consumer_state} =
consumer_reply =
consumer_module.handle_call(
message,
from,
consumer_state
)

{:reply, reply, %{state | consumer_state: new_consumer_state}, 0}
case consumer_reply do
{:reply, reply, new_consumer_state} ->
{:reply, reply, %{state | consumer_state: new_consumer_state}, 0}

{:stop, reason, new_consumer_state} ->
{:stop, reason, %{state | consumer_state: new_consumer_state}}

{:stop, reason, reply, new_consumer_state} ->
{:stop, reason, reply, %{state | consumer_state: new_consumer_state}}
end
end

def handle_cast(
Expand All @@ -618,13 +637,19 @@ defmodule KafkaEx.GenConsumer do
# which we turn into a timeout = 0 clause so that we continue to consume.
# any other GenServer flow control could have unintended consequences,
# so we leave that for later consideration
{:noreply, new_consumer_state} =
consumer_reply =
consumer_module.handle_cast(
message,
consumer_state
)

{:noreply, %{state | consumer_state: new_consumer_state}, 0}
case consumer_reply do
{:noreply, new_consumer_state} ->
{:noreply, %{state | consumer_state: new_consumer_state}, 0}

{:stop, reason, new_consumer_state} ->
{:stop, reason, %{state | consumer_state: new_consumer_state}}
end
end

def handle_info(
Expand Down Expand Up @@ -660,13 +685,19 @@ defmodule KafkaEx.GenConsumer do
# which we turn into a timeout = 0 clause so that we continue to consume.
# any other GenServer flow control could have unintended consequences,
# so we leave that for later consideration
{:noreply, new_consumer_state} =
consumer_reply =
consumer_module.handle_info(
message,
consumer_state
)

{:noreply, %{state | consumer_state: new_consumer_state}, 0}
case consumer_reply do
{:noreply, new_consumer_state} ->
{:noreply, %{state | consumer_state: new_consumer_state}, 0}

{:stop, reason, new_consumer_state} ->
{:stop, reason, %{state | consumer_state: new_consumer_state}}
end
end

def terminate(_reason, %State{} = state) do
Expand All @@ -689,7 +720,8 @@ defmodule KafkaEx.GenConsumer do
KafkaEx.fetch(
topic,
partition,
Keyword.merge(fetch_options,
Keyword.merge(
fetch_options,
offset: offset,
api_version: Map.fetch!(state.api_versions, :fetch)
)
Expand Down Expand Up @@ -850,9 +882,12 @@ defmodule KafkaEx.GenConsumer do
# one of these needs to match, depending on which client
case partition_response do
# old client
^partition -> :ok
^partition ->
:ok

# new client
%{error_code: :no_error, partition: ^partition} -> :ok
%{error_code: :no_error, partition: ^partition} ->
:ok
end

Logger.debug(fn ->
Expand Down
93 changes: 86 additions & 7 deletions test/integration/consumer_group_implementation_test.exs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
defmodule KafkaEx.ConsumerGroupImplementationTest do
use ExUnit.Case
use ExUnit.Case, async: false

alias KafkaEx.ConsumerGroup
alias KafkaEx.GenConsumer
Expand Down Expand Up @@ -75,14 +75,30 @@ defmodule KafkaEx.ConsumerGroupImplementationTest do
{:reply, Map.get(state, key), state}
end

def handle_call({:stop, msg}, _from, state) do
{:stop, :test_stop, msg, state}
end

def handle_call(:stop, _from, state) do
{:stop, :test_stop, state}
end

def handle_cast({:set, key, value}, state) do
{:noreply, Map.put_new(state, key, value)}
end

def handle_cast(:stop, state) do
{:stop, :test_stop, state}
end

def handle_info({:set, key, value}, state) do
{:noreply, Map.put_new(state, key, value)}
end

def handle_info(:stop, state) do
{:stop, :test_stop, state}
end

def handle_message_set(message_set, state) do
Logger.debug(fn ->
"Consumer #{inspect(self())} handled message set #{inspect(message_set)}"
Expand Down Expand Up @@ -130,14 +146,14 @@ defmodule KafkaEx.ConsumerGroupImplementationTest do
|> length
end

setup do
setup context do
ports_before = num_open_ports()
{:ok, test_partitioner_pid} = TestPartitioner.start_link()

{:ok, consumer_group_pid1} =
ConsumerGroup.start_link(
TestConsumer,
@consumer_group_name,
consumer_group_name(context),
[@topic_name],
heartbeat_interval: 100,
partition_assignment_callback: &TestPartitioner.assign_partitions/2,
Expand All @@ -147,7 +163,7 @@ defmodule KafkaEx.ConsumerGroupImplementationTest do
{:ok, consumer_group_pid2} =
ConsumerGroup.start_link(
TestConsumer,
@consumer_group_name,
consumer_group_name(context),
[@topic_name],
heartbeat_interval: 100,
partition_assignment_callback: &TestPartitioner.assign_partitions/2,
Expand Down Expand Up @@ -183,7 +199,9 @@ defmodule KafkaEx.ConsumerGroupImplementationTest do
generation_id2 = ConsumerGroup.generation_id(context[:consumer_group_pid2])
assert generation_id1 == generation_id2

assert @consumer_group_name ==
consumer_group_name = consumer_group_name(context)

assert consumer_group_name ==
ConsumerGroup.group_name(context[:consumer_group_pid1])

member1 = ConsumerGroup.member_id(context[:consumer_group_pid1])
Expand Down Expand Up @@ -289,7 +307,11 @@ defmodule KafkaEx.ConsumerGroupImplementationTest do
for px <- partition_range do
wait_for(fn ->
ending_offset =
latest_consumer_offset_number(@topic_name, px, @consumer_group_name)
latest_consumer_offset_number(
@topic_name,
px,
consumer_group_name(context)
)

last_offset = Map.get(last_offsets, px)
ending_offset == last_offset + 1
Expand Down Expand Up @@ -318,7 +340,7 @@ defmodule KafkaEx.ConsumerGroupImplementationTest do
{:ok, consumer_group_pid3} =
ConsumerGroup.start_link(
TestConsumer,
@consumer_group_name,
consumer_group_name(context),
[@topic_name],
heartbeat_interval: 100,
partition_assignment_callback: &TestPartitioner.assign_partitions/2
Expand Down Expand Up @@ -374,4 +396,61 @@ defmodule KafkaEx.ConsumerGroupImplementationTest do
assert :value == TestConsumer.get(consumer_pid, :test_info)
end
end

test "handle call stop returns from callbacks", context do
consumer_group_pid =
ConsumerGroup.consumer_supervisor_pid(context[:consumer_group_pid1])

[c1, c2] = GenConsumer.Supervisor.child_pids(consumer_group_pid)
assert :foo = GenConsumer.call(c1, {:stop, :foo})

try do
GenConsumer.call(c2, :stop)
catch
_, err ->
assert {:test_stop, _} = err
end

assert nil == Process.info(c1)
assert nil == Process.info(c2)
end

test "handle cast stop returns from callbacks", context do
consumer_group_pid =
ConsumerGroup.consumer_supervisor_pid(context[:consumer_group_pid1])

[c1, _c2] = GenConsumer.Supervisor.child_pids(consumer_group_pid)
GenConsumer.cast(c1, :stop)

try do
:sys.get_state(c1)
catch
_, err ->
assert {:test_stop, _} = err
end

assert nil == Process.info(c1)
end

test "handle info stop returns from callbacks", context do
consumer_group_pid =
ConsumerGroup.consumer_supervisor_pid(context[:consumer_group_pid1])

[c1, _c2] = GenConsumer.Supervisor.child_pids(consumer_group_pid)
send(c1, :stop)

try do
:sys.get_state(c1)
catch
_, err ->
assert {:test_stop, _} = err
end

assert nil == Process.info(c1)
end

def consumer_group_name(context) do
test_name = context[:test] |> to_string() |> String.replace(" ", "_")
@consumer_group_name <> test_name
end
end

0 comments on commit 1d8c0e5

Please sign in to comment.