Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions kloppy/domain/services/transformers/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ def change_point_dimensions(
point_base = self._from_pitch_dimensions.to_metric_base(
point, pitch_length=base_pitch_length, pitch_width=base_pitch_width
)
print(point_base)
print(self._to_pitch_dimensions.from_metric_base)
point_to = self._to_pitch_dimensions.from_metric_base(
point=point_base,
pitch_length=base_pitch_length,
Expand Down Expand Up @@ -329,8 +327,8 @@ def transform_event(self, event: Event) -> Event:
):
event = self.__flip_event(event)

if event.freeze_frame:
event.freeze_frame = self.transform_frame(event.freeze_frame)
if event.freeze_frame:
event.freeze_frame = self.transform_frame(event.freeze_frame)

return event

Expand Down
61 changes: 40 additions & 21 deletions kloppy/infra/serializers/event/statsbomb/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Orientation,
Period,
Player,
PositionType,
Provider,
Team,
)
Expand Down Expand Up @@ -77,6 +78,29 @@ def deserialize(
)
for event in new_events:
if self.should_include_event(event):
if "freeze_frame" in event.raw_event.get("shot", {}):
event.freeze_frame = parse_freeze_frame(
freeze_frame=event.raw_event["shot"][
"freeze_frame"
],
home_team=teams[0],
away_team=teams[1],
event=event,
fidelity_version=data_version.shot_fidelity_version,
)
if (
not event.freeze_frame
and event.event_id in three_sixty_data
):
freeze_frame = three_sixty_data[event.event_id]
event.freeze_frame = parse_freeze_frame(
freeze_frame=freeze_frame["freeze_frame"],
home_team=teams[0],
away_team=teams[1],
event=event,
fidelity_version=data_version.xy_fidelity_version,
visible_area=freeze_frame["visible_area"],
)
# Transform event to the coordinate system
event = self.transformer.transform_event(event)
events.append(event)
Expand All @@ -102,29 +126,24 @@ def deserialize(
**additional_metadata,
)
dataset = EventDataset(metadata=metadata, records=events)
# We can now update GK identities in the freeze frames
# because we know the positions of the GKs at the event times
for event in dataset:
if "freeze_frame" in event.raw_event.get("shot", {}):
event.freeze_frame = self.transformer.transform_frame(
parse_freeze_frame(
freeze_frame=event.raw_event["shot"]["freeze_frame"],
home_team=teams[0],
away_team=teams[1],
event=event,
fidelity_version=data_version.shot_fidelity_version,
)
)
if not event.freeze_frame and event.event_id in three_sixty_data:
freeze_frame = three_sixty_data[event.event_id]
event.freeze_frame = self.transformer.transform_frame(
parse_freeze_frame(
freeze_frame=freeze_frame["freeze_frame"],
home_team=teams[0],
away_team=teams[1],
event=event,
fidelity_version=data_version.xy_fidelity_version,
visible_area=freeze_frame["visible_area"],
if not event.freeze_frame:
continue

new_players_data = {}
for player, data in event.freeze_frame.players_data.items():
if player.attributes.get("goalkeeper", False):
actual_gk = player.team.get_player_by_position(
position=PositionType.Goalkeeper,
time=event.time,
)
)
new_players_data[actual_gk] = data
else:
new_players_data[player] = data

event.freeze_frame.players_data = new_players_data
return dataset

def load_data(self, inputs: StatsBombInputs):
Expand Down
25 changes: 20 additions & 5 deletions kloppy/infra/serializers/event/statsbomb/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
PlayerData,
Point,
Point3D,
PositionType,
Team,
)
from kloppy.domain.services.frame_factory import create_frame
Expand Down Expand Up @@ -107,18 +106,34 @@ def parse_freeze_frame(

def get_player_from_freeze_frame(player_data, team, i):
if "player" in player_data:
return team.get_player_by_id(player_data["player"]["id"])
elif player_data.get("actor"):
home_player = home_team.get_player_by_id(
player_data["player"]["id"]
)
if home_player:
return home_player
away_player = away_team.get_player_by_id(
player_data["player"]["id"]
)
if away_player:
return away_player

if player_data.get("actor"):
return event.player
elif player_data.get("keeper"):
return team.get_player_by_position(
position=PositionType.Goalkeeper, time=event.time
# We can later identify the goalkeeper by their position
# if we know the formation, but for now we just flag them
return Player(
player_id=f"T{team.team_id}-E{event.event_id}-{i}",
team=team,
jersey_no=None,
attributes={"goalkeeper": True},
)
else:
return Player(
player_id=f"T{team.team_id}-E{event.event_id}-{i}",
team=team,
jersey_no=None,
attributes={"goalkeeper": False},
)

for i, freeze_frame_player in enumerate(freeze_frame):
Expand Down
9 changes: 9 additions & 0 deletions kloppy/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Module to store common fixtures."""

import os
from pathlib import Path

import pytest
Expand All @@ -8,3 +9,11 @@
@pytest.fixture(scope="session")
def base_dir() -> Path:
return Path(__file__).parent


@pytest.fixture(scope="session")
def with_visualization(base_dir):
enable_viz = os.environ.get("KLOPPY_TESTWITHVIZ") == "1"
if enable_viz:
(base_dir / "outputs").mkdir(exist_ok=True)
return enable_viz
42 changes: 20 additions & 22 deletions kloppy/tests/test_statsbomb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections import defaultdict
from datetime import datetime, timedelta, timezone
import os
from pathlib import Path
from typing import cast

Expand Down Expand Up @@ -51,19 +50,9 @@
from kloppy.infra.serializers.event.statsbomb.helpers import parse_str_ts
import kloppy.infra.serializers.event.statsbomb.specification as SB

ENABLE_PLOTTING = True
API_URL = "https://raw.githubusercontent.com/statsbomb/open-data/master/data/"


def test_with_visualization():
if (
"KLOPPY_TESTWITHVIZ" in os.environ
and os.environ["KLOPPY_TESTWITHVIZ"] == "1"
):
return True
return False


@pytest.fixture(scope="module")
def dataset() -> EventDataset:
"""Load StatsBomb data for Belgium - Portugal at Euro 2020"""
Expand Down Expand Up @@ -318,7 +307,9 @@ def test_synthetic_out_events(self, dataset: EventDataset):

assert ball_out_events[0].ball_state == BallState.DEAD

def test_freeze_frame_shot(self, dataset: EventDataset, base_dir: Path):
def test_freeze_frame_shot(
self, dataset: EventDataset, base_dir: Path, with_visualization: bool
):
"""Test if shot freeze-frame is properly parsed and attached to shot events"""
shot_event = dataset.get_event_by_id(
"a5c60797-631e-418a-9f24-1e9779cb2b42"
Expand All @@ -344,7 +335,7 @@ def test_freeze_frame_shot(self, dataset: EventDataset, base_dir: Path):
91.45, 28.15
)

if test_with_visualization():
if with_visualization:
import matplotlib.pyplot as plt
from mplsoccer import VerticalPitch

Expand All @@ -359,7 +350,10 @@ def test_freeze_frame_shot(self, dataset: EventDataset, base_dir: Path):
def get_color(player):
if player.team == shot_event.player.team:
return "#b94b75"
elif player.starting_position.position_id == "1":
elif (
player.starting_position.position_group
== PositionType.Goalkeeper
):
return "#c15ca5"
else:
return "#7f63b8"
Expand Down Expand Up @@ -426,7 +420,9 @@ def get_color(player):
base_dir / "outputs" / "test_statsbomb_freeze_frame_shot.png"
)

def test_freeze_frame_360(self, dataset: EventDataset, base_dir: Path):
def test_freeze_frame_360(
self, dataset: EventDataset, base_dir: Path, with_visualization: bool
):
"""Test if 360 freeze-frame is properly parsed and attached to shot events"""
pass_event = dataset.get_event_by_id(
"8022c113-e349-4b0b-b4a7-a3bb662535f8"
Expand Down Expand Up @@ -502,7 +498,7 @@ def test_freeze_frame_360(self, dataset: EventDataset, base_dir: Path):
abs=1e-2,
)

if test_with_visualization():
if with_visualization:
import matplotlib.pyplot as plt
from mplsoccer import Pitch

Expand Down Expand Up @@ -610,13 +606,20 @@ def test_correct_normalized_deserialization(self):
pass_event = dataset.get_event_by_id(
"8022c113-e349-4b0b-b4a7-a3bb662535f8"
)
assert (
pass_event.coordinates.x
== pass_event.freeze_frame.ball_coordinates.x
)
assert (
pass_event.coordinates.y
== pass_event.freeze_frame.ball_coordinates.y
)
coordinates_per_team = defaultdict(list)
for (
player,
coordinates,
) in pass_event.freeze_frame.players_coordinates.items():
coordinates_per_team[player.team.name].append(coordinates)
print(coordinates_per_team)
assert coordinates_per_team == {
"Belgium": [
Point(x=0.30230680550305883, y=0.5224074534269804),
Expand Down Expand Up @@ -1233,11 +1236,6 @@ def test_player_position(self, base_dir):
event_data=base_dir / "files/statsbomb_event.json",
)

for item in dataset.aggregate("minutes_played", include_position=True):
print(
f"{item.player} {item.player.player_id}- {item.start_time} - {item.end_time} - {item.duration} - {item.position}"
)

home_team, away_team = dataset.metadata.teams
period1, period2 = dataset.metadata.periods

Expand Down
Loading