@@ -26,13 +26,15 @@ import threading
2626import opcode
2727import os
2828import types
29+ from weakref import WeakSet
2930
3031NOP_VALUE: int = opcode.opmap[' NOP' ]
3132
3233# The Op code should be 2 bytes as stated in
3334# https://docs.python.org/3/library/dis.html
3435# if sys.version_info[0:2] >= (3, 11):
35- NOP_BYTES: bytes = NOP_VALUE.to_bytes(2 , byteorder = byteorder)
36+ NOP_BYTES_LEN: int = 2
37+ NOP_BYTES: bytes = NOP_VALUE.to_bytes(NOP_BYTES_LEN, byteorder = byteorder)
3638
3739# This should be true for Python >=3.11a1
3840HAS_CO_QUALNAME: bool = hasattr (types.CodeType, ' co_qualname' )
@@ -150,6 +152,24 @@ cdef inline int64 compute_line_hash(uint64 block_hash, uint64 linenum):
150152 return block_hash ^ linenum
151153
152154
cdef inline object multibyte_rstrip(bytes bytecode):
    """
    Strip trailing NOP padding from ``bytecode``.

    Returns:
        result (tuple[bytes, int])
            - First item is the bare unpadded bytecode
            - Second item is the number of :py:const:`NOP_BYTES`
              ``bytecode`` has been padded with
    """
    # Bind module-level constants to locals once, then peel one
    # multi-byte NOP unit off the tail per iteration.
    padding: bytes = NOP_BYTES
    unit: int = NOP_BYTES_LEN
    stripped: bytes = bytecode
    pad_count: int = 0
    while stripped.endswith(padding):
        stripped = stripped[:len(stripped) - unit]
        pad_count += 1
    return (stripped, pad_count)
