Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 57 additions & 198 deletions src/airflow_docker/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,24 @@
# under the License.
import ast
import json
from tempfile import TemporaryDirectory
from typing import Dict, Iterable, List, Optional, Union

import airflow.configuration as conf
import airflow_docker_helper
import six
from airflow.exceptions import AirflowConfigException, AirflowException
from airflow.hooks.docker_hook import DockerHook
from airflow.models import BaseOperator, SkipMixin
from airflow.models import SkipMixin
from airflow.providers.docker.hooks.docker import DockerHook
from airflow.providers.docker.operators.docker import (
DockerOperator as AirflowDockerOperator,
)
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
from airflow.utils.file import TemporaryDirectory
from docker import APIClient, tls

from airflow_docker.conf import get_boolean_default, get_default
from airflow_docker.ext import delegate_to_extensions, register_extensions
from airflow_docker.utils import get_config
from docker import APIClient, tls

DEFAULT_HOST_TEMPORARY_DIRECTORY = "/tmp/airflow"


class ShortCircuitMixin(SkipMixin):
Expand Down Expand Up @@ -79,7 +81,7 @@ def execute(self, context):
class BranchMixin(SkipMixin):
def execute(self, context):
branch = super(BranchMixin, self).execute(context)
if isinstance(branch, six.string_types):
if isinstance(branch, str):
branch = [branch]
self.log.info("Following branch %s", branch)
self.log.info("Marking other directly downstream tasks as skipped")
Expand Down Expand Up @@ -107,7 +109,7 @@ def execute(self, context):


@register_extensions
class BaseDockerOperator(object):
class DockerOperator(AirflowDockerOperator):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't going to work with the sensor I don't think

"""
Execute a command inside a docker container.

Expand All @@ -127,31 +129,31 @@ class BaseDockerOperator(object):
:param api_version: Remote API version. Set to ``auto`` to automatically
detect the server's version.
:type api_version: str
:param auto_remove: Auto-removal of the container on daemon side when the
container's process exits.
The default is True.
:type auto_remove: bool
:param command: Command to be run in the container. (templated)
:type command: str or list
:param container_name: Name of the container. Optional (templated)
:type container_name: str or None
:param cpus: Number of CPUs to assign to the container.
This value gets multiplied by 1024. See
https://docs.docker.com/engine/reference/run/#cpu-share-constraint
:type cpus: float
:param dns: Docker custom DNS servers
:type dns: list of strings
:param dns_search: Docker custom DNS search domain
:type dns_search: list of strings
:param docker_url: URL of the host running the docker daemon.
Default is unix://var/run/docker.sock
:type docker_url: str
:param environment: Environment variables to set in the container. (templated)
:type environment: dict
:param force_pull: Pull the docker image on every run. Default is True.
:param private_environment: Private environment variables to set in the container.
These are not templated, and hidden from the website.
:type private_environment: dict
:param force_pull: Pull the docker image on every run. Default is False.
:type force_pull: bool
:param mem_limit: Maximum amount of memory the container can use.
Either a float value, which represents the limit in bytes,
or a string like ``128m`` or ``1g``.
:type mem_limit: float or str
:param host_tmp_dir: Specify the location of the temporary directory on the host which will
be mapped to tmp_dir. If not provided defaults to using the standard system temp directory.
:type host_tmp_dir: str
:param network_mode: Network mode for the container.
:type network_mode: str
:param tls_ca_cert: Path to a PEM-encoded certificate authority
Expand All @@ -176,65 +178,51 @@ class BaseDockerOperator(object):
:type user: int or str
:param volumes: List of volumes to mount into the container, e.g.
``['/host/path:/container/path', '/host/path2:/container/path2:ro']``.
:type volumes: list
:param working_dir: Working directory to
set on the container (equivalent to the -w switch of the docker client)
:type working_dir: str
:param xcom_push: Whether stdout is pushed to the next step using XCom.
The default is False.
:type xcom_push: bool
:param xcom_all: Push all the stdout or just the last line.
The default is False (last line).
:type xcom_all: bool
:param docker_conn_id: ID of the Airflow connection to use
:type docker_conn_id: str
:param dns: Docker custom DNS servers
:type dns: list[str]
:param dns_search: Docker custom DNS search domain
:type dns_search: list[str]
:param auto_remove: Auto-removal of the container on daemon side when the
container's process exits.
The default is False.
:type auto_remove: bool
:param shm_size: Size of ``/dev/shm`` in bytes. The size must be
greater than 0. If omitted uses system default.
:type shm_size: int
:param tty: Allocate pseudo-TTY to the container
This needs to be set to see logs of the Docker container.
:type tty: bool

:param provide_context: If True, make a serialized form of the context available.
:type provide_context: bool

:param environment_preset: The name of the environment-preset to pull from the config.
If omitted defaults to the "default" key, see `EnvironmentPresetExtension`.
:type environment_preset: string
"""

