/* Copyright (C) 2024 Simo Sorce <simo@redhat.com>
   Copyright 2025 NXP
   SPDX-License-Identifier: Apache-2.0 */

#include "provider.h"

#if SKEY_SUPPORT == 1

#include "cipher.h"
#include "openssl/prov_ssl.h"
#include "openssl/rand.h"
#include <string.h>

/* Number of bytes needed to hold `bits` bits, rounded up. The argument is
 * parenthesized so that expressions such as `a << 3` expand correctly. */
#define BITS_TO_BYTES(bits) (((bits) + 7) / 8)
#define BYTES_TO_BITS(bytes) ((bytes) * 8)

/* Upper bound on the number of trailing bytes tlsunpad() scans when checking
 * TLS CBC padding. No trailing semicolon so it is usable in expressions. */
#define MAX_PADDING 256
#define AESBLOCK 16 /* 128 bits for all AES modes except GCM */

/* Default IV length, in bytes, per mode name */
#define IVSIZE_cbc AESBLOCK
#define IVSIZE_cfb AESBLOCK
#define IVSIZE_cfb1 AESBLOCK
#define IVSIZE_cfb8 AESBLOCK
#define IVSIZE_ctr AESBLOCK
#define IVSIZE_cts AESBLOCK
#define IVSIZE_ecb 0
#define IVSIZE_gcm 12
#define IVSIZE_ofb AESBLOCK
#define IVSIZE_poly1305 12

/* Forward declarations of the OpenSSL cipher dispatch entry points defined in
 * this file (e.g. p11prov_cipher_freectx, p11prov_common_dupctx). The first
 * argument is the name prefix; DISPATCH_CIPHER_FN comes from an included
 * header and presumably expands to the matching OSSL_FUNC_cipher_* prototype
 * (TODO confirm). "common" entries are shared by AES and ChaCha20. */
DISPATCH_CIPHER_FN(cipher, freectx);
DISPATCH_CIPHER_FN(common, dupctx);
DISPATCH_CIPHER_FN(cipher, encrypt_init);
DISPATCH_CIPHER_FN(cipher, decrypt_init);
DISPATCH_CIPHER_FN(cipher, update);
DISPATCH_CIPHER_FN(cipher, final);
DISPATCH_CIPHER_FN(common, cipher);
DISPATCH_CIPHER_FN(common, get_ctx_params);
DISPATCH_CIPHER_FN(common, set_ctx_params);
DISPATCH_CIPHER_FN(common, gettable_ctx_params);
DISPATCH_CIPHER_FN(common, settable_ctx_params);
DISPATCH_CIPHER_FN(cipher, encrypt_skey_init);
DISPATCH_CIPHER_FN(cipher, decrypt_skey_init);

/* Per-operation state of a PKCS#11-backed symmetric cipher context. */
struct p11prov_cipher_ctx {
    /* provider context this cipher context was created from */
    P11PROV_CTX *provctx;

    /* referenced key object (NULL until a key is supplied) */
    P11PROV_OBJ *key;
    /* key size in bytes (newctx receives it in bits) */
    int keysize;

    /* padding enabled; true by default, like OpenSSL */
    bool pad;

    /* PKCS#11 mechanism, including its (IV / AEAD) parameter block */
    CK_MECHANISM mech;
    /* CKF_ENCRYPT, CKF_DECRYPT, CKF_MESSAGE_ENCRYPT or CKF_MESSAGE_DECRYPT */
    CK_FLAGS operation;

    /* session obtained lazily on first use, returned in freectx */
    P11PROV_SESSION *session;
    enum {
        SESS_UNUSED, /* no operation started on the session yet */
        SESS_INITIALIZED, /* an operation was successfully initialized */
        SESS_FINALIZED, /* the operation was terminated on context free */
    } session_state;

    /* OpenSSL violates layering separation and decided
     * to process AES CBC MAC/padding handling in TLS 1.x < 1.3
     * in the lower cipher layer, so we have to do it here as well
     * for compatibility ... */
    unsigned int tlsver; /* TLS protocol version, 0 when not in TLS mode */
    size_t tlsmacsize; /* size of the record MAC to extract */
    unsigned char *tlsmac; /* MAC extracted from the last decrypted record */

    /* default IV length, in bytes, for this cipher */
    size_t ivsize;

    /* copy of the additional authenticated data for AEAD modes */
    unsigned char *aad;
    size_t aadsize;
};

/* Allocate a zeroed cipher context for `mechanism`.
 *
 * size is the key size in bits (stored in bytes), ivsize the default IV length
 * in bytes. Padding starts enabled to mirror OpenSSL's default. Returns NULL
 * if the allocation fails. */
static void *p11prov_cipher_newctx(void *provctx, int size, size_t ivsize,
                                   CK_ULONG mechanism)
{
    struct p11prov_cipher_ctx *newctx = NULL;

    P11PROV_debug("New Cipher context for mechanism %ld (key size: %d)",
                  mechanism, size);

    newctx = OPENSSL_zalloc(sizeof(*newctx));
    if (newctx != NULL) {
        newctx->provctx = (P11PROV_CTX *)provctx;
        newctx->mech.mechanism = mechanism;
        newctx->keysize = size / 8;
        newctx->ivsize = ivsize;
        /* OpenSSL Pads by default */
        newctx->pad = true;
    }

    return newctx;
}

/* Algorithm-level parameters every cipher in this file can report; returned
 * by p11prov_cipher_gettable_params() and filled by
 * p11prov_cipher_get_params(). */
static const OSSL_PARAM cipher_gettable_params[] = {
    OSSL_PARAM_uint(OSSL_CIPHER_PARAM_MODE, NULL),
    OSSL_PARAM_size_t(OSSL_CIPHER_PARAM_KEYLEN, NULL),
    OSSL_PARAM_size_t(OSSL_CIPHER_PARAM_IVLEN, NULL),
    OSSL_PARAM_size_t(OSSL_CIPHER_PARAM_BLOCK_SIZE, NULL),
    OSSL_PARAM_int(OSSL_CIPHER_PARAM_AEAD, NULL),
    OSSL_PARAM_int(OSSL_CIPHER_PARAM_CUSTOM_IV, NULL),
    OSSL_PARAM_int(OSSL_CIPHER_PARAM_CTS, NULL),
    OSSL_PARAM_int(OSSL_CIPHER_PARAM_TLS1_MULTIBLOCK, NULL),
    OSSL_PARAM_int(OSSL_CIPHER_PARAM_HAS_RAND_KEY, NULL),
    OSSL_PARAM_END
};

/* Describe the algorithm parameters that can be queried. Every cipher shares
 * one static table, so the provider context is not consulted. */
static const OSSL_PARAM *p11prov_cipher_gettable_params(void *provctx)
{
    (void)provctx;
    return cipher_gettable_params;
}

/* Maps each boolean cipher parameter to the MODE_flag_* bit that turns it on.
 * Terminated by a NULL name. The table is only ever read, so it lives in
 * const storage and cannot be modified by accident. */
static const struct {
    const char *name;
    int flag;
} param_to_flag[] = {
    { OSSL_CIPHER_PARAM_AEAD, MODE_flag_aead },
    { OSSL_CIPHER_PARAM_CUSTOM_IV, MODE_flag_custom_iv },
    { OSSL_CIPHER_PARAM_CTS, MODE_flag_cts },
    { OSSL_CIPHER_PARAM_TLS1_MULTIBLOCK, MODE_flag_tls1_mb },
    { OSSL_CIPHER_PARAM_HAS_RAND_KEY, MODE_flag_rand_key },
    { NULL, 0 },
};

/* Fill whichever of the requested algorithm parameters are present in
 * `params`: the EVP cipher mode, one 0/1 int per entry of param_to_flag
 * (derived from `flags`), and the key, block and IV sizes in bytes.
 * Any failure to store a value raises PROV_R_FAILED_TO_SET_PARAMETER. */
static int p11prov_cipher_get_params(OSSL_PARAM params[], unsigned int mode,
                                     int flags, size_t keysize, size_t ivsize,
                                     size_t blocksize)
{
    /* size_t valued parameters, handled in this order */
    const struct {
        const char *key;
        size_t val;
    } sizes[] = {
        { OSSL_CIPHER_PARAM_KEYLEN, keysize },
        { OSSL_CIPHER_PARAM_BLOCK_SIZE, blocksize },
        { OSSL_CIPHER_PARAM_IVLEN, ivsize },
    };
    OSSL_PARAM *p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_MODE);

    if (p != NULL && OSSL_PARAM_set_uint(p, mode) != RET_OSSL_OK) {
        goto fail;
    }

    for (size_t idx = 0; param_to_flag[idx].name != NULL; idx++) {
        int enabled = ((flags & param_to_flag[idx].flag) != 0) ? 1 : 0;

        p = OSSL_PARAM_locate(params, param_to_flag[idx].name);
        if (p != NULL && OSSL_PARAM_set_int(p, enabled) != RET_OSSL_OK) {
            goto fail;
        }
    }

    for (size_t idx = 0; idx < sizeof(sizes) / sizeof(sizes[0]); idx++) {
        p = OSSL_PARAM_locate(params, sizes[idx].key);
        if (p != NULL
            && OSSL_PARAM_set_size_t(p, sizes[idx].val) != RET_OSSL_OK) {
            goto fail;
        }
    }

    return RET_OSSL_OK;

fail:
    ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_SET_PARAMETER);
    return RET_OSSL_ERR;
}

#define p11prov_aes_get_params p11prov_common_get_params
#define p11prov_chacha20_get_params p11prov_common_get_params

/* Report algorithm-level parameters for an AES or ChaCha20 cipher.
 *
 * size:      key size in bits (reported to OpenSSL in bytes)
 * ivsize:    IV length in bytes
 * mode:      a MODE_* value, possibly combined with MODE_flag_* bits
 * mechanism: the PKCS#11 mechanism (currently unused)
 *
 * GCM and ChaCha20-Poly1305 always advertise the AEAD and custom-IV flags, and
 * AEAD modes report a block size of 1; all others report the AES block size.
 * An unknown mode raises PROV_R_FAILED_TO_SET_PARAMETER and fails. */