172+
153173if CAN_USE_SYS_MONITORING:
154174 def _is_main_thread () -> bool:
155175 return threading.current_thread() == threading.main_thread()
@@ -318,8 +338,13 @@ cdef class LineProfiler:
318338 cdef public double timer_unit
319339 cdef public object threaddata
320340
321- # This is shared between instances and threads
341+ # These are shared between instances and threads
342+ # type: dict[int, set[LineProfiler]], int = thread id
322343 _all_active_instances = {}
344+ # type: dict[bytes, int], bytes = bytecode
345+ _all_paddings = {}
346+ # type: dict[int, weakref.WeakSet[LineProfiler]], int = func id
347+ _all_instances_by_funcs = {}
323348
324349 def __init__ (self , *functions ):
325350 self .functions = []
@@ -345,39 +370,74 @@ cdef class LineProfiler:
345370 )
346371 try :
347372 code = func.__code__
373+ func_id = id (func)
348374 except AttributeError :
349375 try :
350376 code = func.__func__.__code__
377+ func_id = id (func.__func__)
351378 except AttributeError :
352379 import warnings
353380 warnings.warn(" Could not extract a code object for the object %r " % (func,))
354381 return
355382
383+ # Note: if we are to alter the code object, other profilers
384+ # which previously added this function would still expect the
385+ # old bytecode, and thus will not see anything when the function
386+ # is executed;
387+ # hence:
388+ # - When doing bytecode padding, take into account all instances
389+ # that refer to the same base bytecode to ensure
390+ # disambiguation
391+ # - Update all existing instances referring to the old code
392+ # object
393+ # Since no code padding is/can be done with Cython mock
394+ # "code objects", it is *probably* okay to only do the special
395+ # handling on the non-Cython branch.
396+ # XXX: tests for the above assertion if necessary
397+ co_code: bytes = code.co_code
356398 code_hashes = []
357- if any (code.co_code): # Normal Python functions
358- if code.co_code in self .dupes_map:
359- self .dupes_map[code.co_code] += [code]
360- # code hash already exists, so there must be a duplicate
361- # function. add no-op
362- co_padding : bytes = NOP_BYTES * (len (self .dupes_map[code.co_code]) + 1 )
363- co_code = code.co_code + co_padding
364- CodeType = type (code)
399+ if any (co_code): # Normal Python functions
400+ # Figure out how much padding we need and strip the bytecode
401+ base_co_code: bytes
402+ npad_code: int
403+ base_co_code, npad_code = multibyte_rstrip(co_code)
404+ try :
405+ npad = self ._all_paddings[base_co_code]
406+ except KeyError :
407+ npad = 0
408+ self ._all_paddings[base_co_code] = max (npad, npad_code) + 1
409+ try :
410+ profilers_to_update = self ._all_instances_by_funcs[func_id]
411+ profilers_to_update.add(self )
412+ except KeyError :
413+ profilers_to_update = WeakSet({self })
414+ self ._all_instances_by_funcs[func_id] = profilers_to_update
415+ # Maintain `.dupes_map` (legacy)
416+ try :
417+ self .dupes_map[base_co_code].append(code)
418+ except KeyError :
419+ self .dupes_map[base_co_code] = [code]
420+ if npad > npad_code:
421+ # Code hash already exists, so there must be a duplicate
422+ # function (on some instance);
423+ # (re-)pad with no-op
424+ co_code = base_co_code + NOP_BYTES * npad
365425 code = _code_replace(func, co_code = co_code)
366426 try :
367427 func.__code__ = code
368428 except AttributeError as e:
369429 func.__func__.__code__ = code
370- else :
371- self .dupes_map[code.co_code] = [code]
430+ else : # No re-padding -> no need to update the other profs
431+ profilers_to_update = { self }
372432 # TODO: Since each line can be many bytecodes, this is kinda
373433 # inefficient
374434 # See if this can be sped up by not needing to iterate over
375435 # every byte
376- for offset, _ in enumerate (code. co_code):
377- code_hash = compute_line_hash (
378- hash ((code.co_code)),
379- PyCode_Addr2Line( < PyCodeObject * > code, offset))
380- code_hashes.append(code_hash )
436+ for offset, _ in enumerate (co_code):
437+ code_hashes.append (
438+ compute_line_hash(
439+ hash (co_code),
440+ PyCode_Addr2Line( < PyCodeObject * > code, offset)) )
381441 else : # Cython functions have empty/zero bytecodes
382442 if CANNOT_LINE_TRACE_CYTHON:
383443 return
@@ -400,13 +460,21 @@ cdef class LineProfiler:
400460 # We can't replace the code object on Cython functions, but
401461 # we can *store* a copy with the correct metadata
402462 code = code.replace(co_filename = cython_source)
403- for code_hash in code_hashes:
404- if not self ._c_code_map.count(code_hash):
405- try :
406- self .code_hash_map[code].append(code_hash)
407- except KeyError :
408- self .code_hash_map[code] = [code_hash]
409- self ._c_code_map[code_hash]
463+ profilers_to_update = {self }
464+ # Update `._c_code_map` and `.code_hash_map` with the new line
465+ # hashes on `self` (and other instances profiling the same
466+ # function if we padded the bytecode)
467+ for instance in profilers_to_update:
468+ prof = < LineProfiler> instance
469+ try :
470+ line_hashes = prof.code_hash_map[code]
471+ except KeyError :
472+ line_hashes = prof.code_hash_map[code] = []
473+ for code_hash in code_hashes:
474+ line_hash = < int64> code_hash
475+ if not prof._c_code_map.count(line_hash):
476+ line_hashes.append(line_hash)
477+ prof._c_code_map[line_hash]
410478
411479 self .functions.append(func)
412480
@@ -541,35 +609,29 @@ cdef class LineProfiler:
541609 """
542610 cdef dict cmap = self ._c_code_map
543611
544- stats = {}
612+ all_entries = {}
545613 for code in self .code_hash_map:
546614 entries = []
547615 for entry in self .code_hash_map[code]:
548- entries += list (cmap[entry].values())
616+ entries.extend (cmap[entry].values())
549617 key = label(code)
550618
551- # Merge duplicate line numbers, which occur for branch entrypoints like `if`
552- nhits_by_lineno = {}
553- total_time_by_lineno = {}
619+ # Merge duplicate line numbers, which occur for branch
620+ # entrypoints like `if`
621+ entries_by_lineno = all_entries.setdefault(key, {})
554622
555623 for line_dict in entries:
556624 _, lineno, total_time, nhits = line_dict.values()
557- nhits_by_lineno[lineno] = nhits_by_lineno.setdefault(lineno, 0 ) + nhits
558- total_time_by_lineno[lineno] = total_time_by_lineno.setdefault(lineno, 0 ) + total_time
559-
560- entries = [(lineno, nhits, total_time_by_lineno[lineno]) for lineno, nhits in nhits_by_lineno.items()]
561- entries.sort()
562-
563- # NOTE: v4.x may produce more than one entry per line. For example:
564- # 1: for x in range(10):
565- # 2: pass
566- # will produce a 1-hit entry on line 1, and 10-hit entries on lines 1 and 2
567- # This doesn't affect `print_stats`, because it uses the last entry for a given line (line number is
568- # used a dict key so earlier entries are overwritten), but to keep compatability with other tools,
569- # let's only keep the last entry for each line
570- # Remove all but the last entry for each line
571- entries = list ({e[0 ]: e for e in entries}.values())
572- stats[key] = entries
625+ orig_nhits, orig_total_time = entries_by_lineno.get(
626+ lineno, (0 , 0 ))
627+ entries_by_lineno[lineno] = (orig_nhits + nhits,
628+ orig_total_time + total_time)
629+
630+ # Aggregate the timing data
631+ stats = {
632+ key: sorted ((line, nhits, time)
633+ for line, (nhits, time) in entries_by_lineno.items())
634+ for key, entries_by_lineno in all_entries.items()}
573635 return LineStats(stats, self .timer_unit)
574636
575637
0 commit comments