#!/usr/bin/python2
#
#********************************************************
#
# Name: docker-starter.py
#
# (c) Copyright International Business Machines Corp 2017.
# US Government Users Restricted Rights - Use, duplication or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.
#
# This is a script to start a job. This script should be alive during the job execution. 
# The starter must wait for the job file to be done. The starter script sets the exit value of the job.
#********************************************************

import sys
import os
import math
import logging
import subprocess
import signal
import string
import time
import resource
import json
import fcntl
import socket
import shutil
import shlex

#append lsfdockerlib into sys.path
lsfserverdir = os.environ.get("LSF_SERVERDIR", None)

if lsfserverdir == None:
    sys.stderr.write("LSF_SERVERDIR is not defined, set LSF environment variable and try again\n")
    os._exit(254)

try:
    sys.path.append(lsfserverdir)
    from lsfdockerlib import *
except Exception as e:
    sys.stderr.write("Import error: %s, try lsfcontainerlib\n" % str(e))

    try:
        from lsfcontainerlib import *
    except Exception as e:
        sys.stderr.write("Import error: %s\n" % str(e))
        os._exit(253)

LSB_RELEASE_DATE = "2020-12-04 10:36:09"
LSB_BUILD_NUMBER= "564512"

#LSB_DOCKER_PLACE_HOLDER, used to support docker entry point job without any options/params
#like:
#docker run -it docker-temp.phibred.com/hello-world-dsi 
#in LSF, just run below:
#bsub -app docker -env LSB_CONTAINER_IMAGE=docker-temp.phibred.com/hello-world-dsi LSB_DOCKER_PLACE_HOLDER
#note: the app docker must be configured with docker exec_drive support
LSB_DOCKER_PLACE_HOLDER = "LSB_DOCKER_PLACE_HOLDER"


logger = None
def is_passenv():
    is_passenv = os.environ.get('LSF_DOCKER_ENV_UNSET', None)
    if is_passenv in [ 'y', 'Y' ]:
        return False

    return True