static int p11prov_common_get_params(OSSL_PARAM params[], int size,
                                     size_t ivsize, int mode,
                                     CK_ULONG mechanism)
{
    int ciph_mode = 0;
    int flags = mode & MODE_flags_mask;
    size_t keysize = size / 8;
    size_t blocksize = AESBLOCK;

    switch (mode & MODE_modes_mask) {
    case MODE_ecb:
        ciph_mode = EVP_CIPH_ECB_MODE;
        break;
    case MODE_cbc:
        ciph_mode = EVP_CIPH_CBC_MODE;
        break;
    case MODE_ofb:
        ciph_mode = EVP_CIPH_OFB_MODE;
        break;
    case MODE_cfb:
        ciph_mode = EVP_CIPH_CFB_MODE;
        break;
    case MODE_ctr:
        ciph_mode = EVP_CIPH_CTR_MODE;
        break;
    case MODE_gcm:
        ciph_mode = EVP_CIPH_GCM_MODE;
        flags |= MODE_flag_aead | MODE_flag_custom_iv;
        break;
    case MODE_poly1305:
        ciph_mode = EVP_CIPH_STREAM_CIPHER;
        flags |= MODE_flag_aead | MODE_flag_custom_iv;
        break;
    default:
        ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_SET_PARAMETER);
        return RET_OSSL_ERR;
    }

    if (flags & MODE_flag_aead) {
        blocksize = 1;
    }

    return p11prov_cipher_get_params(params, ciph_mode, flags, keysize, ivsize,
                                     blocksize);
}

/* Release the parameter block attached to `mech`.
 *
 * For AES-GCM and ChaCha20-Poly1305 the IV/nonce and tag buffers owned by the
 * message parameter structure are released as well. Afterwards the mechanism
 * is reset to "no parameter" so it can never be freed twice or reused with a
 * dangling pointer. */
static void p11prov_cipher_free_mech(CK_MECHANISM_PTR mech)
{
    if (!mech->pParameter) {
        return;
    }

    if (mech->mechanism == CKM_AES_GCM) {
        CK_GCM_MESSAGE_PARAMS_PTR gcm =
            (CK_GCM_MESSAGE_PARAMS_PTR)mech->pParameter;

        OPENSSL_clear_free(gcm->pIv, gcm->ulIvLen);
        OPENSSL_clear_free(gcm->pTag, BITS_TO_BYTES(gcm->ulTagBits));
    } else if (mech->mechanism == CKM_CHACHA20_POLY1305) {
        CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR chacha =
            (CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR)mech->pParameter;

        OPENSSL_clear_free(chacha->pNonce, chacha->ulNonceLen);
        OPENSSL_free(chacha->pTag);
    }

    OPENSSL_clear_free(mech->pParameter, mech->ulParameterLen);
    mech->pParameter = NULL;
    mech->ulParameterLen = 0;
}

/* Destroy a cipher context.
 *
 * If an operation is still active on the context's session it is terminated
 * first (falling back to C_SessionCancel). If that also fails, the session is
 * marked broken before being returned to the pool. Then the key reference,
 * mechanism parameters, extracted TLS MAC, AAD copy and the context itself
 * are released. A NULL ctx is a no-op. */
static void p11prov_cipher_freectx(void *ctx)
{
    struct p11prov_cipher_ctx *cctx = (struct p11prov_cipher_ctx *)ctx;

    if (!cctx) {
        return;
    }

    if (cctx->session) {
        if (cctx->session_state == SESS_INITIALIZED) {
            /* Finalize any operation to avoid leaving a hanging
             * operation on this session. Ignore return errors here
             * intentionally as errors can be returned if the operation was
             * internally finalized because of a previous internal token
             * error state and, in any case, not much to be done. */
            CK_RV ret = CKR_OK;
            CK_SESSION_HANDLE sess = p11prov_session_handle(cctx->session);
            switch (cctx->operation) {
            case CKF_ENCRYPT:
                ret = p11prov_EncryptInit(cctx->provctx, sess, NULL,
                                          CK_INVALID_HANDLE);
                break;
            case CKF_DECRYPT:
                ret = p11prov_DecryptInit(cctx->provctx, sess, NULL,
                                          CK_INVALID_HANDLE);
                break;
            case CKF_MESSAGE_ENCRYPT:
                ret = p11prov_MessageEncryptInit(cctx->provctx, sess, NULL,
                                                 CK_INVALID_HANDLE);
                break;
            case CKF_MESSAGE_DECRYPT:
                ret = p11prov_MessageDecryptInit(cctx->provctx, sess, NULL,
                                                 CK_INVALID_HANDLE);
                break;
            default:
                break;
            }
            if (ret != CKR_OK) {
                /* NSS softokn has a broken interface and is incapable of
                 * dropping operations on sessions returning a generic
                 * CKR_MECHANISM_PARAM_INVALID when the mechanism is set to
                 * NULL. Attempt to force cancellation via C_SessionCancel. */
                ret =
                    p11prov_SessionCancel(cctx->provctx, sess, cctx->operation);
            }
            if (ret != CKR_OK) {
                /* When this happens the session becomes broken as
                 * we can't initialize operations on it anymore */
                p11prov_session_mark_broken(cctx->session);
            }
            cctx->session_state = SESS_FINALIZED;
        }
        p11prov_return_session(cctx->session);
    }

    p11prov_obj_free(cctx->key);
    p11prov_cipher_free_mech(&cctx->mech);
    OPENSSL_clear_free(cctx->tlsmac, cctx->tlsmacsize);
    /* the AAD copy made by p11prov_cipher_update() is owned by the context */
    OPENSSL_clear_free(cctx->aad, cctx->aadsize);
    OPENSSL_clear_free(cctx, sizeof(struct p11prov_cipher_ctx));
}

#define p11prov_aes_dupctx p11prov_common_dupctx
#define p11prov_chacha20_dupctx p11prov_common_dupctx

/* Context duplication is not implemented for AES or ChaCha20: the context is
 * ignored and NULL is returned, which OpenSSL treats as a failed copy. */
static void *p11prov_common_dupctx(void *ctx)
{
    (void)ctx;
    return NULL;
}

static int set_iv(struct p11prov_cipher_ctx *ctx, const unsigned char *iv,
                  size_t ivlen)
{
    /* Free parameter first, as OpenSSL apparently can "init" without
     * keys and just set the IV, and then re-init again with the IV
     * or even set the IV again via parameters ... */
    if (ctx->mech.pParameter) {
        OPENSSL_clear_free(ctx->mech.pParameter, ctx->mech.ulParameterLen);
        ctx->mech.pParameter = NULL;
        ctx->mech.ulParameterLen = 0;
    }
    /* If IV is null it means the app is either trying to clear a context
     * for reuse or did the initialization w/o IV and intends to init again
     * or pass the IV via params, ether way just bail out, the mech will
     * fail to initialize later if the application forgets to set the IV
     * and the mechanism requires it */
    if (iv != NULL && ivlen != 0) {
        if (ctx->mech.mechanism == CKM_AES_CTR) {
            if (ivlen > 16) {
                return CKR_MECHANISM_PARAM_INVALID;
            }
            struct CK_AES_CTR_PARAMS *ctr_params =
                OPENSSL_malloc(sizeof(struct CK_AES_CTR_PARAMS));
            if (!ctr_params) {
                return CKR_HOST_MEMORY;
            }
            memcpy(ctr_params->cb, iv, ivlen);
            ctr_params->ulCounterBits = ivlen;
            ctx->mech.pParameter = ctr_params;
            ctx->mech.ulParameterLen = sizeof(struct CK_AES_CTR_PARAMS);
        } else {
            ctx->mech.pParameter = OPENSSL_memdup(iv, ivlen);
            if (!ctx->mech.pParameter) {
                return CKR_HOST_MEMORY;
            }
            ctx->mech.ulParameterLen = ivlen;
        }
    }
    return CKR_OK;
}

/* Prepare the CK_GCM_MESSAGE_PARAMS block of an AES-GCM mechanism.
 *
 * The parameter block and the tag buffer are allocated on first use and then
 * reused. The IV is replaced on every call. Any IV from a previous
 * (re-)initialization is wiped and freed, and the new IV is stored only when
 * one is supplied; otherwise the IV is left empty. The tag length is set to
 * the maximum AEAD tag size. The fixed-IV/generator fields are reset to
 * "not used".
 *
 * Returns CKR_ARGUMENTS_BAD for a NULL mech, CKR_MECHANISM_PARAM_INVALID for
 * an IV longer than EVP_MAX_IV_LENGTH, or CKR_HOST_MEMORY. */
static CK_RV p11prov_cipher_prep_gcm(CK_MECHANISM_PTR mech,
                                     const unsigned char *iv, size_t ivlen)
{
    CK_GCM_MESSAGE_PARAMS_PTR gcm;

    if (!mech) {
        return CKR_ARGUMENTS_BAD;
    }

    if (ivlen > EVP_MAX_IV_LENGTH) {
        return CKR_MECHANISM_PARAM_INVALID;
    }

    if (!mech->pParameter) {
        mech->ulParameterLen = sizeof(CK_GCM_MESSAGE_PARAMS);
        mech->pParameter = OPENSSL_zalloc(mech->ulParameterLen);
        if (!mech->pParameter) {
            return CKR_HOST_MEMORY;
        }
    }

    gcm = (CK_GCM_MESSAGE_PARAMS_PTR)mech->pParameter;

    /* OpenSSL re-initializes contexts (e.g. once per record in TLS 1.3), so
     * release the IV of the previous initialization instead of leaking it,
     * and never leave ulIvLen describing a buffer it does not match. */
    OPENSSL_clear_free(gcm->pIv, gcm->ulIvLen);
    gcm->pIv = NULL;
    gcm->ulIvLen = 0;

    if (iv && ivlen != 0) {
        gcm->pIv = OPENSSL_memdup(iv, ivlen);
        if (!gcm->pIv) {
            return CKR_HOST_MEMORY;
        }
        gcm->ulIvLen = ivlen;
    }

    gcm->ulTagBits = BYTES_TO_BITS(EVP_MAX_AEAD_TAG_LENGTH);
    if (!gcm->pTag) {
        gcm->pTag = OPENSSL_zalloc(EVP_MAX_AEAD_TAG_LENGTH);
        if (!gcm->pTag) {
            return CKR_HOST_MEMORY;
        }
    }

    /* The IV fixed bits and IV generator are only used
     * in the context of TLS 1.2. Mark them as n/a until
     * OpenSSL explicitly sets a (partial) fixed IV. */
    gcm->ulIvFixedBits = 0;
    gcm->ivGenerator = CKG_NO_GENERATE;

    return CKR_OK;
}

/* Prepare the CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS block of a
 * ChaCha20-Poly1305 mechanism.
 *
 * The parameter block and the tag buffer are allocated on first use and then
 * reused. Any nonce from a previous (re-)initialization is wiped and freed,
 * and the new nonce is stored only when one is supplied.
 *
 * Returns CKR_ARGUMENTS_BAD for a NULL mech, CKR_MECHANISM_PARAM_INVALID for
 * a nonce that is neither empty nor 12 bytes, or CKR_HOST_MEMORY. */
