Skip to content

Commit

Permalink
add plot (#1436)
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz authored Jan 26, 2025
1 parent 3b29601 commit e4a2664
Show file tree
Hide file tree
Showing 2 changed files with 1,475 additions and 1,365 deletions.
2,817 changes: 1,456 additions & 1,361 deletions docs/source/notebooks/mmm/mmm_example.ipynb

Large diffs are not rendered by default.

23 changes: 19 additions & 4 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,23 +1116,38 @@ def plot_grouped_contribution_breakdown_over_time(
ax.legend(title="groups", loc="center left", bbox_to_anchor=(1, 0.5))
return fig

def _get_channel_contributions_share_samples(self) -> DataArray:
def get_channel_contributions_share_samples(self, prior: bool = False) -> DataArray:
"""Get the share of channel contributions in the original scale of the target variable.
Parameters
----------
prior : bool, optional
Whether to use the prior or posterior, by default False (posterior)
Returns
-------
DataArray
The share of channel contributions in the original scale of the target variable.
"""
channel_contribution_original_scale_samples: DataArray = (
self.compute_channel_contribution_original_scale()
self.compute_channel_contribution_original_scale(prior=prior)
)
numerator: DataArray = channel_contribution_original_scale_samples.sum(["date"])
denominator: DataArray = numerator.sum("channel")
return numerator / denominator

def plot_channel_contribution_share_hdi(
self, hdi_prob: float = 0.94, **plot_kwargs: Any
self, hdi_prob: float = 0.94, prior: bool = False, **plot_kwargs: Any
) -> plt.Figure:
"""Plot the share of channel contributions in a forest plot.
Parameters
----------
hdi_prob : float, optional
HDI value to be displayed, by default 0.94
prior : bool, optional
Whether to use the prior or posterior, by default False (posterior)
**plot_kwargs
Additional keyword arguments to pass to `az.plot_forest`.
Expand All @@ -1142,7 +1157,7 @@ def plot_channel_contribution_share_hdi(
"""
channel_contributions_share: DataArray = (
self._get_channel_contributions_share_samples()
self.get_channel_contributions_share_samples(prior=prior)
)

ax, *_ = az.plot_forest(
Expand Down

0 comments on commit e4a2664

Please sign in to comment.