/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.protocol.handler;

import de.rub.nds.protocol.exception.AdjustmentException;
import de.rub.nds.protocol.exception.CryptoException;
import de.rub.nds.tlsattacker.core.constants.AlgorithmResolver;
import de.rub.nds.tlsattacker.core.constants.CipherSuite;
import de.rub.nds.tlsattacker.core.constants.CompressionMethod;
import de.rub.nds.tlsattacker.core.constants.DigestAlgorithm;
import de.rub.nds.tlsattacker.core.constants.ExtensionType;
import de.rub.nds.tlsattacker.core.constants.HKDFAlgorithm;
import de.rub.nds.tlsattacker.core.constants.ProtocolVersion;
import de.rub.nds.tlsattacker.core.constants.Tls13KeySetType;
import de.rub.nds.tlsattacker.core.crypto.HKDFunction;
import de.rub.nds.tlsattacker.core.layer.context.TlsContext;
import de.rub.nds.tlsattacker.core.protocol.handler.HandshakeMessageHandler;
import de.rub.nds.tlsattacker.core.protocol.message.CoreClientHelloMessage;
import de.rub.nds.tlsattacker.core.protocol.message.HelloMessage;
import de.rub.nds.tlsattacker.core.record.cipher.RecordCipherFactory;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeyDerivator;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeySet;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.security.NoSuchAlgorithmException;
import java.util.LinkedList;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Base handler for ClientHello-like messages.
 *
 * <p>Copies the values offered in the message (protocol version, session id, cipher suites,
 * compression methods, optional cookie, client random) into the {@link TlsContext}, delegates
 * extension handling to {@code adjustExtensions}, and, when TLS 1.3 early data is in play, derives
 * the client early traffic secret and installs the matching record cipher.
 *
 * @param <Message> the concrete ClientHello message type handled by this handler
 */