static CK_RV p11prov_cipher_prep_chacha20_poly1305(CK_MECHANISM_PTR mech,
                                                   const unsigned char *iv,
                                                   size_t ivlen)
{
    CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR chacha;

    if (!mech) {
        return CKR_ARGUMENTS_BAD;
    }

    if (ivlen != 0 && ivlen != 12) {
        return CKR_MECHANISM_PARAM_INVALID;
    }

    if (!mech->pParameter) {
        mech->ulParameterLen = sizeof(CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS);
        mech->pParameter = OPENSSL_zalloc(mech->ulParameterLen);
        if (!mech->pParameter) {
            return CKR_HOST_MEMORY;
        }
    }

    chacha = (CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR)mech->pParameter;

    /* Re-initialization (e.g. per TLS 1.3 record) must not leak the previous
     * nonce, nor leave ulNonceLen describing a buffer it does not match. */
    OPENSSL_clear_free(chacha->pNonce, chacha->ulNonceLen);
    chacha->pNonce = NULL;
    chacha->ulNonceLen = 0;

    if (iv && ivlen != 0) {
        chacha->pNonce = OPENSSL_memdup(iv, ivlen);
        if (!chacha->pNonce) {
            return CKR_HOST_MEMORY;
        }
        chacha->ulNonceLen = ivlen;
    }

    if (!chacha->pTag) {
        chacha->pTag = OPENSSL_zalloc(EVP_CHACHAPOLY_TLS_TAG_LEN);
        if (!chacha->pTag) {
            return CKR_HOST_MEMORY;
        }
    }

    return CKR_OK;
}

/* AES and ChaCha20 contexts share a single set_ctx_params implementation. */
#define p11prov_aes_set_ctx_params p11prov_common_set_ctx_params
#define p11prov_chacha20_set_ctx_params p11prov_common_set_ctx_params

/* Forward declaration: p11prov_cipher_prep_mech() below applies init-time
 * parameters through this function, which is defined later in this file. */
static int p11prov_common_set_ctx_params(void *vctx, const OSSL_PARAM params[]);

/* Build the mechanism parameters for the context's mechanism from the IV,
 * then apply the init-time OSSL params.
 *
 * ECB takes no parameter. The other classic AES modes store the IV directly
 * (set_iv). GCM and ChaCha20-Poly1305 prepare their message parameter blocks.
 * In every case `params` is then applied with
 * p11prov_common_set_ctx_params(), as the provider API requires for init
 * calls. For AEAD modes this happens after the parameter block exists, so
 * the params can fill it in.
 *
 * Returns CKR_MECHANISM_INVALID for unsupported mechanisms,
 * CKR_MECHANISM_PARAM_INVALID if the params are rejected, or the error from
 * the mode-specific helper. */
static CK_RV p11prov_cipher_prep_mech(struct p11prov_cipher_ctx *ctx,
                                      const unsigned char *iv, size_t ivlen,
                                      const OSSL_PARAM params[])
{
    CK_RV rv = CKR_OK;
    int ret;

    switch (ctx->mech.mechanism) {
    case CKM_AES_ECB:
        /* ECB has no ck params */
        break;

    case CKM_AES_CBC:
    case CKM_AES_CBC_PAD:
    case CKM_AES_CTS:
    case CKM_AES_OFB:
    case CKM_AES_CFB128:
    case CKM_AES_CFB1:
    case CKM_AES_CFB8:
    case CKM_AES_CTR:
        rv = set_iv(ctx, iv, ivlen);
        break;

    case CKM_AES_GCM:
        rv = p11prov_cipher_prep_gcm(&ctx->mech, iv, ivlen);
        break;

    case CKM_CHACHA20_POLY1305:
        rv = p11prov_cipher_prep_chacha20_poly1305(&ctx->mech, iv, ivlen);
        break;

    default:
        P11PROV_debug("invalid mechanism (ctx=%p, iv=%p, "
                      "ivlen=%zu, params=%p)",
                      ctx, iv, ivlen, params);
        return CKR_MECHANISM_INVALID;
    }

    if (rv != CKR_OK) {
        return rv;
    }

    /* Previously skipped for AEAD modes, silently dropping parameters such
     * as the tag or IV length passed to EVP_CipherInit_ex2(). */
    ret = p11prov_common_set_ctx_params(ctx, params);
    if (ret != RET_OSSL_OK) {
        P11PROV_debug("invalid mechanism param (ctx=%p, iv=%p, "
                      "ivlen=%zu, params=%p)",
                      ctx, iv, ivlen, params);
        return CKR_MECHANISM_PARAM_INVALID;
    }

    return CKR_OK;
}

/* Common initialization for all encrypt/decrypt init entry points.
 *
 * Records the operation, prepares the mechanism parameters from iv/params,
 * and, when keydata is given, makes the context hold a reference to that key.
 * Any key held from a previous initialization is released. A NULL keydata
 * keeps the current key (legacy init may supply the key in a later call).
 * The PKCS#11 session itself is started lazily on first use.
 *
 * Returns the provider status error, the mechanism preparation error, or
 * CKR_KEY_NEEDED if a reference to the key cannot be taken. */
static CK_RV p11prov_cipher_op_init(void *ctx, void *keydata, CK_FLAGS op,
                                    const unsigned char *iv, size_t ivlen,
                                    const OSSL_PARAM params[])
{
    struct p11prov_cipher_ctx *cctx = (struct p11prov_cipher_ctx *)ctx;
    P11PROV_OBJ *key = (P11PROV_OBJ *)keydata;
    CK_RV rv;

    rv = p11prov_ctx_status(cctx->provctx);
    if (rv != CKR_OK) {
        return rv;
    }

    cctx->operation = op;

    rv = p11prov_cipher_prep_mech(cctx, iv, ivlen, params);
    if (rv != CKR_OK) {
        return rv;
    }

    /* If keydata is NULL, it means the application will pass the key later,
     * this is allowed in legacy initialization, so skip full init until we
     * have all the pieces. */
    if (key) {
        /* Take the new reference before dropping the old one, so that
         * re-initializing with the same key object can never free it, and a
         * key from an earlier init on this context is not leaked. */
        P11PROV_OBJ *newkey = p11prov_obj_ref(key);

        p11prov_obj_free(cctx->key);
        cctx->key = newkey;
        if (cctx->key == NULL) {
            return CKR_KEY_NEEDED;
        }
    }

    return CKR_OK;
}

/* Acquire a PKCS#11 session for the context's key and mechanism and start the
 * recorded operation on it.
 *
 * On success the session is marked SESS_INITIALIZED. A session acquisition or
 * operation init failure is returned as-is. An operation flag other than the
 * four encrypt/decrypt variants yields CKR_GENERAL_ERROR. */
static CK_RV p11prov_cipher_session_init(struct p11prov_cipher_ctx *cctx)
{
    CK_RV rv;

    /* In the special TLS mode, de-padding and MAC extraction are done
     * outside the PKCS#11 module to conform to what OpenSSL does, so the
     * token must run plain CBC rather than CBC with padding. */
    if (cctx->tlsver != 0 && cctx->mech.mechanism == CKM_AES_CBC_PAD) {
        cctx->mech.mechanism = CKM_AES_CBC;
    }

    rv = p11prov_try_session_ref(cctx->key, cctx->mech.mechanism, true, false,
                                 &cctx->session);
    if (rv != CKR_OK) {
        return rv;
    }

    if (cctx->operation == CKF_ENCRYPT) {
        rv = p11prov_EncryptInit(cctx->provctx,
                                 p11prov_session_handle(cctx->session),
                                 &cctx->mech, p11prov_obj_get_handle(cctx->key));
    } else if (cctx->operation == CKF_DECRYPT) {
        rv = p11prov_DecryptInit(cctx->provctx,
                                 p11prov_session_handle(cctx->session),
                                 &cctx->mech, p11prov_obj_get_handle(cctx->key));
    } else if (cctx->operation == CKF_MESSAGE_ENCRYPT) {
        rv = p11prov_MessageEncryptInit(cctx->provctx,
                                        p11prov_session_handle(cctx->session),
                                        &cctx->mech,
                                        p11prov_obj_get_handle(cctx->key));
    } else if (cctx->operation == CKF_MESSAGE_DECRYPT) {
        rv = p11prov_MessageDecryptInit(cctx->provctx,
                                        p11prov_session_handle(cctx->session),
                                        &cctx->mech,
                                        p11prov_obj_get_handle(cctx->key));
    } else {
        rv = CKR_GENERAL_ERROR;
    }

    if (rv != CKR_OK) {
        return rv;
    }

    cctx->session_state = SESS_INITIALIZED;
    return CKR_OK;
}

/* Initialization path for the classic raw-key OpenSSL init calls.
 *
 * A raw key (if given) is imported into the token as a session secret key.
 * Its key type is chosen to match the context's mechanism: CKK_CHACHA20 for
 * ChaCha20-Poly1305, CKK_AES otherwise. The common op init then runs. The
 * local reference to the imported key is dropped afterwards; the context
 * keeps its own. Returns RET_OSSL_OK or RET_OSSL_ERR. */
static int p11prov_cipher_legacy_init(void *ctx, CK_FLAGS op,
                                      const unsigned char *key, size_t keylen,
                                      const unsigned char *iv, size_t ivlen,
                                      const OSSL_PARAM params[])
{
    struct p11prov_cipher_ctx *cctx = (struct p11prov_cipher_ctx *)ctx;
    P11PROV_OBJ *skey = NULL;
    CK_RV rv;

    rv = p11prov_ctx_status(cctx->provctx);
    if (rv != CKR_OK) {
        return RET_OSSL_ERR;
    }

    if (key != NULL && keylen > 0) {
        /* The only way to fulfill this request is by importing the key in
         * the token as a session object. The key type must match the
         * mechanism, or the token will refuse to use it (e.g. a ChaCha20
         * key imported as CKK_AES is inconsistent with
         * CKM_CHACHA20_POLY1305). */
        CK_KEY_TYPE key_type = CKK_AES;

        if (cctx->mech.mechanism == CKM_CHACHA20_POLY1305) {
            key_type = CKK_CHACHA20;
        }

        skey = p11prov_obj_import_secret_key(cctx->provctx, key_type, key,
                                             keylen);
        if (!skey) {
            return RET_OSSL_ERR;
        }
    }

    rv = p11prov_cipher_op_init(ctx, skey, op, iv, ivlen, params);

    p11prov_obj_free(skey);

    if (rv != CKR_OK) {
        return RET_OSSL_ERR;
    }
    return RET_OSSL_OK;
}