template_fields = ("command", "environment", "extra_kwargs")
template_ext = (".sh", ".bash")
template_fields = ('command', 'environment', 'container_name', "extra_kwargs")
known_extra_kwargs = set()

@apply_defaults
def __init__(
self,
image,
api_version=None,
entrypoint=None,
command=None,
cpus=1.0,
docker_url="unix://var/run/docker.sock",
environment=None,
force_pull=get_boolean_default("force_pull", True),
mem_limit=None,
network_mode=get_default("network_mode", None),
tls_ca_cert=None,
tls_client_cert=None,
tls_client_key=None,
tls_hostname=None,
tls_ssl_version=None,
tmp_dir="/tmp/airflow",
user=None,
volumes=None,
working_dir=None,
xcom_push=False,
xcom_all=False,
docker_conn_id=None,
dns=None,
dns_search=None,
auto_remove=get_boolean_default("auto_remove", True),
shm_size=None,
provide_context=False,
*args,
**kwargs
):
self,
image: str,
force_pull: bool = get_boolean_default("force_pull", True),
network_mode: Optional[str] = get_default("network_mode", None),
auto_remove: bool = get_boolean_default("auto_remove", True),
provide_context=False,
*args,
**kwargs) -> None:

self.extra_kwargs = {
known_key: kwargs.pop(known_key)
for known_key in self.known_extra_kwargs
Expand All @@ -243,148 +231,26 @@ def __init__(
if known_key in kwargs
}

super(BaseDockerOperator, self).__init__(*args, **kwargs)
self.api_version = api_version
self.auto_remove = auto_remove
self.command = command
self.entrypoint = entrypoint
self.cpus = cpus
self.dns = dns
self.dns_search = dns_search
self.docker_url = docker_url
self.environment = environment or {}
self.force_pull = force_pull
self.image = image
self.mem_limit = mem_limit
self.network_mode = network_mode
self.tls_ca_cert = tls_ca_cert
self.tls_client_cert = tls_client_cert
self.tls_client_key = tls_client_key
self.tls_hostname = tls_hostname
self.tls_ssl_version = tls_ssl_version
self.tmp_dir = tmp_dir
self.user = user
self.volumes = volumes or []
self.working_dir = working_dir
self.xcom_push_flag = xcom_push
self.xcom_all = xcom_all
self.docker_conn_id = docker_conn_id
self.shm_size = shm_size
self.provide_context = provide_context

self.cli = None
self.container = None
self._host_client = None # Shim for attaching a test client
super().__init__(*args, force_pull=force_pull, network_mode=network_mode, auto_remove=auto_remove, **kwargs)

def get_hook(self):
    """
    Build a ``DockerHook`` from this operator's connection settings.

    :return: a hook created from ``docker_conn_id``, ``docker_url``,
        ``api_version`` and the TLS configuration.
    """
    # The dict literal is evaluated top to bottom, so ``docker_url`` is read
    # before the TLS helper runs, exactly as in a direct keyword call.
    hook_settings = {
        "docker_conn_id": self.docker_conn_id,
        "base_url": self.docker_url,
        "version": self.api_version,
        "tls": self.__get_tls_config(),
    }
    return DockerHook(**hook_settings)
self._host_client = None # Shim for attaching a test client

def _execute(self, context):
self.log.info("Starting docker container from image %s", self.image)
def execute(self, context):
# Hook for creating mounted meta directories
self.prepare_host_tmp_dir(context, self.host_tmp_dir)
self.prepare_environment(context, self.host_tmp_dir)

tls_config = self.__get_tls_config()
if self.provide_context:
self.write_context(context, self.host_tmp_dir)

if self.docker_conn_id:
self.cli = self.get_hook().get_conn()
else:
self.cli = APIClient(
base_url=self.docker_url, version=self.api_version, tls=tls_config
)
super().execute(context)

if self.force_pull or len(self.cli.images(name=self.image)) == 0:
self.log.info("Pulling docker image %s", self.image)
for l in self.cli.pull(self.image, stream=True):
output = json.loads(l.decode("utf-8").strip())
if "status" in output:
self.log.info("%s", output["status"])

with TemporaryDirectory(
prefix="airflowtmp", dir=self.host_tmp_base_dir
) as host_tmp_dir:
self.environment["AIRFLOW_TMP_DIR"] = self.tmp_dir
additional_volumes = ["{0}:{1}".format(host_tmp_dir, self.tmp_dir)]

# Hook for creating mounted meta directories
self.prepare_host_tmp_dir(context, host_tmp_dir)
self.prepare_environment(context, host_tmp_dir)

if self.provide_context:
self.write_context(context, host_tmp_dir)

self.container = self.cli.create_container(
command=self.get_command(),
entrypoint=self.entrypoint,
environment=self.environment,
host_config=self.cli.create_host_config(
auto_remove=self.auto_remove,
binds=self.volumes + additional_volumes,
network_mode=self.network_mode,
shm_size=self.shm_size,
dns=self.dns,
dns_search=self.dns_search,
cpu_shares=int(round(self.cpus * 1024)),
mem_limit=self.mem_limit,
),
image=self.image,
user=self.user,
working_dir=self.working_dir,
)
self.cli.start(self.container["Id"])

line = ""
for line in self.cli.logs(container=self.container["Id"], stream=True):
line = line.strip()
if hasattr(line, "decode"):
line = line.decode("utf-8")
self.log.info(line)

result = self.cli.wait(self.container["Id"])
if result["StatusCode"] != 0:
raise AirflowException("docker container failed: " + repr(result))

# Move the in-container xcom-pushes into airflow.
result = self.host_client.get_xcom_push_data(host_tmp_dir)
for row in result:
self.xcom_push(context, key=row["key"], value=row["value"])

if self.xcom_push_flag:
return (
self.cli.logs(container=self.container["Id"])
if self.xcom_all
else str(line)
)
# Move the in-container xcom-pushes into airflow.
result = self.host_client.get_xcom_push_data(self.host_tmp_dir)
for row in result:
self.xcom_push(context, key=row["key"], value=row["value"])

return self.do_meta_operation(context, host_tmp_dir)

def get_command(self):
if self.command is not None and self.command.strip().find("[") == 0:
commands = ast.literal_eval(self.command)
else:
commands = self.command
return commands

def on_kill(self):
    """
    Stop the running container when the task is killed.

    Does nothing unless both a docker client and a created container exist.
    The container check matters because ``self.container`` can still be
    ``None`` while a client is already set, and indexing ``None`` would
    raise ``TypeError``.
    """
    if self.cli is None or self.container is None:
        return
    self.log.info("Stopping docker container")
    self.cli.stop(self.container["Id"])

def __get_tls_config(self):
    """
    Build the TLS configuration for the docker client, if one is configured.

    When the CA certificate, client certificate and client key are all set,
    a ``tls.TLSConfig`` is created and ``self.docker_url`` is rewritten from
    ``tcp://`` to ``https://``.

    :return: the TLS configuration, or ``None`` when any of the three
        certificate settings is missing.
    """
    has_full_tls = (
        self.tls_ca_cert and self.tls_client_cert and self.tls_client_key
    )
    if not has_full_tls:
        return None

    secured = tls.TLSConfig(
        ca_cert=self.tls_ca_cert,
        client_cert=(self.tls_client_cert, self.tls_client_key),
        verify=True,
        ssl_version=self.tls_ssl_version,
        assert_hostname=self.tls_hostname,
    )
    # Switch the daemon URL to https only once a TLS config exists.
    self.docker_url = self.docker_url.replace("tcp://", "https://")
    return secured
return self.do_meta_operation(context, self.host_tmp_dir)

def do_meta_operation(self, context, host_tmp_dir):
    """
    Hook run at the end of ``execute``; its result becomes the return value.

    The base implementation does nothing and returns ``None``.

    :param context: the task context passed to ``execute``.
    :param host_tmp_dir: the host temporary directory used for the run.
    """
    return None
Expand All @@ -400,13 +266,6 @@ def prepare_host_tmp_dir(self, context, host_tmp_dir):
def write_context(self, context, host_tmp_dir):
    """
    Write the task context into the host temporary directory.

    Delegates to ``self.host_client.write_context``.

    :param context: the task context to write.
    :param host_tmp_dir: host directory that receives the context.
    """
    client = self.host_client
    client.write_context(context, host_tmp_dir)

@property
def host_tmp_base_dir(self):
    """
    Base directory on the host for temporary directories.

    Read from the ``host_temporary_directory`` option of the ``worker``
    config section; falls back to ``DEFAULT_HOST_TEMPORARY_DIRECTORY`` when
    looking the option up raises ``AirflowConfigException``.
    """
    try:
        base_dir = conf.get("worker", "host_temporary_directory")
    except AirflowConfigException:
        base_dir = DEFAULT_HOST_TEMPORARY_DIRECTORY
    return base_dir

def host_meta_dir(self, context, host_tmp_dir):
    """
    Return the meta path inside the given host temporary directory.

    :param context: the task context (not used by this implementation).
    :param host_tmp_dir: host temporary directory to resolve the path in.
    :return: the result of ``airflow_docker_helper.get_host_meta_path``.
    """
    meta_path = airflow_docker_helper.get_host_meta_path(host_tmp_dir)
    return meta_path

Expand Down