class DockerStarter:
    def __init__(self, argv):
        self.uuidfile = os.environ.get('LSB_EXEC_DRIVER_ID', None)
        self.options = os.environ.get('LSB_CONTAINER_OPTIONS', None)
        self.docker_image = os.environ.get('LSB_DOCKER_IMAGE', None)
        self.docker_wrapper = os.environ.get('LSB_DOCKER_CLASS', None)
        self.envbak = os.environ
        self.binary_prefix = '/usr/bin'
        self.cgroup_type = 'systemd'
        self.limit_options = os.environ.get('LSB_DOCKER_LIMITS_OPTIONS', None)

        if self.docker_wrapper is None:
            self.binary_name = os.environ.get('LSB_CONTAINER_NAME', None)
        else:
            self.binary_name = self.docker_wrapper
            os.putenv('NV_HOST', '')
            os.putenv('NV_DOCKER', '')
            cuda_orig_dev = os.environ.get('CUDA_VISIBLE_DEVICES_ORIG', None)
            if cuda_orig_dev:
                os.putenv('NV_GPU', cuda_orig_dev)
        self.extra_opts = []
        devices = os.environ.get('CUDA_VISIBLE_DEVICES')
        if not devices:
            devices = os.environ.get('CUDA_VISIBLE_DEVICES_ORIG')
        if devices and len(devices.strip()) > 0:
            self.extra_opts.extend(["--env", "NVIDIA_VISIBLE_DEVICES=%s" % (devices.strip())])
        else:
            self.extra_opts.extend(["--env", "NVIDIA_VISIBLE_DEVICES=none"])
        self.argv = argv

    def clean(self):
        uuidfile = self.envbak.get('LSB_EXEC_DRIVER_ID', None)
        inputfile = self.envbak.get("LSB_CONTAINER_USER_INPUT", None)

        try:
            if uuidfile:
                os.unlink(uuidfile)
            if inputfile:
                os.unlink(inputfile)
        except OSError:
            pass

        return

    def inspect_image(self):
        """
        the function do below routine:
        1. check whether the image contains an entry point or not
        2. and merge image environment 
        """

        #should check whether image exists or not
        dbgmsg("history image.")
        cmd = "%s/%s history %s" % (self.binary_prefix, self.binary_name, self.docker_image)

        try:
            out, err, rc = runcmd(cmd)
        except Exception as e:
            errmsg("%s\n" % str(e))
            lsbexit(-1)

        if (rc != 0) or (out is None):
            #pull image first
            try:
                out, err, rc = self.pull_image()
            except Exception as e:
                errmsg("pull image %s exception: %s" % (self.docker_image, str(e)))
                lsbexit(-1)

            if rc != 0:
                errmsg("pull image %s failed:%s" % (self.docker_image, "\n".join(err)))
                lsbexit(-1)
        
        dbgmsg("inspect image.")
        cmd = "%s/%s inspect %s" % (self.binary_prefix, self.binary_name, self.docker_image)

        try:
            out, err, rc = runcmd(cmd)
        except Exception as e:
            errmsg(str(e))
            lsbexit(-1)

        if rc != 0:
            lsbexit(-1)
        
        if (len(out) < 1):
            lsbexit(-1)

        jimages = json.loads(str( "\n".join(out)))
        if jimages is None or len(jimages) == 0:
            return False
        
        jimage = jimages[0]
        #dbgmsg("jimage: %s" % (str(jimage)))

        #merge environment variables
        self.merge_image_env(jimage)

        if self.options: 
            options = self.options.split(" ")
            if options:
                for option in options:
                    opt = option.strip()
                    if opt.startswith("--entrypoint"):
                        return True
                        
        if (('Config' not in  jimage.keys()) or ('Entrypoint' not in jimage['Config'].keys())):
            return False

        epoint = jimage['Config']['Entrypoint']

        if epoint is None:
            dbgmsg("entry point is None")
            return False
        
        dbgmsg("entry point: %s" % epoint)
        return True

    def trim_options(self):
        #check whether passenv, if pass, do nothing, if not, remove --env from self.options

        if is_passenv():
            return self.options.split()

        newoptions = []
        
        dbgmsg("self.options: %s" % self.options)
        oldoptions = self.options.split()

        #example opitons:
        #-u 0:0 -v /scratch/dev/xianwu/env/ulaworks002:/scratch/dev/xianwu/env/ulaworks002 
        #-v /root:/root -w /root/ -v /root/.lsbatch:/root/.lsbatch -v /tmp:/tmp --env-file=/tmp/.sbd.env.ulaworks002.691 --rm

        skipnext = False
        for option in oldoptions:
            if skipnext:
                skipnext = False
                continue

            if option.strip() == "--env":
                skipnext = True
                continue

            opt = option.strip()
            if opt.startswith("--env-file="):
                envfile = opt.split("=")[1]
                dbgmsg("option --env-file = %s, skipped" % envfile)
                continue

            newoptions.append(option.strip())

        if len(newoptions) < 1:
            return []

        #dbgmsg("new options: %s" % " ".join(newoptions))
        return newoptions



    def get_user_input(self):
        dbgmsg("get user input.")

        inputfile = os.environ.get("LSB_CONTAINER_USER_INPUT", None)

        if not os.path.exists(inputfile):
            return None
        
        content =  []
        fp = None

        try:
            fp = open(inputfile, 'r')
            content = fp.readlines()
        except Exception as e:
            raise Exception("cannot open inputfile: %s" % (str(e)))
        finally:
            if fp != None:
                fp.close()

        if len(content) == 0:
            return None

        return " ".join(content)
    
    def removeuuid(self):
        """
        used to handle brequeue case.
        """
        
        # uuid file like "lsf.ulaworks008.job.4006.1595437295.task.2.1595437300532899"
        tmplist = self.uuidfile.split(".")
        olduuid = ".".join(tmplist[0:len(tmplist) - 1]) 

        runcmd("rm -fr /tmp/%s*" % olduuid)


    def run_job(self):
        dbgmsg("run docker job.")

        if None in [self.uuidfile, self.options, self.docker_image, self.binary_name]:
            return -1

        is_entry_point_image = self.inspect_image()

        cmd = ['%s/%s' % (self.binary_prefix, self.binary_name),  'run', '--cidfile', self.uuidfile]

        self.removeuuid()

        if self.cgroup_type == 'systemd' and (not is_str_empty(self.limit_options)):
            dbgmsg("limit options: %s; cgroup type: %s" % (str(self.limit_options), str(self.cgroup_type)))
            cmd.extend(self.limit_options.split())
            
        taskid = os.getenv("LSF_PM_TASKID")

        # for podman only add --userns=keep-id
        if ispodman():
            self.options = "%s %s" % (self.options,  "--userns=keep-id")

        cmd.extend(self.trim_options())

        dbgmsg("docker extra_opts: %s, type: %s" % (self.extra_opts, str(type(self.extra_opts))))
        if len(self.extra_opts):
            cmd.extend(self.extra_opts)

        if ENABLE_CONTAINER_NAME:
            namestr = ""

            container_name = os.environ.get("LSB_DOCKER_NAME", None);
            #env container name string:
            #normal job: LSB_DOCKER_NAME=job.5811
            #blaunch job: LSB_DOCKER_NAME=job.5811
            #array job: LSB_DOCKER_NAME=job.5811.1
            
            #blaunch task id like blow:
            #LSF_PM_TASKID=2
            taskid = os.environ.get("LSF_PM_TASKID", None)

            #cluster name: LSB_EXEC_CLUSTER=ulaworks001
            cluster_name = os.environ.get("LSB_EXEC_CLUSTER", None);

            if None in [ container_name, cluster_name ]:
                errmsg("cannot get container name or cluster name")

            namestr = "%s.%s" % (cluster_name, container_name)

            if ((taskid != None) and (int(taskid) > 0)):
                namestr = "%s.task.%s" % (namestr, taskid)

            dbgmsg("Adding container name option: %s" % namestr)

            cmd.extend(["--name", namestr])

        # SPILLBOX fix 
        exec_user = os.environ.get('LSB_EXEC_DRIVER_USER', None)
        cmd.extend(["--user", exec_user])
        if os.environ.get("SPILLBOX_VNC") is not None:
            cmd.extend(["-p", "3389", "--shm-size=1g", "--detach", "--volume", "/run/dbus:/run/dbus", "--volume", "/sys/fs/cgroup:/sys/fs/cgroup"])
        else:
            cmd.extend(["--net=host"])

        if os.environ.get("SPLAPP_INTERACTIVE_RUN") is not None:
            if os.environ.get("SPLAPP_LSF_ATTACH_TTY") is not None:
                cmd.extend(["-it"])
            else:
                cmd.extend(["-i"])

        cmd.append(self.docker_image)

        if is_entry_point_image:
            try:
                userinput = self.get_user_input()
            except Exception as e:
                errmsg(str(e))
                return -1

            if userinput == None:
                return -1

            dbgmsg("user input: %s" % userinput)

            if userinput.strip() != LSB_DOCKER_PLACE_HOLDER:
                cmd.extend(userinput.split())
        else:
            cmd.extend(self.argv)

        #dbgmsg("is_passenv: %s" % str(is_passenv()))
        if not is_passenv():
            os.environ = {}
        


        dbgmsg("final cmd: %s" % str(cmd))

        oldmask = None
        try:
            oldmask = os.umask(0o022)
            if os.environ.get("SPILLBOX_VNC") is not None:
                f = open(os.devnull, 'w')
                proc = subprocess.Popen(cmd, stdout=f)
            else:
                if os.environ.get("SPLAPP_LSF_ATTACH_TTY") is not None:
                    cmd_str = " ".join(cmd)
                    #cmd_str = " ".join(shlex.quote(c) for c in cmd)
                    cmdtty = ["/usr/bin/script", "-q", "-f", "/dev/null", "-c", cmd_str]
                    #cmdtty = ["/usr/bin/script", "-q", "-f", "/dev/null", "-c", shlex.quote(cmd_str)]
                    proc = subprocess.Popen(cmdtty)
                else:
                    proc = subprocess.Popen(cmd)

            pid, status = os.waitpid(proc.pid, 0)
            if os.environ.get("SPILLBOX_VNC") is not None:
                try:
                    with open(uuidfile, 'r') as file:
                        content = file.read()
                        cmd = "%s/%s port %s | awk '{print $3}' | awk -F: '{print $2}'" % (self.binary_prefix, self.binary_name, content)
                        try:
                            out, err, rc = runcmd(cmd)
                        except Exception as e:
                            errmsg("%s\n" % str(e))
                            lsbexit(-1)

                        if (rc == 0) or (out is not None):
                            vnc = "Remote Desktop Server : %s:%s\n" % (socket.getfqdn(), str(out[0]))
                            sys.stderr.write(vnc)
                except FileNotFoundError:
                    sys.stderr.write("Error: Docker UUID file not found.")

        except Exception as e:
            self.clean()
            os.umask(oldmask)
            return -1

        if os.WIFSIGNALED(status):
            os.kill(os.getpid(), os.WTERMSIG(status))

        if os.WIFEXITED(status):
            ret = os.WEXITSTATUS(status)
        else:
            ret = proc.returncode

        self.clean()
        os.umask(oldmask)

        return ret

    def merge_image_env(self, jimage):
        dbgmsg("merge image env.")
        if ('Config' not in jimage.keys()) or ('Env' not in jimage['Config'].keys()):
            return

        envs = jimage['Config']['Env']

        dbgmsg("envs: %s" % str(envs))

        
        # SPILLBOX fix
        if envs is None:
            return

        for env in envs:
            if len(env) < 1:
                continue
            
            index = env.find("=")
            if index < 0:
                dbgmsg("cannot find '=' in env: %s" % env)
                continue

            envname = env[0:index]
            if envname in ['PATH', 'LD_LIBRARY_PATH', 'LD_PRELOAD']:
                
                #handle LD_LIBRARY_PATH
                if envname == "LD_LIBRARY_PATH":
                    lsflibdir = os.environ.get("LSF_LIBDIR", None)
                    if lsflibdir is not None:
                        env = "%s:%s" %(env, lsflibdir)

                #get system environment value and merge them together
                sysenv = os.environ.get(envname)
                if (sysenv is not None) and (len(sysenv) > 0):
                    self.extra_opts.extend(["--env", "%s:%s" % (env, sysenv)])
                else:
                    self.extra_opts.extend(["--env", env])

        dbgmsg("extra_opts: %s" % self.extra_opts)

    def update_image(self):
        dbgmsg("update image.")
        isupdate = os.environ.get('LSB_CONTAINER_IMAGE_UPDATE', None)

        if (isupdate is not None) and isupdate in ['y', 'Y']:
            try:
                out, err, ret = self.pull_image()
            except Exception as e:
                errmsg("pull image %s exception: %s" % (self.docker_image, str(e)))

            if ret != 0 :
                conclusion="Update docker image (%s) failed." % self.docker_image
                jobid = os.environ.get('LSB_BATCH_JID', None)
                if jobid is not None:

                    cmd = 'bpost -d "%s. %s" %s' % (err.replace("\n", ""), conclusion, jobid)
                    try:
                        runcmd(cmd)
                    except:
                        pass

                lsbexit(ret)

    def pull_image(self):
        dbgmsg("pull image.")
        cmd = "%s/%s pull %s" % (self.binary_prefix, self.binary_name, self.docker_image)

        try:
            out, err, ret = runcmd(cmd)
        except Exception as e:
            errmsg("pull image exception: %s\nout: %s\nerr: %s\n" % (str(e), str(out), str(err)))
            lsbexit(-1)

        if ret != 0:
            errmsg("pull image error out: %s\n, err: %s\n" % (str(out), str(err)))
            lsbexit(-1)

        return out, err, ret

    def add_cgroup_parent(self):
        dbgmsg("add cgroup parent to docker job if it is not managed by systemd")
        if is_str_empty(self.binary_name):
            errexit("Unable to get contaner name: LSB_CONTAINER_NAME or LSB_DOCKER_CLASS")

        cgtype = CGHelper.get_docker_cgroup_type(self.binary_prefix, self.binary_name)
        if cgtype == "cgroupfs":
            self.cgroup_type = 'cgroupfs'
            cg_parent_dir = os.environ.get("LSB_JOB_CGROUP_PARENT", None)
            self.extra_opts.extend(["--cgroup-parent", cg_parent_dir])
        elif cgtype == "systemd":
            dbgmsg("The cgroupdriver of docker daemon is managed by systemd, no cgroup parent shall be added")