/* Pick the PKCS#11 operation flag to use for `def_flag`.
 *
 * AEAD mechanisms (AES-GCM, ChaCha20-Poly1305) use the message-based API, so
 * CKF_ENCRYPT/CKF_DECRYPT map to CKF_MESSAGE_ENCRYPT/CKF_MESSAGE_DECRYPT. Any
 * other flag for an AEAD mechanism raises an error and returns 0. Non-AEAD
 * mechanisms return def_flag unchanged. */
static CK_FLAGS p11prov_cipher_get_op(struct p11prov_cipher_ctx *ctx,
                                      CK_FLAGS def_flag)
{
    bool is_aead = ctx->mech.mechanism == CKM_AES_GCM
                   || ctx->mech.mechanism == CKM_CHACHA20_POLY1305;

    if (!is_aead) {
        return def_flag;
    }

    if (def_flag == CKF_ENCRYPT) {
        return CKF_MESSAGE_ENCRYPT;
    }
    if (def_flag == CKF_DECRYPT) {
        return CKF_MESSAGE_DECRYPT;
    }

    ERR_raise(ERR_LIB_PROV, PROV_R_CIPHER_OPERATION_FAILED);
    return 0;
}

/* OSSL_FUNC_cipher_encrypt_init: raw-key encryption init. The operation flag
 * is chosen by p11prov_cipher_get_op(), so AEAD modes use the message API. */
static int p11prov_cipher_encrypt_init(void *ctx, const unsigned char *key,
                                       size_t keylen, const unsigned char *iv,
                                       size_t ivlen, const OSSL_PARAM params[])
{
    int ret;

    P11PROV_debug("encrypt init (ctx=%p, key=%p, iv=%p, params=%p)", ctx, key,
                  iv, params);

    ret = p11prov_cipher_legacy_init(
        ctx, p11prov_cipher_get_op(ctx, CKF_ENCRYPT), key, keylen, iv, ivlen,
        params);
    return ret;
}

/* OSSL_FUNC_cipher_decrypt_init: raw-key decryption init. The operation flag
 * is chosen by p11prov_cipher_get_op(), so AEAD modes use the message API. */
static int p11prov_cipher_decrypt_init(void *ctx, const unsigned char *key,
                                       size_t keylen, const unsigned char *iv,
                                       size_t ivlen, const OSSL_PARAM params[])
{
    int ret;

    P11PROV_debug("decrypt init (ctx=%p, key=%p, iv=%p, params=%p)", ctx, key,
                  iv, params);

    ret = p11prov_cipher_legacy_init(
        ctx, p11prov_cipher_get_op(ctx, CKF_DECRYPT), key, keylen, iv, ivlen,
        params);
    return ret;
}

/* OSSL_FUNC_cipher_encrypt_skey_init: encryption init with an opaque key
 * object already owned by this provider. Returns RET_OSSL_OK on success. */
static int p11prov_cipher_encrypt_skey_init(void *ctx, void *keydata,
                                            const unsigned char *iv,
                                            size_t ivlen,
                                            const OSSL_PARAM params[])
{
    P11PROV_debug("encrypt skey init (ctx=%p, key=%p, params=%p)", ctx, keydata,
                  params);

    if (p11prov_cipher_op_init(ctx, keydata,
                               p11prov_cipher_get_op(ctx, CKF_ENCRYPT), iv,
                               ivlen, params)
        == CKR_OK) {
        return RET_OSSL_OK;
    }
    return RET_OSSL_ERR;
}

/* OSSL_FUNC_cipher_decrypt_skey_init: decryption init with an opaque key
 * object already owned by this provider. Returns RET_OSSL_OK on success. */
static int p11prov_cipher_decrypt_skey_init(void *ctx, void *keydata,
                                            const unsigned char *iv,
                                            size_t ivlen,
                                            const OSSL_PARAM params[])
{
    P11PROV_debug("decrypt skey init (ctx=%p, key=%p, params=%p)", ctx, keydata,
                  params);

    if (p11prov_cipher_op_init(ctx, keydata,
                               p11prov_cipher_get_op(ctx, CKF_DECRYPT), iv,
                               ivlen, params)
        == CKR_OK) {
        return RET_OSSL_OK;
    }
    return RET_OSSL_ERR;
}

/* This function needs to be executed in constant time */
/* Remove TLS 1.x CBC padding from a decrypted record and extract its MAC.
 *
 * out/inlen: the decrypted record. For TLS 1.1/1.2 it starts with the explicit
 *            IV block, which is skipped.
 * outlen:    set to the payload length with MAC and padding removed.
 *
 * The record MAC (cctx->tlsmacsize bytes) is copied into a freshly allocated
 * cctx->tlsmac; on a padding failure it is replaced by random bytes. Only
 * public lengths are checked with ordinary branches. Everything derived from
 * the padding byte goes through the constant_* mask helpers. Returns
 * CKR_BUFFER_TOO_SMALL for records too short to hold the overhead, and
 * CKR_GENERAL_ERROR on allocation/RNG failure or (with no MAC) bad padding. */
static CK_RV tlsunpad(struct p11prov_cipher_ctx *cctx, unsigned char *out,
                      CK_ULONG inlen, CK_ULONG *outlen)
{
    CK_RV rv = CKR_GENERAL_ERROR;
    CK_ULONG overhead = cctx->tlsmacsize + 1; /* mac size + padlen byte */
    CK_ULONG maxcheck = MAX_PADDING;
    CK_ULONG padsize;
    CK_ULONG olen = inlen;
    CK_ULONG pass;
    CK_ULONG i, j;

    /* Remove explicit IV for TLS 1.1 and 1.2 */
    if (cctx->tlsver != TLS1_VERSION) {
        /* inlen is public, so it can be checked normally; without this
         * check the subtraction below would underflow */
        if (inlen < AESBLOCK) {
            return CKR_BUFFER_TOO_SMALL;
        }
        /* This is a bad interface as it make it seem that
         * the returned output buffer is incorrectly pointing
         * at the IV and not the data, but OpenSSL will in turn
         * offset the buffer later, based on knowledge that this
         * cipher return a length that excludes the IV from the
         * count. */
        out += AESBLOCK;
        olen = inlen - AESBLOCK;
    }

    /* olen is public known so can be checked normally */
    if (olen < overhead) {
        return CKR_BUFFER_TOO_SMALL;
    }

    if (olen < cctx->tlsmacsize) {
        return CKR_BUFFER_TOO_SMALL;
    }

    /* Read the padding length byte only once the buffer is known to be
     * non-empty (olen >= overhead >= 1). It is the last byte of the
     * record, i.e. the same byte as the original out[inlen - 1]. */
    padsize = out[olen - 1];

    if (maxcheck > olen) {
        maxcheck = olen;
    }

    /* olen must not be smaller than padsize + overhead */
    pass = ~constant_smaller_mask(olen, overhead + padsize);

    /* creates a mask so that we check only the padding bytes
     * without revealing the padding length in a conditional.
     * mask is 0xff when i < padsize, and 0 otherwise, allowing
     * us to scan the whole buffer while really only testing for
     * equality only the padding part, as the xoring with non-pad
     * data is ignored by the empty mask. We skip checking the
     * last value itself as that is always == padsize */
    for (i = 0; i < maxcheck - 1; i++) {
        unsigned char mask = constant_smaller_mask(i, padsize);
        unsigned char data = out[olen - i - 2];

        pass &= ~(mask & (padsize ^ data));
    }

    /* renormalize to a CK_ULONG */
    pass = constant_equal_mask(pass, 0xff);

    if (cctx->tlsmacsize > 0) {
        unsigned char randmac[EVP_MAX_MD_SIZE];
        CK_ULONG mac_pos = olen - cctx->tlsmacsize - (pass & (padsize + 1));
        CK_ULONG mac_area = 0;
        int err = RET_OSSL_ERR;

        /* allocate space for the mac, releasing the one extracted from a
         * previous record (if any) so that repeated calls do not leak */
        OPENSSL_clear_free(cctx->tlsmac, cctx->tlsmacsize);
        cctx->tlsmac = OPENSSL_zalloc(cctx->tlsmacsize);
        if (!cctx->tlsmac) {
            return CKR_GENERAL_ERROR;
        }

        /* random mac we return if something is wrong */
        err = RAND_bytes_ex(p11prov_ctx_get_libctx(cctx->provctx), randmac,
                            sizeof(randmac), 0);
        if (err != RET_OSSL_OK) {
            return CKR_GENERAL_ERROR;
        }

        /* olen and mac size are public data, so we can do this
         * assignment without bothering with constant time */
        if (olen > cctx->tlsmacsize + 256) {
            mac_area = olen - cctx->tlsmacsize - 256;
        }

        for (i = mac_area; i < olen; i++) {
            for (j = 0; j < cctx->tlsmacsize; j++) {
                unsigned char mask =
                    ~constant_smaller_mask(i, mac_pos)
                    & constant_smaller_mask(i, mac_pos + cctx->tlsmacsize)
                    & constant_equal_mask(i, j + mac_pos);
                cctx->tlsmac[j] |= out[i] & mask;
            }
        }

        /* on depadding failure overwrite with random data */
        for (j = 0; j < cctx->tlsmacsize; j++) {
            cctx->tlsmac[j] =
                constant_select_byte_mask(cctx->tlsmac[j], randmac[j], pass);
        }

        rv = CKR_OK;
    } else {
        /* no MAC to check just return the result */
        if (pass + 1 == 0) {
            rv = CKR_OK;
        }
    }

    *outlen = olen - cctx->tlsmacsize - (pass & (padsize + 1));
    return rv;
}

/* Locate, inside the mechanism's AEAD parameters, the explicit-IV and tag
 * buffers used for TLS 1.2 records.
 *
 * OpenSSL hands TLS 1.2 AEAD records over as
 *     [explicit IV] [plaintext / ciphertext] [authentication tag]
 * and on encryption expects the provider to fill in the explicit IV and tag.
 * The explicit IV is:
 *   AES-GCM:           EVP_GCM_TLS_EXPLICIT_IV_LEN (8) bytes, stored after
 *                      the fixed part of the IV in pIv
 *   CHACHA20-POLY1305: 0 bytes
 *
 * Returns CKR_MECHANISM_INVALID for any other mechanism. */
static CK_RV tls_aead_get_data(CK_MECHANISM_PTR mech, data_buffer *explicitiv,
                               data_buffer *tag)
{
    switch (mech->mechanism) {
    case CKM_AES_GCM: {
        CK_GCM_MESSAGE_PARAMS_PTR gcm_params =
            (CK_GCM_MESSAGE_PARAMS_PTR)mech->pParameter;

        explicitiv->data =
            gcm_params->pIv + BITS_TO_BYTES(gcm_params->ulIvFixedBits);
        explicitiv->length = EVP_GCM_TLS_EXPLICIT_IV_LEN;
        tag->data = gcm_params->pTag;
        tag->length = BITS_TO_BYTES(gcm_params->ulTagBits);
        return CKR_OK;
    }
    case CKM_CHACHA20_POLY1305: {
        CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR chacha_params =
            (CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR)mech->pParameter;

        explicitiv->data = NULL;
        explicitiv->length = 0;
        tag->data = chacha_params->pTag;
        tag->length = EVP_CHACHAPOLY_TLS_TAG_LEN;
        return CKR_OK;
    }
    default:
        return CKR_MECHANISM_INVALID;
    }
}

