diff --git a/README.md b/README.md
index e58d107f..6359b2d0 100644
--- a/README.md
+++ b/README.md
@@ -52,6 +52,7 @@ AgentLab Features:
| Benchmark | Setup Link | # Task Template | Seed Diversity | Max Step | Multi-tab | Hosted Method | BrowserGym Leaderboard |
|-----------|------------|---------|----------------|-----------|-----------|---------------|----------------------|
| [WebArena](https://webarena.dev/) | [setup](https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/webarena/README.md) | 812 | None | 30 | yes | self hosted (docker) | soon |
+| [WebArena-Verified](https://github.com/ServiceNow/webarena-verified) | [setup](https://github.com/ServiceNow/BrowserGym/blob/wa_verified/browsergym/webarena_verified/README.md) | 812 | None | 30 | yes | self hosted | soon |
| [WorkArena](https://github.com/ServiceNow/WorkArena) L1 | [setup](https://github.com/ServiceNow/WorkArena?tab=readme-ov-file#getting-started) | 33 | High | 30 | no | demo instance | soon |
| [WorkArena](https://github.com/ServiceNow/WorkArena) L2 | [setup](https://github.com/ServiceNow/WorkArena?tab=readme-ov-file#getting-started) | 341 | High | 50 | no | demo instance | soon |
| [WorkArena](https://github.com/ServiceNow/WorkArena) L3 | [setup](https://github.com/ServiceNow/WorkArena?tab=readme-ov-file#getting-started) | 341 | High | 50 | no | demo instance | soon |
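
For context, a minimal sketch of how the new WebArena-Verified tasks would be exercised through BrowserGym's gym registry, assuming the package registers task ids under a `webarena_verified.*` prefix (the exact ids are not confirmed by this diff):

```python
import gymnasium as gym

import browsergym.webarena_verified  # noqa: F401 -- registers the tasks on import

# "webarena_verified.0" is a hypothetical task id; check the package README
# for the actual naming scheme.
env = gym.make("browsergym/webarena_verified.0", headless=True)
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step("noop()")
env.close()
```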
diff --git a/src/agentlab/experiments/loop.py b/src/agentlab/experiments/loop.py
index f69322a6..75dd9f40 100644
--- a/src/agentlab/experiments/loop.py
+++ b/src/agentlab/experiments/loop.py
@@ -915,6 +915,11 @@ def _get_env_name(task_name: str):
elif task_name.startswith("webarena"):
import browsergym.webarena
import browsergym.webarenalite
+
+ try:
+ import browsergym.webarena_verified
+ except ImportError:
+ logger.warning("browsergym.webarena_verified not found. Skipping import.")
elif task_name.startswith("visualwebarena"):
import browsergym.visualwebarena
elif task_name.startswith("assistantbench"):
diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py
index 188747ac..d69147d7 100644
--- a/src/agentlab/llm/chat_api.py
+++ b/src/agentlab/llm/chat_api.py
@@ -433,7 +433,7 @@ def __init__(
min_retry_wait_time=min_retry_wait_time,
client_class=OpenAI,
client_args=client_args,
- pricing_func=tracking.get_pricing_openai,
+ pricing_func=partial(tracking.get_pricing_litellm, model_name=model_name),
log_probs=log_probs,
)
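
`pricing_func` is stored and later invoked with no arguments, returning a `{model_name: {"prompt": ..., "completion": ...}}` mapping (see the `__init__` hunk below), which is why `functools.partial` binds `model_name` up front. A hypothetical stand-in illustrating that contract (made-up prices, not the real `tracking.get_pricing_litellm`):

```python
from functools import partial

def get_pricing_stub(model_name: str) -> dict:
    """Stand-in pricing lookup: per-token USD prices keyed by model name."""
    table = {"gpt-4o": {"prompt": 2.5e-06, "completion": 1e-05}}  # made-up values
    return {m: p for m, p in table.items() if m == model_name}

pricing_func = partial(get_pricing_stub, model_name="gpt-4o")
pricings = pricing_func()  # zero-argument call, as in __init__ below
print(pricings["gpt-4o"]["prompt"])  # 2.5e-06
```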
@@ -492,6 +492,7 @@ def __init__(
temperature=0.5,
max_tokens=100,
max_retry=4,
+ pricing_func=None,
):
self.model_name = model_name
self.temperature = temperature
@@ -501,6 +502,22 @@ def __init__(
api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
self.client = anthropic.Anthropic(api_key=api_key)
+ # Get pricing information
+ if pricing_func:
+ pricings = pricing_func()
+ try:
+ self.input_cost = float(pricings[model_name]["prompt"])
+ self.output_cost = float(pricings[model_name]["completion"])
+ except KeyError:
+ logging.warning(
+ f"Model {model_name} not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community."
+ )
+ self.input_cost = 0.0
+ self.output_cost = 0.0
+ else:
+ self.input_cost = 0.0
+ self.output_cost = 0.0
+
def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
# Convert OpenAI format to Anthropic format
system_message = None
@@ -528,13 +545,29 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
response = self.client.messages.create(**kwargs)
+ usage = getattr(response, "usage", {})
+ new_input_tokens = getattr(usage, "input_tokens", 0)
+ output_tokens = getattr(usage, "output_tokens", 0)
+ cache_read_tokens = getattr(usage, "cache_read_input_tokens", 0)
+ cache_write_tokens = getattr(usage, "cache_creation_input_tokens", 0)
+ cache_read_cost = (
+ self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"]
+ )
+ cache_write_cost = (
+ self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"]
+ )
+ cost = (
+ new_input_tokens * self.input_cost
+ + output_tokens * self.output_cost
+ + cache_read_tokens * cache_read_cost
+ + cache_write_tokens * cache_write_cost
+ )
+
# Track usage if available
- if hasattr(tracking.TRACKER, "instance"):
- tracking.TRACKER.instance(
- response.usage.input_tokens,
- response.usage.output_tokens,
- 0, # cost calculation would need pricing info
- )
+ if hasattr(tracking.TRACKER, "instance") and isinstance(
+ tracking.TRACKER.instance, tracking.LLMTracker
+ ):
+ tracking.TRACKER.instance(new_input_tokens, output_tokens, cost)
return AIMessage(response.content[0].text)
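
As a sanity check on the cost formula above, a worked example with assumed per-token prices and assumed cache billing factors (the real values live in `tracking.ANTHROPIC_CACHE_PRICING_FACTOR`):

```python
# Worked example of the cache-aware cost formula above (all numbers assumed).
input_cost = 3e-06    # USD per new input token
output_cost = 15e-06  # USD per output token

# Assumed factors, mirroring the keys used in the diff.
ANTHROPIC_CACHE_PRICING_FACTOR = {
    "cache_read_tokens": 0.1,    # cache reads billed at 10% of the input price
    "cache_write_tokens": 1.25,  # cache writes billed at 125% of the input price
}

new_input_tokens, output_tokens = 1_000, 200
cache_read_tokens, cache_write_tokens = 5_000, 0

cost = (
    new_input_tokens * input_cost
    + output_tokens * output_cost
    + cache_read_tokens * input_cost * ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"]
    + cache_write_tokens * input_cost * ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"]
)
print(f"${cost:.6f}")  # 0.003 + 0.003 + 0.0015 + 0 = $0.007500
```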
@@ -552,6 +585,7 @@ def make_model(self):
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
+ pricing_func=partial(tracking.get_pricing_litellm, model_name=self.model_name),
)
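
Finally, the `isinstance` guard added to `__call__` assumes a thread-local `TRACKER` whose `instance` attribute, when set, is an `LLMTracker` accumulator. A self-contained sketch of that pattern, using the names from the diff but simplified internals:

```python
import threading

class LLMTracker:
    """Accumulates token counts and cost across LLM calls (simplified)."""

    def __init__(self):
        self.input_tokens = 0
        self.output_tokens = 0
        self.cost = 0.0

    def __call__(self, input_tokens: int, output_tokens: int, cost: float):
        self.input_tokens += input_tokens
        self.output_tokens += output_tokens
        self.cost += cost

TRACKER = threading.local()

def record(input_tokens: int, output_tokens: int, cost: float):
    # Mirror of the guard in the diff: record only when a tracker is active
    # on this thread and is actually an LLMTracker.
    if hasattr(TRACKER, "instance") and isinstance(TRACKER.instance, LLMTracker):
        TRACKER.instance(input_tokens, output_tokens, cost)

TRACKER.instance = LLMTracker()
record(1_000, 200, 0.0075)
print(TRACKER.instance.cost)  # 0.0075
```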