Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions backend/annotation/serializers_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
from typing import Any

from django.utils import timezone
from django.contrib.auth.models import Permission
from rest_framework.test import APIRequestFactory

Expand All @@ -11,7 +10,6 @@
LabelAnnotationSerializer,
SaveLabelsInputSerializer,
)
from problem.serializers import ProblemSerializer


@pytest.mark.django_db
Expand Down
35 changes: 35 additions & 0 deletions backend/problem/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest

from problem.models import Problem, Sentence


@pytest.fixture
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've lifted these fixtures here so they are available to more test files.

def hypothesis_sentence(db):
return Sentence.objects.create(text="Hypothesis")


@pytest.fixture
def premise_sentence(db):
return Sentence.objects.create(text="Premise")


@pytest.fixture
def user_problem(db, hypothesis_sentence, premise_sentence):
problem = Problem.objects.create(
dataset=Problem.Dataset.USER,
hypothesis=hypothesis_sentence,
extra_data={},
)
problem.premises.add(premise_sentence)
return problem


@pytest.fixture
def non_user_problem(db, hypothesis_sentence, premise_sentence):
problem = Problem.objects.create(
dataset=Problem.Dataset.SICK,
hypothesis=hypothesis_sentence,
extra_data={},
)
problem.premises.add(premise_sentence)
return problem
16 changes: 16 additions & 0 deletions backend/problem/migrations/0010_problem_gold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("problem", "0009_problem_hidden"),
]

operations = [
migrations.AddField(
model_name="problem",
name="gold",
field=models.BooleanField(default=False),
),
]
23 changes: 23 additions & 0 deletions backend/problem/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class EntailmentLabel(models.TextChoices):
CONTRADICTION = "contradiction", "Contradiction"
UNKNOWN = "unknown", "Unknown"

class Status(models.TextChoices):
GOLD = "gold", "Gold"
SILVER = "silver", "Silver"
BRONZE = "bronze", "Bronze"

