@@ -44,6 +44,7 @@ def _get_invalid_idx_value(cls):
4444 """
4545 if torch .onnx .is_in_onnx_export ():
4646 if cls .SUBFUNC_ENABLED :
47+ # TODO: returning 0 here can hurt perf — remove this if-condition once it is safe to do so
4748 return 0
4849 else :
4950 return torch .iinfo (torch .int32 ).max
@@ -722,9 +723,22 @@ def full_cache_update_chunked(
722723 cache_kwargs : Optional [Dict [str , Any ]] = None ,
723724 ) -> Tuple [torch .Tensor , torch .Tensor ]:
724725 position_ids = cache_kwargs .get ("position_ids" )
726+ batch_index = cache_kwargs .get ("batch_index" )
727+ invalid_idx_value = InvalidIndexProvider ._get_invalid_idx_value ()
725728
726- self .key_cache [layer_idx ] = CtxScatterFunc .apply (self .key_cache [layer_idx ], position_ids , key_states )
727- self .value_cache [layer_idx ] = CtxScatterFunc .apply (self .value_cache [layer_idx ], position_ids , value_states )
729+ # Scatter
730+ if batch_index is not None :
731+ if torch .onnx .is_in_onnx_export ():
732+ scatter_position_ids = torch .where (position_ids < 0 , torch .iinfo (torch .int32 ).max , position_ids )
733+ self .key_cache [layer_idx ] = CtxScatterFuncCB .apply (
734+ self .key_cache [layer_idx ], batch_index , scatter_position_ids , key_states
735+ )
736+ self .value_cache [layer_idx ] = CtxScatterFuncCB .apply (
737+ self .value_cache [layer_idx ], batch_index , scatter_position_ids , value_states
738+ )
739+ else :
740+ self .key_cache [layer_idx ] = CtxScatterFunc .apply (self .key_cache [layer_idx ], position_ids , key_states )
741+ self .value_cache [layer_idx ] = CtxScatterFunc .apply (self .value_cache [layer_idx ], position_ids , value_states )
728742
729743 k_out , v_out = self .key_cache [layer_idx ], self .value_cache [layer_idx ]
730744
@@ -733,11 +747,13 @@ def full_cache_update_chunked(
733747 ctx_indices = torch .arange (ctx_len )[None , None , ...]
734748 gather_limit = position_ids .max (1 , keepdim = True ).values .unsqueeze (1 )
735749 invalid_mask = ctx_indices > gather_limit
736-
737- invalid_idx_value = InvalidIndexProvider ._get_invalid_idx_value ()
738750 ctx_indices = torch .where (invalid_mask , invalid_idx_value , ctx_indices )
739- k_out = CtxGatherFunc .apply (k_out , ctx_indices , ctx_len )
740- v_out = CtxGatherFunc .apply (v_out , ctx_indices , ctx_len )
751+ if batch_index is not None :
752+ k_out = CtxGatherFuncCB .apply (k_out , batch_index , ctx_indices , ctx_len )
753+ v_out = CtxGatherFuncCB .apply (v_out , batch_index , ctx_indices , ctx_len )
754+ else :
755+ k_out = CtxGatherFunc .apply (k_out , ctx_indices , ctx_len )
756+ v_out = CtxGatherFunc .apply (v_out , ctx_indices , ctx_len )
741757 v_out = torch .where (invalid_mask .unsqueeze (- 1 ), torch .tensor (0.0 , dtype = torch .float32 ), v_out )
742758
743759 return k_out , v_out
@@ -750,26 +766,40 @@ def sliding_window_update_chunked(
750766 cache_kwargs : Optional [Dict [str , Any ]] = None ,
751767 ) -> Tuple [torch .Tensor , torch .Tensor ]:
752768 position_ids = cache_kwargs .get ("position_ids" )
769+ batch_index = cache_kwargs .get ("batch_index" )
770+ invalid_idx_value = InvalidIndexProvider ._get_invalid_idx_value ()
753771
754- self .key_cache [layer_idx ] = CtxScatterFunc .apply (self .key_cache [layer_idx ], position_ids , key_states )
755- self .value_cache [layer_idx ] = CtxScatterFunc .apply (self .value_cache [layer_idx ], position_ids , value_states )
772+ if batch_index is not None :
773+ if torch .onnx .is_in_onnx_export ():
774+ scatter_position_ids = torch .where (position_ids < 0 , torch .iinfo (torch .int32 ).max , position_ids )
775+ self .key_cache [layer_idx ] = CtxScatterFuncCB .apply (
776+ self .key_cache [layer_idx ], batch_index , scatter_position_ids , key_states
777+ )
778+ self .value_cache [layer_idx ] = CtxScatterFuncCB .apply (
779+ self .value_cache [layer_idx ], batch_index , scatter_position_ids , value_states
780+ )
781+ else :
782+ self .key_cache [layer_idx ] = CtxScatterFunc .apply (self .key_cache [layer_idx ], position_ids , key_states )
783+ self .value_cache [layer_idx ] = CtxScatterFunc .apply (self .value_cache [layer_idx ], position_ids , value_states )
756784
757785 k_out , v_out = self .key_cache [layer_idx ], self .value_cache [layer_idx ]
758786 sliding_window_len = cache_kwargs .get ("sliding_window" )
787+
759788 # Gather
760789 ctx_len = position_ids .shape [1 ] + sliding_window_len
761790 ctx_indices = torch .arange (ctx_len )[None , None , ...]
762- # positive_pos_ids = torch.where(position_ids<0, 0, position_ids)
763791 first_pos_idx = position_ids [0 ][0 ]
764792 add_idx = torch .where (first_pos_idx >= sliding_window_len , first_pos_idx - sliding_window_len , 0 )
765793 ctx_indices += add_idx
766794 gather_limit = position_ids .max (1 , keepdim = True ).values .unsqueeze (1 )
767795 invalid_mask = ctx_indices > gather_limit
768-
769- invalid_idx_value = InvalidIndexProvider ._get_invalid_idx_value ()
770796 ctx_indices = torch .where (invalid_mask , invalid_idx_value , ctx_indices )
771- k_out = CtxGatherFunc .apply (k_out , ctx_indices , ctx_len )
772- v_out = CtxGatherFunc .apply (v_out , ctx_indices , ctx_len )
797+ if batch_index is not None :
798+ k_out = CtxGatherFuncCB .apply (k_out , batch_index , ctx_indices , ctx_len )
799+ v_out = CtxGatherFuncCB .apply (v_out , batch_index , ctx_indices , ctx_len )
800+ else :
801+ k_out = CtxGatherFunc .apply (k_out , ctx_indices , ctx_len )
802+ v_out = CtxGatherFunc .apply (v_out , ctx_indices , ctx_len )
773803 v_out = torch .where (invalid_mask .unsqueeze (- 1 ), torch .tensor (0.0 , dtype = torch .float32 ), v_out )
774804
775805 return k_out , v_out
0 commit comments