Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ exclude = .git,
.mypy_cache,
src/robusta/integrations/kubernetes/autogenerated,
src/robusta/integrations/kubernetes/custom_models.py
ignore = E501, W503, E203
ignore = E501, W503, E203, E704
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/ambv/black
rev: 23.1.0
rev: 26.3.1
hooks:
- id: black
language_version: python3
Expand Down
3 changes: 2 additions & 1 deletion enforcer/dal/robusta_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ class RobustaConfig(BaseModel):
sinks_config: List[Dict[str, Dict]]
global_config: dict


class RobustaToken(BaseModel):
    """Fields of a Robusta token after it has been decoded into a model.

    Every field is a required string; the `password` field is declared
    exactly once.
    """

    store_url: str
    api_key: str
    account_id: str
    email: str
    password: str
40 changes: 12 additions & 28 deletions enforcer/dal/supabase_dal.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def __init__(self):
if not self.enabled:
logging.info("Not connecting to Robusta platform - robusta token not provided")
return
logging.info(
f"Initializing Robusta platform connection for account {self.account_id} cluster {self.cluster}"
)
logging.info(f"Initializing Robusta platform connection for account {self.account_id} cluster {self.cluster}")
options = ClientOptions(postgrest_client_timeout=SUPABASE_TIMEOUT_SECONDS)
self.client = create_client(self.url, self.api_key, options)
self.user_id = self.sign_in()
Expand All @@ -67,9 +65,7 @@ def execute_with_retry(_self):
message = exc.message or ""
if exc.code == "PGRST301" or "expired" in message.lower():
# JWT expired. Sign in again and retry the query
logging.error(
"JWT token expired/invalid, signing in to Supabase again"
)
logging.error("JWT token expired/invalid, signing in to Supabase again")
self.sign_in()
# update the session to the new one, after re-sign in
_self.session = self.client.postgrest.session
Expand All @@ -81,7 +77,7 @@ def execute_with_retry(_self):
SyncQueryRequestBuilder.execute = execute_with_retry