dataset = models.CharField(
max_length=255,
choices=Dataset.choices,
Expand Down Expand Up @@ -55,6 +60,8 @@ class EntailmentLabel(models.TextChoices):

hidden = models.BooleanField(default=False)

gold = models.BooleanField(default=False)

extra_data = models.JSONField()

class Meta:
Expand All @@ -76,3 +83,19 @@ def get_index(self, qs: QuerySet) -> int | None:
except Exception as e:
logger.exception(f"Error getting index for problem {self.pk}: {e}")
return None

@property
def status(self) -> "Problem.Status":
"""
Returns the computed status of this problem:
- GOLD if the problem is marked as gold.
- SILVER if not gold but has active annotations (KB items or labels).
- BRONZE otherwise (no annotations).
"""
if self.gold:
return Problem.Status.GOLD
has_annotations = (
self.knowledgebaseannotations.filter(removed_at__isnull=True).exists()
or self.labelannotations.filter(removed_at__isnull=True).exists()
)
return Problem.Status.SILVER if has_annotations else Problem.Status.BRONZE
87 changes: 87 additions & 0 deletions backend/problem/models_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import pytest

from annotation.models import KnowledgeBaseAnnotation, LabelAnnotation
from problem.models import Problem


@pytest.fixture
def kb_annotation(db, annotator_session, non_user_problem):
return KnowledgeBaseAnnotation.objects.create(
problem=non_user_problem,
entity1="dog",
entity2="canine",
relationship=KnowledgeBaseAnnotation.Relationship.EQUAL,
Comment on lines +11 to +13
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Zoology nitpick: dogs are a subset of canines.

session=annotator_session,
created_by=annotator_session.user,
)

@pytest.mark.django_db
def test_status_bronze_no_annotations(db, non_user_problem):
"""
A problem with no annotations and gold=False is bronze.

This also functions as an initial assumption check for the tests below.
"""
assert non_user_problem.status == Problem.Status.BRONZE


@pytest.mark.django_db
def test_status_gold_without_annotation(db, non_user_problem):
"""A problem with gold=True is gold without annotations."""
non_user_problem.gold = True
non_user_problem.save()
assert non_user_problem.status == Problem.Status.GOLD

@pytest.mark.django_db
def test_status_gold_with_annotation(db, non_user_problem, kb_annotation):
"""A problem with gold=True is gold even with annotations."""
non_user_problem.gold = True
non_user_problem.save()
assert non_user_problem.status == Problem.Status.GOLD


@pytest.mark.django_db
def test_status_silver_with_kb_annotation(db, non_user_problem, kb_annotation):
"""A problem with an active KB annotation and gold=False is silver."""
assert non_user_problem.status == Problem.Status.SILVER

@pytest.mark.django_db
def test_status_silver_with_label_annotation(db, non_user_problem, annotator_session, sample_label):
"""A problem with an active label annotation and gold=False is silver."""
LabelAnnotation.objects.create(
problem=non_user_problem,
label=sample_label,
session=annotator_session,
created_by=annotator_session.user,
)
assert non_user_problem.status == Problem.Status.SILVER


@pytest.mark.django_db
def test_status_bronze_when_all_annotations_removed(db, non_user_problem, annotator_session):
"""A problem whose only annotation is removed reverts to bronze."""
from django.utils import timezone

kb = KnowledgeBaseAnnotation.objects.create(
problem=non_user_problem,
entity1="dog",
entity2="canine",
relationship=KnowledgeBaseAnnotation.Relationship.EQUAL,
session=annotator_session,
created_by=annotator_session.user,
)
kb.removed_at = timezone.now()
kb.removed_by = annotator_session.user
kb.save()

assert non_user_problem.status == Problem.Status.BRONZE


@pytest.mark.django_db
def test_status_serialized(db, non_user_problem):
"""The status field is correctly serialized."""
from problem.serializers import ProblemSerializer

serializer = ProblemSerializer(non_user_problem)
assert serializer.data["status"] == Problem.Status.BRONZE
assert serializer.data["gold"] is False
Comment on lines +86 to +87
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why go with == on one line and is on the next?

40 changes: 36 additions & 4 deletions backend/problem/problem_details.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass

from django.db.models import Exists, OuterRef
from django.http import QueryDict
from django.db.models import QuerySet, Q

Expand Down Expand Up @@ -59,7 +60,6 @@ def get_filters(query_params: QueryDict, user: User | None = None) -> Q | None:
"""
dataset = query_params.get("dataset")
entailment_label = query_params.get("entailmentLabel")
gold = query_params.get("gold")
text = query_params.get("text")
hidden = query_params.get("hidden", None)

Expand All @@ -70,9 +70,6 @@ def get_filters(query_params: QueryDict, user: User | None = None) -> Q | None:
filters &= Q(dataset=dataset)
if entailment_label:
filters &= Q(entailment_label=entailment_label)
if gold:
logger.warning(f"Filtering by gold is not implemented yet.")
pass
if text:
filters &= Q(
Q(hypothesis__text__icontains=text) | Q(premises__text__icontains=text)
Expand All @@ -84,3 +81,38 @@ def get_filters(query_params: QueryDict, user: User | None = None) -> Q | None:
filters &= Q(hidden=hidden.lower() == 'true')

return filters


def apply_status_filter(
qs: QuerySet[Problem], query_params: QueryDict
) -> QuerySet[Problem]:
"""
Applies a status filter to the queryset based on the 'status' query parameter.
Returns the queryset unchanged if no valid status is provided.
"""
from annotation.models import KnowledgeBaseAnnotation, LabelAnnotation

status_param = query_params.get("status")
if not status_param:
return qs

has_active_kb = Exists(
KnowledgeBaseAnnotation.objects.filter(
problem=OuterRef("pk"), removed_at__isnull=True
)
)
has_active_label = Exists(
LabelAnnotation.objects.filter(
problem=OuterRef("pk"), removed_at__isnull=True
)
)

match status_param:
case Problem.Status.GOLD:
return qs.filter(gold=True)
case Problem.Status.SILVER:
return qs.filter(gold=False).filter(has_active_kb | has_active_label)
case Problem.Status.BRONZE:
return qs.filter(gold=False).exclude(has_active_kb | has_active_label)
case _:
return qs
3 changes: 3 additions & 0 deletions backend/problem/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class ProblemSerializer(serializers.ModelSerializer):
hypothesis = serializers.SerializerMethodField()
entailmentLabel = serializers.CharField(source="entailment_label")
extraData = serializers.SerializerMethodField()
status = serializers.CharField(read_only=True)

class Meta:
model = Problem
Expand All @@ -33,6 +34,8 @@ class Meta:
"extraData",
"base",
"hidden",
"gold",
"status",
]

def get_premises(self, problem: Problem):
Expand Down
33 changes: 0 additions & 33 deletions backend/problem/serializers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,6 @@

from annotation.models import KnowledgeBaseAnnotation
from .serializers import ProblemInputSerializer
from .models import Problem, Sentence


@pytest.fixture
def hypothesis_sentence(db):
return Sentence.objects.create(text="Hypothesis")


@pytest.fixture
def premise_sentence(db):
return Sentence.objects.create(text="Premise")


@pytest.fixture
def user_problem(db, hypothesis_sentence, premise_sentence):
problem = Problem.objects.create(
dataset=Problem.Dataset.USER,
hypothesis=hypothesis_sentence,
extra_data={},
)
problem.premises.add(premise_sentence)
return problem


@pytest.fixture
def non_user_problem(db, hypothesis_sentence, premise_sentence):
problem = Problem.objects.create(
dataset=Problem.Dataset.SICK,
hypothesis=hypothesis_sentence,
extra_data={},
)
problem.premises.add(premise_sentence)
return problem


@pytest.mark.django_db
Expand Down
29 changes: 29 additions & 0 deletions backend/problem/views/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from problem.problem_details import (
get_filters,
apply_status_filter,
get_related_problem_ids,
)
from problem.models import Problem
Expand Down Expand Up @@ -45,6 +46,11 @@ def has_permission(self, request, view):
return super().has_permission(request, view) and request.user.can_change_problem_visibility


class ChangeProblemStatusPermission(IsAuthenticated):
def has_permission(self, request, view):
return super().has_permission(request, view) and request.user.can_change_problem_status


class ProblemView(ModelViewSet):
queryset = Problem.objects.all()
serializer_class = ProblemSerializer
Expand All @@ -56,6 +62,8 @@ def get_permissions(self):
return [EditProblemPermission()]
if self.action == "set_visibility":
return [ChangeProblemVisibilityPermission()]
if self.action == "set_status":
return [ChangeProblemStatusPermission()]
return [IsAuthenticatedOrReadOnly()]

def list(self, request: Request) -> Response:
Expand All @@ -69,9 +77,28 @@ def list(self, request: Request) -> Response:
if filters is not None:
qs = qs.filter(filters)

qs = apply_status_filter(qs, request.query_params)

serializer = self.get_serializer(qs, many=True)
return Response(serializer.data, status=HTTP_200_OK)

@action(detail=True, methods=["post"], url_path="set-status")
def set_status(self, request: Request, pk: int) -> Response:
"""
Toggles the gold status of a Problem.
Expects a JSON body with a boolean 'gold' field.
"""
problem = get_object_or_404(Problem, id=pk)
gold = request.data.get("gold")
if not isinstance(gold, bool):
return Response(
{"detail": "'gold' must be a boolean."},
status=HTTP_400_BAD_REQUEST,
)
Comment on lines +93 to +97
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this validation come with DRF out of the box?

problem.gold = gold
problem.save(update_fields=["gold"])
return Response({"gold": problem.gold, "status": problem.status}, status=HTTP_200_OK)

@action(detail=False, methods=["get"], url_path="first")
def first(self, request: Request) -> Response:
"""
Expand Down Expand Up @@ -115,6 +142,8 @@ def _get_problem_response(self, request: Request, pk: int | None) -> Response:
if filters is not None:
qs = qs.filter(filters).distinct()

qs = apply_status_filter(qs, request.query_params)

problem = None
if pk is not None:
try:
Expand Down
7 changes: 7 additions & 0 deletions backend/user/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ def can_change_problem_visibility(self) -> bool:
"""
return self.has_perm("problem.change_problem_visibility")

@property
def can_change_problem_status(self) -> bool:
"""
Determines whether the user can change problem status (gold/ungold).
"""
return self.has_perm("problem.change_problem_status")

@property
def can_edit_kb(self) -> bool:
"""
Expand Down
Loading
Loading