Skip to content

Commit 6b61a34

Browse files
authored
Merge pull request #76 from pattern-tech/refactor/ai-v2
feat: Enhance contract ABI fetching to support proxy contracts and im…
2 parents 42ad6d3 + 728e74b commit 6b61a34

2 files changed

Lines changed: 35 additions & 11 deletions

File tree

src/agentflow/providers/chain_scan_tools.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,21 @@ def get_contract_abi(contract_address: str, chain_id: str) -> Dict:
222222
Dict: The contract ABI.
223223
"""
224224
api_key = _get_chain_config(chain_id)["API_KEY"]
225-
return fetch_contract_abi(contract_address, chain_id, api_key)
225+
226+
final_output = {"proxy": [], "implementation": []}
227+
228+
response = fetch_contract_source_code(contract_address, chain_id, api_key)
229+
230+
current_address = contract_address
231+
while int(response.get('Proxy', 0)):
232+
final_output["proxy"].append(json.loads(response["ABI"]))
233+
current_address = response["Implementation"]
234+
response = fetch_contract_source_code(
235+
current_address, chain_id, api_key)
236+
237+
final_output["implementation"].append(json.loads(response["ABI"]))
238+
239+
return final_output
226240

227241

228242
@tool
@@ -663,7 +677,17 @@ def call_contract_function(contract_address: str, chain_id: str, function_name:
663677
try:
664678
# Get the contract ABI
665679
contract_address = Web3.to_checksum_address(contract_address)
666-
abi = fetch_contract_abi(contract_address, chain_id, api_key)
680+
implementation_contract_address = contract_address
681+
# Check if the contract address is a proxy and get the implementation address
682+
if int(fetch_contract_source_code(contract_address, chain_id, api_key).get('Proxy', 0)):
683+
implementation_contract_address = fetch_contract_source_code(
684+
contract_address, chain_id, api_key)["Implementation"]
685+
implementation_contract_address = Web3.to_checksum_address(
686+
implementation_contract_address)
687+
688+
abi = fetch_contract_abi(
689+
implementation_contract_address, chain_id, api_key)
690+
667691
contract = web3.eth.contract(address=contract_address, abi=abi)
668692

669693
# Find the function in the ABI

src/agentflow/providers/moralis_tools.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ def check_chain_supported(chain: str) -> bool:
3232
@handle_exceptions
3333
def get_wallet_token_balances(wallet_address: str, chain: str, output_include: list[str], cursor: str = "") -> dict:
3434
"""
35-
Get token balances for a specific wallet address and their token prices in USD. (paginated)
35+
Get token balances for a specific wallet address in a specific chian. (paginated)
3636
apply decimal conversion for balance
3737
3838
Args:
39-
wallet_address (str): Ethereum wallet address
39+
wallet_address (str): Wallet address
4040
chain (str): The chain ID can be ["eth", "0x1", "polygon", "0x89", "bsc", "0x38", "avalanche", "0xa86a", "fantom", "0xfa", "palm", "0x2a15c308d", "cronos", "0x19", "arbitrum", "0xa4b1", "chiliz", "0x15b38","gnosis", "0x64", "base", "0x2105", "optimism", "0xa", "linea", "0xe708", "moonbeam", "0x504", "moonriver", "0x505", "flow", "0x2eb", "ronin", "0x7e4", "lisk", "0x46f", "pulse", "0x171"]
4141
output_include (list[str]): A list of field names to include in the output.
4242
cursor (str): The cursor returned in the previous response (used for getting the next page). end of page cursor is None
@@ -81,10 +81,10 @@ def get_wallet_token_balances(wallet_address: str, chain: str, output_include: l
8181
@handle_exceptions
8282
def get_wallet_stats(wallet_address: str, chain: str, output_include: list[str]) -> dict:
8383
"""
84-
Get the stats for a wallet address.
84+
Get the stats for a wallet address in a specific chain.
8585
8686
Args:
87-
wallet_address (str): Ethereum wallet address
87+
wallet_address (str): Wallet address
8888
chain (str): The chain ID can be ["eth", "0x1", "polygon", "0x89", "bsc", "0x38", "avalanche", "0xa86a", "fantom", "0xfa", "palm", "0x2a15c308d", "cronos", "0x19", "arbitrum", "0xa4b1", "chiliz", "0x15b38","gnosis", "0x64", "base", "0x2105", "optimism", "0xa", "linea", "0xe708", "moonbeam", "0x504", "moonriver", "0x505", "flow", "0x2eb", "ronin", "0x7e4", "lisk", "0x46f", "pulse", "0x171"]
8989
output_include (list[str]): A list of field names to include in the output.
9090
@@ -117,10 +117,10 @@ def get_wallet_stats(wallet_address: str, chain: str, output_include: list[str])
117117
def get_wallet_history(wallet_address: str, chain: str, output_include: list[str], cursor: str = "") -> dict:
118118
"""
119119
Retrieve the full transaction history of a specified wallet address, including sends, receives, token and NFT transfers
120-
and contract interactions. (paginated & in descending order)
120+
and contract interactions in a specific chain. (paginated & in descending order)
121121
122122
Args:
123-
wallet_address (str): Ethereum wallet address
123+
wallet_address (str): Wallet address
124124
chain (str): The chain ID can be ["eth", "0x1", "polygon", "0x89", "bsc", "0x38", "avalanche", "0xa86a", "fantom", "0xfa", "palm", "0x2a15c308d", "cronos", "0x19", "arbitrum", "0xa4b1", "chiliz", "0x15b38","gnosis", "0x64", "base", "0x2105", "optimism", "0xa", "linea", "0xe708", "moonbeam", "0x504", "moonriver", "0x505", "flow", "0x2eb", "ronin", "0x7e4", "lisk", "0x46f", "pulse", "0x171"]
125125
output_include (list[str]): A list of field names to include in the output.
126126
cursor (str): The cursor returned in the previous response (used for getting the next page). end of page cursor is None
@@ -165,7 +165,7 @@ def get_wallet_history(wallet_address: str, chain: str, output_include: list[str
165165
@handle_exceptions
166166
def get_transaction_detail(transaction_hash: str, chain: str, output_include: list[str]) -> dict:
167167
"""
168-
Get the contents of a transaction by the given transaction hash.
168+
Get the contents of a transaction for a specific chain by the given transaction hash.
169169
170170
Args:
171171
transaction_hash (str): transaction hash to be decoded
@@ -205,10 +205,10 @@ def get_transaction_detail(transaction_hash: str, chain: str, output_include: li
205205
@handle_exceptions
206206
def get_token_approvals(wallet_address: str, chain: str, output_include: list[str], cursor: str = "") -> dict:
207207
"""
208-
Get ERC20 approvals for one or many wallet addresses and/or contract addresses, ordered by block number in descending order.
208+
Get ERC20 approvals for one or many wallet addresses for a specific chain and/or contract addresses, ordered by block number in descending order.
209209
210210
Args:
211-
wallet_address (str): Ethereum wallet address
211+
wallet_address (str): Wallet address
212212
chain (str): The chain ID can be ["eth", "0x1", "polygon", "0x89", "bsc", "0x38", "avalanche", "0xa86a", "fantom", "0xfa", "palm", "0x2a15c308d", "cronos", "0x19", "arbitrum", "0xa4b1", "chiliz", "0x15b38","gnosis", "0x64", "base", "0x2105", "optimism", "0xa", "linea", "0xe708", "moonbeam", "0x504", "moonriver", "0x505", "flow", "0x2eb", "ronin", "0x7e4", "lisk", "0x46f", "pulse", "0x171"]
213213
output_include (list[str]): A list of field names to include in the output.
214214
cursor (str): The cursor returned in the previous response (used for getting the next page). end of page cursor is None

0 commit comments

Comments
 (0)