static CK_RV tls_pre_aead(struct p11prov_cipher_ctx *cctx,
                          const unsigned char **in, size_t *inl,
                          unsigned char **out, size_t *outl)
{
    data_buffer iv = { 0 };
    data_buffer tag = { 0 };
    CK_RV rv;

    rv = tls_aead_get_data(&cctx->mech, &iv, &tag);
    if (rv != CKR_OK) {
        return rv;
    }

    if (cctx->mech.mechanism == CKM_CHACHA20_POLY1305) {
        CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR chacha =
            (CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR)cctx->mech.pParameter;

        /* RFC7905: Before encrypting, XOR the nonce with the sequence number,
         * which is encoded in the AAD. */
        for (size_t i = 0; i < 8; i++)
            chacha->pNonce[i + 4] ^= cctx->aad[i];
    }

    if (cctx->operation == CKF_MESSAGE_DECRYPT) {
        if (iv.data && iv.length) {
            memcpy(iv.data, *in, iv.length);
        }
        if (tag.data && tag.length) {
            memcpy(tag.data, *in + *inl - tag.length, tag.length);
        }
    }

    *in += iv.length;
    *inl -= (iv.length + tag.length);
    *out += iv.length;
    *outl = *inl;

    return CKR_OK;
}

/* Finish a TLS 1.2 AEAD record after the PKCS#11 message operation.
 *
 * For ChaCha20-Poly1305 the sequence number is XORed out of the nonce again,
 * restoring the value tls_pre_aead() modified. On encryption the explicit IV
 * is written just before `out` and the tag just after the *outl ciphertext
 * bytes; *outl grows by both lengths. Decryption leaves *outl unchanged. */
static CK_RV tls_post_aead(struct p11prov_cipher_ctx *cctx, unsigned char *out,
                           size_t *outl)
{
    data_buffer iv_buf = { 0 };
    data_buffer tag_buf = { 0 };
    CK_RV rv = tls_aead_get_data(&cctx->mech, &iv_buf, &tag_buf);

    if (rv != CKR_OK) {
        return rv;
    }

    if (cctx->mech.mechanism == CKM_CHACHA20_POLY1305) {
        CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR chacha =
            (CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR)cctx->mech.pParameter;
        unsigned char *seq_part = chacha->pNonce + 4;

        /* Post encryption, XOR the nonce again with the sequence number,
         * to restore the original nonce value. */
        for (size_t k = 0; k < 8; k++) {
            seq_part[k] ^= cctx->aad[k];
        }
    }

    if (cctx->operation != CKF_MESSAGE_ENCRYPT) {
        return CKR_OK;
    }

    if (iv_buf.data != NULL && iv_buf.length != 0) {
        memcpy(out - iv_buf.length, iv_buf.data, iv_buf.length);
    }
    if (tag_buf.data != NULL && tag_buf.length != 0) {
        memcpy(out + *outl, tag_buf.data, tag_buf.length);
    }
    *outl += iv_buf.length + tag_buf.length;

    return CKR_OK;
}

/* Feed data to the current cipher operation.
 *
 * Plain (non-TLS) mode:
 *  - A call with a non-empty `in` and `out == NULL` only stores a copy of
 *    `in` as the AAD for the next AEAD message and reports *outl = 0.
 *  - CKF_ENCRYPT/CKF_DECRYPT use p11prov_EncryptUpdate/DecryptUpdate.
 *  - CKF_MESSAGE_ENCRYPT/CKF_MESSAGE_DECRYPT process one whole message with
 *    the stored AAD, which is wiped afterwards.
 *
 * TLS mode (a TLS version was set): every call is one complete record
 * processed in place (`in == out`). CBC records are padded/unpadded here
 * and handled with single-shot p11prov_Encrypt/Decrypt, after which the
 * session is returned. TLS 1.2 AEAD records go through tls_pre_aead() and
 * tls_post_aead() for explicit IV, tag and ChaCha20 nonce handling.
 *
 * Returns RET_OSSL_OK with *outl set to the bytes produced, or
 * RET_OSSL_ERR. */
static int p11prov_cipher_update(void *ctx, unsigned char *out, size_t *outl,
                                 size_t outsize, const unsigned char *in,
                                 size_t inl)
{
    struct p11prov_cipher_ctx *cctx = (struct p11prov_cipher_ctx *)ctx;
    CK_SESSION_HANDLE session_handle;
    CK_ULONG outlen = outsize;
    CK_ULONG inlen = inl;
    CK_RV rv;

    if (!in) {
        ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DATA);
        return RET_OSSL_ERR;
    }

    if (cctx->tlsver != 0) {
        /* Special OpenSSL layering violating mode.
         * A single update is a full record.
         * Inputs need to be consistent with stricter requirements */
        if (in != out || outsize < inl || !cctx->pad) {
            ERR_raise(ERR_LIB_PROV, PROV_R_CIPHER_OPERATION_FAILED);
            return RET_OSSL_ERR;
        }
    }

    /* When not doing TLS, or when doing TLS1.3, the `in` buffer
     * contains the authentication data, and `out` is NULL.
     * NOTE(review): with tlsver == TLS1_3_VERSION the `in != out` check
     * above already rejects out == NULL, so this branch is effectively only
     * reached when tlsver == 0 -- confirm whether TLS 1.3 needs it. */
    if (cctx->tlsver == 0 || cctx->tlsver == TLS1_3_VERSION) {
        if (in && inl && !out) {
            OPENSSL_clear_free(cctx->aad, cctx->aadsize);
            cctx->aad = NULL;
            cctx->aadsize = 0;

            cctx->aad = OPENSSL_memdup(in, inl);
            if (!cctx->aad) {
                P11PROV_raise(cctx->provctx, CKR_HOST_MEMORY,
                              "Memory allocation failed");
                return RET_OSSL_ERR;
            }
            cctx->aadsize = inl;

            *outl = 0;
            return RET_OSSL_OK;
        }
    }

    if (!cctx->session) {
        rv = p11prov_cipher_session_init(cctx);
        if (rv != CKR_OK) {
            return RET_OSSL_ERR;
        }
    }
    session_handle = p11prov_session_handle(cctx->session);

    switch (cctx->operation) {
    case CKF_ENCRYPT:
        if (cctx->tlsver != 0) {
            size_t padsize = AESBLOCK - (inl % AESBLOCK);
            unsigned char padval = (unsigned char)(padsize - 1);

            if (outsize < inl + padsize) {
                rv = CKR_BUFFER_TOO_SMALL;
                P11PROV_raise(cctx->provctx, rv, "Output buffer too small");
                return RET_OSSL_ERR;
            }
            inlen += padsize;
            if ((inlen % AESBLOCK) != 0) {
                rv = CKR_ARGUMENTS_BAD;
                P11PROV_raise(cctx->provctx, rv, "Invalid input buffer size");
                return RET_OSSL_ERR;
            }
            /* add the padding, relies on in == out and therefore enough
             * space available in the buffer */
            memset(&out[inl], padval, padsize);

            /* in TLS mode we must use single shot encryption to properly
             * auto-finalize the session as OpenSSL won't */
            rv = p11prov_Encrypt(cctx->provctx, session_handle, (void *)in,
                                 inlen, out, &outlen);

            cctx->session_state = SESS_FINALIZED;
            /* unconditionally return the session */
            p11prov_return_session(cctx->session);
            cctx->session = NULL;
        } else {
            rv = p11prov_EncryptUpdate(cctx->provctx, session_handle,
                                       (void *)in, inlen, out, &outlen);
        }
        break;
    case CKF_DECRYPT:
        if (cctx->tlsver != 0) {
            if ((inlen % AESBLOCK) != 0) {
                rv = CKR_ARGUMENTS_BAD;
                P11PROV_raise(cctx->provctx, rv, "Invalid input buffer size");
                return RET_OSSL_ERR;
            }
            /* in TLS mode we must use single shot decryption to properly
             * auto-finalize the session as OpenSSL won't */
            rv = p11prov_Decrypt(cctx->provctx, session_handle, (void *)in,
                                 inlen, out, &outlen);

            cctx->session_state = SESS_FINALIZED;
            /* unconditionally return the session */
            p11prov_return_session(cctx->session);
            cctx->session = NULL;

            if (rv != CKR_OK) {
                P11PROV_raise(cctx->provctx, rv, "Decryption failure");
                return RET_OSSL_ERR;
            }
            /* remove padding and fill in tlsmac as needed */
            if (cctx->tlsmac) {
                OPENSSL_clear_free(cctx->tlsmac, cctx->tlsmacsize);
                cctx->tlsmac = NULL;
            }

            /* Assumes inlen = outlen on correct decryption */
            rv = tlsunpad(cctx, out, inlen, &outlen);
        } else {
            rv = p11prov_DecryptUpdate(cctx->provctx, session_handle,
                                       (void *)in, inlen, out, &outlen);
        }
        break;
    case CKF_MESSAGE_ENCRYPT:
        if (cctx->tlsver == TLS1_2_VERSION) {
            /* The record AAD (OSSL_CIPHER_PARAM_AEAD_TLS1_AAD) must be set
             * for every TLS 1.2 record: it is consumed (and wiped) by each
             * message, and tls_pre_aead()/tls_post_aead() read the sequence
             * number from it for ChaCha20-Poly1305. */
            if (!cctx->aad || cctx->aadsize != EVP_AEAD_TLS1_AAD_LEN) {
                rv = CKR_ARGUMENTS_BAD;
                P11PROV_raise(cctx->provctx, rv, "Missing TLS record AAD");
                return RET_OSSL_ERR;
            }
            rv = tls_pre_aead(cctx, &in, &inl, &out, &outlen);
            if (rv != CKR_OK) {
                P11PROV_raise(cctx->provctx, rv, "AEAD encryption failure");
                return RET_OSSL_ERR;
            }
        }

        rv = p11prov_EncryptMessage(
            cctx->provctx, session_handle, cctx->mech.pParameter,
            cctx->mech.ulParameterLen, (CK_BYTE_PTR)cctx->aad, cctx->aadsize,
            (CK_BYTE_PTR)in, inl, out, &outlen);

        if (rv == CKR_OK && cctx->tlsver == TLS1_2_VERSION) {
            /* tls_post_aead() takes a size_t length, outlen is a CK_ULONG */
            size_t postlen = outlen;
            rv = tls_post_aead(cctx, out, &postlen);
            outlen = postlen;
        }

        OPENSSL_clear_free(cctx->aad, cctx->aadsize);
        cctx->aad = NULL;
        cctx->aadsize = 0;
        break;
    case CKF_MESSAGE_DECRYPT:
        if (cctx->tlsver == TLS1_2_VERSION) {
            /* Same record AAD requirement as for encryption */
            if (!cctx->aad || cctx->aadsize != EVP_AEAD_TLS1_AAD_LEN) {
                rv = CKR_ARGUMENTS_BAD;
                P11PROV_raise(cctx->provctx, rv, "Missing TLS record AAD");
                return RET_OSSL_ERR;
            }
            rv = tls_pre_aead(cctx, &in, &inl, &out, &outlen);
            if (rv != CKR_OK) {
                P11PROV_raise(cctx->provctx, rv, "AEAD decryption failure");
                return RET_OSSL_ERR;
            }
        }

        rv = p11prov_DecryptMessage(
            cctx->provctx, session_handle, cctx->mech.pParameter,
            cctx->mech.ulParameterLen, (CK_BYTE_PTR)cctx->aad, cctx->aadsize,
            (CK_BYTE_PTR)in, inl, out, &outlen);

        if (cctx->tlsver == TLS1_2_VERSION) {
            /* tls_pre_aead() XORed the record sequence number into the
             * ChaCha20-Poly1305 nonce; tls_post_aead() XORs it back and,
             * when decrypting, does nothing else. Restore even on failure
             * so the next record starts from the original nonce. */
            size_t postlen = outlen;
            CK_RV restore_rv = tls_post_aead(cctx, out, &postlen);
            if (rv == CKR_OK) {
                rv = restore_rv;
            }
        }

        OPENSSL_clear_free(cctx->aad, cctx->aadsize);
        cctx->aad = NULL;
        cctx->aadsize = 0;
        break;
    default:
        rv = CKR_GENERAL_ERROR;
        break;
    }

    if (rv != CKR_OK) {
        return RET_OSSL_ERR;
    }

    *outl = outlen;
    return RET_OSSL_OK;
}

