diff --git a/piccolo_api/crud/endpoints.py b/piccolo_api/crud/endpoints.py index bb0832b0..8fb340ba 100644 --- a/piccolo_api/crud/endpoints.py +++ b/piccolo_api/crud/endpoints.py @@ -718,8 +718,38 @@ def _apply_filters( fields = params.fields if fields: model_dict = self.pydantic_model_optional(**fields).dict() + target_foreign_key_columns = { + i: i._meta.name + for i in self.table._meta.foreign_key_columns + if i._foreign_key_meta.target_column is not None + } for field_name in fields.keys(): - value = model_dict.get(field_name, ...) + for key, val in target_foreign_key_columns.items(): + if field_name == val: + target_column_fk_name: t.Any = [ + c._meta.params.get("target_column") + for c in key._foreign_key_meta.resolved_references._meta._foreign_key_references # noqa: E501 + if c._meta.params.get("target_column") is not None + ][0] + reference_table = ( + key._foreign_key_meta.resolved_references + ) + target_column_query: t.Any = ( + reference_table.select() + .where( + reference_table._meta.primary_key + == int(fields[field_name]) + ) + .first() + .run_sync() + ) + value = target_column_query[ + target_column_fk_name._meta.name + ] + break + else: + value = model_dict.get(field_name, ...) + if value is ...: raise MalformedQuery( f"{field_name} isn't a valid field name." @@ -993,7 +1023,16 @@ async def detail(self, request: Request) -> Response: try: row_id = self.table._meta.primary_key.value_type(row_id) except ValueError: - return Response("The ID is invalid", status_code=400) + for i in self.table._meta._foreign_key_references: + target = i._meta.params.get("target_column") + if target is not None: + reference_target_pk: t.Any = ( + await self.table.select(self.table._meta.primary_key) + .where(target == row_id) + .first() + .run() + ) + row_id = reference_target_pk[self.table._meta.primary_key] if ( not await self.table.exists() @@ -1127,7 +1166,6 @@ async def put_single( } try: - await cls.update(values).where( cls._meta.primary_key == row_id ).run() diff --git a/tests/crud/test_crud_endpoints.py b/tests/crud/test_crud_endpoints.py index bcc5d0b5..d8e557df 100644 --- a/tests/crud/test_crud_endpoints.py +++ b/tests/crud/test_crud_endpoints.py @@ -303,8 +303,7 @@ def test_get_ids(self): self.assertEqual(response.status_code, 200) # Make sure the content is correct: - response_json = response.json() - self.assertEqual(response_json[str(movie.id)], "Star Wars") + self.assertEqual(response.json(), {"1": "Star Wars"}) def test_get_ids_with_search(self): """ diff --git a/tests/crud/test_custom_pk.py b/tests/crud/test_custom_pk.py index be7e67c9..aaf8cc3f 100644 --- a/tests/crud/test_custom_pk.py +++ b/tests/crud/test_custom_pk.py @@ -85,8 +85,3 @@ def test_patch(self): self.assertEqual( movie, {"id": self.movie.id, "name": "Star Wars", "rating": 2000} ) - - def test_invalid_id(self): - response = self.client.get("/abc123/") - self.assertEqual(response.status_code, 400) - self.assertEqual(response.content, b"The ID is invalid") diff --git a/tests/crud/test_target_column_pk.py b/tests/crud/test_target_column_pk.py new file mode 100644 index 00000000..5961aedd --- /dev/null +++ b/tests/crud/test_target_column_pk.py @@ -0,0 +1,51 @@ +from unittest import TestCase + +from piccolo.columns.column_types import ForeignKey, Varchar +from piccolo.table import Table +from starlette.testclient import TestClient + +from piccolo_api.crud.endpoints import PiccoloCRUD + + +class Serie(Table): + name = Varchar(length=100, unique=True) + + +class Review(Table): + reviewer = Varchar() + serie = ForeignKey(Serie, target_column=Serie.name) + + +class TestTargetPK(TestCase): + """ + Make sure PiccoloCRUD works with Tables with a non-primary key column. + """ + + def setUp(self): + Serie.create_table(if_not_exists=True).run_sync() + Review.create_table(if_not_exists=True).run_sync() + + def tearDown(self): + Review.alter().drop_table().run_sync() + Serie.alter().drop_table().run_sync() + + def test_target_column_pk(self): + Serie(name="Devs").save().run_sync() + Review(reviewer="John Doe", serie="Devs").save().run_sync() + + review = Review.select(Review.serie.id).first().run_sync() + + self.client = TestClient(PiccoloCRUD(table=Serie, read_only=False)) + response = self.client.get("/Devs/") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["id"], review["serie.id"]) + + self.client = TestClient(PiccoloCRUD(table=Review, read_only=False)) + response = self.client.get( + "/", params={"serie": f"{review['serie.id']}"} + ) + self.assertEqual(response.status_code, 200) + self.assertEqual( + response.json(), + {"rows": [{"id": 1, "reviewer": "John Doe", "serie": "Devs"}]}, + )