def sigexit(signum, frame):
    lsbexit(-signum)

def is_service_running(service_name):
    try:
        result = subprocess.Popen(
            ["sudo", "systemctl", "is-active", service_name],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )
        stdout, _ = result.communicate()
        return stdout.strip() == "active"
    except OSError as e:
        print("OS error: {}".format(e))
        return False

def get_os_id():
    try:
        with open("/etc/os-release") as f:
            for line in f:
                if line.startswith("ID="):
                    return line.strip().split("=")[1].strip('"')
    except FileNotFoundError:
        return None

def install_cachefilesd():
    if not os.access("/sbin/cachefilesd", os.X_OK):
        os_id = get_os_id()
        apt = "apt-get"
        if os_id in ("centos", "rhel", "rocky", "amzn"):
            apt = "yum"
        with open("/dev/null", "w") as outfile:
            subprocess.call(["sudo", apt, "install", "-y", "cachefilesd"], stdout=outfile, stderr=outfile)

def start_service(service_name):
    if service_name == "cachefilesd":
        install_cachefilesd()
        spillbox_cache_dir = os.getenv('SPILLBOX_BOOST_CLIENT_CACHE_DIR', '/default/path')
        full_path = os.path.join(spillbox_cache_dir, 'spillbox_cache')
        try:
            subprocess.check_call(['sudo', 'mkdir', '-p', full_path])
        except subprocess.CalledProcessError as e:
            print("Failed to create directory {}. Error: {}".format(full_path, e))

        content = """dir {}/spillbox_cache
brun 10%
bcull 7%
bstop 3%
frun 10%
fcull 7%
fstop 3%
#secctx system_u:system_r:cachefiles_kernel_t:s0
""".format(spillbox_cache_dir)
        file_path = '/tmp/cachefilesd.conf'
        try:
            with open(file_path, 'w') as file:
                fcntl.flock(file, fcntl.LOCK_EX)
                file.write(content)
                fcntl.flock(file, fcntl.LOCK_UN)
        except IOError as e:
                print("Failed to write configuration file {}. Error: {}".format(file_path, e))
                return
        try:
            subprocess.check_call(['sudo', 'cp', '-f', file_path, '/etc/cachefilesd.conf'])
        except subprocess.CalledProcessError:
            pass

        try:
            subprocess.check_call(['sudo', 'chmod', '644', '/etc/cachefilesd.conf'])
        except subprocess.CalledProcessError:
            pass

        content = """RUN=yes
DAEMON_OPTS=""
""".format()
        file_path = '/tmp/cachefilesd'
        try:
            with open(file_path, 'w') as file:
                fcntl.flock(file, fcntl.LOCK_EX)
                file.write(content)
                fcntl.flock(file, fcntl.LOCK_UN)
        except IOError as e:
                print("Failed to write configuration file {}. Error: {}".format(file_path, e))
                return
        try:
            subprocess.check_call(['sudo', 'cp', '-f', file_path, '/etc/default/cachefilesd'])
        except subprocess.CalledProcessError:
            pass

        try:
            subprocess.check_call(['sudo', 'chmod', '644', '/etc/default/cachefilesd'])
        except subprocess.CalledProcessError as e:
            print("Failed to update configuration file. Error: {}".format(e))
            return

        with open("/dev/null", "w") as outfile:
            subprocess.call(['sudo', 'setenforce', '0', '||', 'true'], stdout=outfile, stderr=outfile)

        if os.access("/sbin/cachefilesd", os.X_OK):
            modprobe_path = shutil.which("modprobe")
            sudocmd = "sudo"  # or "" if already root

            # Load module if not already loaded
            lsmod = run_cmd("/sbin/lsmod | grep cachefiles", capture_output=True, ignore_errors=True)
            if not lsmod and modprobe_path:
                ret = run_cmd(sudocmd + " " + modprobe_path + " cachefiles", ignore_errors=True)
                if ret is None:
                    mod = run_cmd("find /usr/lib/modules/ -name cachefiles.ko.xz", capture_output=True)
                    if mod:
                        ret = run_cmd(sudocmd + " insmod " + mod.strip(), ignore_errors=True)
                        if ret is None:
                            print("Failed to load fscache module")
    
            # Create /dev/cachefiles if it doesn't exist
            if not os.path.exists("/dev/cachefiles"):
                run_cmd(sudocmd + " mknod /dev/cachefiles c 10 56", ignore_errors=True)
    
            try:
                subprocess.check_call(["sudo", "systemctl", "restart", service_name])
            except subprocess.CalledProcessError as e:
                print("Failed to start {}: {}".format(service_name, e))

        else:
            print("missing /sbin/cachefilesd")