public abstract class CoreClientHelloHandler<Message extends CoreClientHelloMessage>
        extends HandshakeMessageHandler<Message> {
    private static final Logger LOGGER = LogManager.getLogger();

    public CoreClientHelloHandler(TlsContext tlsContext) {
        super(tlsContext);
    }

    /**
     * Copies the contents of {@code message} into the context.
     *
     * <p>If the chooser's selected protocol version is TLS 1.3 and the early_data extension is
     * marked as negotiated, the early traffic secret is derived and the client record cipher is
     * switched to the early-traffic keys.
     *
     * @param message the ClientHello whose values are stored in the context
     * @throws AdjustmentException if deriving the early traffic secret or key set fails
     */
    @Override
    public void adjustContext(Message message) {
        adjustProtocolVersion(message);
        adjustSessionID(message);
        adjustClientSupportedCipherSuites(message);
        adjustClientSupportedCompressions(message);
        if (isCookieFieldSet(message)) {
            adjustDTLSCookie(message);
        }
        adjustExtensions(message);
        // Runs after adjustExtensions, presumably because that call records which extensions
        // were proposed -- NOTE(review): confirm against HandshakeMessageHandler.
        warnOnConflictingExtensions();
        adjustRandomContext(message);
        if (tlsContext.getChooser().getSelectedProtocolVersion().is13()
                && tlsContext.isExtensionNegotiated(ExtensionType.EARLY_DATA)) {
            try {
                adjustEarlyTrafficSecret();
                setClientRecordCipherEarly();
            } catch (CryptoException ex) {
                throw new AdjustmentException("Could not adjust", ex);
            }
        }
    }

    /** Returns whether the message has a cookie field object (the field itself is non-null). */
    private boolean isCookieFieldSet(Message message) {
        return message.getCookie() != null;
    }

    /**
     * Stores the cipher suites listed in the message in the context. Unknown code points are
     * skipped; a cipher suite field of odd length causes {@code null} to be stored.
     */
    private void adjustClientSupportedCipherSuites(Message message) {
        List<CipherSuite> suiteList = convertCipherSuites(message.getCipherSuites().getValue());
        tlsContext.setClientSupportedCipherSuites(suiteList);
        if (suiteList != null) {
            LOGGER.debug("Set ClientSupportedCipherSuites in Context to {}", suiteList);
        } else {
            LOGGER.debug("Set ClientSupportedCipherSuites in Context to null");
        }
    }

    /** Stores the compression methods listed in the message in the context (unknown ones skipped). */
    private void adjustClientSupportedCompressions(Message message) {
        List<CompressionMethod> compressionList =
                convertCompressionMethods(message.getCompressions().getValue());
        tlsContext.setClientSupportedCompressions(compressionList);
        LOGGER.debug("Set ClientSupportedCompressions in Context to {}", compressionList);
    }

    /** Stores the cookie value carried in the message as the context's DTLS cookie. */
    private void adjustDTLSCookie(Message message) {
        byte[] dtlsCookie = message.getCookie().getValue();
        tlsContext.setDtlsCookie(dtlsCookie);
        LOGGER.debug("Set DTLS Cookie in Context to {}", dtlsCookie);
    }

    /** Stores the session id offered by the client in the context. */
    private void adjustSessionID(Message message) {
        byte[] sessionId = message.getSessionId().getValue();
        tlsContext.setClientSessionId(sessionId);
        LOGGER.debug("Set SessionId in Context to {}", sessionId);
    }

    /**
     * Stores the client's advertised protocol version as the highest client protocol version.
     * If the bytes do not map to a known {@link ProtocolVersion}, the context is left unchanged
     * and a warning is logged.
     */
    private void adjustProtocolVersion(Message message) {
        byte[] versionBytes = message.getProtocolVersion().getValue();
        ProtocolVersion version = ProtocolVersion.getProtocolVersion(versionBytes);
        if (version != null) {
            tlsContext.setHighestClientProtocolVersion(version);
            LOGGER.debug("Set HighestClientProtocolVersion in Context to {}", version.name());
        } else {
            LOGGER.warn(
                    "Did not Adjust ProtocolVersion since version is undefined {}", versionBytes);
        }
    }

    /** Stores the client random from the message in the context. */
    private void adjustRandomContext(Message message) {
        tlsContext.setClientRandom(message.getRandom().getValue());
        LOGGER.debug("Set ClientRandom in Context to {}", tlsContext.getClientRandom());
    }

    /**
     * Converts each byte into a {@link CompressionMethod}, skipping (and warning about) bytes
     * that do not map to a known method.
     *
     * @param bytesToConvert raw compression method bytes
     * @return the recognised methods in wire order (never {@code null})
     */
    private List<CompressionMethod> convertCompressionMethods(byte[] bytesToConvert) {
        List<CompressionMethod> list = new LinkedList<>();
        for (byte b : bytesToConvert) {
            CompressionMethod method = CompressionMethod.getCompressionMethod(b);
            if (method == null) {
                LOGGER.warn("Could not convert {} into a CompressionMethod", b);
                continue;
            }
            list.add(method);
        }
        return list;
    }

    /**
     * Converts consecutive two-byte code points into {@link CipherSuite}s, skipping (and warning
     * about) unknown code points.
     *
     * @param bytesToConvert raw cipher suite bytes
     * @return the recognised suites in wire order, or {@code null} if the length is odd
     */
    private List<CipherSuite> convertCipherSuites(byte[] bytesToConvert) {
        if (bytesToConvert.length % 2 != 0) {
            LOGGER.warn("Cannot convert: {} to a List<CipherSuite>", bytesToConvert);
            return null;
        }
        List<CipherSuite> list = new LinkedList<>();
        for (int i = 0; i < bytesToConvert.length; i += 2) {
            byte[] copied = {bytesToConvert[i], bytesToConvert[i + 1]};
            CipherSuite suite = CipherSuite.getCipherSuite(copied);
            if (suite == null) {
                LOGGER.warn("Cannot convert {} to a CipherSuite", copied);
                continue;
            }
            list.add(suite);
        }
        return list;
    }

    /**
     * After serializing a ClientHello as the client, derives the early traffic secret and
     * installs the early record cipher if the early_data extension was proposed. Failures are
     * logged and not propagated (best-effort).
     */
    @Override
    public void adjustContextAfterSerialize(Message message) {
        if (tlsContext.getChooser().getConnectionEndType() == ConnectionEndType.CLIENT
                && tlsContext.isExtensionProposed(ExtensionType.EARLY_DATA)) {
            try {
                adjustEarlyTrafficSecret();
                setClientRecordCipherEarly();
            } catch (CryptoException ex) {
                LOGGER.warn("Encountered an exception in adjust after Serialize", ex);
            }
        }
    }

    /**
     * Computes the early secret by HKDF-extract over the early-data PSK with an empty salt.
     * Then derives the client early traffic secret ("c e traffic") over the current handshake
     * digest. Both values are stored in the context.
     *
     * @throws CryptoException if the HKDF operations fail
     */
    private void adjustEarlyTrafficSecret() throws CryptoException {
        CipherSuite earlyDataSuite = tlsContext.getChooser().getEarlyDataCipherSuite();
        HKDFAlgorithm hkdfAlgorithm = AlgorithmResolver.getHKDFAlgorithm(earlyDataSuite);
        DigestAlgorithm digestAlgo =
                AlgorithmResolver.getDigestAlgorithm(ProtocolVersion.TLS13, earlyDataSuite);
        byte[] earlySecret =
                HKDFunction.extract(
                        hkdfAlgorithm, new byte[0], tlsContext.getChooser().getEarlyDataPsk());
        tlsContext.setEarlySecret(earlySecret);
        byte[] earlyTrafficSecret =
                HKDFunction.deriveSecret(
                        hkdfAlgorithm,
                        digestAlgo.getJavaName(),
                        tlsContext.getChooser().getEarlySecret(),
                        "c e traffic",
                        tlsContext.getDigest().getRawBytes(),
                        tlsContext.getChooser().getSelectedProtocolVersion());
        tlsContext.setClientEarlyTrafficSecret(earlyTrafficSecret);
        LOGGER.debug("EarlyTrafficSecret: {}", earlyTrafficSecret);
    }

    /**
     * Switches the active client key set to the early traffic secrets and installs a record
     * cipher built from them.
     *
     * <p>On the server the cipher is installed for decryption (with the read connection id); on
     * the client it is installed for encryption (with the write connection id). If no record
     * layer is available, the key set type is still updated but no cipher is installed. This
     * check applies to both ends, instead of only the client as before.
     *
     * @throws CryptoException if the key set cannot be generated due to an unknown algorithm
     */
    private void setClientRecordCipherEarly() throws CryptoException {
        try {
            tlsContext.setActiveClientKeySetType(Tls13KeySetType.EARLY_TRAFFIC_SECRETS);
            LOGGER.debug("Setting cipher for client to use early secrets");
            KeySet clientKeySet =
                    KeyDerivator.generateKeySet(
                            tlsContext,
                            ProtocolVersion.TLS13,
                            tlsContext.getActiveClientKeySetType());
            if (tlsContext.getRecordLayer() == null) {
                LOGGER.debug(
                        "No record layer available, not installing early traffic record cipher");
                return;
            }
            CipherSuite earlyDataSuite = tlsContext.getChooser().getEarlyDataCipherSuite();
            if (tlsContext.getChooser().getConnectionEndType() == ConnectionEndType.SERVER) {
                tlsContext
                        .getRecordLayer()
                        .updateDecryptionCipher(
                                RecordCipherFactory.getRecordCipher(
                                        tlsContext,
                                        clientKeySet,
                                        earlyDataSuite,
                                        tlsContext.getReadConnectionId()));
            } else {
                tlsContext
                        .getRecordLayer()
                        .updateEncryptionCipher(
                                RecordCipherFactory.getRecordCipher(
                                        tlsContext,
                                        clientKeySet,
                                        earlyDataSuite,
                                        tlsContext.getWriteConnectionId()));
            }
        } catch (NoSuchAlgorithmException ex) {
            LOGGER.error("Unable to generate KeySet - unknown algorithm");
            throw new CryptoException(ex);
        }
    }

    /**
     * Logs a warning when both max_fragment_length and record_size_limit are marked as proposed.
     * It only fires while the talking end equals this instance's peer (presumably meaning the
     * ClientHello was received rather than sent -- NOTE(review): confirm).
     */
    private void warnOnConflictingExtensions() {
        boolean peerIsTalking =
                tlsContext.getTalkingConnectionEndType()
                        == tlsContext.getChooser().getMyConnectionPeer();
        if (peerIsTalking
                && tlsContext.isExtensionProposed(ExtensionType.MAX_FRAGMENT_LENGTH)
                && tlsContext.isExtensionProposed(ExtensionType.RECORD_SIZE_LIMIT)) {
            LOGGER.warn("Client sent max_fragment_length AND record_size_limit extensions");
        }
    }
}

