From 024f264202c2fe893d629643896dd4793bff15eb Mon Sep 17 00:00:00 2001 From: Maximilian Langenfeld <15726643+ezdac@users.noreply.github.com> Date: Tue, 7 Jan 2025 11:44:44 +0100 Subject: [PATCH] Add more verbosity to contract-address getter defaults Before the `GetAddresses` function silently returned the mainnet addresses as a default value, without the caller being made aware of this. Now, the `GetAddresses` function will return `nil` whenever it does not explicitly maps the passed in chain-id to a set of addresses. The newly added `GetAddressesOrDefault` function allows for retrieving the addresses with a default fallback for easier use in inline calls. This replicates the behavior of the old `GetAddresses` functionality with more verbosity. --- contracts/addresses/addresses.go | 20 ++++++++++++++++---- contracts/fee_currencies.go | 10 ++++++++-- core/celo_state_transition.go | 2 +- core/vm/celo_contracts.go | 2 +- 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/contracts/addresses/addresses.go b/contracts/addresses/addresses.go index 4cfa57a71a..1b893bd487 100644 --- a/contracts/addresses/addresses.go +++ b/contracts/addresses/addresses.go @@ -33,19 +33,31 @@ var ( } ) -// GetAddresses returns the addresses for the given chainID. +// GetAddresses returns the addresses for the given chainID or +// nil if not found. func GetAddresses(chainID *big.Int) *CeloAddresses { // ChainID can be uninitialized in some tests if chainID == nil { - return MainnetAddresses + return nil } - switch chainID.Uint64() { case params.CeloAlfajoresChainID: return AlfajoresAddresses case params.CeloBaklavaChainID: return BaklavaAddresses - default: + case params.CeloMainnetChainID: return MainnetAddresses + default: + return nil + } +} + +// GetAddressesOrDefault returns the addresses for the given chainID or +// the Mainnet addresses if none are found. +func GetAddressesOrDefault(chainID *big.Int, defaultValue *CeloAddresses) *CeloAddresses { + addresses := GetAddresses(chainID) + if addresses == nil { + return defaultValue } + return addresses } diff --git a/contracts/fee_currencies.go b/contracts/fee_currencies.go index 592d68807c..d4c99ed0ef 100644 --- a/contracts/fee_currencies.go +++ b/contracts/fee_currencies.go @@ -189,7 +189,10 @@ func GetRegisteredCurrencies(caller *abigen.FeeCurrencyDirectoryCaller) ([]commo // GetExchangeRates returns the exchange rates for the provided gas currencies func GetExchangeRates(caller *CeloBackend) (common.ExchangeRates, error) { - directory, err := abigen.NewFeeCurrencyDirectoryCaller(addresses.GetAddresses(caller.ChainConfig.ChainID).FeeCurrencyDirectory, caller) + directory, err := abigen.NewFeeCurrencyDirectoryCaller( + addresses.GetAddressesOrDefault(caller.ChainConfig.ChainID, addresses.MainnetAddresses).FeeCurrencyDirectory, + caller, + ) if err != nil { return common.ExchangeRates{}, fmt.Errorf("failed to access FeeCurrencyDirectory: %w", err) } @@ -203,7 +206,10 @@ func GetExchangeRates(caller *CeloBackend) (common.ExchangeRates, error) { // GetFeeCurrencyContext returns the fee currency block context for all registered gas currencies from CELO func GetFeeCurrencyContext(caller *CeloBackend) (common.FeeCurrencyContext, error) { var feeContext common.FeeCurrencyContext - directory, err := abigen.NewFeeCurrencyDirectoryCaller(addresses.GetAddresses(caller.ChainConfig.ChainID).FeeCurrencyDirectory, caller) + directory, err := abigen.NewFeeCurrencyDirectoryCaller( + addresses.GetAddressesOrDefault(caller.ChainConfig.ChainID, addresses.MainnetAddresses).FeeCurrencyDirectory, + caller, + ) if err != nil { return feeContext, fmt.Errorf("failed to access FeeCurrencyDirectory: %w", err) } diff --git a/core/celo_state_transition.go b/core/celo_state_transition.go index b4ce8c0daf..b8d7df632c 100644 --- a/core/celo_state_transition.go +++ b/core/celo_state_transition.go @@ -94,7 +94,7 @@ func (st *StateTransition) distributeTxFees() error { tipTxFee := new(big.Int).Sub(totalTxFee, baseTxFee) feeCurrency := st.msg.FeeCurrency - feeHandlerAddress := addresses.GetAddresses(st.evm.ChainConfig().ChainID).FeeHandler + feeHandlerAddress := addresses.GetAddressesOrDefault(st.evm.ChainConfig().ChainID, addresses.MainnetAddresses).FeeHandler log.Trace("distributeTxFees", "from", from, "refund", refund, "feeCurrency", feeCurrency, "coinbaseFeeRecipient", st.evm.Context.Coinbase, "coinbaseFee", tipTxFee, diff --git a/core/vm/celo_contracts.go b/core/vm/celo_contracts.go index 186127673b..d16a7571c6 100644 --- a/core/vm/celo_contracts.go +++ b/core/vm/celo_contracts.go @@ -47,7 +47,7 @@ func celoPrecompileAddress(index byte) common.Address { } func (ctx *celoPrecompileContext) IsCallerCeloToken() (bool, error) { - tokenAddress := addresses.GetAddresses(ctx.evm.ChainConfig().ChainID).CeloToken + tokenAddress := addresses.GetAddressesOrDefault(ctx.evm.ChainConfig().ChainID, addresses.MainnetAddresses).CeloToken return tokenAddress == ctx.caller, nil }