Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 197 additions & 6 deletions ext/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ package ext

import (
"fmt"
"math"
"net/netip"
"reflect"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
)

const (
Expand Down Expand Up @@ -182,7 +185,11 @@ const (

var (
// Definitions for the Opaque Types
IPType = types.NewOpaqueType("net.IP")

// IPType represents a network IP address.
IPType = types.NewOpaqueType("net.IP")

// CIDRType represents a CIDR-format network range.
CIDRType = types.NewOpaqueType("net.CIDR")
)

Expand All @@ -196,13 +203,11 @@ func (*networkLib) LibraryName() string {

func (*networkLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
// 1. Register Types
cel.Types(
IPType,
CIDRType,
),

// 2. Register Functions
cel.Function(cidrFunc,
// K8s Parity: Following the pattern, this is "string_to_cidr"
cel.Overload("string_to_cidr", []*cel.Type{cel.StringType}, CIDRType,
Expand Down Expand Up @@ -288,11 +293,58 @@ func (*networkLib) CompileOptions() []cel.EnvOption {
networkFormatValidator{funcName: ipFunc, argNum: 0, check: checkIP},
networkFormatValidator{funcName: cidrFunc, argNum: 0, check: checkCIDR},
),
cel.CostEstimatorOptions(
checker.OverloadCostEstimate("string_to_cidr", estimateNetworkParseCost),
checker.OverloadCostEstimate("cidr_to_string", estimateNetworkNominalStringCost),
checker.OverloadCostEstimate("cidr_contains_cidr", estimateNetworkContainsCIDRCIDRCost),
checker.OverloadCostEstimate("cidr_contains_cidr_string", estimateNetworkContainsCIDRStringCost),
checker.OverloadCostEstimate("cidr_contains_ip_ip", estimateNetworkContainsIPIPCost),
checker.OverloadCostEstimate("cidr_contains_ip_string", estimateNetworkContainsIPStringCost),
checker.OverloadCostEstimate("ip_family", estimateNetworkNominalCost),
checker.OverloadCostEstimate("string_to_ip", estimateNetworkParseCost),
checker.OverloadCostEstimate("cidr_ip", estimateNetworkNominalOpaqueCost),
checker.OverloadCostEstimate("ip_to_string", estimateNetworkNominalStringCost),
checker.OverloadCostEstimate("ip_is_canonical", estimateIPIsCanonicalCost),
checker.OverloadCostEstimate("is_cidr", estimateNetworkParseBoolCost),
checker.OverloadCostEstimate("ip_is_global_unicast", estimateNetworkNominalCost),
checker.OverloadCostEstimate("is_ip", estimateNetworkParseBoolCost),
checker.OverloadCostEstimate("ip_is_link_local_multicast", estimateNetworkNominalCost),
checker.OverloadCostEstimate("ip_is_link_local_unicast", estimateNetworkNominalCost),
checker.OverloadCostEstimate("ip_is_loopback", estimateNetworkNominalCost),
checker.OverloadCostEstimate("cidr_is_mask", estimateNetworkNominalCost),
checker.OverloadCostEstimate("ip_is_unspecified", estimateNetworkNominalCost),
checker.OverloadCostEstimate("cidr_masked", estimateNetworkNominalOpaqueCost),
checker.OverloadCostEstimate("cidr_prefix_length", estimateNetworkNominalCost),
),
}
}

func (*networkLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
return []cel.ProgramOption{
cel.CostTrackerOptions(
interpreter.OverloadCostTracker("string_to_cidr", trackNetworkParseCost),
interpreter.OverloadCostTracker("cidr_to_string", trackNetworkNominalCost),
interpreter.OverloadCostTracker("cidr_contains_cidr", trackNetworkContainsCIDRCIDRCost),
interpreter.OverloadCostTracker("cidr_contains_cidr_string", trackNetworkContainsCIDRStringCost),
interpreter.OverloadCostTracker("cidr_contains_ip_ip", trackNetworkContainsIPIPCost),
interpreter.OverloadCostTracker("cidr_contains_ip_string", trackNetworkContainsIPStringCost),
interpreter.OverloadCostTracker("ip_family", trackNetworkNominalCost),
interpreter.OverloadCostTracker("string_to_ip", trackNetworkParseCost),
interpreter.OverloadCostTracker("cidr_ip", trackNetworkNominalCost),
interpreter.OverloadCostTracker("ip_to_string", trackNetworkNominalCost),
interpreter.OverloadCostTracker("ip_is_canonical", trackIPIsCanonicalCost),
interpreter.OverloadCostTracker("is_cidr", trackNetworkParseCost),
interpreter.OverloadCostTracker("ip_is_global_unicast", trackNetworkNominalCost),
interpreter.OverloadCostTracker("is_ip", trackNetworkParseCost),
interpreter.OverloadCostTracker("ip_is_link_local_multicast", trackNetworkNominalCost),
interpreter.OverloadCostTracker("ip_is_link_local_unicast", trackNetworkNominalCost),
interpreter.OverloadCostTracker("ip_is_loopback", trackNetworkNominalCost),
interpreter.OverloadCostTracker("cidr_is_mask", trackNetworkNominalCost),
interpreter.OverloadCostTracker("ip_is_unspecified", trackNetworkNominalCost),
interpreter.OverloadCostTracker("cidr_masked", trackNetworkNominalCost),
interpreter.OverloadCostTracker("cidr_prefix_length", trackNetworkNominalCost),
),
}
}

// networkAdapter adapts netip types while preserving existing adapters.
Expand Down Expand Up @@ -478,8 +530,7 @@ func parseIPAddr(raw string) (netip.Addr, error) {
return addr, nil
}

// --- Opaque Type Wrappers ---

// IP represents an IP address type.
type IP struct {
netip.Addr
}
Expand Down Expand Up @@ -527,6 +578,13 @@ func (i IP) Value() any {
return i.Addr
}

// Size returns the size of the IP address in bytes.
// /Used in the size estimation of the runtime cost.
func (i IP) Size() ref.Val {
return types.Int(int64(math.Ceil(float64(i.Addr.BitLen()) / 8)))
}

// CIDR represents the CIDR network mask format.
type CIDR struct {
netip.Prefix
}
Expand Down Expand Up @@ -574,6 +632,12 @@ func (c CIDR) Value() any {
return c.Prefix
}

// Size returns the size of the CIDR prefix address in bytes.
// Used in the size estimation of the runtime cost.
func (c CIDR) Size() ref.Val {
return types.Int(int64(math.Ceil(float64(c.Prefix.Bits()) / 8)))
}

// --- Static Validators ---

type argChecker func(e *cel.Env, call, arg ast.Expr) error
Expand Down Expand Up @@ -617,3 +681,130 @@ func checkCIDR(e *cel.Env, call, arg ast.Expr) error {
_, err := parseCIDR(pattern)
return err
}

// Cost estimation functions for network extensions.

func estimateNetworkParseCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if len(args) < 1 {
return nil
}
sz := estimateSize(estimator, args[0])
resultSize := rangedSizeEstimate(4, 16)
return callEstimate(sz.MultiplyByCostFactor(stringCostFactor), &resultSize)
}

func estimateNetworkParseBoolCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if len(args) < 1 {
return nil
}
sz := estimateSize(estimator, args[0])
return callEstimate(sz.MultiplyByCostFactor(stringCostFactor), nil)
}

func estimateIPIsCanonicalCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if len(args) < 1 {
return nil
}
sz := estimateSize(estimator, args[0])
return callEstimate(sz.MultiplyByCostFactor(2*stringCostFactor), nil)
}

func estimateNetworkNominalCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
return callEstimate(callCostEstimate, nil)
}

func estimateNetworkNominalOpaqueCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
resultSize := rangedSizeEstimate(4, 16)
return callEstimate(callCostEstimate, &resultSize)
}

func estimateNetworkNominalStringCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
resultSize := rangedSizeEstimate(3, 45)
return callEstimate(callCostEstimate, &resultSize)
}

func estimateNetworkContainsIPIPCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
sz := rangedSizeEstimate(4, 16)
ipCompCost := sz.Add(sz).MultiplyByCostFactor(stringCostFactor)
return callEstimate(ipCompCost, nil)
}

func estimateNetworkContainsIPStringCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if len(args) < 1 {
return nil
}
sz := rangedSizeEstimate(4, 16)
ipCompCost := sz.Add(sz).MultiplyByCostFactor(stringCostFactor)
argSz := estimateSize(estimator, args[0])
ipCompCost = ipCompCost.Add(argSz.MultiplyByCostFactor(stringCostFactor))
return callEstimate(ipCompCost, nil)
}

func estimateNetworkContainsCIDRCIDRCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
sz := rangedSizeEstimate(4, 16)
ipCompCost := sz.Add(sz).MultiplyByCostFactor(stringCostFactor)
ipCompCost = ipCompCost.Add(sz.MultiplyByCostFactor(stringCostFactor))
// K8s adds one for the extra IP traversal
ipCompCost = ipCompCost.Add(callCostEstimate)
return callEstimate(ipCompCost, nil)
}

func estimateNetworkContainsCIDRStringCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if len(args) < 1 {
return nil
}
sz := rangedSizeEstimate(4, 16)
ipCompCost := sz.Add(sz).MultiplyByCostFactor(stringCostFactor)
ipCompCost = ipCompCost.Add(sz.MultiplyByCostFactor(stringCostFactor))
argSz := estimateSize(estimator, args[0])
ipCompCost = ipCompCost.Add(argSz.MultiplyByCostFactor(stringCostFactor))
// K8s adds one for the extra IP traversal
ipCompCost = ipCompCost.Add(callCostEstimate)
return callEstimate(ipCompCost, nil)
}

// Runtime cost tracking functions for network extensions.

func trackNetworkParseCost(args []ref.Val, result ref.Val) *uint64 {
cost := uint64(math.Ceil(float64(actualSize(args[0])) * stringCostFactor))
return &cost
}

func trackIPIsCanonicalCost(args []ref.Val, result ref.Val) *uint64 {
cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * stringCostFactor))
return &cost
}

func trackNetworkNominalCost(args []ref.Val, result ref.Val) *uint64 {
return &callCost
}

func trackNetworkContainsIPIPCost(args []ref.Val, result ref.Val) *uint64 {
cidrSize := actualSize(args[0])
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * stringCostFactor))
return &cost
}

func trackNetworkContainsIPStringCost(args []ref.Val, result ref.Val) *uint64 {
cidrSize := actualSize(args[0])
otherSize := actualSize(args[1])
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * stringCostFactor))
cost = safeAdd(cost, uint64(math.Ceil(float64(otherSize)*stringCostFactor)))
return &cost
}

func trackNetworkContainsCIDRCIDRCost(args []ref.Val, result ref.Val) *uint64 {
cidrSize := actualSize(args[0])
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * stringCostFactor))
cost = safeAdd(cost, uint64(math.Ceil(float64(cidrSize)*stringCostFactor)), 1)
return &cost
}

func trackNetworkContainsCIDRStringCost(args []ref.Val, result ref.Val) *uint64 {
cidrSize := actualSize(args[0])
otherSize := actualSize(args[1])
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * stringCostFactor))
cost = safeAdd(cost, uint64(math.Ceil(float64(cidrSize)*stringCostFactor)), 1)
cost = safeAdd(cost, uint64(math.Ceil(float64(otherSize)*stringCostFactor)))
return &cost
}
Loading