Skip to content

Commit ce96b8b

Browse files
committed
feat(multi-atm): linear accrual interpolated
1 parent d26f97c commit ce96b8b

2 files changed

Lines changed: 2792 additions & 0 deletions

File tree

Lines changed: 398 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,398 @@
1+
// SPDX-License-Identifier: MIT
2+
3+
pragma solidity ^0.8.27;
4+
5+
import { IAuthority } from "@openzeppelin/contracts/access/manager/IAuthority.sol";
6+
import { IERC20 } from "@openzeppelin/contracts/interfaces/IERC20.sol";
7+
import { IERC20Metadata } from "@openzeppelin/contracts/interfaces/IERC20Metadata.sol";
8+
import { ERC2771Context } from "@openzeppelin/contracts/metatx/ERC2771Context.sol";
9+
import { SafeERC20 } from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol";
10+
import { Hashes } from "@openzeppelin/contracts/utils/cryptography/Hashes.sol";
11+
import { Math } from "@openzeppelin/contracts/utils/math/Math.sol";
12+
import { SafeCast } from "@openzeppelin/contracts/utils/math/SafeCast.sol";
13+
import { SignedMath } from "@openzeppelin/contracts/utils/math/SignedMath.sol";
14+
import { Context } from "@openzeppelin/contracts/utils/Context.sol";
15+
import { Multicall } from "@openzeppelin/contracts/utils/Multicall.sol";
16+
import { Oracle } from "../oracle/Oracle.sol";
17+
import { PermissionManaged } from "../permissions/PermissionManaged.sol";
18+
19+
contract MultiATMLinearInterpolated is ERC2771Context, PermissionManaged, Multicall {
20+
using Math for *;
21+
using SafeCast for *;
22+
23+
uint256 private constant _BASIS_POINT_SCALE = 1e4;
24+
uint256 private constant _PRECISION = 1e18;
25+
uint8 private constant _MAX_REGRESSION_POINTS = 30;
26+
27+
struct Pair {
28+
IERC20 token1;
29+
IERC20 token2;
30+
Oracle oracle;
31+
uint256 oracleTTL;
32+
uint256 numerator;
33+
uint256 denominator;
34+
uint8 accrualRounds;
35+
}
36+
// Numerator and denominator account for the difference in decimals between the two tokens AND for the decimals
37+
// of the oracle. They are used to scale the conversion rate between the two tokens.
38+
//
39+
// For example, if token A has 18 decimals and token B has 6 decimals, and the oracle has 8 decimals, then
40+
// - 1 token A correspond to 10**18 units (wei),
41+
// - 1 token B correspond to 10**6 units (wei),
42+
// - the rate provided by the oracle must be divided by 10**8.
43+
//
44+
// Therefore:
45+
// (<Amount of token A> / 10**18) * (rate / 10**8) = (<Amount of token B> / 10**6)
46+
// i.e. <Amount of token A> * rate * 10**6 = <Amount of token B> * 10**(18 + 8)
47+
//
48+
// Which gives us the following conversion rate:
49+
// * <Amount of token A> * rate * <numerator> / <denominator> = <Amount of token B>
50+
// * <Amount of token B> / rate / <numerator> * <denominator> = <Amount of token A>
51+
//
52+
// with:
53+
// * numerator = 10**<decimals of token B>
54+
// * denominator = 10**(<decimals of token A> + <decimals of oracle>).
55+
56+
mapping(bytes32 id => Pair) private _pairs;
57+
uint256 public feeBasisPoints;
58+
59+
event SwapExact(
60+
IERC20 indexed input,
61+
IERC20 indexed output,
62+
uint256 inputAmount,
63+
uint256 outputAmount,
64+
address from,
65+
address to
66+
);
67+
event PairUpdated(
68+
bytes32 indexed id,
69+
IERC20 indexed token1,
70+
IERC20 indexed token2,
71+
Oracle oracle,
72+
uint256 oracleTTL,
73+
uint8 accrualRounds
74+
);
75+
event PairRemoved(bytes32 indexed id);
76+
event FeeUpdated(uint256 newFeeBasisPoints);
77+
error OutputAmountTooLow(uint256 outputAmount, uint256 minOutputAmount);
78+
error InputAmountTooHigh(uint256 inputAmount, uint256 maxInputAmount);
79+
error OracleValueTooOld(Oracle oracle);
80+
error UnknownPair(IERC20 input, IERC20 output);
81+
error InvalidFee(uint256 feeBasisPoints);
82+
error InvalidAccrualRounds(uint8 accrualRounds);
83+
error InvalidOracleData();
84+
85+
/// @custom:oz-upgrades-unsafe-allow constructor
86+
constructor(
87+
IAuthority _authority,
88+
address _trustedForwarder
89+
) PermissionManaged(_authority) ERC2771Context(_trustedForwarder) {}
90+
91+
/****************************************************************************************************************
92+
* Getters *
93+
****************************************************************************************************************/
94+
function viewPairDetails(
95+
IERC20 input,
96+
IERC20 output
97+
)
98+
public
99+
view
100+
virtual
101+
returns (
102+
bytes32 id,
103+
IERC20 token1,
104+
IERC20 token2,
105+
Oracle oracle,
106+
uint256 oracleTTL,
107+
uint256 numerator,
108+
uint256 denominator,
109+
uint8 accrualRounds
110+
)
111+
{
112+
id = hashPair(input, output);
113+
Pair storage pair = _pairs[id];
114+
115+
return (
116+
id,
117+
pair.token1,
118+
pair.token2,
119+
pair.oracle,
120+
pair.oracleTTL,
121+
pair.numerator,
122+
pair.denominator,
123+
pair.accrualRounds
124+
);
125+
}
126+
127+
function hashPair(IERC20 input, IERC20 output) public view virtual returns (bytes32) {
128+
return
129+
Hashes.commutativeKeccak256(
130+
bytes32(uint256(uint160(address(input)))),
131+
bytes32(uint256(uint160(address(output))))
132+
);
133+
}
134+
135+
/****************************************************************************************************************
136+
* Core - preview swaps *
137+
****************************************************************************************************************/
138+
function previewExactInput(
139+
IERC20[] memory path,
140+
uint256 inputAmount
141+
) public view virtual returns (uint256 /*outputAmount*/) {
142+
uint256 outputAmount = inputAmount;
143+
for (uint256 i = 0; i < path.length - 1; ++i) {
144+
outputAmount = _exactInput(path[i], path[i + 1], outputAmount);
145+
}
146+
return outputAmount.mulDiv(_BASIS_POINT_SCALE - feeBasisPoints, _BASIS_POINT_SCALE, Math.Rounding.Floor);
147+
}
148+
149+
function previewExactOutput(
150+
IERC20[] memory path,
151+
uint256 outputAmount
152+
) public view virtual returns (uint256 /*inputAmount*/) {
153+
uint256 inputAmount = outputAmount;
154+
for (uint256 i = path.length - 1; i > 0; --i) {
155+
inputAmount = _exactOutput(path[i - 1], path[i], inputAmount);
156+
}
157+
return inputAmount.mulDiv(_BASIS_POINT_SCALE, _BASIS_POINT_SCALE - feeBasisPoints, Math.Rounding.Ceil);
158+
}
159+
160+
function previewExactInputSingle(
161+
IERC20 input,
162+
IERC20 output,
163+
uint256 inputAmount
164+
) public view virtual returns (uint256 /*outputAmount*/) {
165+
return
166+
_exactInput(input, output, inputAmount).mulDiv(
167+
_BASIS_POINT_SCALE - feeBasisPoints,
168+
_BASIS_POINT_SCALE,
169+
Math.Rounding.Floor
170+
);
171+
}
172+
173+
function previewExactOutputSingle(
174+
IERC20 input,
175+
IERC20 output,
176+
uint256 outputAmount
177+
) public view virtual returns (uint256 /*inputAmount*/) {
178+
return
179+
_exactOutput(input, output, outputAmount).mulDiv(
180+
_BASIS_POINT_SCALE,
181+
_BASIS_POINT_SCALE - feeBasisPoints,
182+
Math.Rounding.Ceil
183+
);
184+
}
185+
186+
function _exactInput(
187+
IERC20 input,
188+
IERC20 output,
189+
uint256 inputAmount
190+
) internal view virtual returns (uint256 /*outputAmount*/) {
191+
(
192+
,
193+
IERC20 token1,
194+
,
195+
Oracle oracle,
196+
uint256 oracleTTL,
197+
uint256 numerator,
198+
uint256 denominator,
199+
uint8 accrualRounds
200+
) = viewPairDetails(input, output);
201+
202+
require(address(oracle) != address(0), UnknownPair(input, output));
203+
204+
(int256 minPrice, int256 maxPrice) = _getPrices(oracle, oracleTTL, accrualRounds);
205+
return
206+
inputAmount.mulDiv(
207+
Math.ternary(input == token1, numerator * minPrice.toUint256(), denominator),
208+
Math.ternary(input == token1, denominator, numerator * maxPrice.toUint256()),
209+
Math.Rounding.Floor
210+
);
211+
}
212+
213+
function _exactOutput(
214+
IERC20 input,
215+
IERC20 output,
216+
uint256 outputAmount
217+
) internal view virtual returns (uint256 /*inputAmount*/) {
218+
(
219+
,
220+
IERC20 token1,
221+
,
222+
Oracle oracle,
223+
uint256 oracleTTL,
224+
uint256 numerator,
225+
uint256 denominator,
226+
uint8 accrualRounds
227+
) = viewPairDetails(input, output);
228+
229+
require(address(oracle) != address(0), UnknownPair(input, output));
230+
231+
(int256 minPrice, int256 maxPrice) = _getPrices(oracle, oracleTTL, accrualRounds);
232+
return
233+
outputAmount.mulDiv(
234+
Math.ternary(input == token1, denominator, numerator * maxPrice.toUint256()),
235+
Math.ternary(input == token1, numerator * minPrice.toUint256(), denominator),
236+
Math.Rounding.Ceil
237+
);
238+
}
239+
240+
/****************************************************************************************************************
241+
* Core - execute swaps *
242+
****************************************************************************************************************/
243+
function swapExactInput(
244+
IERC20[] memory path,
245+
uint256 inputAmount,
246+
address recipient,
247+
uint256 minOutputAmount
248+
) public virtual restricted returns (uint256 /*outputAmount*/) {
249+
uint256 outputAmount = previewExactInput(path, inputAmount);
250+
require(outputAmount >= minOutputAmount, OutputAmountTooLow(outputAmount, minOutputAmount));
251+
_swapExact(path[0], path[path.length - 1], inputAmount, outputAmount, _msgSender(), recipient);
252+
return outputAmount;
253+
}
254+
255+
function swapExactInputSingle(
256+
IERC20 input,
257+
IERC20 output,
258+
uint256 inputAmount,
259+
address recipient,
260+
uint256 minOutputAmount
261+
) public virtual restricted returns (uint256 /*outputAmount*/) {
262+
uint256 outputAmount = previewExactInputSingle(input, output, inputAmount);
263+
require(outputAmount >= minOutputAmount, OutputAmountTooLow(outputAmount, minOutputAmount));
264+
_swapExact(input, output, inputAmount, outputAmount, _msgSender(), recipient);
265+
return outputAmount;
266+
}
267+
268+
function swapExactOutput(
269+
IERC20[] memory path,
270+
uint256 outputAmount,
271+
address recipient,
272+
uint256 maxInputAmount
273+
) public virtual restricted returns (uint256 /*inputAmount*/) {
274+
uint256 inputAmount = previewExactOutput(path, outputAmount);
275+
require(inputAmount <= maxInputAmount, InputAmountTooHigh(inputAmount, maxInputAmount));
276+
_swapExact(path[0], path[path.length - 1], inputAmount, outputAmount, _msgSender(), recipient);
277+
return inputAmount;
278+
}
279+
280+
function swapExactOutputSingle(
281+
IERC20 input,
282+
IERC20 output,
283+
uint256 outputAmount,
284+
address recipient,
285+
uint256 maxInputAmount
286+
) public virtual restricted returns (uint256 /*inputAmount*/) {
287+
uint256 inputAmount = previewExactOutputSingle(input, output, outputAmount);
288+
require(inputAmount <= maxInputAmount, InputAmountTooHigh(inputAmount, maxInputAmount));
289+
_swapExact(input, output, inputAmount, outputAmount, _msgSender(), recipient);
290+
return inputAmount;
291+
}
292+
293+
function _swapExact(
294+
IERC20 input,
295+
IERC20 output,
296+
uint256 inputAmount,
297+
uint256 outputAmount,
298+
address from,
299+
address to
300+
) private {
301+
SafeERC20.safeTransferFrom(input, from, address(this), inputAmount);
302+
SafeERC20.safeTransfer(output, to, outputAmount);
303+
emit SwapExact(input, output, inputAmount, outputAmount, from, to);
304+
}
305+
306+
function _getPrices(
307+
Oracle oracle,
308+
uint256 oracleTTL,
309+
uint256 accrualRounds
310+
) internal view virtual returns (int256 min, int256 max) {
311+
(uint80 roundId, int256 latest, , uint256 updatedAt, ) = oracle.latestRoundData();
312+
require(roundId + 1 >= accrualRounds, InvalidOracleData());
313+
require(block.timestamp < updatedAt + oracleTTL, OracleValueTooOld(oracle));
314+
if (accrualRounds == 0) {
315+
(, int256 previous, , , ) = oracle.getRoundData(roundId - 1);
316+
min = SignedMath.min(latest, previous);
317+
max = SignedMath.max(latest, previous);
318+
} else {
319+
int256 sumT = 0;
320+
int256 sumP = 0;
321+
int256 sumTT = 0;
322+
int256 sumTP = 0;
323+
for (uint256 currentRound = roundId + 1 - accrualRounds; currentRound <= roundId; ++currentRound) {
324+
(, latest, , updatedAt, ) = oracle.getRoundData(currentRound.toUint80());
325+
sumT += updatedAt.toInt256();
326+
sumP += latest;
327+
sumTT += updatedAt.toInt256() * updatedAt.toInt256();
328+
sumTP += updatedAt.toInt256() * latest;
329+
}
330+
min = accrualRounds.toInt256();
331+
max = min * sumTP - sumT * sumP;
332+
latest = min * sumTT - sumT * sumT;
333+
require(latest > 0, InvalidOracleData());
334+
min = max = (sumP - (sumT * max) / latest) / min + (block.timestamp.toInt256() * max) / latest;
335+
require(min > 0, InvalidOracleData());
336+
}
337+
}
338+
339+
/****************************************************************************************************************
340+
* Admin actions *
341+
****************************************************************************************************************/
342+
function setPair(
343+
IERC20Metadata token1,
344+
IERC20Metadata token2,
345+
Oracle oracle,
346+
uint256 oracleTTL,
347+
uint8 accrualRounds
348+
) public virtual restricted {
349+
bytes32 id = hashPair(token1, token2);
350+
require(
351+
accrualRounds == 0 || (accrualRounds >= 2 && accrualRounds <= _MAX_REGRESSION_POINTS),
352+
InvalidAccrualRounds(accrualRounds)
353+
);
354+
_pairs[id] = Pair({
355+
token1: token1,
356+
token2: token2,
357+
oracle: oracle,
358+
oracleTTL: oracleTTL,
359+
numerator: 10 ** token2.decimals(),
360+
denominator: 10 ** (token1.decimals() + oracle.decimals()),
361+
accrualRounds: accrualRounds
362+
});
363+
364+
emit PairUpdated(id, token1, token2, oracle, oracleTTL, accrualRounds);
365+
}
366+
367+
function removePair(IERC20 token1, IERC20 token2) public virtual restricted {
368+
bytes32 id = hashPair(token1, token2);
369+
delete _pairs[id];
370+
371+
emit PairRemoved(id);
372+
}
373+
374+
function setFee(uint256 newFeeBasisPoints) public virtual restricted {
375+
require(newFeeBasisPoints <= 50, InvalidFee(newFeeBasisPoints)); // Max 0.5%
376+
feeBasisPoints = newFeeBasisPoints;
377+
emit FeeUpdated(newFeeBasisPoints);
378+
}
379+
380+
function withdraw(IERC20 _token, address _to, uint256 _amount) public virtual restricted {
381+
SafeERC20.safeTransfer(_token, _to, _amount == type(uint256).max ? _token.balanceOf(address(this)) : _amount);
382+
}
383+
384+
/****************************************************************************************************************
385+
* Context overrides *
386+
****************************************************************************************************************/
387+
function _msgSender() internal view override(Context, ERC2771Context) returns (address) {
388+
return super._msgSender();
389+
}
390+
391+
function _msgData() internal view override(Context, ERC2771Context) returns (bytes calldata) {
392+
return super._msgData();
393+
}
394+
395+
function _contextSuffixLength() internal view override(Context, ERC2771Context) returns (uint256) {
396+
return super._contextSuffixLength();
397+
}
398+
}

0 commit comments

Comments
 (0)