/* Finish the current operation and release its session.
 *
 * For multi-part encryption/decryption any remaining output from
 * p11prov_EncryptFinal/DecryptFinal is written to `out`. Ending a
 * message-based (AEAD) operation produces no output, so *outl is 0 then.
 *
 * The session is returned and the context marked SESS_FINALIZED whatever
 * the outcome. Fails when no session is active, e.g. after a previous
 * final or after a TLS-mode update, which returns the session itself. */
static int p11prov_cipher_final(void *ctx, unsigned char *out, size_t *outl,
                                size_t outsize)
{
    struct p11prov_cipher_ctx *cctx = (struct p11prov_cipher_ctx *)ctx;
    CK_ULONG outlen = outsize;
    CK_RV rv;

    if (!cctx->session) {
        return RET_OSSL_ERR;
    }

    switch (cctx->operation) {
    case CKF_ENCRYPT:
        rv = p11prov_EncryptFinal(
            cctx->provctx, p11prov_session_handle(cctx->session), out, &outlen);
        break;
    case CKF_DECRYPT:
        rv = p11prov_DecryptFinal(
            cctx->provctx, p11prov_session_handle(cctx->session), out, &outlen);
        break;
    case CKF_MESSAGE_ENCRYPT:
        rv = p11prov_MessageEncryptFinal(cctx->provctx,
                                         p11prov_session_handle(cctx->session));
        /* nothing is written: do not report the buffer size as output */
        outlen = 0;
        break;
    case CKF_MESSAGE_DECRYPT:
        rv = p11prov_MessageDecryptFinal(cctx->provctx,
                                         p11prov_session_handle(cctx->session));
        /* nothing is written: do not report the buffer size as output */
        outlen = 0;
        break;
    default:
        rv = CKR_GENERAL_ERROR;
        break;
    }

    cctx->session_state = SESS_FINALIZED;
    /* unconditionally return session here as well */
    p11prov_return_session(cctx->session);
    cctx->session = NULL;

    if (rv != CKR_OK) {
        return RET_OSSL_ERR;
    }

    *outl = outlen;
    return RET_OSSL_OK;
}

#define p11prov_aes_cipher p11prov_common_cipher
#define p11prov_chacha20_cipher p11prov_common_cipher

/* One-shot cipher entry point, aliased above as p11prov_aes_cipher and
 * p11prov_chacha20_cipher. Not supported by this implementation: every call
 * fails with RET_OSSL_ERR and no argument is read or written. */
static int p11prov_common_cipher(void *ctx, unsigned char *out, size_t *outl,
                                 size_t outsize, const unsigned char *in,
                                 size_t inl)
{
    (void)ctx;
    (void)out;
    (void)outl;
    (void)outsize;
    (void)in;
    (void)inl;

    return RET_OSSL_ERR;
}

/* Report the AEAD-specific context parameters for AES-GCM and
 * ChaCha20-Poly1305.
 *
 * OSSL_CIPHER_PARAM_AEAD_TAGLEN and OSSL_CIPHER_PARAM_AEAD_TLS1_AAD_PAD both
 * report the tag length: ulTagBits converted to bytes for GCM, the fixed
 * EVP_CHACHAPOLY_TLS_TAG_LEN for ChaCha20-Poly1305. OSSL_CIPHER_PARAM_AEAD_TAG
 * copies out the tag currently held in the PKCS#11 message parameters.
 *
 * Raises PROV_R_INVALID_AEAD for any other mechanism and
 * PROV_R_FAILED_TO_SET_PARAMETER if a requested value cannot be stored. */
static int p11prov_aead_get_ctx_params(struct p11prov_cipher_ctx *cctx,
                                       OSSL_PARAM params[])
{
    unsigned char *tag = NULL;
    size_t taglen = 0;
    OSSL_PARAM *p = NULL;
    int ret = 0;

    CK_GCM_MESSAGE_PARAMS_PTR gcm = NULL;
    CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR chacha = NULL;

    switch (cctx->mech.mechanism) {
    case CKM_AES_GCM:
        gcm = (CK_GCM_MESSAGE_PARAMS_PTR)cctx->mech.pParameter;

        if (gcm) {
            tag = gcm->pTag;
            taglen = BITS_TO_BYTES(gcm->ulTagBits);
        }
        break;

    case CKM_CHACHA20_POLY1305:
        chacha =
            (CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR)cctx->mech.pParameter;

        if (chacha) {
            tag = chacha->pTag;
            taglen = EVP_CHACHAPOLY_TLS_TAG_LEN;
        }
        break;

    default:
        ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_AEAD);
        return RET_OSSL_ERR;
    }

    p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_AEAD_TAGLEN);
    if (p) {
        /* fill the located entry, not whatever happens to be params[0] */
        ret = OSSL_PARAM_set_size_t(p, taglen);
        if (ret != RET_OSSL_OK) {
            ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_SET_PARAMETER);
            return RET_OSSL_ERR;
        }
    }

    p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_AEAD_TAG);
    if (p) {
        ret = OSSL_PARAM_set_octet_string(p, tag, taglen);
        if (ret != RET_OSSL_OK) {
            ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_SET_PARAMETER);
            return RET_OSSL_ERR;
        }
    }

    p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_AEAD_TLS1_AAD_PAD);
    if (p) {
        /* the TLS record grows by the tag appended after the ciphertext */
        ret = OSSL_PARAM_set_size_t(p, taglen);
        if (ret != RET_OSSL_OK) {
            ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_SET_PARAMETER);
            return RET_OSSL_ERR;
        }
    }

    return RET_OSSL_OK;
}

#define p11prov_aes_get_ctx_params p11prov_common_get_ctx_params
#define p11prov_chacha20_get_ctx_params p11prov_common_get_ctx_params

static int p11prov_common_get_ctx_params(void *ctx, OSSL_PARAM params[])
{
    struct p11prov_cipher_ctx *cctx = ctx;
    OSSL_PARAM *p;
    int ret;

    p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_IVLEN);
    if (p) {
        ret = OSSL_PARAM_set_size_t(p, cctx->ivsize);
        if (ret != RET_OSSL_OK) {
            ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_SET_PARAMETER);
            return RET_OSSL_ERR;
        }
    }

    p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_PADDING);
    if (p) {
        int pad = 0;
        if (cctx->pad) {
            pad = 1;
        }
        ret = OSSL_PARAM_set_uint(p, pad);
        if (ret != RET_OSSL_OK) {
            ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_SET_PARAMETER);
            return RET_OSSL_ERR;
        }
    }

    p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_IV);
    if (p) {
        ret = OSSL_PARAM_set_octet_string(p, cctx->mech.pParameter,
                                          cctx->mech.ulParameterLen);
        if (ret != RET_OSSL_OK) {
            ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_SET_PARAMETER);
            return RET_OSSL_ERR;
        }
    }

    p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_UPDATED_IV);
    if (p) {
        ret = OSSL_PARAM_set_octet_string(p, cctx->mech.pParameter,
                                          cctx->mech.ulParameterLen);
        if (ret != RET_OSSL_OK) {
            ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_SET_PARAMETER);
            return RET_OSSL_ERR;
        }
    }

    p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_NUM);
    if (p) {
        int num = 0;
        ret = OSSL_PARAM_set_uint(p, num);
        if (ret != RET_OSSL_OK) {
            ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_SET_PARAMETER);
            return RET_OSSL_ERR;
        }
    }

    p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_KEYLEN);
    if (p) {
        size_t keylen = cctx->keysize;
        ret = OSSL_PARAM_set_size_t(p, keylen);
        if (ret != RET_OSSL_OK) {
            ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_SET_PARAMETER);
            return RET_OSSL_ERR;
        }
    }

    p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_TLS_MAC);
    if (p) {
        ret = OSSL_PARAM_set_octet_ptr(p, cctx->tlsmac, cctx->tlsmacsize);
        if (ret != RET_OSSL_OK) {
            ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_SET_PARAMETER);
            return RET_OSSL_ERR;
        }
    }

    if (cctx->mech.mechanism == CKM_AES_GCM
        || cctx->mech.mechanism == CKM_CHACHA20_POLY1305) {
        return p11prov_aead_get_ctx_params(cctx, params);
    }

    return RET_OSSL_OK;
}

