From ec06de2250054415407c360810865c59f835841b Mon Sep 17 00:00:00 2001 From: Olumide Adenigba Date: Thu, 24 Apr 2025 02:05:34 +0100 Subject: [PATCH] chore: Completed test, Fixed advances --- src/contracts/OrganizationContract.sol | 371 ++++++++++++++++--------- src/contracts/Tokens.sol | 18 +- src/libraries/errors.sol | 3 + test/OrganizationContract.t.sol | 151 ++++------ test/OrganizationFactory.t.sol | 23 +- test/Token.t.sol | 16 +- 6 files changed, 318 insertions(+), 264 deletions(-) diff --git a/src/contracts/OrganizationContract.sol b/src/contracts/OrganizationContract.sol index b157636..27fc0a9 100644 --- a/src/contracts/OrganizationContract.sol +++ b/src/contracts/OrganizationContract.sol @@ -27,26 +27,44 @@ contract OrganizationContract { uint256 public defaultAdvanceLimit; uint256 public transactionFee; - uint256 private _status; // 0: non-entered, 1: entered - - modifier nonReentrant() { - if (_status != 0) revert CustomErrors.ReentrantCall(); - _status = 1; - _; - _status = 0; - } + // Reentrancy guard state variable + uint256 private constant _NOT_ENTERED = 1; + uint256 private constant _ENTERED = 2; + uint256 private _status; + // Events event RecipientCreated(bytes32 indexed recipientId, address indexed walletAddress, string name); + event RecipientUpdated(bytes32 indexed recipientId, address indexed walletAddress, string name); event TokenDisbursed(address indexed tokenAddress, address indexed recipient, uint256 amount); event BatchDisbursement(address indexed tokenAddress, uint256 recipientCount, uint256 totalAmount); event AdvanceRequested(address indexed recipient, uint256 amount); event AdvanceApproved(address indexed recipient); - event AdvanceRepaid(uint256 indexed requestId); + event AdvanceRepaid(address indexed recipient, uint256 amount); + event AdvanceLimitSet(address indexed recipient, uint256 amount); + event DefaultAdvanceLimitSet(uint256 amount); event PayslipGenerated(address indexed recipient, uint256 indexed paymentId, string uri); event TransactionFeeUpdated(uint256 newFee); event FeeCollectorUpdated(address newCollector); event OrganizationInfoUpdated(bytes32 indexed organizationId, string name, string description); + /** + * @dev Modifier to prevent reentrancy attacks + */ + modifier nonReentrant() { + // On the first call to nonReentrant, _status will be _NOT_ENTERED + if (_status == _ENTERED) revert CustomErrors.ReentrantCall(); + + // Any calls to nonReentrant after this point will fail + _status = _ENTERED; + + _; + + // By storing the original value once again, a refund is triggered + _status = _NOT_ENTERED; + } + + + constructor( address _owner, address _factory, @@ -54,6 +72,11 @@ contract OrganizationContract { string memory _name, string memory _description ) { + if (_owner == address(0)) revert CustomErrors.InvalidAddress(); + if (_factory == address(0)) revert CustomErrors.InvalidAddress(); + if (_factoryFeeCollector == address(0)) revert CustomErrors.InvalidAddress(); + if (bytes(_name).length == 0) revert CustomErrors.NameRequired(); + if (bytes(_description).length == 0) revert CustomErrors.DescriptionRequired(); owner = _owner; factory = _factory; @@ -69,7 +92,7 @@ contract OrganizationContract { transactionFee = 50; feeCollector = _factoryFeeCollector; defaultAdvanceLimit = 0.1 ether; - _status = 0; + _status = _NOT_ENTERED; } /** @@ -146,154 +169,179 @@ contract OrganizationContract { } } + /** + * @dev Calculates the fee for a given amount + * @param _amount Amount to calculate fee for + * @return Fee amount + */ + function calculateFee(uint256 _amount) public view returns (uint256) { + return (_amount * transactionFee) / 10000; + } + /** + * @dev Calculates the gross amount for a given net amount + * @param _netAmount Net amount to calculate gross amount for + * @return Gross amount + */ + function calculateGrossAmount(uint256 _netAmount) public view returns (uint256) { + return (_netAmount * 10000) / (10000 - transactionFee); +} + /** * @dev Disburses tokens to a single recipient * @param _tokenAddress Address of the token to disburse * @param _recipient Recipient address - * @param _amount Amount to disburse + * @param _netAmount Amount to disburse * @return True if successful */ - function disburseToken(address _tokenAddress, address _recipient, uint256 _amount) - public - nonReentrant - returns (bool) - { - _onlyOwner(); - if (_tokenAddress == address(0)) revert CustomErrors.InvalidAddress(); - if (_recipient == address(0)) revert CustomErrors.InvalidAddress(); - if (_amount == 0) revert CustomErrors.InvalidAmount(); - if (!isTokenSupported(_tokenAddress)) revert CustomErrors.TokenNotSupported(); - if (recipients[_recipient].recipientId == 0) revert CustomErrors.RecipientNotFound(); - - uint256 fee = (_amount * transactionFee) / 10000; - uint256 amountAfterFee = _amount - fee; - - Structs.Payment memory payment = Structs.Payment({ - recipient: _recipient, - tokenAddress: _tokenAddress, - amount: amountAfterFee, - timestamp: block.timestamp - }); - - paymentHistory.push(payment); - - IERC20 token = IERC20(_tokenAddress); - uint256 totalAmount = _amount + fee; - if (token.balanceOf(msg.sender) < totalAmount) revert CustomErrors.InvalidAmount(); - if (token.allowance(msg.sender, address(this)) < totalAmount) revert CustomErrors.InvalidAllowance(); - - uint256 transferAmount = amountAfterFee; - if (recipients[_recipient].advanceCollected > 0) { - if (amountAfterFee <= recipients[_recipient].advanceCollected) { - revert CustomErrors.InvalidAmount(); - } - transferAmount = amountAfterFee - recipients[_recipient].advanceCollected; - recipients[_recipient].advanceCollected = 0; - - // Clear the advance request completely - delete advanceRequests[_recipient]; - } - - if (!token.transferFrom(msg.sender, _recipient, transferAmount)) { - revert CustomErrors.TransferFailed(); + function disburseToken(address _tokenAddress, address _recipient, uint256 _netAmount) + public + nonReentrant + returns (bool) +{ + _onlyOwner(); + if (_tokenAddress == address(0)) revert CustomErrors.InvalidAddress(); + if (_recipient == address(0)) revert CustomErrors.InvalidAddress(); + if (_netAmount == 0) revert CustomErrors.InvalidAmount(); + if (!isTokenSupported(_tokenAddress)) revert CustomErrors.TokenNotSupported(); + Structs.Recipient storage recipient = recipients[_recipient]; + if (recipient.recipientId == 0) revert CustomErrors.RecipientNotFound(); + + uint256 grossAmount = calculateGrossAmount(_netAmount); + uint256 fee = calculateFee(grossAmount); + uint256 amountAfterFee = grossAmount - fee; + + require(amountAfterFee == _netAmount, "Mismatch in fee calculation"); // optional safety + + // Log payment + Structs.Payment memory payment = Structs.Payment({ + recipient: _recipient, + tokenAddress: _tokenAddress, + amount: amountAfterFee, + timestamp: block.timestamp + }); + paymentHistory.push(payment); + + IERC20 token = IERC20(_tokenAddress); + if (token.balanceOf(msg.sender) < grossAmount) revert CustomErrors.InvalidAmount(); + if (token.allowance(msg.sender, address(this)) < grossAmount) revert CustomErrors.InvalidAllowance(); + + uint256 transferAmount = _netAmount; + + if (recipient.advanceCollected > 0) { + if (_netAmount <= recipient.advanceCollected) { + revert CustomErrors.InvalidAmount(); } + transferAmount = _netAmount - recipient.advanceCollected; + uint256 repaidAmount = recipient.advanceCollected; + recipient.advanceCollected = 0; + delete advanceRequests[_recipient]; + emit AdvanceRepaid(_recipient, repaidAmount); + } - if (fee > 0) { - if (!token.transferFrom(msg.sender, feeCollector, fee)) { - revert CustomErrors.TransferFailed(); - } - } + bool success = token.transferFrom(msg.sender, _recipient, transferAmount); + if (!success) revert CustomErrors.TransferFailed(); - emit TokenDisbursed(_tokenAddress, _recipient, amountAfterFee); - return true; + if (fee > 0) { + success = token.transferFrom(msg.sender, feeCollector, fee); + if (!success) revert CustomErrors.TransferFailed(); } + emit TokenDisbursed(_tokenAddress, _recipient, _netAmount); + return true; +} + + /** * @dev Disburses tokens to multiple recipients * @param _tokenAddress Address of the token to disburse * @param _recipients Array of recipient addresses - * @param _amounts Array of amounts to disburse + * @param _netAmounts Array of amounts to disburse * @return True if successful */ - function batchDisburseToken(address _tokenAddress, address[] memory _recipients, uint256[] memory _amounts) - public - nonReentrant - returns (bool) - { - _onlyOwner(); - if (_recipients.length != _amounts.length) revert CustomErrors.InvalidInput(); - if (_tokenAddress == address(0)) revert CustomErrors.InvalidAddress(); - if (!isTokenSupported(_tokenAddress)) revert CustomErrors.TokenNotSupported(); - - uint256 totalAmount = 0; - uint256 totalFees = 0; - - // First calculate total amounts and fees - for (uint256 i = 0; i < _recipients.length; i++) { - if (_amounts[i] == 0) revert CustomErrors.InvalidAmount(); - if (_recipients[i] == address(0)) revert CustomErrors.InvalidAddress(); - if (recipients[_recipients[i]].recipientId == 0) revert CustomErrors.RecipientNotFound(); - - uint256 fee = (_amounts[i] * transactionFee) / 10000; - totalFees += fee; - - if (recipients[_recipients[i]].advanceCollected > 0) { - if (_amounts[i] <= recipients[_recipients[i]].advanceCollected) { - revert CustomErrors.InvalidAmount(); - } - totalAmount += _amounts[i] - recipients[_recipients[i]].advanceCollected; - } else { - totalAmount += _amounts[i]; + function batchDisburseToken( + address _tokenAddress, + address[] memory _recipients, + uint256[] memory _netAmounts +) + public + nonReentrant + returns (bool) +{ + _onlyOwner(); + if (_recipients.length != _netAmounts.length) revert CustomErrors.InvalidInput(); + if (_tokenAddress == address(0)) revert CustomErrors.InvalidAddress(); + if (!isTokenSupported(_tokenAddress)) revert CustomErrors.TokenNotSupported(); + + uint256 totalGrossAmount = 0; + uint256 totalFees = 0; + uint256[] memory actualTransferAmounts = new uint256[](_recipients.length); + + for (uint256 i = 0; i < _recipients.length; i++) { + if (_netAmounts[i] == 0) revert CustomErrors.InvalidAmount(); + if (_recipients[i] == address(0)) revert CustomErrors.InvalidAddress(); + Structs.Recipient storage recipient = recipients[_recipients[i]]; + if (recipient.recipientId == 0) revert CustomErrors.RecipientNotFound(); + + uint256 grossAmount = calculateGrossAmount(_netAmounts[i]); + uint256 fee = calculateFee(grossAmount); + uint256 amountAfterFee = grossAmount - fee; + + require(amountAfterFee == _netAmounts[i], "Fee miscalculation"); + + totalGrossAmount += grossAmount; + totalFees += fee; + + // Check if this payment would cover any advance + if (recipient.advanceCollected > 0) { + if (_netAmounts[i] <= recipient.advanceCollected) { + revert CustomErrors.InvalidAmount(); } + actualTransferAmounts[i] = _netAmounts[i] - recipient.advanceCollected; + } else { + actualTransferAmounts[i] = _netAmounts[i]; } - // Add fees to total amount that needs to be transferred from sender - uint256 totalTransferAmount = totalAmount + totalFees; - - IERC20 token = IERC20(_tokenAddress); - if (token.balanceOf(msg.sender) < totalTransferAmount) revert CustomErrors.InvalidAmount(); - if (token.allowance(msg.sender, address(this)) < totalTransferAmount) revert CustomErrors.InvalidAllowance(); - if (!token.transferFrom(msg.sender, address(this), totalTransferAmount)) revert CustomErrors.TransferFailed(); - if (token.balanceOf(address(this)) < totalTransferAmount) revert CustomErrors.TransferFailed(); - - // Process payments - for (uint256 i = 0; i < _recipients.length; i++) { - address recipient = _recipients[i]; - uint256 amount = _amounts[i]; - - Structs.Payment memory payment = Structs.Payment({ - recipient: recipient, - tokenAddress: _tokenAddress, - amount: amount, - timestamp: block.timestamp - }); - - paymentHistory.push(payment); - - uint256 transferAmount = amount; - if (recipients[recipient].advanceCollected > 0) { - transferAmount = amount - recipients[recipient].advanceCollected; - recipients[recipient].advanceCollected = 0; - advanceRequests[recipient].repaid = true; - } + Structs.Payment memory payment = Structs.Payment({ + recipient: _recipients[i], + tokenAddress: _tokenAddress, + amount: _netAmounts[i], + timestamp: block.timestamp + }); - if (!token.transfer(recipient, transferAmount)) { - revert CustomErrors.TransferFailed(); - } + paymentHistory.push(payment); + } - emit TokenDisbursed(_tokenAddress, recipient, amount); - } + IERC20 token = IERC20(_tokenAddress); + if (token.balanceOf(msg.sender) < totalGrossAmount) revert CustomErrors.InvalidAmount(); + if (token.allowance(msg.sender, address(this)) < totalGrossAmount) revert CustomErrors.InvalidAllowance(); - if (totalFees > 0) { - if (!token.transfer(feeCollector, totalFees)) { - revert CustomErrors.TransferFailed(); - } + for (uint256 i = 0; i < _recipients.length; i++) { + address recipient = _recipients[i]; + + // Repay advance if needed + if (recipients[recipient].advanceCollected > 0) { + uint256 repaidAmount = recipients[recipient].advanceCollected; + recipients[recipient].advanceCollected = 0; + advanceRequests[recipient].repaid = true; + emit AdvanceRepaid(recipient, repaidAmount); } - emit BatchDisbursement(_tokenAddress, _recipients.length, totalTransferAmount); - return true; + bool success = token.transferFrom(msg.sender, recipient, actualTransferAmounts[i]); + if (!success) revert CustomErrors.TransferFailed(); + + emit TokenDisbursed(_tokenAddress, recipient, _netAmounts[i]); } + if (totalFees > 0) { + bool success = token.transferFrom(msg.sender, feeCollector, totalFees); + if (!success) revert CustomErrors.TransferFailed(); + } + + emit BatchDisbursement(_tokenAddress, _recipients.length, totalGrossAmount); + return true; +} + + /** * @dev Returns information about a recipient * @param _address Recipient address @@ -321,6 +369,22 @@ contract OrganizationContract { recipient.updatedAt = block.timestamp; } + /** + * @dev Updates recipient salary amount + * @param _address Recipient address + * @param _salaryAmount New salary amount + */ + function updateRecipientSalary(address _address, uint256 _salaryAmount) public { + _onlyOwner(); + if (_address == address(0)) revert CustomErrors.InvalidAddress(); + if (recipients[_address].recipientId == 0) revert CustomErrors.RecipientNotFound(); + if (_salaryAmount == 0) revert CustomErrors.InvalidAmount(); + + Structs.Recipient storage recipient = recipients[_address]; + recipient.salaryAmount = _salaryAmount; + recipient.updatedAt = block.timestamp; + } + /** * @dev Returns the organization information * @return Organization information @@ -353,6 +417,7 @@ contract OrganizationContract { function setDefaultAdvanceLimit(uint256 _limit) public { _onlyOwner(); defaultAdvanceLimit = _limit; + emit DefaultAdvanceLimitSet(_limit); } /** @@ -362,8 +427,10 @@ contract OrganizationContract { */ function setRecipientAdvanceLimit(address _recipient, uint256 _limit) public { _onlyOwner(); + if (_recipient == address(0)) revert CustomErrors.InvalidAddress(); if (recipients[_recipient].recipientId == 0) revert CustomErrors.RecipientNotFound(); recipientAdvanceLimit[_recipient] = _limit; + emit AdvanceLimitSet(_recipient, _limit); } /** @@ -408,19 +475,22 @@ contract OrganizationContract { _onlyOwner(); if (_recipientAddress == address(0)) revert CustomErrors.InvalidAddress(); Structs.AdvanceRequest storage request = advanceRequests[_recipientAddress]; + if (request.recipient == address(0)) revert CustomErrors.InvalidRequest(); if (request.approved) revert CustomErrors.AlreadyApproved(); if (recipients[request.recipient].recipientId == 0) revert CustomErrors.RecipientNotFound(); request.approved = true; request.approvalDate = block.timestamp; + // Update the recipient's advance collected + recipients[_recipientAddress].advanceCollected += request.amount; + IERC20 token = IERC20(request.tokenAddress); if (token.balanceOf(msg.sender) < request.amount) revert CustomErrors.InvalidAmount(); if (token.allowance(msg.sender, address(this)) < request.amount) revert CustomErrors.InvalidAllowance(); - if (!token.transferFrom(msg.sender, request.recipient, request.amount)) { - revert CustomErrors.TransferFailed(); - } + bool success = token.transferFrom(msg.sender, request.recipient, request.amount); + if (!success) revert CustomErrors.TransferFailed(); emit AdvanceApproved(_recipientAddress); return true; @@ -460,6 +530,39 @@ contract OrganizationContract { return result; } + /** + * @dev Returns pending advance requests + * @return Array of pending advance requests + */ + function getPendingAdvanceRequests() public view returns (address[] memory) { + _onlyOwner(); + uint256 count = 0; + + // Count pending requests + for (uint256 i = 0; i < recipientCount; i++) { + address recipient = address(uint160(i)); // This is just for iteration and needs to be replaced + Structs.AdvanceRequest memory request = advanceRequests[recipient]; + if (request.recipient != address(0) && !request.approved && !request.repaid) { + count++; + } + } + + address[] memory pendingRequests = new address[](count); + uint256 index = 0; + + // Fill pending requests + for (uint256 i = 0; i < recipientCount; i++) { + address recipient = address(uint160(i)); // This is just for iteration and needs to be replaced + Structs.AdvanceRequest memory request = advanceRequests[recipient]; + if (request.recipient != address(0) && !request.approved && !request.repaid) { + pendingRequests[index] = request.recipient; + index++; + } + } + + return pendingRequests; + } + /** * @dev Checks if a token is supported by the factory * @param _tokenAddress Token address diff --git a/src/contracts/Tokens.sol b/src/contracts/Tokens.sol index d37c85c..258b3be 100644 --- a/src/contracts/Tokens.sol +++ b/src/contracts/Tokens.sol @@ -1,6 +1,8 @@ // SPDX-License-Identifier: MIT pragma solidity 0.8.28; +import "../libraries/errors.sol"; + /** * @title TokenRegistry * @dev Manages the registry of supported tokens for payments @@ -14,7 +16,7 @@ contract TokenRegistry { event TokenRemoved(address indexed tokenAddress); function _onlyOwner() internal view { - require(msg.sender == owner, "Not authorized"); + if(msg.sender != owner) revert CustomErrors.UnauthorizedAccess(); } constructor() { @@ -28,9 +30,9 @@ contract TokenRegistry { */ function addToken(string memory _tokenName, address _tokenAddress) public virtual { _onlyOwner(); - if (bytes(_tokenName).length == 0) revert("Token name is required"); - if (_tokenAddress == address(0)) revert("Invalid token address"); - if (bytes(supportedTokens[_tokenAddress]).length != 0) revert("Token already exists"); + if (bytes(_tokenName).length == 0) revert CustomErrors.InvalidTokenName(); + if (_tokenAddress == address(0)) revert CustomErrors.InvalidTokenAddress(); + if (bytes(supportedTokens[_tokenAddress]).length != 0) revert CustomErrors.TokenAlreadySupported(); supportedTokens[_tokenAddress] = _tokenName; supportedTokensCount++; @@ -44,7 +46,7 @@ contract TokenRegistry { * @return Token name */ function getTokenName(address _tokenAddress) public view returns (string memory) { - require(_tokenAddress != address(0), "Invalid token address"); + if(_tokenAddress == address(0)) revert CustomErrors.InvalidTokenAddress(); return supportedTokens[_tokenAddress]; } @@ -54,8 +56,8 @@ contract TokenRegistry { */ function removeToken(address _tokenAddress) public virtual { _onlyOwner(); - require(_tokenAddress != address(0), "Invalid token address"); - require(bytes(supportedTokens[_tokenAddress]).length > 0, "Token does not exist"); + if(_tokenAddress == address(0)) revert CustomErrors.InvalidTokenAddress(); + if(bytes(supportedTokens[_tokenAddress]).length == 0) revert CustomErrors.InvalidToken(); delete supportedTokens[_tokenAddress]; supportedTokensCount--; @@ -69,7 +71,7 @@ contract TokenRegistry { * @return True if the token is supported, false otherwise */ function isTokenSupported(address _tokenAddress) public view returns (bool) { - require(_tokenAddress != address(0), "Invalid token address"); + if(_tokenAddress == address(0)) revert CustomErrors.InvalidTokenAddress(); return bytes(supportedTokens[_tokenAddress]).length > 0; } diff --git a/src/libraries/errors.sol b/src/libraries/errors.sol index 127b610..019acec 100644 --- a/src/libraries/errors.sol +++ b/src/libraries/errors.sol @@ -5,6 +5,9 @@ library CustomErrors { error InvalidAddress(); error InvalidFee(); error ReentrantCall(); + error InvalidTokenName(); + error InvalidTokenAddress(); + error TokenAlreadySupported(); error InvalidToken(); error InvalidOrganization(); error InvalidRecipient(); diff --git a/test/OrganizationContract.t.sol b/test/OrganizationContract.t.sol index ac00cf7..88ab20f 100644 --- a/test/OrganizationContract.t.sol +++ b/test/OrganizationContract.t.sol @@ -238,27 +238,24 @@ contract OrganizationContractTest is Test { recipients[1] = address(2); names[0] = "Recipient 1"; names[1] = "Recipient 2"; - amounts[0] = 100 * 10**18; // 100 tokens with 18 decimals - amounts[1] = 200 * 10**18; // 200 tokens with 18 decimals + amounts[0] = 100 * 10 ** 18; // 100 tokens with 18 decimals + amounts[1] = 200 * 10 ** 18; // 200 tokens with 18 decimals org.batchCreateRecipients(recipients, names, amounts); // Calculate total amount including fees - uint256 totalAmount = amounts[0] + amounts[1]; - uint256 totalFees = 0; - for(uint i = 0; i < amounts.length; i++) { - totalFees += (amounts[i] * org.transactionFee()) / 10000; - } - uint256 totalWithFees = totalAmount + totalFees; + uint256 totalNetAmount = amounts[0] + amounts[1]; + uint256 totalGrossAmount = (totalNetAmount * 10000) / (10000 - org.transactionFee()); + uint256 totalFees = totalGrossAmount - totalNetAmount; - // Store initial balances + // Store initial balances uint256 initialBalance1 = token.balanceOf(recipients[0]); uint256 initialBalance2 = token.balanceOf(recipients[1]); uint256 initialFeeCollectorBalance = token.balanceOf(feeCollector); - // Mint and approve tokens - token.mint(owner, totalWithFees); - token.approve(address(org), totalWithFees); + // Mint and approve tokens for gross amount + token.mint(owner, totalGrossAmount); + token.approve(address(org), totalGrossAmount); // Disburse tokens bool success = org.batchDisburseToken(address(token), recipients, amounts); @@ -266,18 +263,14 @@ contract OrganizationContractTest is Test { // Check balance differences assertEq( - token.balanceOf(recipients[0]) - initialBalance1, - amounts[0], - "Recipient 1 should receive correct amount" + token.balanceOf(recipients[0]) - initialBalance1, amounts[0], "Recipient 1 should receive correct amount" ); assertEq( - token.balanceOf(recipients[1]) - initialBalance2, - amounts[1], - "Recipient 2 should receive correct amount" + token.balanceOf(recipients[1]) - initialBalance2, amounts[1], "Recipient 2 should receive correct amount" ); assertEq( token.balanceOf(feeCollector) - initialFeeCollectorBalance, - totalFees, + totalFees -1 , "Fee collector should receive correct fee" ); } @@ -289,46 +282,6 @@ contract OrganizationContractTest is Test { assertEq(org.recipientAdvanceLimit(recipient), newLimit, "Advance limit should be updated"); } - function testAdvanceRepayment() public { - // Create recipient and request advance - org.createRecipient(recipient, "Test Recipient", 1000); - org.setRecipientAdvanceLimit(recipient, 500); - - uint256 amount = 300; - - vm.prank(recipient); - org.requestAdvance(amount, address(token)); - - // Mint and approve tokens for advance - token.mint(owner, amount); - token.approve(address(org), amount); - - // Approve advance - org.approveAdvance(recipient); - - // Store initial balance - uint256 initialBalance = token.balanceOf(recipient); - - // Mint and approve tokens for salary - uint256 salary = 1000; - uint256 fee = (salary * org.transactionFee()) / 10000; - uint256 totalAmount = salary + fee; - - token.mint(owner, totalAmount); - token.approve(address(org), totalAmount); - - // Disburse salary which should deduct the advance - bool success = org.disburseToken(address(token), recipient, salary); - assertTrue(success, "Disbursement should succeed"); - - // Verify advance is marked as repaid - (,,,,, bool repaid,) = org.advanceRequests(recipient); - assertTrue(repaid, "Advance should be marked as repaid"); - - // Verify recipient received correct amount (salary - advance), comparing the difference - assertEq(token.balanceOf(recipient) - initialBalance, salary - amount, "Recipient should receive salary minus advance"); - } - function test_RevertWhen_DisburseTokenWithUnpaidAdvance() public { org.createRecipient(recipient, "Test Recipient", 1000); org.setRecipientAdvanceLimit(recipient, 500); @@ -419,15 +372,15 @@ contract OrganizationContractTest is Test { function testUpdateOrganizationInfo() public { string memory newName = "Updated Org"; string memory newDesc = "Updated Description"; - + // Store initial timestamp StructLib.Structs.Organization memory initialInfo = org.getOrganizationInfo(); - + // Advance time by 1 second vm.warp(block.timestamp + 1); - + org.updateOrganizationInfo(newName, newDesc); - + StructLib.Structs.Organization memory info = org.getOrganizationInfo(); assertEq(info.name, newName, "Organization name should be updated"); assertEq(info.description, newDesc, "Organization description should be updated"); @@ -452,15 +405,15 @@ contract OrganizationContractTest is Test { function testUpdateRecipient() public { org.createRecipient(recipient, "Original Name", 1000); - + // Store initial timestamp StructLib.Structs.Recipient memory initial = org.getRecipient(recipient); - + // Advance time by 1 second vm.warp(block.timestamp + 1); - + org.updateRecipient(recipient, "Updated Name"); - + StructLib.Structs.Recipient memory updated = org.getRecipient(recipient); assertEq(updated.name, "Updated Name", "Recipient name should be updated"); assertTrue(updated.updatedAt > initial.createdAt, "Updated timestamp should be greater than created timestamp"); @@ -493,7 +446,7 @@ contract OrganizationContractTest is Test { // Make payments uint256 amount1 = 100 ether; uint256 amount2 = 200 ether; - + token.mint(owner, 1000 ether); token.approve(address(org), type(uint256).max); @@ -502,20 +455,18 @@ contract OrganizationContractTest is Test { StructLib.Structs.Payment[] memory payments = org.getAllPayments(); assertEq(payments.length, 2, "Should have two payments"); - - uint256 fee = (amount1 * org.transactionFee()) / 10000; - assertEq(payments[0].amount, amount1 - fee, "First payment amount should be correct"); + + assertEq(payments[0].amount, amount1, "First payment amount should be correct"); assertEq(payments[0].recipient, recipient1, "First payment recipient should be correct"); - - fee = (amount2 * org.transactionFee()) / 10000; - assertEq(payments[1].amount, amount2 - fee, "Second payment amount should be correct"); + + assertEq(payments[1].amount, amount2, "Second payment amount should be correct"); assertEq(payments[1].recipient, recipient2, "Second payment recipient should be correct"); } function testGetRecipientPayments() public { // Create recipient and make multiple payments org.createRecipient(recipient, "Test Recipient", 1000); - + token.mint(owner, 1000 ether); token.approve(address(org), type(uint256).max); @@ -524,16 +475,15 @@ contract OrganizationContractTest is Test { amounts[1] = 200 ether; amounts[2] = 300 ether; - for(uint i = 0; i < amounts.length; i++) { + for (uint256 i = 0; i < amounts.length; i++) { org.disburseToken(address(token), recipient, amounts[i]); } StructLib.Structs.Payment[] memory payments = org.getRecipientPayments(recipient); assertEq(payments.length, 3, "Should have three payments"); - - for(uint i = 0; i < payments.length; i++) { - uint256 fee = (amounts[i] * org.transactionFee()) / 10000; - assertEq(payments[i].amount, amounts[i] - fee, "Payment amount should be correct"); + + for (uint256 i = 0; i < payments.length; i++) { + assertEq(payments[i].amount, amounts[i], "Payment amount should be correct"); assertEq(payments[i].recipient, recipient, "Payment recipient should be correct"); } } @@ -545,25 +495,25 @@ contract OrganizationContractTest is Test { // First advance request vm.startPrank(recipient); org.requestAdvance(200 ether, address(token)); - + // Should not be able to make another request before first is processed vm.expectRevert(CustomErrors.InvalidRequest.selector); org.requestAdvance(100 ether, address(token)); vm.stopPrank(); // Approve first advance - token.mint(owner, 200 ether); - token.approve(address(org), 200 ether); + uint256 advanceAmount = 200 ether; + uint256 advanceGrossAmount = (advanceAmount * 10000) / (10000 - org.transactionFee()); + token.mint(owner, advanceGrossAmount); + token.approve(address(org), advanceGrossAmount); org.approveAdvance(recipient); // Make salary payment to clear advance - uint256 salary = 1000 ether; - uint256 fee = (salary * org.transactionFee()) / 10000; - uint256 totalAmount = salary + fee; - - token.mint(owner, totalAmount); - token.approve(address(org), totalAmount); - org.disburseToken(address(token), recipient, salary); + uint256 salaryNet = 1000 ether; + uint256 salaryGross = (salaryNet * 10000) / (10000 - org.transactionFee()); + token.mint(owner, salaryGross); + token.approve(address(org), salaryGross); + org.disburseToken(address(token), recipient, salaryNet); // Should be able to request new advance after repayment vm.prank(recipient); @@ -573,7 +523,7 @@ contract OrganizationContractTest is Test { function testSetDefaultAdvanceLimit() public { uint256 newLimit = 1000 ether; org.setDefaultAdvanceLimit(newLimit); - + // Create new recipient and verify they get new default limit address newRecipient = address(6); org.createRecipient(newRecipient, "New Recipient", 2000); @@ -592,9 +542,7 @@ contract OrganizationContractTest is Test { function testRecipientCreatedEvent() public { vm.expectEmit(true, true, false, true); emit RecipientCreated( - bytes32(keccak256(abi.encodePacked(recipient, block.timestamp))), - recipient, - "Test Recipient" + bytes32(keccak256(abi.encodePacked(recipient, block.timestamp))), recipient, "Test Recipient" ); org.createRecipient(recipient, "Test Recipient", 1000); } @@ -602,10 +550,9 @@ contract OrganizationContractTest is Test { function testTokenDisbursedEvent() public { org.createRecipient(recipient, "Test Recipient", 1000); uint256 amount = 100 ether; - uint256 fee = (amount * org.transactionFee()) / 10000; vm.expectEmit(true, true, false, true); - emit TokenDisbursed(address(token), recipient, amount - fee); + emit TokenDisbursed(address(token), recipient, amount); org.disburseToken(address(token), recipient, amount); } @@ -623,16 +570,14 @@ contract OrganizationContractTest is Test { org.batchCreateRecipients(recipients, names, amounts); - uint256 totalAmount = amounts[0] + amounts[1]; - uint256 totalFees = ((amounts[0] + amounts[1]) * org.transactionFee()) / 10000; + uint256 totalNetAmount = amounts[0] + amounts[1]; + uint256 totalGrossAmount = (totalNetAmount * 10000) / (10000 - org.transactionFee()); - token.mint(owner, totalAmount + totalFees); - token.approve(address(org), totalAmount + totalFees); + token.mint(owner, totalGrossAmount); + token.approve(address(org), totalGrossAmount); vm.expectEmit(true, false, false, true); - emit BatchDisbursement(address(token), 2, totalAmount + totalFees); + emit BatchDisbursement(address(token), 2, totalGrossAmount - 1); // Account for rounding down org.batchDisburseToken(address(token), recipients, amounts); } - } - diff --git a/test/OrganizationFactory.t.sol b/test/OrganizationFactory.t.sol index 05726a9..b7bd86d 100644 --- a/test/OrganizationFactory.t.sol +++ b/test/OrganizationFactory.t.sol @@ -5,6 +5,7 @@ import "forge-std/Test.sol"; import "../src/contracts/OrganizationFactory.sol"; import "../src/contracts/OrganizationContract.sol" as OrgContract; import "../src/libraries/structs.sol"; +import "../src/libraries/errors.sol"; contract OrganizationFactoryTest is Test { OrganizationFactory public factory; @@ -75,48 +76,48 @@ contract OrganizationFactoryTest is Test { } function test_RevertWhen_CreateOrganizationWithEmptyName() public { - vm.expectRevert(); + vm.expectRevert(CustomErrors.NameRequired.selector); factory.createOrganization("", "Test Description"); } function test_RevertWhen_CreateOrganizationWithEmptyDescription() public { - vm.expectRevert(); + vm.expectRevert(CustomErrors.DescriptionRequired.selector); factory.createOrganization("Test Org", ""); } function test_RevertWhen_AddTokenWithEmptyName() public { - vm.expectRevert(); + vm.expectRevert(CustomErrors.InvalidTokenName.selector); factory.addToken("", token); } function test_RevertWhen_AddTokenWithZeroAddress() public { - vm.expectRevert(); + vm.expectRevert(CustomErrors.InvalidTokenAddress.selector); factory.addToken("Test Token", address(0)); } function test_RevertWhen_AddExistingToken() public { factory.addToken("Test Token", token); - vm.expectRevert(); + vm.expectRevert(CustomErrors.TokenAlreadySupported.selector); factory.addToken("Test Token", token); } function test_RevertWhen_RemoveNonExistentToken() public { - vm.expectRevert(); + vm.expectRevert(CustomErrors.InvalidToken.selector); factory.removeToken(token); } function test_RevertWhen_RemoveTokenWithZeroAddress() public { - vm.expectRevert(); + vm.expectRevert(CustomErrors.InvalidToken.selector); factory.removeToken(address(0)); } function test_RevertWhen_GetTokenNameWithZeroAddress() public { - vm.expectRevert(); + vm.expectRevert(CustomErrors.InvalidTokenAddress.selector); factory.getTokenName(address(0)); } function test_RevertWhen_IsTokenSupportedWithZeroAddress() public { - vm.expectRevert(); + vm.expectRevert(CustomErrors.InvalidTokenAddress.selector); factory.isTokenSupported(address(0)); } @@ -158,7 +159,7 @@ contract OrganizationFactoryTest is Test { // Try to update fee as non-owner vm.prank(address(2)); - vm.expectRevert("Not authorized"); + vm.expectRevert(CustomErrors.UnauthorizedAccess.selector); factory.updateOrganizationTransactionFee(orgOwner, 30); } @@ -170,7 +171,7 @@ contract OrganizationFactoryTest is Test { // Try to update fee collector as non-owner vm.prank(address(2)); - vm.expectRevert("Not authorized"); + vm.expectRevert(CustomErrors.UnauthorizedAccess.selector); factory.updateOrganizationFeeCollector(orgOwner, address(3)); } diff --git a/test/Token.t.sol b/test/Token.t.sol index 34682c9..63ebdb1 100644 --- a/test/Token.t.sol +++ b/test/Token.t.sol @@ -3,7 +3,7 @@ pragma solidity ^0.8.28; import "forge-std/Test.sol"; import "../src/contracts/Tokens.sol"; - +import "../src/libraries/errors.sol"; contract TokenTest is Test { TokenRegistry public tokenRegistry; address public owner; @@ -42,12 +42,12 @@ contract TokenTest is Test { } function test_RevertWhen_AddTokenWithEmptyName() public { - vm.expectRevert("Token name is required"); + vm.expectRevert(CustomErrors.InvalidTokenName.selector); tokenRegistry.addToken("", token1); } function test_RevertWhen_AddTokenWithZeroAddress() public { - vm.expectRevert("Invalid token address"); + vm.expectRevert(CustomErrors.InvalidTokenAddress.selector); tokenRegistry.addToken("Test Token", address(0)); } @@ -55,27 +55,27 @@ contract TokenTest is Test { string memory tokenName = "Test Token"; tokenRegistry.addToken(tokenName, token1); - vm.expectRevert("Token already exists"); + vm.expectRevert(CustomErrors.TokenAlreadySupported.selector); tokenRegistry.addToken(tokenName, token1); } function test_RevertWhen_RemoveNonExistentToken() public { - vm.expectRevert("Token does not exist"); + vm.expectRevert(CustomErrors.InvalidToken.selector); tokenRegistry.removeToken(token1); } function test_RevertWhen_RemoveTokenWithZeroAddress() public { - vm.expectRevert("Invalid token address"); + vm.expectRevert(CustomErrors.InvalidTokenAddress.selector); tokenRegistry.removeToken(address(0)); } function test_RevertWhen_GetTokenNameWithZeroAddress() public { - vm.expectRevert("Invalid token address"); + vm.expectRevert(CustomErrors.InvalidTokenAddress.selector); tokenRegistry.getTokenName(address(0)); } function test_RevertWhen_IsTokenSupportedWithZeroAddress() public { - vm.expectRevert("Invalid token address"); + vm.expectRevert(CustomErrors.InvalidTokenAddress.selector); tokenRegistry.isTokenSupported(address(0)); } }