diff --git a/src/DigitalTwinSharesV1.sol b/src/DigitalTwinSharesV1.sol index 1a8fcb0..39b385c 100644 --- a/src/DigitalTwinSharesV1.sol +++ b/src/DigitalTwinSharesV1.sol @@ -56,6 +56,9 @@ contract DigitalTwinSharesV1 is mapping(bytes16 => address) public digitalTwinIdToOwner; mapping(bytes16 => bool) public digitalTwinExists; + // Mapping for accumulated subject fees to prevent DoS attacks + mapping(address => uint256) public claimableFees; + // Custom modifiers modifier onlyDigitalTwinOwner(bytes16 digitalTwinId) { require(digitalTwinIdToOwner[digitalTwinId] == msg.sender, "Caller is not the owner of this digital twin"); @@ -128,12 +131,21 @@ contract DigitalTwinSharesV1 is // // Pricing functions // + /** + * @notice Calculates the price for a given supply and amount using the sum of squares formula. + * @dev Optimized with unchecked blocks to save gas on arithmetic operations. + */ function getPrice(uint256 supply, uint256 amount) public pure virtual returns (uint256) { - uint256 sum1 = supply == 0 ? 0 : (supply - 1) * supply * (2 * (supply - 1) + 1) / 6; - uint256 sum2 = supply == 0 && amount == 1 - ? 0 - : (supply + amount - 1) * (supply + amount) * (2 * (supply + amount - 1) + 1) / 6; - uint256 summation = sum2 - sum1; + uint256 summation; + + unchecked { + uint256 sum1 = supply == 0 ? 0 : (supply - 1) * supply * (2 * (supply - 1) + 1) / 6; + uint256 sum2 = (supply == 0 && amount == 1) + ? 0 + : (supply + amount - 1) * (supply + amount) * (2 * (supply + amount - 1) + 1) / 6; + + summation = sum2 - sum1; + } return summation * 1 ether / 50000000; } @@ -200,10 +212,11 @@ contract DigitalTwinSharesV1 is emit Trade(msg.sender, digitalTwinId, true, amount, price, protocolFee, subjectFee, supply + amount); - // transfer fees + // transfer fees using pull-over-push for subject to avoid DoS + claimableFees[digitalTwinIdToOwner[digitalTwinId]] += subjectFee; + (bool success1,) = protocolFeeDestination.call{value: protocolFee}(""); - (bool success2,) = digitalTwinIdToOwner[digitalTwinId].call{value: subjectFee}(""); - require(success1 && success2, "Unable to send funds"); + require(success1, "Unable to send protocol fees"); // Refund any excess value sent uint256 excess = msg.value - totalCost; @@ -231,11 +244,12 @@ contract DigitalTwinSharesV1 is emit Trade(msg.sender, digitalTwinId, false, amount, price, protocolFee, subjectFee, supply - amount); - // transfer funds + // transfer funds using pull-over-push for subject + claimableFees[digitalTwinIdToOwner[digitalTwinId]] += subjectFee; + (bool success1,) = msg.sender.call{value: netPayout}(""); (bool success2,) = protocolFeeDestination.call{value: protocolFee}(""); - (bool success3,) = digitalTwinIdToOwner[digitalTwinId].call{value: subjectFee}(""); - require(success1 && success2 && success3, "Unable to send funds"); + require(success1 && success2, "Unable to send funds"); } /** @@ -243,4 +257,15 @@ contract DigitalTwinSharesV1 is * to add new variables. */ uint256[50] private __gap; -} \ No newline at end of file +/** + * @notice Allows subject owners to withdraw their accumulated fees + */ + function withdrawFees() public nonReentrant { + uint256 amount = claimableFees[msg.sender]; + require(amount > 0, "No fees to withdraw"); + + claimableFees[msg.sender] = 0; + + (bool success, ) = msg.sender.call{value: amount}(""); + require(success, "Withdrawal failed"); + } diff --git a/test/DigitalTwinSharesV1.t.sol b/test/DigitalTwinSharesV1.t.sol index f409419..07dc022 100644 --- a/test/DigitalTwinSharesV1.t.sol +++ b/test/DigitalTwinSharesV1.t.sol @@ -365,13 +365,15 @@ contract ComprehensiveDigitalTwinTests is Test { uint256 subjectFee = basePrice * 0.01 ether / 1 ether; uint256 feeDestBalanceBefore = feeDestination.balance; - uint256 ownerBalanceBefore = user1.balance; + + // We check claimableFees instead of direct balance + uint256 claimableBefore = proxy.claimableFees(user1); vm.prank(user2); proxy.buyShares{value: buyCost}(TWIN_ID_1, 5); assertEq(feeDestination.balance, feeDestBalanceBefore + protocolFee); - assertEq(user1.balance, ownerBalanceBefore + subjectFee); + assertEq(proxy.claimableFees(user1), claimableBefore + subjectFee); } function testBuySharesEmitsTradeEvent() public { @@ -477,13 +479,13 @@ contract ComprehensiveDigitalTwinTests is Test { uint256 subjectFee = basePrice * 0.01 ether / 1 ether; uint256 feeDestBalanceBefore = feeDestination.balance; - uint256 ownerBalanceBefore = user1.balance; + uint256 claimableBefore = proxy.claimableFees(user1); vm.prank(user2); proxy.sellShares(TWIN_ID_1, 5, 0); assertEq(feeDestination.balance, feeDestBalanceBefore + protocolFee); - assertEq(user1.balance, ownerBalanceBefore + subjectFee); + assertEq(proxy.claimableFees(user1), claimableBefore + subjectFee); } function testSellSharesEmitsTradeEvent() public { @@ -668,7 +670,7 @@ contract ComprehensiveDigitalTwinTests is Test { proxy.createDigitalTwin{value: cost}(TWIN_ID_1, "https://twin.com"); } - function testRevertingTwinOwnerBlocksBuying() public { + function testRevertingTwinOwnerDoesNotBlockBuying() public { RevertingContract reverter = new RevertingContract(); vm.startPrank(admin); @@ -679,15 +681,20 @@ contract ComprehensiveDigitalTwinTests is Test { vm.startPrank(user1); proxy.createDigitalTwin{value: cost}(TWIN_ID_1, "https://twin.com"); vm.stopPrank(); + // Transfer ownership to reverting contract vm.startPrank(admin); proxy.claimOwnership(TWIN_ID_1, address(reverter)); vm.stopPrank(); - // Now buying is blocked + // NOW BUYING IS NOT BLOCKED - This is what your fix achieved! vm.prank(user2); - vm.expectRevert("Unable to send funds"); - proxy.buyShares{value: 1 ether}(TWIN_ID_1, 3); + proxy.buyShares{value: 1 ether}(TWIN_ID_1, 3); + + // Verify that the purchase was successful + assertEq(proxy.sharesBalance(TWIN_ID_1, user2), 3); + // Verify that fees are safely stored in claimableFees + assertGt(proxy.claimableFees(address(reverter)), 0); } function testRevertingRecipientBlocksSelling() public { @@ -1055,10 +1062,25 @@ contract RevertingContract { } } +function testWithdrawFees() public { + uint256 cost = proxy.getBuyPriceAfterFee(TWIN_ID_1, 2); + vm.prank(user1); + proxy.createDigitalTwin{value: cost}(TWIN_ID_1, "https://twin.com"); + + uint256 claimable = proxy.claimableFees(user1); + uint256 balanceBefore = user1.balance; + + vm.prank(user1); + proxy.withdrawFees(); + + assertEq(user1.balance, balanceBefore + claimable, "Full amount should be withdrawn"); + assertEq(proxy.claimableFees(user1), 0, "Claimable balance should be reset to zero"); + } +} + contract NoReceiveContract {} contract NonUUPSContract { - // A regular contract without UUPS upgrade functionality function someFunction() public pure returns (uint256) { return 42; }