/* Apply OpenSSL cipher context parameters.
 *
 * Handles, in this order:
 *  - the padding flag, toggling CKM_AES_CBC <-> CKM_AES_CBC_PAD;
 *  - the CTS mode name for CKM_AES_CTS (only CS1 is accepted);
 *  - the TLS version and TLS MAC size used by the TLS record mode;
 *  - the AEAD IV length, the TLS 1.2 fixed IV (AES-GCM only), the TLS 1.2
 *    record AAD and the expected tag for AES-GCM and ChaCha20-Poly1305.
 *
 * On the first failure an error is raised and RET_OSSL_ERR returned;
 * parameters handled earlier in the same call stay applied. */
static int p11prov_common_set_ctx_params(void *vctx, const OSSL_PARAM params[])
{
    struct p11prov_cipher_ctx *ctx = (struct p11prov_cipher_ctx *)vctx;
    const OSSL_PARAM *p;

    p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_PADDING);
    if (p) {
        unsigned int pad;
        int ret = OSSL_PARAM_get_uint(p, &pad);
        if (ret != RET_OSSL_OK) {
            ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_GET_PARAMETER);
            return RET_OSSL_ERR;
        }
        if (pad > 1) {
            ERR_raise(ERR_LIB_PROV, PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE);
            return RET_OSSL_ERR;
        }
        ctx->pad = pad == 1;

        switch (ctx->mech.mechanism) {
        case CKM_AES_CBC:
            if (ctx->pad) {
                ctx->mech.mechanism = CKM_AES_CBC_PAD;
            }
            break;

        case CKM_AES_CBC_PAD:
            if (!ctx->pad) {
                ctx->mech.mechanism = CKM_AES_CBC;
            }
            break;

        default:
            if (ctx->pad) {
                /* FIXME: we need to do our padding as there is no _PAD mode
                 * for non CBC modes in PKCS#11 */
                ERR_raise(ERR_LIB_PROV,
                          PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE);
                return RET_OSSL_ERR;
            }
        }
    }

    if (ctx->mech.mechanism == CKM_AES_CTS) {
        p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_CTS_MODE);
        if (p) {
            const char *mode;
            int ret = OSSL_PARAM_get_utf8_ptr(p, &mode);
            if (ret != RET_OSSL_OK) {
                CK_RV rv = CKR_MECHANISM_PARAM_INVALID;
                P11PROV_raise(ctx->provctx, rv, "Invalid mode parameter");
                return RET_OSSL_ERR;
            }
            /* Currently only CS1 is supported */
            if (strcmp(mode, OSSL_CIPHER_CTS_MODE_CS1) != 0) {
                CK_RV rv = CKR_MECHANISM_PARAM_INVALID;
                P11PROV_raise(ctx->provctx, rv, "Unsupported mode: %s", mode);
                return RET_OSSL_ERR;
            }
        }
    }

    p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_TLS_VERSION);
    if (p) {
        CK_RV rv = CKR_MECHANISM_PARAM_INVALID;
        unsigned int version;
        int ret = OSSL_PARAM_get_uint(p, &version);
        if (ret != RET_OSSL_OK) {
            P11PROV_raise(ctx->provctx, rv, "Invalid TLS Version parameter");
            return RET_OSSL_ERR;
        }
        switch (version) {
        case TLS1_VERSION:
        case TLS1_1_VERSION:
        case TLS1_2_VERSION:
        case TLS1_3_VERSION:
            ctx->tlsver = version;
            break;
        default:
            P11PROV_raise(ctx->provctx, rv, "Unsupported TLS Version");
            return RET_OSSL_ERR;
        }
    }

    p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_TLS_MAC_SIZE);
    if (p) {
        CK_RV rv = CKR_MECHANISM_PARAM_INVALID;
        size_t macsize;
        int ret = OSSL_PARAM_get_size_t(p, &macsize);
        if (ret != RET_OSSL_OK) {
            P11PROV_raise(ctx->provctx, rv, "Invalid TLS MAC Size parameter");
            return RET_OSSL_ERR;
        }
        if (macsize > EVP_MAX_MD_SIZE) {
            P11PROV_raise(ctx->provctx, rv, "Invalid TLS Mac Size");
            return RET_OSSL_ERR;
        }
        ctx->tlsmacsize = macsize;
    }

    p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_AEAD_IVLEN);
    if (p) {
        CK_RV rv = CKR_MECHANISM_PARAM_INVALID;
        size_t ivlen = 0;
        int ret = OSSL_PARAM_get_size_t(p, &ivlen);
        if (ret != RET_OSSL_OK) {
            P11PROV_raise(ctx->provctx, CKR_MECHANISM_PARAM_INVALID,
                          "Invalid AEAD IV Length parameter");
            return RET_OSSL_ERR;
        }

        if (ctx->mech.mechanism == CKM_AES_GCM) {
            rv = p11prov_cipher_prep_gcm(&ctx->mech, NULL, ivlen);
        } else if (ctx->mech.mechanism == CKM_CHACHA20_POLY1305) {
            rv = p11prov_cipher_prep_chacha20_poly1305(&ctx->mech, NULL, ivlen);
        } else {
            P11PROV_raise(ctx->provctx, rv,
                          "AEAD IV Length not supported for this mechanism");
            return RET_OSSL_ERR;
        }

        if (rv != CKR_OK) {
            return RET_OSSL_ERR;
        }
    }

    p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_AEAD_TLS1_IV_FIXED);
    if (p) {
        CK_GCM_MESSAGE_PARAMS_PTR gcm;
        size_t fixedlen = 0;
        int ret;

        if (ctx->mech.mechanism != CKM_AES_GCM) {
            P11PROV_raise(ctx->provctx, CKR_MECHANISM_PARAM_INVALID,
                          "Fixed IV not supported for this mechanism");
            return RET_OSSL_ERR;
        }

        gcm = (CK_GCM_MESSAGE_PARAMS_PTR)ctx->mech.pParameter;
        if (!gcm) {
            P11PROV_raise(ctx->provctx, CKR_MECHANISM_PARAM_INVALID,
                          "GCM parameters not initialized");
            return RET_OSSL_ERR;
        }

        /* Replace any previous IV with a zeroed buffer sized for the
         * whole TLS nonce: fixed part followed by the explicit part */
        OPENSSL_clear_free(gcm->pIv, gcm->ulIvLen);
        gcm->pIv = NULL;
        gcm->ulIvLen = 0;
        gcm->ulIvFixedBits = 0;

        gcm->pIv = OPENSSL_zalloc(EVP_GCM_TLS_FIXED_IV_LEN
                                  + EVP_GCM_TLS_EXPLICIT_IV_LEN);
        if (!gcm->pIv) {
            P11PROV_raise(ctx->provctx, CKR_HOST_MEMORY,
                          "Memory allocation failed");
            return RET_OSSL_ERR;
        }
        gcm->ulIvLen = EVP_GCM_TLS_FIXED_IV_LEN + EVP_GCM_TLS_EXPLICIT_IV_LEN;

        /* Copy the fixed part to the start of pIv. The used length goes
         * through a size_t: ulIvFixedBits is a CK_ULONG, whose width need
         * not match size_t on every platform. */
        ret = OSSL_PARAM_get_octet_string(p, (void **)&gcm->pIv, gcm->ulIvLen,
                                          &fixedlen);
        if (ret != RET_OSSL_OK) {
            P11PROV_raise(ctx->provctx, CKR_MECHANISM_PARAM_INVALID,
                          "Invalid Fixed IV parameter");
            return RET_OSSL_ERR;
        }

        gcm->ulIvFixedBits = BYTES_TO_BITS(fixedlen);
        gcm->ivGenerator = CKG_GENERATE_COUNTER;
    }

    p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_AEAD_TLS1_AAD);
    if (p) {
        size_t ivlen = 0;
        size_t taglen = 0;
        size_t len;
        int ret;

        if (ctx->mech.mechanism == CKM_AES_GCM) {
            ivlen = EVP_GCM_TLS_EXPLICIT_IV_LEN;
            taglen = EVP_GCM_TLS_TAG_LEN;
        } else if (ctx->mech.mechanism == CKM_CHACHA20_POLY1305) {
            ivlen = 0;
            taglen = EVP_CHACHAPOLY_TLS_TAG_LEN;
        } else {
            P11PROV_raise(ctx->provctx, CKR_MECHANISM_PARAM_INVALID,
                          "AAD not supported for this mechanism");
            return RET_OSSL_ERR;
        }

        OPENSSL_clear_free(ctx->aad, ctx->aadsize);
        ctx->aad = NULL;
        ctx->aadsize = 0;

        ret = OSSL_PARAM_get_octet_string(p, (void **)&ctx->aad,
                                          EVP_AEAD_TLS1_AAD_LEN, &ctx->aadsize);
        if (ret != RET_OSSL_OK) {
            OPENSSL_clear_free(ctx->aad, ctx->aadsize);
            ctx->aad = NULL;
            ctx->aadsize = 0;
            P11PROV_raise(ctx->provctx, CKR_GENERAL_ERROR,
                          "Invalid AAD parameter");
            return RET_OSSL_ERR;
        }

        /* The record length lives in the last two AAD bytes and the TLS
         * record code reads the 8-byte sequence number from the start, so
         * anything but a full TLS AAD would be accessed out of bounds. */
        if (ctx->aadsize != EVP_AEAD_TLS1_AAD_LEN) {
            OPENSSL_clear_free(ctx->aad, ctx->aadsize);
            ctx->aad = NULL;
            ctx->aadsize = 0;
            P11PROV_raise(ctx->provctx, CKR_GENERAL_ERROR,
                          "Invalid AAD length");
            return RET_OSSL_ERR;
        }

        /* OpenSSL encodes the record length in the last two bytes of AAD, this
         * value needs to be adjusted. See also gcm_tls_init() in OpenSSL. */
        len = ctx->aad[ctx->aadsize - 2] << 8 | ctx->aad[ctx->aadsize - 1];
        if (len <= ivlen) {
            P11PROV_raise(ctx->provctx, CKR_GENERAL_ERROR,
                          "Invalid AAD encoded length");
            return RET_OSSL_ERR;
        }
        len -= ivlen;

        if (ctx->operation == CKF_MESSAGE_DECRYPT) {
            if (len <= taglen) {
                P11PROV_raise(ctx->provctx, CKR_GENERAL_ERROR,
                              "Invalid AAD encoded length");
                return RET_OSSL_ERR;
            }
            len -= taglen;
        }

        ctx->aad[ctx->aadsize - 2] = (unsigned char)((len >> 8) & 0xff);
        ctx->aad[ctx->aadsize - 1] = (unsigned char)(len & 0xff);
    }

    p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_AEAD_TAG);
    if (p) {
        unsigned char *tag = NULL;
        size_t taglen = 0;
        int ret = OSSL_PARAM_get_octet_string(p, (void **)&tag,
                                              EVP_MAX_AEAD_TAG_LENGTH, &taglen);
        if (ret != RET_OSSL_OK) {
            P11PROV_raise(ctx->provctx, CKR_MECHANISM_PARAM_INVALID,
                          "Invalid AEAD Tag parameter");
            return RET_OSSL_ERR;
        }

        if (ctx->mech.mechanism == CKM_AES_GCM) {
            CK_GCM_MESSAGE_PARAMS_PTR gcm =
                (CK_GCM_MESSAGE_PARAMS_PTR)ctx->mech.pParameter;
            if (!gcm || taglen == 0) {
                OPENSSL_clear_free(tag, taglen);
                P11PROV_raise(ctx->provctx, CKR_MECHANISM_PARAM_INVALID,
                              "Invalid AEAD Tag parameter");
                return RET_OSSL_ERR;
            }
            OPENSSL_clear_free(gcm->pTag, BITS_TO_BYTES(gcm->ulTagBits));
            gcm->pTag = tag;
            gcm->ulTagBits = BYTES_TO_BITS(taglen);
        } else if (ctx->mech.mechanism == CKM_CHACHA20_POLY1305) {
            CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR chacha =
                (CK_SALSA20_CHACHA20_POLY1305_MSG_PARAMS_PTR)
                    ctx->mech.pParameter;
            /* The ChaCha20-Poly1305 tag buffer is always treated as
             * EVP_CHACHAPOLY_TLS_TAG_LEN bytes long (no length field), so a
             * shorter tag would later be over-read */
            if (!chacha || taglen != EVP_CHACHAPOLY_TLS_TAG_LEN) {
                OPENSSL_clear_free(tag, taglen);
                P11PROV_raise(ctx->provctx, CKR_MECHANISM_PARAM_INVALID,
                              "Invalid AEAD Tag parameter");
                return RET_OSSL_ERR;
            }
            OPENSSL_clear_free(chacha->pTag, EVP_CHACHAPOLY_TLS_TAG_LEN);
            chacha->pTag = tag;
        } else {
            OPENSSL_clear_free(tag, taglen);
            P11PROV_raise(ctx->provctx, CKR_MECHANISM_PARAM_INVALID,
                          "AEAD Tag not supported for this mechanism");
            return RET_OSSL_ERR;
        }
    }

    return RET_OSSL_OK;
}