@staticmethod
def __load_robusta_config() -> (Optional[RobustaToken],Optional[str]):
def __load_robusta_config() -> (Optional[RobustaToken], Optional[str]):
config_file_path = ROBUSTA_CONFIG_PATH
env_ui_token = os.environ.get("ROBUSTA_UI_TOKEN")
cluster_name = os.environ.get("CLUSTER_NAME")
Expand All @@ -92,9 +88,7 @@ def __load_robusta_config() -> (Optional[RobustaToken],Optional[str]):
decoded = base64.b64decode(env_ui_token)
return RobustaToken(**json.loads(decoded)), cluster_name
except binascii.Error:
raise Exception(
"binascii.Error encountered. The Robusta UI token is not a valid base64."
)
raise Exception("binascii.Error encountered. The Robusta UI token is not a valid base64.")
except json.JSONDecodeError:
raise Exception(
"json.JSONDecodeError encountered. The Robusta UI token could not be parsed as JSON after being base64 decoded."
Expand All @@ -112,10 +106,7 @@ def __load_robusta_config() -> (Optional[RobustaToken],Optional[str]):
if "robusta_sink" in conf.keys():
token = conf["robusta_sink"].get("token")
if not token:
raise Exception(
"No robusta token provided.\n"
"Please set a valid Robusta UI token.\n "
)
raise Exception("No robusta token provided.\n" "Please set a valid Robusta UI token.\n ")
env_replacement_token = get_env_replacement(token)
if env_replacement_token:
token = env_replacement_token
Expand All @@ -131,9 +122,7 @@ def __load_robusta_config() -> (Optional[RobustaToken],Optional[str]):
decoded = base64.b64decode(token)
return RobustaToken(**json.loads(decoded)), config.global_config.get("cluster_name")
except binascii.Error:
raise Exception(
"binascii.Error encountered. The robusta token provided is not a valid base64."
)
raise Exception("binascii.Error encountered. The robusta token provided is not a valid base64.")
except json.JSONDecodeError:
raise Exception(
"json.JSONDecodeError encountered. The Robusta token provided could not be parsed as JSON after being base64 decoded."
Expand Down Expand Up @@ -167,12 +156,8 @@ def __init_config(self) -> bool:

def sign_in(self) -> str:
    """Sign in with the stored email/password and return the user id.

    Performs a single password sign-in, then hands the tokens of the
    returned session to the auth client (``set_session``) and the access
    token to the PostgREST client (``postgrest.auth``).

    Returns:
        The ``id`` of the user contained in the sign-in response.
    """
    logging.info("Supabase DAL login")
    res = self.client.auth.sign_in_with_password({"email": self.email, "password": self.password})
    session = res.session
    self.client.auth.set_session(session.access_token, session.refresh_token)
    self.client.postgrest.auth(session.access_token)
    return res.user.id

Expand Down Expand Up @@ -200,7 +185,7 @@ def get_latest_krr_scan(self, current_scan_id: Optional[str]) -> (Optional[str],
latest_scan_data = sorted_scans[0]
else:
latest_scan_data = scans_meta_response.data[0]

latest_scan_id = latest_scan_data["scan_id"]

if latest_scan_id == current_scan_id:
Expand All @@ -211,7 +196,9 @@ def get_latest_krr_scan(self, current_scan_id: Optional[str]) -> (Optional[str],
scan_datetime = datetime.fromisoformat(scan_start)
max_age = timedelta(hours=SCAN_AGE_HOURS_THRESHOLD)
if datetime.now(timezone.utc) - scan_datetime > max_age:
logging.warning(f"Latest scan {latest_scan_id} is too old (started {scan_start}). No fresh KRR scan available.")
logging.warning(
f"Latest scan {latest_scan_id} is too old (started {scan_start}). No fresh KRR scan available."
)
return None, None

scans_results_response = (
Expand All @@ -229,6 +216,3 @@ def get_latest_krr_scan(self, current_scan_id: Optional[str]) -> (Optional[str],
except Exception:
logging.exception("Supabase error while retrieving krr scan data")
return None, None



105 changes: 54 additions & 51 deletions enforcer/enforcer_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,44 +35,43 @@
# Configure logging: the root logger writes timestamped records to stdout.
# The level is read from the LOG_LEVEL environment variable (default "INFO").
logger = logging.getLogger()
logHandler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
logHandler.setFormatter(formatter)
logger.addHandler(logHandler)
logger.setLevel(os.environ.get("LOG_LEVEL", "INFO"))

# Define the mention pattern regex: "@" followed by word chars, dots or dashes.
MENTION_PATTERN = re.compile(r"@[\w.-]+")
# Values compared against the pod's krr-mutation-mode annotation.
ENFORCE = "enforce"
IGNORE = "ignore"

# The FastAPI application that the route decorators in this module attach to.
app = FastAPI(
    title="KRR Enforcer mutation webhook",
    description="A KRR recommendations mutating webhook server for Kubernetes",
    version="1.0.0",
)

# Module-level singletons shared by the request handlers.
dal = SupabaseDal()
recommendation_store = RecommendationStore(dal)
owner_store = OwnerStore()


class AdmissionReview(BaseModel):
    """Request body model accepted by the webhook handlers."""

    # Top-level string fields of the review document.
    apiVersion: str
    kind: str
    # Free-form request payload; handlers read keys such as "uid", "kind",
    # "object" and "operation" from it.
    request: Dict[str, Any]


def admission_allowed(request: AdmissionReview) -> Dict[str, Any]:
    """Build a response that allows the admission request without changes.

    Args:
        request: The incoming admission review. Only ``request.request["uid"]``
            is read; it is echoed back (``None`` when the key is absent).

    Returns:
        A dict with ``apiVersion`` "admission.k8s.io/v1", ``kind``
        "AdmissionReview" and a ``response`` holding the uid and
        ``allowed: True``. No patch is attached.
    """
    return {
        "apiVersion": "admission.k8s.io/v1",
        "kind": "AdmissionReview",
        "response": {"uid": request.request.get("uid"), "allowed": True},
    }


def enforce_pod(pod: Dict[str, Any]) -> bool:
mode = pod.get('metadata', {}).get('annotations', {}).get("admission.robusta.dev/krr-mutation-mode", None)
mode = pod.get("metadata", {}).get("annotations", {}).get("admission.robusta.dev/krr-mutation-mode", None)
if mode == ENFORCE:
return True
elif mode == IGNORE:
Expand All @@ -85,30 +84,29 @@ def enforce_pod(pod: Dict[str, Any]) -> bool:
async def mutate(request: AdmissionReview):
"""
Handle mutating webhook requests from Kubernetes.

Args:
request (AdmissionReview): The admission review request from Kubernetes

Returns:
dict: Admission review response
"""
start_time = time.time()
try:
logging.debug("Admission request received %s", request)
# Extract the object being reviewed
object_to_review = request.request.get('object', {})
kind = request.request.get('kind', {}).get('kind')
object_to_review = request.request.get("object", {})
kind = request.request.get("kind", {}).get("kind")

if kind == "ReplicaSet": # use create/delete admission requests, to track new/removed replica sets owners
owner_store.handle_rs_admission(request.request)
operation = request.request.get('operation', 'UNKNOWN')
operation = request.request.get("operation", "UNKNOWN")
replicaset_admissions.labels(operation=operation).inc()
admission_duration.labels(kind='ReplicaSet').observe(time.time() - start_time)
admission_duration.labels(kind="ReplicaSet").observe(time.time() - start_time)
# Update rs_owners size metric
rs_owners_size.set(owner_store.get_rs_owners_count())
return admission_allowed(request)


if kind != "Pod":
logger.warning(f"Received unexpected resource mutation: {kind}")
return admission_allowed(request)
Expand Down Expand Up @@ -144,12 +142,12 @@ async def mutate(request: AdmissionReview):
logger.debug("Pod Recommendations %s", recommendations)

patches = []

containers = object_to_review.get("spec", {}).get("containers", [])
for i, container in enumerate(containers):
container_name = container.get("name")
patches.extend(patch_container_resources(i, container, recommendations.get(container_name)))

# Record metrics for Pod mutation
was_mutated = len(patches) > 0
reason = "success" if was_mutated else "no_changes_needed"
Expand All @@ -166,91 +164,96 @@ async def mutate(request: AdmissionReview):
response["patchType"] = "JSONPatch"
response["patch"] = base64.b64encode(json.dumps(patches).encode()).decode()

return {
"apiVersion": "admission.k8s.io/v1",
"kind": "AdmissionReview",
"response": response
}

return {"apiVersion": "admission.k8s.io/v1", "kind": "AdmissionReview", "response": response}

except Exception as e:
logger.exception("Error processing webhook request")
# Record failure metric for Pod requests
if request.request.get('kind', {}).get('kind') == "Pod":
if request.request.get("kind", {}).get("kind") == "Pod":
pod_admission_mutations.labels(mutated="false", reason="processing_error").inc()
admission_duration.labels(kind="Pod").observe(time.time() - start_time)
raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    """
    Health check endpoint.

    Every call also invokes ``owner_store.finalize_owner_initialization()``
    before reporting the status.

    Returns:
        dict: ``{"status": "healthy"}``
    """
    # Init loading owners from api server, after accepting api requests
    owner_store.finalize_owner_initialization()
    status = {"status": "healthy"}
    return status


@app.get("/recommendations/{namespace}/{kind}/{name}")
async def get_recommendations(namespace: str, kind: str, name: str):
    """
    Get recommendations for a workload.

    Args:
        namespace: Kubernetes namespace
        kind: Workload kind (e.g., Deployment, StatefulSet)
        name: Workload name

    Returns:
        dict: ``namespace``, ``kind``, ``name`` and a ``containers`` mapping of
        container name to ``{"cpu": ..., "memory": ...}``. Each resource entry
        is ``{"request": ..., "limit": ...}`` or ``None`` when the store has no
        value for that resource.

    Raises:
        HTTPException: 404 when the store returns nothing for the workload,
            500 (with the error text as detail) on any other failure.
    """
    try:
        recommendations: WorkloadRecommendation = recommendation_store.get_recommendations(
            name=name, namespace=namespace, kind=kind
        )

        if not recommendations:
            raise HTTPException(status_code=404, detail="No recommendations found for this workload")

        result = {}
        for container_name, container_recommendation in recommendations.container_recommendations.items():
            cpu = container_recommendation.cpu
            memory = container_recommendation.memory
            result[container_name] = {
                "cpu": {"request": cpu.request, "limit": cpu.limit} if cpu else None,
                "memory": {"request": memory.request, "limit": memory.limit} if memory else None,
            }

        return {"namespace": namespace, "kind": kind, "name": name, "containers": result}

    except HTTPException:
        # Let the deliberate 404 above propagate instead of turning it into a 500.
        raise
    except Exception as e:
        logger.exception("Error retrieving recommendations")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/metrics")
async def metrics():
    """
    Prometheus metrics endpoint.

    Returns:
        Response: output of ``generate_latest()`` served with the
        ``CONTENT_TYPE_LATEST`` media type
    """
    # Update rs_owners size metric before returning metrics
    owner_count = owner_store.get_rs_owners_count()
    rs_owners_size.set(owner_count)
    payload = generate_latest()
    return Response(payload, media_type=CONTENT_TYPE_LATEST)


if __name__ == "__main__":
    import uvicorn

    # Serve the webhook over TLS on port 8443, bound to all interfaces.
    # The key/cert paths come from the ENFORCER_SSL_* settings.
    logger.info("Starting Kubernetes Webhook server on 8443...")
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8443,
        ssl_keyfile=ENFORCER_SSL_KEY_FILE,
        ssl_certfile=ENFORCER_SSL_CERT_FILE,
        log_level="warning",
    )
13 changes: 7 additions & 6 deletions enforcer/env_vars.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os

# Settings read from the environment at import time; each falls back to the
# default given as the second argument.
ROBUSTA_CONFIG_PATH = os.environ.get("ROBUSTA_CONFIG_PATH", "/etc/robusta/config/active_playbooks.yaml")
ROBUSTA_ACCOUNT_ID = os.environ.get("ROBUSTA_ACCOUNT_ID", "")
STORE_URL = os.environ.get("STORE_URL", "")
STORE_API_KEY = os.environ.get("STORE_API_KEY", "")
Expand All @@ -18,10 +16,13 @@
# Mutation mode used when nothing more specific applies; defaults to "enforce".
KRR_MUTATION_MODE_DEFAULT = os.environ.get("KRR_MUTATION_MODE_DEFAULT", "enforce")
# NOTE(review): units are not shown here (presumably seconds) - confirm at the call sites.
REPLICA_SET_CLEANUP_INTERVAL = int(os.environ.get("REPLICA_SET_CLEANUP_INTERVAL", 600))
REPLICA_SET_DELETION_WAIT = int(os.environ.get("REPLICA_SET_DELETION_WAIT", 600))
SCAN_AGE_HOURS_THRESHOLD = int(os.environ.get("SCAN_AGE_HOURS_THRESHOLD", 360))  # 15 days

# TLS key/certificate file paths; empty string when unset.
ENFORCER_SSL_KEY_FILE = os.environ.get("ENFORCER_SSL_KEY_FILE", "")
ENFORCER_SSL_CERT_FILE = os.environ.get("ENFORCER_SSL_CERT_FILE", "")

# Comma-separated container names; surrounding whitespace is stripped and
# empty entries are dropped, so an unset variable yields an empty list.
EXCLUDED_CONTAINERS = [
    container_name.strip()
    for container_name in os.environ.get("EXCLUDED_CONTAINERS", "").split(",")
    if container_name.strip()
]
Loading
Loading