Skip to content

Commit

Permalink
support fee on transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
1kresh committed Jul 26, 2024
1 parent 6d4a827 commit a4113dd
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 7 deletions.
17 changes: 11 additions & 6 deletions src/contracts/vault/Vault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ import {SafeERC20, IERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeE
import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";
import {Time} from "@openzeppelin/contracts/utils/types/Time.sol";
import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol";
import {ReentrancyGuardUpgradeable} from "@openzeppelin/contracts-upgradeable/utils/ReentrancyGuardUpgradeable.sol";

contract Vault is VaultStorage, MigratableEntity, AccessControlUpgradeable, IVault {
contract Vault is VaultStorage, MigratableEntity, AccessControlUpgradeable, ReentrancyGuardUpgradeable, IVault {
using Checkpoints for Checkpoints.Trace256;
using Math for uint256;
using SafeCast for uint256;
Expand Down Expand Up @@ -77,7 +78,7 @@ contract Vault is VaultStorage, MigratableEntity, AccessControlUpgradeable, IVau
/**
* @inheritdoc IVault
*/
function deposit(address onBehalfOf, uint256 amount) external returns (uint256 shares) {
function deposit(address onBehalfOf, uint256 amount) external nonReentrant returns (uint256 shares) {
if (onBehalfOf == address(0)) {
revert InvalidOnBehalfOf();
}
Expand All @@ -86,12 +87,14 @@ contract Vault is VaultStorage, MigratableEntity, AccessControlUpgradeable, IVau
revert NotWhitelistedDepositor();
}

uint256 balanceBefore = IERC20(collateral).balanceOf(address(this));
IERC20(collateral).safeTransferFrom(msg.sender, address(this), amount);
amount = IERC20(collateral).balanceOf(address(this)) - balanceBefore;

if (amount == 0) {
revert InsufficientDeposit();
}

IERC20(collateral).safeTransferFrom(msg.sender, address(this), amount);

uint256 activeStake_ = activeStake();
uint256 activeShares_ = activeShares();

Expand Down Expand Up @@ -148,7 +151,7 @@ contract Vault is VaultStorage, MigratableEntity, AccessControlUpgradeable, IVau
/**
* @inheritdoc IVault
*/
function claim(uint256 epoch) external returns (uint256 amount) {
function claim(uint256 epoch) external nonReentrant returns (uint256 amount) {
amount = _claim(epoch);

IERC20(collateral).safeTransfer(msg.sender, amount);
Expand All @@ -157,7 +160,7 @@ contract Vault is VaultStorage, MigratableEntity, AccessControlUpgradeable, IVau
/**
* @inheritdoc IVault
*/
function claimBatch(uint256[] calldata epochs) external returns (uint256 amount) {
function claimBatch(uint256[] calldata epochs) external nonReentrant returns (uint256 amount) {
uint256 length = epochs.length;
if (length == 0) {
revert InvalidLengthEpochs();
Expand Down Expand Up @@ -296,6 +299,8 @@ contract Vault is VaultStorage, MigratableEntity, AccessControlUpgradeable, IVau
revert NotSlasher();
}

__ReentrancyGuard_init();

collateral = params.collateral;

delegator = params.delegator;
Expand Down
204 changes: 203 additions & 1 deletion test/vault/Vault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {VetoSlasher} from "src/contracts/slasher/VetoSlasher.sol";
import {IVault} from "src/interfaces/vault/IVault.sol";
import {SimpleCollateral} from "test/mocks/SimpleCollateral.sol";
import {Token} from "test/mocks/Token.sol";
import {FeeOnTransferToken} from "test/mocks/FeeOnTransferToken.sol";
import {VaultConfigurator} from "src/contracts/VaultConfigurator.sol";
import {IVaultConfigurator} from "src/interfaces/IVaultConfigurator.sol";
import {INetworkRestakeDelegator} from "src/interfaces/delegator/INetworkRestakeDelegator.sol";
Expand Down Expand Up @@ -52,6 +53,7 @@ contract VaultTest is Test {
OptInService operatorNetworkOptInService;

SimpleCollateral collateral;
FeeOnTransferToken feeOnTransferCollateral;
VaultConfigurator vaultConfigurator;

Vault vault;
Expand Down Expand Up @@ -132,6 +134,7 @@ contract VaultTest is Test {

Token token = new Token("Token");
collateral = new SimpleCollateral(address(token));
feeOnTransferCollateral = new FeeOnTransferToken("FeeOnTransferToken");

collateral.mint(token.totalSupply());

Expand Down Expand Up @@ -565,6 +568,205 @@ contract VaultTest is Test {
assertGt(gasSpent, gasLeft - gasleft());
}

function test_DepositTwiceFeeOnTransferCollateral(uint256 amount1, uint256 amount2) public {
amount1 = bound(amount1, 2, 100 * 10 ** 18);
amount2 = bound(amount2, 2, 100 * 10 ** 18);

uint256 blockTimestamp = block.timestamp * block.timestamp / block.timestamp * block.timestamp / block.timestamp;
blockTimestamp = blockTimestamp + 1_720_700_948;
vm.warp(blockTimestamp);

uint48 epochDuration = 1;
{
address[] memory networkLimitSetRoleHolders = new address[](1);
networkLimitSetRoleHolders[0] = alice;
address[] memory operatorNetworkSharesSetRoleHolders = new address[](1);
operatorNetworkSharesSetRoleHolders[0] = alice;
(address vault_,,) = vaultConfigurator.create(
IVaultConfigurator.InitParams({
version: vaultFactory.lastVersion(),
owner: alice,
vaultParams: IVault.InitParams({
collateral: address(feeOnTransferCollateral),
delegator: address(0),
slasher: address(0),
burner: address(0xdEaD),
epochDuration: epochDuration,
depositWhitelist: false,
defaultAdminRoleHolder: alice,
depositorWhitelistRoleHolder: alice
}),
delegatorIndex: 0,
delegatorParams: abi.encode(
INetworkRestakeDelegator.InitParams({
baseParams: IBaseDelegator.BaseParams({
defaultAdminRoleHolder: alice,
hook: address(0),
hookSetRoleHolder: alice
}),
networkLimitSetRoleHolders: networkLimitSetRoleHolders,
operatorNetworkSharesSetRoleHolders: operatorNetworkSharesSetRoleHolders
})
),
withSlasher: false,
slasherIndex: 0,
slasherParams: ""
})
);

vault = Vault(vault_);
}

uint256 tokensBefore = feeOnTransferCollateral.balanceOf(address(vault));
uint256 shares1 = (amount1 - 1) * 10 ** 0;
feeOnTransferCollateral.transfer(alice, amount1 + 1);
vm.startPrank(alice);
feeOnTransferCollateral.approve(address(vault), amount1);
assertEq(vault.deposit(alice, amount1), shares1);
vm.stopPrank();
assertEq(feeOnTransferCollateral.balanceOf(address(vault)) - tokensBefore, amount1 - 1);

assertEq(vault.totalStake(), amount1 - 1);
assertEq(vault.activeSharesAt(uint48(blockTimestamp - 1), ""), 0);
assertEq(vault.activeSharesAt(uint48(blockTimestamp), ""), shares1);
assertEq(vault.activeShares(), shares1);
assertEq(vault.activeStakeAt(uint48(blockTimestamp - 1), ""), 0);
assertEq(vault.activeStakeAt(uint48(blockTimestamp), ""), amount1 - 1);
assertEq(vault.activeStake(), amount1 - 1);
assertEq(vault.activeSharesOfAt(alice, uint48(blockTimestamp - 1), ""), 0);
assertEq(vault.activeSharesOfAt(alice, uint48(blockTimestamp), ""), shares1);
assertEq(vault.activeSharesOf(alice), shares1);
assertEq(vault.activeBalanceOfAt(alice, uint48(blockTimestamp - 1), ""), 0);
assertEq(vault.activeBalanceOfAt(alice, uint48(blockTimestamp), ""), amount1 - 1);
assertEq(vault.activeBalanceOf(alice), amount1 - 1);
assertEq(vault.balanceOf(alice), amount1 - 1);

blockTimestamp = blockTimestamp + 1;
vm.warp(blockTimestamp);

uint256 shares2 = (amount2 - 1) * (shares1 + 10 ** 0) / (amount1 - 1 + 1);
feeOnTransferCollateral.transfer(alice, amount2 + 1);
vm.startPrank(alice);
feeOnTransferCollateral.approve(address(vault), amount2);
assertEq(vault.deposit(alice, amount2), shares2);
vm.stopPrank();

assertEq(vault.totalStake(), amount1 - 1 + amount2 - 1);
assertEq(vault.activeSharesAt(uint48(blockTimestamp - 1), ""), shares1);
assertEq(vault.activeSharesAt(uint48(blockTimestamp), ""), shares1 + shares2);
assertEq(vault.activeShares(), shares1 + shares2);
uint256 gasLeft = gasleft();
assertEq(vault.activeSharesAt(uint48(blockTimestamp - 1), abi.encode(1)), shares1);
uint256 gasSpent = gasLeft - gasleft();
gasLeft = gasleft();
assertEq(vault.activeSharesAt(uint48(blockTimestamp - 1), abi.encode(0)), shares1);
assertGt(gasSpent, gasLeft - gasleft());
gasLeft = gasleft();
assertEq(vault.activeSharesAt(uint48(blockTimestamp), abi.encode(0)), shares1 + shares2);
gasSpent = gasLeft - gasleft();
gasLeft = gasleft();
assertEq(vault.activeSharesAt(uint48(blockTimestamp), abi.encode(1)), shares1 + shares2);
assertGt(gasSpent, gasLeft - gasleft());
assertEq(vault.activeStakeAt(uint48(blockTimestamp - 1), ""), amount1 - 1);
assertEq(vault.activeStakeAt(uint48(blockTimestamp), ""), amount1 - 1 + amount2 - 1);
assertEq(vault.activeStake(), amount1 - 1 + amount2 - 1);
gasLeft = gasleft();
assertEq(vault.activeStakeAt(uint48(blockTimestamp - 1), abi.encode(1)), amount1 - 1);
gasSpent = gasLeft - gasleft();
gasLeft = gasleft();
assertEq(vault.activeStakeAt(uint48(blockTimestamp - 1), abi.encode(0)), amount1 - 1);
assertGt(gasSpent, gasLeft - gasleft());
gasLeft = gasleft();
assertEq(vault.activeStakeAt(uint48(blockTimestamp), abi.encode(0)), amount1 - 1 + amount2 - 1);
gasSpent = gasLeft - gasleft();
gasLeft = gasleft();
assertEq(vault.activeStakeAt(uint48(blockTimestamp), abi.encode(1)), amount1 - 1 + amount2 - 1);
assertGt(gasSpent, gasLeft - gasleft());
assertEq(vault.activeStakeAt(uint48(blockTimestamp - 1), ""), shares1);
assertEq(vault.activeStakeAt(uint48(blockTimestamp), ""), shares1 + shares2);
assertEq(vault.activeSharesOf(alice), shares1 + shares2);
gasLeft = gasleft();
assertEq(vault.activeSharesOfAt(alice, uint48(blockTimestamp - 1), abi.encode(1)), shares1);
gasSpent = gasLeft - gasleft();
gasLeft = gasleft();
assertEq(vault.activeSharesOfAt(alice, uint48(blockTimestamp - 1), abi.encode(0)), shares1);
assertGt(gasSpent, gasLeft - gasleft());
gasLeft = gasleft();
assertEq(vault.activeSharesOfAt(alice, uint48(blockTimestamp), abi.encode(0)), shares1 + shares2);
gasSpent = gasLeft - gasleft();
gasLeft = gasleft();
assertEq(vault.activeSharesOfAt(alice, uint48(blockTimestamp), abi.encode(1)), shares1 + shares2);
assertGt(gasSpent, gasLeft - gasleft());
assertEq(vault.activeBalanceOfAt(alice, uint48(blockTimestamp - 1), ""), amount1 - 1);
assertEq(vault.activeBalanceOfAt(alice, uint48(blockTimestamp), ""), amount1 - 1 + amount2 - 1);
assertEq(vault.activeBalanceOf(alice), amount1 - 1 + amount2 - 1);
assertEq(vault.balanceOf(alice), amount1 - 1 + amount2 - 1);
gasLeft = gasleft();
assertEq(
vault.activeBalanceOfAt(
alice,
uint48(blockTimestamp - 1),
abi.encode(
IVault.ActiveBalanceOfHints({
activeSharesOfHint: abi.encode(1),
activeStakeHint: abi.encode(1),
activeSharesHint: abi.encode(1)
})
)
),
amount1 - 1
);
gasSpent = gasLeft - gasleft();
gasLeft = gasleft();
assertEq(
vault.activeBalanceOfAt(
alice,
uint48(blockTimestamp - 1),
abi.encode(
IVault.ActiveBalanceOfHints({
activeSharesOfHint: abi.encode(0),
activeStakeHint: abi.encode(0),
activeSharesHint: abi.encode(0)
})
)
),
amount1 - 1
);
assertGt(gasSpent, gasLeft - gasleft());
gasLeft = gasleft();
assertEq(
vault.activeBalanceOfAt(
alice,
uint48(blockTimestamp),
abi.encode(
IVault.ActiveBalanceOfHints({
activeSharesOfHint: abi.encode(0),
activeStakeHint: abi.encode(0),
activeSharesHint: abi.encode(0)
})
)
),
amount1 - 1 + amount2 - 1
);
gasSpent = gasLeft - gasleft();
gasLeft = gasleft();
assertEq(
vault.activeBalanceOfAt(
alice,
uint48(blockTimestamp),
abi.encode(
IVault.ActiveBalanceOfHints({
activeSharesOfHint: abi.encode(1),
activeStakeHint: abi.encode(1),
activeSharesHint: abi.encode(1)
})
)
),
amount1 - 1 + amount2 - 1
);
assertGt(gasSpent, gasLeft - gasleft());
}

function test_DepositBoth(uint256 amount1, uint256 amount2) public {
amount1 = bound(amount1, 1, 100 * 10 ** 18);
amount2 = bound(amount2, 1, 100 * 10 ** 18);
Expand Down Expand Up @@ -1201,7 +1403,7 @@ contract VaultTest is Test {
slashAmount1 = bound(slashAmount1, 1, type(uint256).max / 2);
slashAmount2 = bound(slashAmount2, 1, type(uint256).max / 2);
captureAgo = bound(captureAgo, 1, 10 days);
vm.assume(depositAmount >= withdrawAmount1 + withdrawAmount2);
vm.assume(depositAmount > withdrawAmount1 + withdrawAmount2);
vm.assume(depositAmount > slashAmount1);
vm.assume(captureAgo <= 7 days);

Expand Down

0 comments on commit a4113dd

Please sign in to comment.