/* Gettable context parameters for the non-AEAD AES modes (also returned
 * when no context is available).
 * OSSL_CIPHER_PARAM_TLS_MAC is declared as an octet pointer because
 * p11prov_common_get_ctx_params() fills it with OSSL_PARAM_set_octet_ptr(),
 * which only accepts OSSL_PARAM_OCTET_PTR parameters. */
static const OSSL_PARAM p11prov_aes_generic_gettable_ctx_params[] = {
    OSSL_PARAM_size_t(OSSL_CIPHER_PARAM_KEYLEN, NULL),
    OSSL_PARAM_size_t(OSSL_CIPHER_PARAM_IVLEN, NULL),
    OSSL_PARAM_uint(OSSL_CIPHER_PARAM_PADDING, NULL),
    OSSL_PARAM_uint(OSSL_CIPHER_PARAM_NUM, NULL),
    OSSL_PARAM_octet_string(OSSL_CIPHER_PARAM_IV, NULL, 0),
    OSSL_PARAM_octet_string(OSSL_CIPHER_PARAM_UPDATED_IV, NULL, 0),
    OSSL_PARAM_octet_ptr(OSSL_CIPHER_PARAM_TLS_MAC, NULL, 0),
    OSSL_PARAM_END
};

#define p11prov_aes_gettable_ctx_params p11prov_common_gettable_ctx_params
#define p11prov_chacha20_gettable_ctx_params p11prov_common_gettable_ctx_params

static const OSSL_PARAM *p11prov_common_gettable_ctx_params(void *vctx,
                                                            void *provctx)
{
    struct p11prov_cipher_ctx *ctx = (struct p11prov_cipher_ctx *)vctx;

    if (!ctx) {
        /* There are some cases where openssl will ask for context
         * parameters but will pass NULL for the context, for now
         * we return the generic parameters, but in future we may
         * need to allocate shim functions for each cipher in their
         * dispatch table if it becomes important to return different
         * results for each cipher */
        return p11prov_aes_generic_gettable_ctx_params;
    }

    switch (ctx->mech.mechanism) {
    case CKM_AES_ECB:
    case CKM_AES_CBC_PAD:
    case CKM_AES_OFB:
    case CKM_AES_CFB128:
    case CKM_AES_CFB1:
    case CKM_AES_CFB8:
    case CKM_AES_CTR:
    case CKM_AES_CTS:
        return p11prov_aes_generic_gettable_ctx_params;
    }
    return NULL;
}

/* Settable parameter entries shared by the AES settable tables below:
 * currently only the padding flag. */
#define GENERIC_SETTABLE_CTX_PARAMS() \
    OSSL_PARAM_uint(OSSL_CIPHER_PARAM_PADDING, NULL)
/* Supported by OpenSSL but not here:
 * OSSL_CIPHER_PARAM_NUM (uint)
 * OSSL_CIPHER_PARAM_USE_BITS (uint)
 */

/* Settable parameters for the non-CTS AES modes: padding plus the TLS
 * version and TLS MAC size consumed by p11prov_common_set_ctx_params(). */
static const OSSL_PARAM p11prov_aes_generic_settable_ctx_params[] = {
    GENERIC_SETTABLE_CTX_PARAMS(),
    OSSL_PARAM_uint(OSSL_CIPHER_PARAM_TLS_VERSION, NULL),
    OSSL_PARAM_size_t(OSSL_CIPHER_PARAM_TLS_MAC_SIZE, NULL), OSSL_PARAM_END
};

/* Settable parameters for AES-CTS: padding plus the CTS mode name
 * (p11prov_common_set_ctx_params() accepts only CS1). */
static const OSSL_PARAM p11prov_aes_cts_settable_ctx_params[] = {
    GENERIC_SETTABLE_CTX_PARAMS(),
    OSSL_PARAM_utf8_string(OSSL_CIPHER_PARAM_CTS_MODE, NULL, 0), OSSL_PARAM_END
};

#define p11prov_aes_settable_ctx_params p11prov_common_settable_ctx_params
#define p11prov_chacha20_settable_ctx_params p11prov_common_settable_ctx_params

static const OSSL_PARAM *p11prov_common_settable_ctx_params(void *vctx,
                                                            void *provctx)
{
    struct p11prov_cipher_ctx *ctx = (struct p11prov_cipher_ctx *)vctx;
    if (!ctx) {
        /* See the explanation in p11prov_aes_gettable_ctx_params() for
         * why we handle this case this way */
        return p11prov_aes_generic_settable_ctx_params;
    }
    switch (ctx->mech.mechanism) {
    case CKM_AES_ECB:
    case CKM_AES_CBC_PAD:
    case CKM_AES_OFB:
    case CKM_AES_CFB128:
    case CKM_AES_CFB1:
    case CKM_AES_CFB8:
    case CKM_AES_CTR:
        return p11prov_aes_generic_settable_ctx_params;
    case CKM_AES_CTS:
        return p11prov_aes_cts_settable_ctx_params;
    }
    return NULL;
}

/* Dispatch tables for every exported cipher. The arguments appear to be
 * (family, key size in bits, mode name, default PKCS#11 mechanism); the
 * macro itself is defined elsewhere -- confirm there.
 * CBC is registered with CKM_AES_CBC_PAD; p11prov_common_set_ctx_params()
 * switches it to CKM_AES_CBC when padding is turned off. */
DISPATCH_TABLE_CIPHER_FN(aes, 128, ecb, CKM_AES_ECB);
DISPATCH_TABLE_CIPHER_FN(aes, 192, ecb, CKM_AES_ECB);
DISPATCH_TABLE_CIPHER_FN(aes, 256, ecb, CKM_AES_ECB);
DISPATCH_TABLE_CIPHER_FN(aes, 128, cbc, CKM_AES_CBC_PAD);
DISPATCH_TABLE_CIPHER_FN(aes, 192, cbc, CKM_AES_CBC_PAD);
DISPATCH_TABLE_CIPHER_FN(aes, 256, cbc, CKM_AES_CBC_PAD);
DISPATCH_TABLE_CIPHER_FN(aes, 128, ofb, CKM_AES_OFB);
DISPATCH_TABLE_CIPHER_FN(aes, 192, ofb, CKM_AES_OFB);
DISPATCH_TABLE_CIPHER_FN(aes, 256, ofb, CKM_AES_OFB);
DISPATCH_TABLE_CIPHER_FN(aes, 128, cfb, CKM_AES_CFB128);
DISPATCH_TABLE_CIPHER_FN(aes, 192, cfb, CKM_AES_CFB128);
DISPATCH_TABLE_CIPHER_FN(aes, 256, cfb, CKM_AES_CFB128);
DISPATCH_TABLE_CIPHER_FN(aes, 128, cfb1, CKM_AES_CFB1);
DISPATCH_TABLE_CIPHER_FN(aes, 192, cfb1, CKM_AES_CFB1);
DISPATCH_TABLE_CIPHER_FN(aes, 256, cfb1, CKM_AES_CFB1);
DISPATCH_TABLE_CIPHER_FN(aes, 128, cfb8, CKM_AES_CFB8);
DISPATCH_TABLE_CIPHER_FN(aes, 192, cfb8, CKM_AES_CFB8);
DISPATCH_TABLE_CIPHER_FN(aes, 256, cfb8, CKM_AES_CFB8);
DISPATCH_TABLE_CIPHER_FN(aes, 128, ctr, CKM_AES_CTR);
DISPATCH_TABLE_CIPHER_FN(aes, 192, ctr, CKM_AES_CTR);
DISPATCH_TABLE_CIPHER_FN(aes, 256, ctr, CKM_AES_CTR);
DISPATCH_TABLE_CIPHER_FN(aes, 128, cts, CKM_AES_CTS);
DISPATCH_TABLE_CIPHER_FN(aes, 192, cts, CKM_AES_CTS);
DISPATCH_TABLE_CIPHER_FN(aes, 256, cts, CKM_AES_CTS);
/* AEAD ciphers: parameters are handled through the PKCS#11 message API */
DISPATCH_TABLE_CIPHER_FN(aes, 128, gcm, CKM_AES_GCM);
DISPATCH_TABLE_CIPHER_FN(aes, 192, gcm, CKM_AES_GCM);
DISPATCH_TABLE_CIPHER_FN(aes, 256, gcm, CKM_AES_GCM);
DISPATCH_TABLE_CIPHER_FN(chacha20, 256, poly1305, CKM_CHACHA20_POLY1305);

#endif