def main(argv):
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    updatets("start")
    docker_starter = DockerStarter(argv)
    updatets("update_image")
    docker_starter.update_image()

    updatets("signal")
    signal.signal(signal.SIGTERM, sigexit)

    if not ispodman():
        updatets("add cgroup parent")
        docker_starter.add_cgroup_parent()

    updatets("inspectimage")
    ret = docker_starter.inspect_image()

    if os.environ.get("SPILLBOX_BOOSTMODE") is not None:
        updatets("starting cachefilesd")
        service_name = "cachefilesd"
        lock_file_path = "/tmp/{service_name}.lock"
        with open(lock_file_path, "w") as lock_file:
            try:
                # Try to acquire a non-blocking exclusive lock
                fcntl.flock(lock_file, fcntl.LOCK_EX)

                if not is_service_running(service_name):
                    start_service(service_name)

                fcntl.flock(file, fcntl.LOCK_UN)
                # Lock is held until we exit the 'with' block
            except:
                # Another process already holds the lock
                pass

    updatets("startjob")
    ret = docker_starter.run_job()

    dbgmsg("exit code: %s" % ret)
    lsbexit(ret)

if __name__ == "__main__":
    if len(sys.argv) == 1:
        ret = show_file_info(sys.argv[0])
        lsbexit(ret)

    if os.environ.get('LSB_EXEC_DRIVER_ID', None) == None:
        lsbexit(-1)

    uuidfile = os.environ.get('LSB_EXEC_DRIVER_ID', None)
    #uuidfile:
    #/tmp/lsf.ulaworks004.job.131.1576083749
    if is_str_empty(uuidfile):
        lsbexit(-1)
    
    filename = uuidfile.split('/')[-1]

    logger = DriverLogger("starter", filename,  "debug")
    #logger.init()
    
    if os.environ.get("XDG_RUNTIME_DIR", None) != None:
        del(os.environ["XDG_RUNTIME_DIR"])

    main(sys.argv[1:])
