/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.protocol.preparator;

import de.rub.nds.protocol.util.SilentByteArrayOutputStream;
import de.rub.nds.tlsattacker.core.config.Config;
import de.rub.nds.tlsattacker.core.constants.ExtensionType;
import de.rub.nds.tlsattacker.core.constants.HandshakeMessageType;
import de.rub.nds.tlsattacker.core.layer.data.Preparator;
import de.rub.nds.tlsattacker.core.protocol.ProtocolMessagePreparator;
import de.rub.nds.tlsattacker.core.protocol.ProtocolMessageSerializer;
import de.rub.nds.tlsattacker.core.protocol.message.ClientHelloMessage;
import de.rub.nds.tlsattacker.core.protocol.message.HandshakeMessage;
import de.rub.nds.tlsattacker.core.protocol.message.ServerHelloMessage;
import de.rub.nds.tlsattacker.core.protocol.message.extension.EncryptedServerNameIndicationExtensionMessage;
import de.rub.nds.tlsattacker.core.protocol.message.extension.ExtensionMessage;
import de.rub.nds.tlsattacker.core.protocol.message.extension.KeyShareExtensionMessage;
import de.rub.nds.tlsattacker.core.protocol.message.extension.PreSharedKeyExtensionMessage;
import de.rub.nds.tlsattacker.core.protocol.preparator.extension.EncryptedServerNameIndicationExtensionPreparator;
import de.rub.nds.tlsattacker.core.protocol.preparator.extension.ExtensionPreparator;
import de.rub.nds.tlsattacker.core.protocol.preparator.extension.PreSharedKeyExtensionPreparator;
import de.rub.nds.tlsattacker.core.protocol.serializer.HandshakeMessageSerializer;
import de.rub.nds.tlsattacker.core.workflow.chooser.Chooser;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public abstract class HandshakeMessagePreparator<T extends HandshakeMessage>
extends ProtocolMessagePreparator<T> {
    private static final Logger LOGGER = LogManager.getLogger();

    public HandshakeMessagePreparator(Chooser chooser, T message) {
        super(chooser, message);
    }

    protected void prepareMessageLength(int length) {
        ((HandshakeMessage)this.message).setLength(length);
        LOGGER.debug("Length: {}", ((HandshakeMessage)this.message).getLength().getValue());
    }

    private void prepareMessageType(HandshakeMessageType type) {
        ((HandshakeMessage)this.message).setType(type.getValue());
        LOGGER.debug("Type: {}", ((HandshakeMessage)this.message).getType().getValue());
    }

    private void prepareMessageContent(byte[] content) {
        ((HandshakeMessage)this.message).setMessageContent(content);
        LOGGER.debug("Handshake message content: {}", ((HandshakeMessage)this.message).getMessageContent().getValue());
    }

    @Override
    protected void prepareProtocolMessageContents() {
        this.prepareHandshakeMessageContents();
        this.prepareEncapsulatingFields();
    }

    public void prepareEncapsulatingFields() {
        ProtocolMessageSerializer serializer = ((HandshakeMessage)this.message).getSerializer(this.chooser.getContext());
        byte[] content = ((HandshakeMessageSerializer)serializer).serializeHandshakeMessageContent();
        this.prepareMessageContent(content);
        this.prepareMessageLength(((byte[])((HandshakeMessage)this.message).getMessageContent().getValue()).length);
        this.prepareMessageType(((HandshakeMessage)this.message).getHandshakeMessageType());
    }

    public void autoSelectExtensions(Config tlsConfig, Set<ExtensionType> proposedExtensions, Set<ExtensionType> forbiddenExtensions, ExtensionType ... exceptions) {
        this.setExtensionsBasedOnProposals(((HandshakeMessage)this.message).createConfiguredExtensions(tlsConfig), proposedExtensions, forbiddenExtensions, exceptions);
        LOGGER.debug("Automatically selected extensions for message {}: {}", (Object)((HandshakeMessage)this.message).getHandshakeMessageType().name(), (Object)((HandshakeMessage)this.message).getExtensions().stream().map(ExtensionMessage::getExtensionTypeConstant).map(Enum::name).collect(Collectors.joining(",")));
    }

    public final void setExtensionsBasedOnProposals(List<ExtensionMessage> configuredExtensions, Set<ExtensionType> clientProposedExtensions, Set<ExtensionType> forbiddenExtensions, ExtensionType ... exceptions) {
        ((HandshakeMessage)this.message).setExtensions(new LinkedList<ExtensionMessage>());
        List<ExtensionType> listedExceptions = Arrays.asList(exceptions);
        configuredExtensions.stream().filter(configuredExtension -> !forbiddenExtensions.contains((Object)configuredExtension.getExtensionTypeConstant()) && (clientProposedExtensions.contains((Object)configuredExtension.getExtensionTypeConstant()) || listedExceptions.contains((Object)configuredExtension.getExtensionTypeConstant()))).forEach(((HandshakeMessage)this.message)::addExtension);
    }

    protected abstract void prepareHandshakeMessageContents();

    protected void prepareExtensions() {
        SilentByteArrayOutputStream stream = new SilentByteArrayOutputStream();
        if (((HandshakeMessage)this.message).getExtensions() != null) {
            for (ExtensionMessage extensionMessage : ((HandshakeMessage)this.message).getExtensions()) {
                if (extensionMessage instanceof KeyShareExtensionMessage && this.message instanceof ServerHelloMessage) {
                    ServerHelloMessage serverHello = (ServerHelloMessage)this.message;
                    KeyShareExtensionMessage ksExt = (KeyShareExtensionMessage)extensionMessage;
                    if (serverHello.setRetryRequestModeInKeyShare()) {
                        ksExt.setRetryRequestMode(true);
                    }
                }
                ((ExtensionPreparator)extensionMessage.getPreparator(this.chooser.getContext())).prepare();
                stream.write((byte[])extensionMessage.getExtensionBytes().getValue());
            }
        }
        ((HandshakeMessage)this.message).setExtensionBytes(stream.toByteArray());
        LOGGER.debug("ExtensionBytes: {}", ((HandshakeMessage)this.message).getExtensionBytes().getValue());
    }

    protected void afterPrepareExtensions() {
        SilentByteArrayOutputStream stream = new SilentByteArrayOutputStream();
        if (((HandshakeMessage)this.message).getExtensions() != null) {
            for (ExtensionMessage extensionMessage : ((HandshakeMessage)this.message).getExtensions()) {
                Preparator preparator = extensionMessage.getPreparator(this.chooser.getContext());
                if (extensionMessage instanceof PreSharedKeyExtensionMessage && this.message instanceof ClientHelloMessage && this.chooser.getConnectionEndType() == ConnectionEndType.CLIENT) {
                    ((PreSharedKeyExtensionPreparator)preparator).setClientHello((ClientHelloMessage)this.message);
                    preparator.afterPrepare();
                } else if (extensionMessage instanceof EncryptedServerNameIndicationExtensionMessage && this.message instanceof ClientHelloMessage && this.chooser.getConnectionEndType() == ConnectionEndType.CLIENT) {
                    ClientHelloMessage clientHelloMessage = (ClientHelloMessage)this.message;
                    ((EncryptedServerNameIndicationExtensionPreparator)preparator).setClientHelloMessage(clientHelloMessage);
                    preparator.afterPrepare();
                }
                if (extensionMessage.getExtensionBytes() != null && extensionMessage.getExtensionBytes().getValue() != null) {
                    stream.write((byte[])extensionMessage.getExtensionBytes().getValue());
                    continue;
                }
                LOGGER.debug("If we are in a SSLv2 or SSLv3 Connection we do not add extensions, as SSL did not contain extensions");
                LOGGER.debug("If however, the extensions are prepared, we will add them");
            }
        }
        ((HandshakeMessage)this.message).setExtensionBytes(stream.toByteArray());
        this.prepareEncapsulatingFields();
        LOGGER.debug("ExtensionBytes: {}", ((HandshakeMessage)this.message).getExtensionBytes().getValue());
    }

    protected void prepareExtensionLength() {
        ((HandshakeMessage)this.message).setExtensionsLength(((byte[])((HandshakeMessage)this.message).getExtensionBytes().getValue()).length);
        LOGGER.debug("ExtensionLength: {}", ((HandshakeMessage)this.message).getExtensionsLength().getValue());
    }
}

