#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# This file is part of cepces.
#
# cepces is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cepces is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with cepces.  If not, see <http://www.gnu.org/licenses/>.
#
# pylint: disable=broad-except,invalid-name

"""
This is a submission helper for automatic user certificate enrollment.

It requires a valid Kerberos ticket in your credential cache (check with `klist`).
Such a ticket is normally created automatically by SSSD when you log in with a domain account.
You can acquire one manually with `kinit username@DOMAIN.TLD`.
"""

import os
import sys
import time
from cryptography import x509
from datetime import datetime, timezone
from cepces.user import UserEnrollment, ApprovalPendingException


def sleep_or_exit(interval):
    """Wait *interval* seconds before the next polling round.

    An interval below one second means "do not keep polling": instead of
    waiting, the process terminates with exit status 0.
    """
    if interval < 1:
        # Polling disabled -> stop cleanly after the current pass.
        sys.exit(0)
    time.sleep(interval)

def remaining_validity_days(cert, now=None):
    """Return the number of whole days left before *cert* expires.

    Newer ``cryptography`` releases expose the timezone-aware
    ``not_valid_after_utc``. Older ones only provide ``not_valid_after``,
    a naive datetime that cryptography documents as being in UTC. That
    naive value is tagged with UTC here. Comparing it with
    ``datetime.now()`` (naive *local* time) would skew the result by the
    local UTC offset.

    :param cert: certificate object, e.g. the result of
        ``x509.load_pem_x509_certificate``.
    :param now: timezone-aware "current time"; defaults to the current UTC
        time. Exposed mainly for testing.
    :return: whole days until expiry (``timedelta.days``, i.e. floored, so a
        certificate that expired an hour ago yields ``-1``).
    """
    if now is None:
        now = datetime.now(timezone.utc)
    expires = getattr(cert, "not_valid_after_utc", None)
    if expires is None:
        expires = cert.not_valid_after.replace(tzinfo=timezone.utc)
    return (expires - now).days


def read_request_file(path):
    """Load the pending-request "cookie" stored by ``write_request_file``.

    The first line holds the request ID and the second the reference. The
    reference line may be absent (an empty reference leaves nothing after
    the first newline); in that case ``""`` is returned for it.

    :param path: path of the request file.
    :return: ``(request_id, reference)`` as stripped strings.
    :raises ValueError: if the file does not contain a request ID.
    :raises OSError: if the file cannot be read.
    """
    with open(path, "r") as f:
        lines = f.read().splitlines()
    request_id = lines[0].strip() if lines else ""
    if not request_id:
        raise ValueError(f"request file {path} does not contain a request ID")
    reference = lines[1].strip() if len(lines) > 1 else ""
    return request_id, reference


def write_request_file(path, request_id, reference):
    """Persist a pending request so a later pass can poll for its approval.

    The file is created with mode 0600 (owner read/write only, further
    restricted by the umask). The mode only takes effect when the file is
    newly created.

    :param path: path of the request file.
    :param request_id: request ID to store (formatted with ``str()``).
    :param reference: polling reference; ``None`` is stored as an empty line.
    :raises OSError: if the file cannot be created or written.
    """
    # Tolerate a missing reference instead of failing on string building.
    reference = "" if reference is None else reference
    fd = os.open(path, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o600)
    with open(fd, "w") as f:
        f.write(f"{request_id}\n{reference}\n")


if __name__ == "__main__":
    g_overrides = {}
    k_overrides = {
        # this script is intended to use as normal user with
        # existing ccache created during login, e.g. /tmp/krb5cc_1000
        # so we disable the ccache creation feature and
        # pass empty principals to use the default one
        "ccache": "False",
        "principals": "",
    }
    user_enrollment = UserEnrollment(g_overrides, k_overrides)
    config = user_enrollment.service._config
    (key_file, cert_file, req_file, profile,
     renew_days, key_size) = config.get_user_config()
    # An interval below 1s makes sleep_or_exit() exit after the first pass.
    poll_interval = int(config.poll_interval)

    # Each pass: a pending request file -> poll it; otherwise an existing
    # cert -> renew if close to expiry; otherwise -> request a new cert.
    while True:
        try:
            if os.path.isfile(req_file):
                print(f"Found pending request file {req_file}, checking approval...")
                request_id, reference = read_request_file(req_file)
                user_enrollment.poll(cert_file, request_id, reference)
                os.unlink(req_file)

            elif os.path.isfile(cert_file):
                print(f"Found cert file {cert_file}, checking expiration...")
                # Close the file before request() may overwrite it.
                with open(cert_file, "rb") as f:
                    cert = x509.load_pem_x509_certificate(f.read())
                remaining_days = remaining_validity_days(cert)
                if renew_days <= 0:
                    print(f"Expires in {remaining_days}d, automatic renewal disabled, nothing to do.")
                elif remaining_days < renew_days:
                    print(f"Expires in {remaining_days}d (< {renew_days}d), trying to request a new cert...")
                    user_enrollment.request(key_file, cert_file, profile, key_size, None)
                else:
                    print(f"Expires in {remaining_days}d (>= {renew_days}d), nothing to do.")

            else:
                print(f"Cert file {cert_file} not found, try requesting one...")
                user_enrollment.request(key_file, cert_file, profile, key_size, None)

        except ApprovalPendingException as e:
            # output the "cookie" that can be used to later poll the status
            print(f"Certificate approval pending (ID: {e.request_id}), trying to poll in {poll_interval}s.")
            try:
                write_request_file(req_file, e.request_id, e.reference)
            except OSError as write_error:
                print(f"Error: Failed to write request file: {write_error}", file=sys.stderr)

        except Exception as e:
            print("Error:", e, f"(trying again in {poll_interval}s)", file=sys.stderr)

        # Exactly one wait per pass; SystemExit/KeyboardInterrupt raised
        # here are not swallowed by the handlers above.
        sleep_or_exit(poll_interval)
