1010from django .core .exceptions import ObjectDoesNotExist
1111from django .core .paginator import Paginator
1212from django .db .models import (
13- Count ,
1413 Case ,
14+ Count ,
1515 Exists ,
1616 F ,
1717 IntegerField ,
1818 Max ,
1919 OuterRef ,
2020 Q ,
2121 Subquery ,
22- When ,
2322 Sum ,
23+ When ,
2424)
25+ from django .db .models .functions import Coalesce
2526from django .utils import timezone
2627from rest_framework import status
2728from rest_framework .response import Response
@@ -308,8 +309,17 @@ def validate_thread_and_user(
308309 ValueError: If the thread or user is not found.
309310 """
310311 try :
311- thread = CommentThread .objects .get (pk = int (thread_id ))
312- user = ForumUser .objects .get (user__pk = user_id )
312+ # Optimize: Use select_related to avoid N+1 queries
313+ thread = CommentThread .objects .select_related ("author" , "closed_by" ).get (
314+ pk = int (thread_id )
315+ )
316+ user = (
317+ ForumUser .objects .select_related ("user" )
318+ .prefetch_related (
319+ "user__course_stats" , "user__read_states__last_read_times"
320+ )
321+ .get (user__pk = user_id )
322+ )
313323 except ObjectDoesNotExist as exc :
314324 raise ValueError ("User / Thread doesn't exist" ) from exc
315325
@@ -348,8 +358,17 @@ def get_pinned_unpinned_thread_serialized_data(
348358 Raises:
349359 ValueError: If the serialization is not valid.
350360 """
351- user = ForumUser .objects .get (user__pk = user_id )
352- updated_thread = CommentThread .objects .get (pk = thread_id )
361+ # Optimize: Use select_related to avoid N+1 queries
362+ user = (
363+ ForumUser .objects .select_related ("user" )
364+ .prefetch_related (
365+ "user__course_stats" , "user__read_states__last_read_times"
366+ )
367+ .get (user__pk = user_id )
368+ )
369+ updated_thread = CommentThread .objects .select_related (
370+ "author" , "closed_by"
371+ ).get (pk = thread_id )
353372 user_data = user .to_dict ()
354373 context = {
355374 "user_id" : user_data ["_id" ],
@@ -401,35 +420,41 @@ def get_abuse_flagged_count(thread_ids: list[str]) -> dict[str, int]:
401420 Returns:
402421 dict[str, int]: A dictionary mapping thread IDs to their corresponding abuse-flagged comment count.
403422 """
404- abuse_flagger_count_subquery = (
423+ # Optimize: Use aggregation to count abuse flaggers per thread in bulk
424+ comment_content_type = ContentType .objects .get_for_model (Comment )
425+
426+ # Get all comments for these threads
427+ comment_ids = Comment .objects .filter (
428+ comment_thread__pk__in = thread_ids
429+ ).values_list ("pk" , flat = True )
430+
431+ if not comment_ids :
432+ return {}
433+
434+ # Count abuse flaggers per comment using aggregation
435+ abuse_flagged_counts = (
405436 AbuseFlagger .objects .filter (
406- content_type = ContentType . objects . get_for_model ( Comment ) ,
407- content_object_id = OuterRef ( "pk" ) ,
437+ content_type = comment_content_type ,
438+ content_object_id__in = comment_ids ,
408439 )
409440 .values ("content_object_id" )
410441 .annotate (count = Count ("pk" ))
411- .values ("count" )
412442 )
413443
414- abuse_flagged_comments = (
415- Comment .objects .filter (
416- comment_thread__pk__in = thread_ids ,
417- )
418- .annotate (
419- abuse_flaggers_count = Subquery (
420- abuse_flagger_count_subquery , output_field = IntegerField ()
421- )
444+ # Map comment IDs back to thread IDs
445+ comment_to_thread = dict (
446+ Comment .objects .filter (pk__in = comment_ids ).values_list (
447+ "pk" , "comment_thread_id"
422448 )
423- .filter (abuse_flaggers_count__gt = 0 )
424449 )
425450
426- result = {}
427- for comment in abuse_flagged_comments :
428- thread_pk = str ( comment . comment_thread . pk )
429- if thread_pk not in result :
430- result [ thread_pk ] = 0
431- abuse_flaggers = "abuse_flaggers_count"
432- result [thread_pk ] += getattr ( comment , abuse_flaggers )
451+ result : dict [ str , int ] = {}
452+ for item in abuse_flagged_counts :
453+ comment_id = item [ "content_object_id" ]
454+ thread_id = comment_to_thread . get ( comment_id )
455+ if thread_id :
456+ thread_pk = str ( thread_id )
457+ result [thread_pk ] = result . get ( thread_pk , 0 ) + item [ "count" ]
433458
434459 return result
435460
@@ -457,28 +482,43 @@ def get_read_states(
457482 except User .DoesNotExist :
458483 return read_states
459484
460- threads = CommentThread .objects .filter (pk__in = thread_ids )
485+ # Convert thread_ids to integers for database queries
486+ try :
487+ thread_ids_int = [int (tid ) for tid in thread_ids ]
488+ except (ValueError , TypeError ):
489+ return read_states
490+
491+ threads = CommentThread .objects .filter (pk__in = thread_ids_int ).values (
492+ "pk" , "last_activity_at"
493+ )
494+ thread_dict = {thread ["pk" ]: thread for thread in threads }
495+
461496 read_state = ReadState .objects .filter (user = user , course_id = course_id ).first ()
462497 if not read_state :
463498 return read_states
464499
465- read_dates = read_state .last_read_times
500+ last_read_times = read_state .last_read_times .select_related (
501+ "comment_thread"
502+ ).filter (comment_thread_id__in = thread_ids_int )
466503
467- for thread in threads :
468- read_date = read_dates .filter (comment_thread = thread ).first ()
469- if not read_date :
504+ for read_date in last_read_times :
505+ thread_id = read_date .comment_thread .pk
506+ thread = thread_dict .get (thread_id )
507+ if not thread :
470508 continue
471509
472- last_activity_at = thread . last_activity_at
510+ last_activity_at = thread [ " last_activity_at" ]
473511 is_read = read_date .timestamp >= last_activity_at
512+
513+ # Count unread comments for this thread
474514 unread_comment_count = (
475515 Comment .objects .filter (
476- comment_thread = thread , created_at__gte = read_date .timestamp
516+ comment_thread_id = thread_id , created_at__gte = read_date .timestamp
477517 )
478518 .exclude (author__pk = user_id )
479519 .count ()
480520 )
481- read_states [str (thread . pk )] = [is_read , unread_comment_count ]
521+ read_states [str (thread_id )] = [is_read , unread_comment_count ]
482522
483523 return read_states
484524
@@ -524,11 +564,14 @@ def get_endorsed(thread_ids: list[str]) -> dict[str, bool]:
524564 Returns:
525565 dict[str, bool]: A dictionary of thread IDs to their endorsed status (True if endorsed, False otherwise).
526566 """
527- endorsed_comments = Comment .objects .filter (
528- comment_thread__pk__in = thread_ids , endorsed = True
567+ # Optimize: Use values_list to avoid loading full objects
568+ endorsed_thread_ids = (
569+ Comment .objects .filter (comment_thread__pk__in = thread_ids , endorsed = True )
570+ .values_list ("comment_thread_id" , flat = True )
571+ .distinct ()
529572 )
530573
531- return {str (comment . comment_thread . pk ): True for comment in endorsed_comments }
574+ return {str (thread_id ): True for thread_id in endorsed_thread_ids }
532575
533576 @staticmethod
534577 def get_user_read_state_by_course_id (
@@ -729,24 +772,44 @@ def handle_threads_query(
729772 base_query = base_query .filter (
730773 commentable_id__in = commentable_ids ,
731774 )
775+ # Annotate comments count
732776 base_query = base_query .annotate (
733- votes_point = Sum ("uservote__vote" ),
734- comments_count = Count ("comment" , distinct = True ),
735- )
736-
737- base_query = base_query .annotate (
738- votes_point = Sum ("uservote__vote" , distinct = True ),
739777 comments_count = Count ("comment" , distinct = True ),
740778 )
741779
742780 sort_criteria = cls .get_sort_criteria (sort_key )
743781
782+ # Only annotate votes_point if sorting by votes to avoid performance issues
783+ # Otherwise calculate votes separately in bulk
784+ if sort_key == "votes" :
785+ comment_thread_content_type = ContentType .objects .get_for_model (
786+ CommentThread
787+ )
788+ base_query = base_query .annotate (
789+ votes_point = Coalesce (
790+ Subquery (
791+ UserVote .objects .filter (
792+ content_type = comment_thread_content_type ,
793+ content_object_id = OuterRef ("pk" ),
794+ )
795+ .values ("content_object_id" )
796+ .annotate (votes_sum = Sum ("vote" ))
797+ .values ("votes_sum" )[:1 ],
798+ output_field = IntegerField (),
799+ ),
800+ 0 ,
801+ ),
802+ )
803+
804+ base_query = base_query .select_related ("author" , "closed_by" )
805+
744806 comment_threads = (
745807 base_query .order_by (* sort_criteria ) if sort_criteria else base_query
746808 )
747809 thread_count = base_query .count ()
748810
749811 if raw_query :
812+ comment_threads = comment_threads .prefetch_related ("comment_set" )
750813 return {
751814 "result" : [
752815 comment_thread .to_dict () for comment_thread in comment_threads
@@ -762,6 +825,7 @@ def handle_threads_query(
762825 to_skip = (page - 1 ) * per_page
763826 has_more = False
764827
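828+ # Note: prefetch_related is deliberately not applied on this path because the queryset is consumed with iterator() (Django ignores prefetching with iterator() before 4.1 and requires an explicit chunk_size from 4.1 on)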
765829 for thread in comment_threads .iterator ():
766830 thread_key = str (thread .pk )
767831 if (
@@ -777,6 +841,8 @@ def handle_threads_query(
777841 skipped += 1
778842 num_pages = page + 1 if has_more else page
779843 else :
844+ # Apply prefetch_related when not using iterator()
845+ comment_threads = comment_threads .prefetch_related ("comment_set" )
780846 threads = [thread .pk for thread in comment_threads ]
781847 page = max (1 , page )
782848 start = per_page * (page - 1 )
@@ -820,7 +886,10 @@ def prepare_thread(
820886 Returns:
821887 dict[str, Any]: A dictionary representing the prepared thread data.
822888 """
823- thread = CommentThread .objects .get (pk = thread_id )
889+ # Optimize: Use select_related to avoid N+1 queries
890+ thread = CommentThread .objects .select_related ("author" , "closed_by" ).get (
891+ pk = thread_id
892+ )
824893 return {
825894 ** thread .to_dict (),
826895 "type" : "thread" ,
@@ -850,7 +919,25 @@ def threads_presentor(
850919 Returns:
851920 list[dict[str, Any]]: A list of prepared thread data.
852921 """
853- threads = CommentThread .objects .filter (pk__in = thread_ids )
922+
923+ threads = CommentThread .objects .filter (pk__in = thread_ids ).select_related (
924+ "author" , "closed_by"
925+ )
926+
927+ threads_dict = {thread .pk : thread for thread in threads }
928+
929+ # Calculate votes in bulk to avoid N+1 queries
930+ comment_thread_content_type = ContentType .objects .get_for_model (CommentThread )
931+ thread_ids_int = [int (tid ) for tid in thread_ids ]
932+ votes_aggregate = (
933+ UserVote .objects .filter (
934+ content_type = comment_thread_content_type ,
935+ content_object_id__in = thread_ids_int ,
936+ )
937+ .values ("content_object_id" )
938+ .annotate (votes_sum = Sum ("vote" ))
939+ )
940+
854941 read_states = cls .get_read_states (thread_ids , user_id , course_id )
855942 threads_endorsed = cls .get_endorsed (thread_ids )
856943 threads_flagged = (
@@ -859,7 +946,9 @@ def threads_presentor(
859946
860947 presenters = []
861948 for thread_id in thread_ids :
862- thread = threads .get (id = thread_id )
949+ thread = threads_dict .get (int (thread_id ))
950+ if not thread :
951+ continue
863952 is_read , unread_count = read_states .get (
864953 str (thread .pk ), (False , thread .comment_count )
865954 )
@@ -1693,7 +1782,10 @@ def update_comment(comment_id: str, **kwargs: Any) -> int:
16931782 @staticmethod
16941783 def get_thread_id_from_comment (comment_id : str ) -> dict [str , Any ] | None :
16951784 """Return thread_id from comment_id."""
1696- comment = Comment .objects .get (pk = comment_id )
1785+ # Optimize: Use select_related to avoid N+1 queries
1786+ comment = Comment .objects .select_related (
1787+ "comment_thread__author" , "comment_thread__closed_by"
1788+ ).get (pk = comment_id )
16971789 if comment .comment_thread :
16981790 return comment .comment_thread .to_dict ()
16991791 raise ValueError ("Comment doesn't have the thread." )
@@ -2114,20 +2206,37 @@ def get_paginated_user_stats(
21142206 cls , course_id : str , page : int , per_page : int , sort_criterion : dict [str , Any ]
21152207 ) -> dict [str , Any ]:
21162208 """Get paginated user stats."""
2117- users = User .objects .filter (
2118- Q (course_stats__course_id = course_id )
2119- & Q (course_stats__course_id__isnull = False )
2120- ).order_by (
2121- * [f"-{ key } " for key , value in sort_criterion .items () if value == - 1 ],
2122- * [key for key , value in sort_criterion .items () if value == 1 ],
2209+
2210+ users = (
2211+ User .objects .filter (
2212+ Q (course_stats__course_id = course_id )
2213+ & Q (course_stats__course_id__isnull = False )
2214+ )
2215+ .select_related ("forum" )
2216+ .prefetch_related ("course_stats" , "read_states__last_read_times" )
2217+ .order_by (
2218+ * [f"-{ key } " for key , value in sort_criterion .items () if value == - 1 ],
2219+ * [key for key , value in sort_criterion .items () if value == 1 ],
2220+ )
21232221 )
21242222
21252223 paginator = Paginator (users , per_page )
21262224 paginated_users = paginator .page (page )
21272225
2226+ user_ids = [user .pk for user in paginated_users .object_list ]
2227+ forum_users_dict = {
2228+ fu .user .pk : fu
2229+ for fu in ForumUser .objects .filter (user__pk__in = user_ids )
2230+ .select_related ("user" )
2231+ .prefetch_related (
2232+ "user__course_stats" , "user__read_states__last_read_times"
2233+ )
2234+ }
2235+
21282236 forum_users = [
2129- ForumUser .objects .get (user_id = user_id )
2130- for user_id in paginated_users .object_list
2237+ forum_users_dict [user_id ]
2238+ for user_id in user_ids
2239+ if user_id in forum_users_dict
21312240 ]
21322241 return {
21332242 "pagination" : [{"total_count" : paginator .count }],
0 commit comments