diff --git a/core/meterer/meterer.go b/core/meterer/meterer.go index 0babbe1f06..f7c13c401d 100644 --- a/core/meterer/meterer.go +++ b/core/meterer/meterer.go @@ -60,7 +60,7 @@ func (m *Meterer) Start(ctx context.Context) { for { select { case <-ticker.C: - if err := m.ChainPaymentState.RefreshOnchainPaymentState(ctx, nil); err != nil { + if err := m.ChainPaymentState.RefreshOnchainPaymentState(ctx); err != nil { m.logger.Error("Failed to refresh on-chain state", "error", err) } case <-ctx.Done(): diff --git a/core/meterer/meterer_test.go b/core/meterer/meterer_test.go index 2b4bc4fa34..0f9c16569c 100644 --- a/core/meterer/meterer_test.go +++ b/core/meterer/meterer_test.go @@ -154,7 +154,7 @@ func setup(_ *testing.M) { } paymentChainState.On("RefreshOnchainPaymentState", testifymock.Anything).Return(nil).Maybe() - if err := paymentChainState.RefreshOnchainPaymentState(context.Background(), nil); err != nil { + if err := paymentChainState.RefreshOnchainPaymentState(context.Background()); err != nil { panic("failed to make initial query to the on-chain state") } diff --git a/core/meterer/onchain_state.go b/core/meterer/onchain_state.go index 48b15c43aa..951d60b974 100644 --- a/core/meterer/onchain_state.go +++ b/core/meterer/onchain_state.go @@ -14,7 +14,7 @@ import ( // OnchainPaymentState is an interface for getting information about the current chain state for payments. type OnchainPayment interface { - RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error + RefreshOnchainPaymentState(ctx context.Context) error GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error) @@ -49,49 +49,45 @@ type PaymentVaultParams struct { } func NewOnchainPaymentState(ctx context.Context, tx *eth.Reader) (*OnchainPaymentState, error) { - paymentVaultParams, err := GetPaymentVaultParams(ctx, tx) - if err != nil { - return nil, err - } - state := OnchainPaymentState{ tx: tx, ReservedPayments: make(map[gethcommon.Address]*core.ReservedPayment), OnDemandPayments: make(map[gethcommon.Address]*core.OnDemandPayment), PaymentVaultParams: atomic.Pointer[PaymentVaultParams]{}, } - state.PaymentVaultParams.Store(paymentVaultParams) - - return &state, nil -} -func GetPaymentVaultParams(ctx context.Context, tx *eth.Reader) (*PaymentVaultParams, error) { - blockNumber, err := tx.GetCurrentBlockNumber(ctx) + paymentVaultParams, err := state.GetPaymentVaultParams(ctx) if err != nil { return nil, err } - quorumNumbers, err := tx.GetRequiredQuorumNumbers(ctx, blockNumber) + state.PaymentVaultParams.Store(paymentVaultParams) + + return &state, nil +} + +func (pcs *OnchainPaymentState) GetPaymentVaultParams(ctx context.Context) (*PaymentVaultParams, error) { + quorumNumbers, err := pcs.GetOnDemandQuorumNumbers(ctx) if err != nil { return nil, err } - globalSymbolsPerSecond, err := tx.GetGlobalSymbolsPerSecond(ctx) + globalSymbolsPerSecond, err := pcs.tx.GetGlobalSymbolsPerSecond(ctx) if err != nil { return nil, err } - minNumSymbols, err := tx.GetMinNumSymbols(ctx) + minNumSymbols, err := pcs.tx.GetMinNumSymbols(ctx) if err != nil { return nil, err } - pricePerSymbol, err := tx.GetPricePerSymbol(ctx) + pricePerSymbol, err := pcs.tx.GetPricePerSymbol(ctx) if err != nil { return nil, err } - reservationWindow, err := tx.GetReservationWindow(ctx) + reservationWindow, err := pcs.tx.GetReservationWindow(ctx) if err != nil { return nil, err } @@ -106,8 +102,8 @@ func GetPaymentVaultParams(ctx context.Context, tx *eth.Reader) (*PaymentVaultPa } // RefreshOnchainPaymentState returns the current onchain payment state -func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error { - paymentVaultParams, err := GetPaymentVaultParams(ctx, tx) +func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context) error { + paymentVaultParams, err := pcs.GetPaymentVaultParams(ctx) if err != nil { return err } @@ -120,7 +116,7 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, accountIDs = append(accountIDs, accountID) } - reservedPayments, err := tx.GetReservedPayments(ctx, accountIDs) + reservedPayments, err := pcs.tx.GetReservedPayments(ctx, accountIDs) if err != nil { return err } @@ -133,7 +129,7 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, accountIDs = append(accountIDs, accountID) } - onDemandPayments, err := tx.GetOnDemandPayments(ctx, accountIDs) + onDemandPayments, err := pcs.tx.GetOnDemandPayments(ctx, accountIDs) if err != nil { return err } @@ -146,10 +142,11 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, // GetReservedPaymentByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) { pcs.ReservationsLock.RLock() - defer pcs.ReservationsLock.RUnlock() if reservation, ok := (pcs.ReservedPayments)[accountID]; ok { + pcs.ReservationsLock.RUnlock() return reservation, nil } + pcs.ReservationsLock.RUnlock() // pulls the chain state res, err := pcs.tx.GetReservedPaymentByAccount(ctx, accountID) @@ -166,10 +163,12 @@ func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context, // GetOnDemandPaymentByAccount returns a pointer to the on-demand payment for the given account ID; no writes will be made to the payment func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) { pcs.OnDemandLocks.RLock() - defer pcs.OnDemandLocks.RUnlock() if payment, ok := (pcs.OnDemandPayments)[accountID]; ok { + pcs.OnDemandLocks.RUnlock() return payment, nil } + pcs.OnDemandLocks.RUnlock() + // pulls the chain state res, err := pcs.tx.GetOnDemandPaymentByAccount(ctx, accountID) if err != nil { diff --git a/core/meterer/onchain_state_test.go b/core/meterer/onchain_state_test.go index 468296be87..1d362cbc0b 100644 --- a/core/meterer/onchain_state_test.go +++ b/core/meterer/onchain_state_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/Layr-Labs/eigenda/core" - "github.com/Layr-Labs/eigenda/core/eth" "github.com/Layr-Labs/eigenda/core/mock" gethcommon "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/assert" @@ -30,7 +29,7 @@ func TestRefreshOnchainPaymentState(t *testing.T) { ctx := context.Background() mockState.On("RefreshOnchainPaymentState", testifymock.Anything, testifymock.Anything).Return(nil) - err := mockState.RefreshOnchainPaymentState(ctx, ð.Reader{}) + err := mockState.RefreshOnchainPaymentState(ctx) assert.NoError(t, err) } diff --git a/core/mock/payment_state.go b/core/mock/payment_state.go index 32a2a0b6cf..9c8746a64e 100644 --- a/core/mock/payment_state.go +++ b/core/mock/payment_state.go @@ -4,7 +4,6 @@ import ( "context" "github.com/Layr-Labs/eigenda/core" - "github.com/Layr-Labs/eigenda/core/eth" "github.com/Layr-Labs/eigenda/core/meterer" gethcommon "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/mock" @@ -25,7 +24,7 @@ func (m *MockOnchainPaymentState) GetCurrentBlockNumber(ctx context.Context) (ui return value, args.Error(1) } -func (m *MockOnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error { +func (m *MockOnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context) error { args := m.Called() return args.Error(0) } diff --git a/disperser/apiserver/server_test.go b/disperser/apiserver/server_test.go index 1bf6629d34..a696c46863 100644 --- a/disperser/apiserver/server_test.go +++ b/disperser/apiserver/server_test.go @@ -748,7 +748,7 @@ func newTestServer(transactor core.Writer, testName string) *apiserver.Dispersal mockState := &mock.MockOnchainPaymentState{} mockState.On("RefreshOnchainPaymentState", tmock.Anything).Return(nil).Maybe() - if err := mockState.RefreshOnchainPaymentState(context.Background(), nil); err != nil { + if err := mockState.RefreshOnchainPaymentState(context.Background()); err != nil { panic("failed to make initial query to the on-chain state") } @@ -798,7 +798,7 @@ func newTestServer(transactor core.Writer, testName string) *apiserver.Dispersal panic("failed to create offchain store") } mt := meterer.NewMeterer(meterer.Config{}, mockState, store, logger) - err = mt.ChainPaymentState.RefreshOnchainPaymentState(context.Background(), nil) + err = mt.ChainPaymentState.RefreshOnchainPaymentState(context.Background()) if err != nil { panic("failed to make initial query to the on-chain state") } diff --git a/disperser/apiserver/server_v2_test.go b/disperser/apiserver/server_v2_test.go index 7f95a8f3bd..5b5a1084fd 100644 --- a/disperser/apiserver/server_v2_test.go +++ b/disperser/apiserver/server_v2_test.go @@ -451,7 +451,7 @@ func newTestServerV2(t *testing.T) *testComponents { mockState.On("GetOnDemandPaymentByAccount", tmock.Anything, tmock.Anything).Return(&core.OnDemandPayment{CumulativePayment: big.NewInt(3864)}, nil) mockState.On("GetOnDemandQuorumNumbers", tmock.Anything).Return([]uint8{0, 1}, nil) - if err := mockState.RefreshOnchainPaymentState(context.Background(), nil); err != nil { + if err := mockState.RefreshOnchainPaymentState(context.Background()); err != nil { panic("failed to make initial query to the on-chain state") } table_names := []string{"reservations_server_" + t.Name(), "ondemand_server_" + t.Name(), "global_server_" + t.Name()} diff --git a/disperser/cmd/apiserver/main.go b/disperser/cmd/apiserver/main.go index e28e3aeb79..1b1d98590b 100644 --- a/disperser/cmd/apiserver/main.go +++ b/disperser/cmd/apiserver/main.go @@ -107,7 +107,7 @@ func RunDisperserServer(ctx *cli.Context) error { if err != nil { return fmt.Errorf("failed to create onchain payment state: %w", err) } - if err := paymentChainState.RefreshOnchainPaymentState(context.Background(), nil); err != nil { + if err := paymentChainState.RefreshOnchainPaymentState(context.Background()); err != nil { return fmt.Errorf("failed to make initial query to the on-chain state: %w", err) } diff --git a/test/integration_test.go b/test/integration_test.go index 56ea282534..e3c390e56d 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -6,7 +6,6 @@ import ( "encoding/hex" "errors" "fmt" - "github.com/stretchr/testify/require" "log" "math" "math/big" @@ -18,6 +17,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/Layr-Labs/eigenda/common/pubip" "github.com/Layr-Labs/eigenda/encoding/kzg" "github.com/Layr-Labs/eigenda/encoding/kzg/prover" @@ -289,7 +290,7 @@ func mustMakeDisperser(t *testing.T, cst core.IndexedChainState, store disperser } mockState.On("RefreshOnchainPaymentState", mock.Anything).Return(nil).Maybe() - if err := mockState.RefreshOnchainPaymentState(context.Background(), nil); err != nil { + if err := mockState.RefreshOnchainPaymentState(context.Background()); err != nil { panic("failed to make initial query to the on-chain state") }