#!/usr/bin/python3
#
# Univention Salt Kerberos Keys
#  Tool that adds a Kerberos salt to the unsalted keys found in the LDAP backend
#
# SPDX-FileCopyrightText: 2010-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

import sys
from urllib.parse import quote

import heimdal
import ldap

import univention.config_registry


def salt_krb5Keys(principal, keys):
    """Re-encode *keys* with the salt heimdal derives from *principal*.

    The operation is all-or-nothing per principal: if any key already carries
    a salt value, nothing is changed.

    :param principal: Kerberos principal name (str) used to derive the salt.
    :param keys: ASN.1-encoded keys, e.g. the values of the ``krb5Key``
        LDAP attribute.
    :returns: the re-encoded keys, in the same order as *keys*; an empty list
        when *keys* is empty or at least one key is already salted (callers
        treat an empty result as "leave this principal untouched").
    """
    decoded = [heimdal.asn1_decode_key(k) for k in keys]
    if not decoded or any(salt.saltvalue() for (_keyblock, salt, _kvno) in decoded):
        return []

    # heimdal.salt() only depends on the principal, so build it once for all
    # keys instead of once per key.
    context = heimdal.context()
    krb5_principal = heimdal.principal(context, principal)
    krb5_salt = heimdal.salt(context, krb5_principal)
    return [
        heimdal.asn1_encode_key(keyblock, krb5_salt, kvno)
        for (keyblock, _salt, kvno) in decoded
    ]


class LdapConnection:
    """Small wrapper around a python-ldap connection using a simple bind.

    Every ``ldap.LDAPError`` raised by the wrapped calls is printed and then
    re-raised unchanged.
    """

    def __init__(self, location, port, binddn, bindpwd, protocol='ldaps'):
        """Connect to ``protocol://location:port`` and simple-bind as *binddn*.

        :param location: host name, or for ``ldapi`` the path of the socket
            (percent-encoded here, as ldapi URIs require).
        :param port: TCP port; ignored for ``ldapi``, which has no port.
        :param binddn: DN used for the simple bind.
        :param bindpwd: password used for the simple bind.
        :param protocol: URI scheme, ``'ldaps'`` by default.
        :raises ldap.LDAPError: if initialisation or the bind fails.
        """
        if protocol.lower() == 'ldapi':
            # The socket path is the URI's host part, so every '/' must be
            # encoded as %2F; quote()'s default safe='/' would leave them as-is.
            uri = '%s://%s' % (protocol, quote(location, safe=''))
        else:
            uri = '%s://%s:%s' % (protocol, location, port)

        try:
            self.lo = ldap.initialize(uri)
            self.lo.simple_bind_s(binddn, bindpwd)
        except ldap.LDAPError as exc:
            print(exc)
            raise

    def ldapsearch_async(self, base, scope=ldap.SCOPE_SUBTREE, ldapfilter='(objectClass=*)', attrlist=None):
        """Yield ``(dn, attributes)`` tuples for the entries matching *ldapfilter*.

        Result messages are fetched one at a time, so entries are yielded while
        the server is still sending results instead of after the whole search
        has completed.

        :param base: search base DN.
        :param scope: search scope, subtree by default.
        :param ldapfilter: LDAP filter string.
        :param attrlist: attributes to request (``None`` requests all).
        :raises ldap.LDAPError: if the search fails.
        """
        try:
            msgid = self.lo.search(base, scope, ldapfilter, attrlist)
            while True:
                # The 2nd parameter of LDAPObject.result() is ``all`` (not a
                # timeout); all=0 returns a single result message per call and
                # the default timeout=None blocks until one is available.
                result_type, result_data = self.lo.result(msgid, all=0)
                if not result_data:
                    # The final search-result message carries no entries.
                    break
                if result_type == ldap.RES_SEARCH_ENTRY:
                    yield from result_data
        except ldap.LDAPError as exc:
            print(exc)
            raise

    def ldapmodify_object(self, dn, object_dict):
        """Replace the attributes of *dn* with the values in *object_dict*.

        Each ``attribute: values`` item becomes one ``MOD_REPLACE`` operation.

        :returns: whatever ``LDAPObject.modify_s`` returns.
        :raises ldap.LDAPError: if the modification fails.
        """
        try:
            modlist = [
                (ldap.MOD_REPLACE, attr, value)
                for attr, value in object_dict.items()
            ]
            return self.lo.modify_s(dn, modlist)
        except ldap.LDAPError as exc:
            print(exc)
            raise


# Plain-LDAP ports of the Primary mapped to the matching LDAPS ports, because
# main() always connects with the ldaps protocol.
_LDAPS_PORT_FOR_LDAP_PORT = {"389": "636", "7389": "7636"}


def _ldaps_port(port):
    """Return the LDAPS port to use for a configured LDAP *port* string."""
    return _LDAPS_PORT_FOR_LDAP_PORT.get(port, port)


def _read_secret(path):
    """Return the content of the file *path* with surrounding whitespace stripped."""
    with open(path) as fd:
        return fd.read().strip()


def main():
    """Salt the Kerberos keys of every krb5Principal object below ldap/base.

    Without ``--binddn`` the script binds as ``cn=admin,<ldap/base>`` with the
    password from /etc/ldap.secret, which is only allowed on Primary and
    Backup Directory Nodes. ``--bindpwdfile`` overrides any other password.

    :returns: process exit code: 0 on success, 1 if no credentials can be used.
    """
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--binddn", help="binddn")
    parser.add_argument("--bindpwd", help="bindpwd")
    parser.add_argument("--bindpwdfile", help="bindpwdfile")
    options = parser.parse_args()

    ucr = univention.config_registry.ConfigRegistry()
    ucr.load()
    ldap_base = ucr['ldap/base']

    if not options.binddn:
        if ucr['server/role'] not in ('domaincontroller_master', 'domaincontroller_backup'):
            print("salt_krb5Keys: Without explicit credentials this only works on a Primary or Backup Directory Node.")
            return 1
        options.binddn = "cn=admin,%s" % ldap_base
        options.bindpwd = _read_secret('/etc/ldap.secret')

    if options.bindpwdfile:
        options.bindpwd = _read_secret(options.bindpwdfile)

    if not options.bindpwd:
        # Fail early instead of attempting an unauthenticated bind.
        parser.error("a password is required: use --bindpwd or --bindpwdfile")

    ldapfilter = "(objectClass=krb5Principal)"
    attrlist = ['krb5PrincipalName', 'krb5Key']

    ldaps_master_port = _ldaps_port(ucr.get('ldap/master/port', "7636"))

    lc = LdapConnection(ucr['ldap/master'], ldaps_master_port, options.binddn, options.bindpwd, 'ldaps')
    for dn, object_dict in lc.ldapsearch_async(ldap_base, ldapfilter=ldapfilter, attrlist=attrlist):
        if 'krb5Key' not in object_dict:
            continue
        principal = object_dict['krb5PrincipalName'][0].decode('UTF-8')
        mod_krb5Keys = salt_krb5Keys(principal, object_dict['krb5Key'])
        # An empty result means the keys are already salted: leave them alone.
        if mod_krb5Keys:
            lc.ldapmodify_object(dn, {'krb5Key': mod_krb5Keys})

    return 0


if __name__ == '__main__':
    # main() returns the process exit code; hand it to the interpreter.
    exit_code = main()
    sys.exit(exit_code)
