/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.layer.impl;

import de.rub.nds.protocol.exception.CryptoException;
import de.rub.nds.protocol.exception.EndOfStreamException;
import de.rub.nds.protocol.exception.TimeoutException;
import de.rub.nds.protocol.util.SilentByteArrayOutputStream;
import de.rub.nds.tlsattacker.core.layer.AcknowledgingProtocolLayer;
import de.rub.nds.tlsattacker.core.layer.LayerConfiguration;
import de.rub.nds.tlsattacker.core.layer.LayerProcessingResult;
import de.rub.nds.tlsattacker.core.layer.constant.ImplementedLayers;
import de.rub.nds.tlsattacker.core.layer.hints.LayerProcessingHint;
import de.rub.nds.tlsattacker.core.layer.hints.QuicPacketLayerHint;
import de.rub.nds.tlsattacker.core.layer.stream.HintedInputStream;
import de.rub.nds.tlsattacker.core.layer.stream.HintedLayerInputStream;
import de.rub.nds.tlsattacker.core.quic.constants.QuicPacketType;
import de.rub.nds.tlsattacker.core.quic.constants.QuicVersion;
import de.rub.nds.tlsattacker.core.quic.crypto.QuicDecryptor;
import de.rub.nds.tlsattacker.core.quic.crypto.QuicEncryptor;
import de.rub.nds.tlsattacker.core.quic.packet.HandshakePacket;
import de.rub.nds.tlsattacker.core.quic.packet.InitialPacket;
import de.rub.nds.tlsattacker.core.quic.packet.OneRTTPacket;
import de.rub.nds.tlsattacker.core.quic.packet.QuicPacket;
import de.rub.nds.tlsattacker.core.quic.packet.RetryPacket;
import de.rub.nds.tlsattacker.core.quic.packet.VersionNegotiationPacket;
import de.rub.nds.tlsattacker.core.quic.packet.ZeroRTTPacket;
import de.rub.nds.tlsattacker.core.state.Context;
import de.rub.nds.tlsattacker.core.state.quic.QuicContext;
import java.io.IOException;
import java.io.InputStream;
import java.net.PortUnreachableException;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Protocol layer that wraps upper-layer data into QUIC packets when sending, and parses, decrypts
 * and reorders QUIC packets when receiving.
 *
 * <p>Received packets are collected in a per-type buffer. The packets of a type are decrypted once
 * the matching secrets are initialized in the {@link QuicContext} and are then sorted by packet
 * number. The payload of at most one packet is handed to the upper layer per read.
 */
public class QuicPacketLayer
        extends AcknowledgingProtocolLayer<Context, QuicPacketLayerHint, QuicPacket> {

    private static final Logger LOGGER = LogManager.getLogger();

    private final Context context;
    private final QuicContext quicContext;
    private final QuicDecryptor decryptor;
    private final QuicEncryptor encryptor;

    /**
     * Received packets, keyed by packet type. The constructor creates an entry for every type, so
     * lookups with a non-null type never return {@code null}.
     */
    private final Map<QuicPacketType, ArrayList<QuicPacket>> receivedPacketBuffer =
            new EnumMap<>(QuicPacketType.class);

    /** Decrypts a single buffered packet in place. */
    @FunctionalInterface
    private interface BufferedPacketDecryption {
        void decrypt(QuicPacket packet) throws CryptoException;
    }

    /**
     * Creates the layer for the given context.
     *
     * @param context the connection context. Its QUIC context provides the secrets state, the
     *     negotiated QUIC version and the list of received packet types.
     */
    public QuicPacketLayer(Context context) {
        super(ImplementedLayers.QUICPACKET);
        this.context = context;
        this.quicContext = context.getQuicContext();
        this.decryptor = new QuicDecryptor(quicContext);
        this.encryptor = new QuicEncryptor(quicContext);
        for (QuicPacketType packetType : QuicPacketType.values()) {
            receivedPacketBuffer.put(packetType, new ArrayList<>());
        }
    }

    /**
     * Protects and sends every unprocessed packet from the layer configuration.
     *
     * <p>Packets considered empty by {@code isEmptyPacket} are skipped. Packets that fail to be
     * protected are logged and not sent.
     */
    @Override
    public LayerProcessingResult<QuicPacket> sendConfiguration() throws IOException {
        if (hasConfiguredContainers()) {
            for (QuicPacket packet : getUnprocessedConfiguredContainers()) {
                if (isEmptyPacket(packet)) {
                    continue;
                }
                try {
                    sendPacket(packet);
                } catch (CryptoException ex) {
                    LOGGER.error("Could not write QUIC packet", ex);
                }
            }
        }
        return getLayerResult();
    }

    /**
     * Wraps {@code data} into one QUIC packet and passes it to the lower layer.
     *
     * <p>If the layer configuration still provides unprocessed packets, the first one is used as
     * the container. Otherwise a new packet of the type named by the {@link QuicPacketLayerHint}
     * is created. Without such a hint the type is UNKNOWN.
     *
     * @throws UnsupportedOperationException if the packet type cannot be sent by this layer
     *     (Retry, Version Negotiation or unknown)
     */
    @Override
    public LayerProcessingResult<QuicPacket> sendData(LayerProcessingHint hint, byte[] data)
            throws IOException {
        QuicPacketType hintedType = QuicPacketType.UNKNOWN;
        if (hint instanceof QuicPacketLayerHint) {
            hintedType = ((QuicPacketLayerHint) hint).getQuicPacketType();
        } else {
            LOGGER.warn("Sending packet without a LayerProcessing hint. Using UNKNOWN as the type.");
        }
        List<QuicPacket> givenPackets = getUnprocessedConfiguredContainers();
        try {
            QuicPacket packet;
            if (hasConfiguredContainers() && !givenPackets.isEmpty()) {
                packet = givenPackets.get(0);
            } else {
                packet = createPacketForType(hintedType);
            }
            packet.setUnprotectedPayload(data);
            sendPacket(packet);
        } catch (CryptoException ex) {
            LOGGER.error("Could not write QUIC packet", ex);
        }
        return getLayerResult();
    }

    /**
     * Reads packets from the lower layer for as long as {@code shouldContinueProcessing()}
     * requests. Timeouts, unreachable ports, end of stream and other I/O errors end the loop and
     * are logged.
     */
    @Override
    public LayerProcessingResult<QuicPacket> receiveData() {
        try {
            do {
                HintedInputStream dataStream = getLowerLayer().getDataStream();
                readPackets(dataStream);
            } while (shouldContinueProcessing());
        } catch (TimeoutException | SocketTimeoutException ex) {
            LOGGER.debug("Received a timeout");
            LOGGER.trace(ex);
        } catch (PortUnreachableException ex) {
            LOGGER.debug("Destination port unreachable");
            LOGGER.trace(ex);
        } catch (EndOfStreamException ex) {
            LOGGER.debug("Reached end of stream, cannot parse more messages");
            LOGGER.trace(ex);
        } catch (IOException ex) {
            LOGGER.warn("The lower layer did not produce a data stream: ", ex);
        }
        return getLayerResult();
    }

    /**
     * Reads one more packet from the lower layer. The {@code hint} is not used. Timeouts,
     * unreachable ports and end of stream are logged; other I/O errors propagate.
     */
    @Override
    public void receiveMoreDataForHint(LayerProcessingHint hint) throws IOException {
        try {
            HintedInputStream dataStream = getLowerLayer().getDataStream();
            readPackets(dataStream);
        } catch (PortUnreachableException ex) {
            LOGGER.debug("Received a ICMP Port Unreachable");
            LOGGER.trace(ex);
        } catch (TimeoutException | SocketTimeoutException ex) {
            LOGGER.debug("Received a timeout");
            LOGGER.trace(ex);
        } catch (EndOfStreamException ex) {
            LOGGER.debug("Reached end of stream, cannot parse more messages");
            LOGGER.trace(ex);
        }
    }

    /**
     * Buffers one packet from {@code dataStream} and decrypts every buffered packet whose secrets
     * are available. It then appends the payload of at most one packet to the stream served to
     * the upper layer.
     *
     * @throws EndOfStreamException if {@code dataStream} has no bytes available
     */
    private void readPackets(InputStream dataStream) throws IOException {
        if (dataStream.available() == 0) {
            throw new EndOfStreamException();
        }
        bufferNextPacket(dataStream);

        // Each secrets flag is read only after the previous packet space has been decrypted.
        // Decrypting one space runs packet handlers that may update the context.
        decryptBufferedPackets(
                QuicPacketType.INITIAL_PACKET,
                quicContext.isInitialSecretsInitialized(),
                packet -> decryptInitialPacket((InitialPacket) packet));
        decryptBufferedPackets(
                QuicPacketType.HANDSHAKE_PACKET,
                quicContext.isHandshakeSecretsInitialized(),
                packet -> decryptHandshakePacket((HandshakePacket) packet));
        decryptBufferedPackets(
                QuicPacketType.ONE_RTT_PACKET,
                quicContext.isApplicationSecretsInitialized(),
                packet -> decryptOneRTTPacket((OneRTTPacket) packet));

        byte[] payload = takeNextPayload();
        if (currentInputStream == null) {
            currentInputStream = new HintedLayerInputStream(null, this);
        }
        currentInputStream.extendStream(payload);
    }

    /**
     * Parses the next packet from {@code dataStream} and stores it in the receive buffer.
     *
     * <p>A first byte of {@code 0} is treated as padding, and everything still available is
     * dropped. A long-header packet with an unexpected QUIC version is dropped in the same way.
     *
     * @throws UnsupportedOperationException for 0-RTT and unknown packet types
     */
    private void bufferNextPacket(InputStream dataStream) throws IOException {
        int firstByte = dataStream.read();
        if (firstByte == 0) {
            discardAvailable(dataStream);
            return;
        }
        byte[] versionBytes = new byte[0];
        QuicPacketType packetType;
        if (QuicPacketType.isLongHeaderPacket(firstByte)) {
            versionBytes = dataStream.readNBytes(4);
            QuicVersion quicVersion = QuicVersion.getFromVersionBytes(versionBytes);
            if (quicVersion == QuicVersion.NULL_VERSION) {
                packetType = QuicPacketType.VERSION_NEGOTIATION;
            } else {
                packetType = QuicPacketType.getPacketTypeFromFirstByte(quicVersion, firstByte);
                if (quicVersion != quicContext.getQuicVersion()
                        && packetType != QuicPacketType.VERSION_NEGOTIATION) {
                    LOGGER.warn("Received packet with unexpected QUIC version, ignoring it.");
                    // Beyond the version-independent header fields, the packet layout depends on
                    // the version. The packet's extent is therefore unknown, so drop the rest of
                    // the available data instead of aborting the whole receive.
                    discardAvailable(dataStream);
                    return;
                }
            }
        } else {
            packetType =
                    QuicPacketType.getPacketTypeFromFirstByte(
                            quicContext.getQuicVersion(), firstByte);
        }

        List<QuicPacket> buffer = receivedPacketBuffer.get(packetType);
        switch (packetType) {
            case INITIAL_PACKET:
                buffer.add(readInitialPacket(firstByte, versionBytes, dataStream));
                break;
            case HANDSHAKE_PACKET:
                buffer.add(readHandshakePacket(firstByte, versionBytes, dataStream));
                break;
            case ONE_RTT_PACKET:
                buffer.add(readOneRTTPacket(firstByte, dataStream));
                break;
            case ZERO_RTT_PACKET:
                throw new UnsupportedOperationException("Zero RTT Packet - Not supported yet.");
            case RETRY_PACKET:
                buffer.add(readRetryPacket(firstByte, dataStream));
                break;
            case VERSION_NEGOTIATION:
                buffer.add(readVersionNegotiationPacket(dataStream));
                break;
            default:
                throw new UnsupportedOperationException("Unknown Packet - Not supported yet.");
        }
    }

    /** Reads and drops every byte that {@code dataStream} reports as available. */
    private static void discardAvailable(InputStream dataStream) throws IOException {
        dataStream.readNBytes(dataStream.available());
    }

    /**
     * Removes the next processable packet from the buffer and records its type in the QUIC
     * context.
     *
     * @return the packet's unprotected payload, or an empty array if no packet is ready
     */
    private byte[] takeNextPayload() {
        QuicPacketType packetTypeToProcess = getPacketTypeToProcessNext();
        if (packetTypeToProcess == null) {
            return new byte[0];
        }
        QuicPacket packet = receivedPacketBuffer.get(packetTypeToProcess).remove(0);
        LOGGER.debug(
                "Processing {} Packet: {}", packetTypeToProcess, packet.getPlainPacketNumber());
        byte[] payload = packet.getUnprotectedPayload().getValue();
        quicContext.getReceivedPackets().add(packet.getPacketType());
        return payload;
    }

    /** Stores {@code data} as the unprotected payload of {@code packet}, then protects and serializes it. */
    private byte[] writePacket(byte[] data, QuicPacket packet) throws CryptoException {
        packet.setUnprotectedPayload(data);
        return writePacket(packet);
    }

    /**
     * Prepares, encrypts, header-protects and serializes {@code packet}.
     *
     * @throws UnsupportedOperationException for Retry, Version Negotiation and unknown packets
     */
    private byte[] writePacket(QuicPacket packet) throws CryptoException {
        switch (packet.getPacketType()) {
            case INITIAL_PACKET:
                return writeInitialPacket((InitialPacket) packet);
            case HANDSHAKE_PACKET:
                return writeHandshakePacket((HandshakePacket) packet);
            case ONE_RTT_PACKET:
                return writeOneRTTPacket((OneRTTPacket) packet);
            case ZERO_RTT_PACKET:
                return writeZeroRTTPacket((ZeroRTTPacket) packet);
            case RETRY_PACKET:
                throw new UnsupportedOperationException("Retry Packet - Not supported yet.");
            case VERSION_NEGOTIATION:
                throw new UnsupportedOperationException(
                        "Version Negotiation Packet - Not supported yet.");
            default:
                throw new UnsupportedOperationException("Unknown Packet - Not supported yet.");
        }
    }

    /** Protects {@code packet}, records it as produced and passes its bytes to the lower layer. */
    private void sendPacket(QuicPacket packet) throws IOException, CryptoException {
        byte[] bytes = writePacket(packet);
        addProducedContainer(packet);
        getLowerLayer().sendData(null, bytes);
    }

    /**
     * Creates an empty packet of the given type for sending.
     *
     * @throws UnsupportedOperationException for Retry, Version Negotiation and unknown types
     */
    private static QuicPacket createPacketForType(QuicPacketType packetType) {
        switch (packetType) {
            case INITIAL_PACKET:
                return new InitialPacket();
            case HANDSHAKE_PACKET:
                return new HandshakePacket();
            case ONE_RTT_PACKET:
                return new OneRTTPacket();
            case ZERO_RTT_PACKET:
                return new ZeroRTTPacket();
            case RETRY_PACKET:
                throw new UnsupportedOperationException("Retry Packet - Not supported yet.");
            case VERSION_NEGOTIATION:
                throw new UnsupportedOperationException(
                        "Version Negotiation Packet - Not supported yet.");
            default:
                throw new UnsupportedOperationException("Unknown Packet - Not supported yet.");
        }
    }

    /** Returns whether a layer configuration with a container list is present. */
    private boolean hasConfiguredContainers() {
        return getLayerConfiguration() != null
                && getLayerConfiguration().getContainerList() != null;
    }

    private byte[] writeInitialPacket(InitialPacket packet) throws CryptoException {
        packet.getPreparator(context).prepare();
        encryptor.encryptInitialPacket(packet);
        packet.updateFlagsWithEncodedPacketNumber();
        encryptor.addHeaderProtectionInitial(packet);
        return packet.getSerializer(context).serialize();
    }

    private byte[] writeHandshakePacket(HandshakePacket packet) throws CryptoException {
        packet.getPreparator(context).prepare();
        encryptor.encryptHandshakePacket(packet);
        packet.updateFlagsWithEncodedPacketNumber();
        encryptor.addHeaderProtectionHandshake(packet);
        return packet.getSerializer(context).serialize();
    }

    private byte[] writeOneRTTPacket(OneRTTPacket packet) throws CryptoException {
        packet.getPreparator(context).prepare();
        encryptor.encryptOneRRTPacket(packet);
        packet.updateFlagsWithEncodedPacketNumber();
        encryptor.addHeaderProtectionOneRRT(packet);
        return packet.getSerializer(context).serialize();
    }

    private byte[] writeZeroRTTPacket(ZeroRTTPacket packet) throws CryptoException {
        packet.getPreparator(context).prepare();
        encryptor.encryptZeroRTTPacket(packet);
        packet.updateFlagsWithEncodedPacketNumber();
        encryptor.addHeaderProtectionZeroRTT(packet);
        return packet.getSerializer(context).serialize();
    }

    /** Parses a still-protected Initial packet; {@code flags} and version were read already. */
    private InitialPacket readInitialPacket(int flags, byte[] versionBytes, InputStream dataStream) {
        InitialPacket packet = new InitialPacket((byte) flags, versionBytes);
        packet.getParser(context, dataStream).parse(packet);
        return packet;
    }

    /** Removes header protection, decrypts, records the packet number and applies the packet. */
    private InitialPacket decryptInitialPacket(InitialPacket packet) throws CryptoException {
        decryptor.removeHeaderProtectionInitial(packet);
        packet.convertCompleteProtectedHeader();
        decryptor.decryptInitialPacket(packet);
        quicContext.addReceivedInitialPacketNumber(packet.getPlainPacketNumber());
        packet.getHandler(context).adjustContext(packet);
        addProducedContainer(packet);
        return packet;
    }

    /** Parses a still-protected Handshake packet; {@code flags} and version were read already. */
    private HandshakePacket readHandshakePacket(
            int flags, byte[] versionBytes, InputStream dataStream) {
        HandshakePacket packet = new HandshakePacket((byte) flags, versionBytes);
        packet.getParser(context, dataStream).parse(packet);
        return packet;
    }

    /** Removes header protection, decrypts, records the packet number and applies the packet. */
    private HandshakePacket decryptHandshakePacket(HandshakePacket packet) throws CryptoException {
        decryptor.removeHeaderProtectionHandshake(packet);
        packet.convertCompleteProtectedHeader();
        decryptor.decryptHandshakePacket(packet);
        quicContext.addReceivedHandshakePacketNumber(packet.getPlainPacketNumber());
        packet.getHandler(context).adjustContext(packet);
        addProducedContainer(packet);
        return packet;
    }

    /** Parses a still-protected 1-RTT packet; {@code flags} was read already. */
    private OneRTTPacket readOneRTTPacket(int flags, InputStream dataStream) {
        OneRTTPacket packet = new OneRTTPacket((byte) flags);
        packet.getParser(context, dataStream).parse(packet);
        return packet;
    }

    /** Removes header protection, decrypts, records the packet number and applies the packet. */
    private OneRTTPacket decryptOneRTTPacket(OneRTTPacket packet) throws CryptoException {
        decryptor.removeHeaderProtectionOneRTT(packet);
        packet.convertCompleteProtectedHeader();
        decryptor.decryptOneRTTPacket(packet);
        quicContext.addReceivedOneRTTPacketNumber(packet.getPlainPacketNumber());
        packet.getHandler(context).adjustContext(packet);
        addProducedContainer(packet);
        return packet;
    }

    /** Parses a Retry packet and applies it to the context right away (no decryption step). */
    private RetryPacket readRetryPacket(int flags, InputStream dataStream) {
        RetryPacket packet = new RetryPacket((byte) flags);
        packet.getParser(context, dataStream).parse(packet);
        packet.getHandler(context).adjustContext(packet);
        addProducedContainer(packet);
        return packet;
    }

    /** Parses a Version Negotiation packet and applies it to the context right away. */
    private VersionNegotiationPacket readVersionNegotiationPacket(InputStream dataStream) {
        VersionNegotiationPacket packet = new VersionNegotiationPacket();
        packet.getParser(context, dataStream).parse(packet);
        packet.getHandler(context).adjustContext(packet);
        addProducedContainer(packet);
        return packet;
    }

    /**
     * Decrypts, in place, every not-yet-decrypted packet buffered under {@code packetType}, then
     * sorts that buffer by plain packet number. Does nothing if the buffer is empty or the secrets
     * are not yet available.
     *
     * @throws CryptoException wrapping the cause if a packet cannot be decrypted
     */
    private void decryptBufferedPackets(
            QuicPacketType packetType,
            boolean secretsAvailable,
            BufferedPacketDecryption decryption) {
        ArrayList<QuicPacket> packets = receivedPacketBuffer.get(packetType);
        if (packets.isEmpty() || !secretsAvailable) {
            return;
        }
        for (QuicPacket packet : packets) {
            // A non-null unprotected payload marks a packet decrypted in an earlier call.
            if (packet.getUnprotectedPayload() != null) {
                continue;
            }
            try {
                decryption.decrypt(packet);
            } catch (CryptoException ex) {
                throw new CryptoException("Could not decrypt packet", ex);
            }
        }
        packets.sort(Comparator.comparingInt(QuicPacket::getPlainPacketNumber));
    }

    /**
     * Selects the buffer whose head should be handed to the upper layer next.
     *
     * <p>Initial packets are only selected while handshake secrets are not initialized. Handshake
     * packets are only selected while application secrets are not initialized. NOTE(review):
     * packets of an earlier space that are still buffered after the next secrets appear remain
     * buffered until {@link #clearReceivedPacketBuffer()}.
     *
     * @return the packet type to process, or {@code null} if nothing is ready
     */
    private QuicPacketType getPacketTypeToProcessNext() {
        if (hasBufferedPackets(QuicPacketType.INITIAL_PACKET)
                && quicContext.isInitialSecretsInitialized()
                && !quicContext.isHandshakeSecretsInitialized()) {
            return QuicPacketType.INITIAL_PACKET;
        }
        if (hasBufferedPackets(QuicPacketType.HANDSHAKE_PACKET)
                && quicContext.isHandshakeSecretsInitialized()
                && !quicContext.isApplicationSecretsInitialized()) {
            return QuicPacketType.HANDSHAKE_PACKET;
        }
        if (hasBufferedPackets(QuicPacketType.ONE_RTT_PACKET)
                && quicContext.isApplicationSecretsInitialized()) {
            return QuicPacketType.ONE_RTT_PACKET;
        }
        return null;
    }

    private boolean hasBufferedPackets(QuicPacketType packetType) {
        return !receivedPacketBuffer.get(packetType).isEmpty();
    }

    /**
     * Returns whether {@code packet} has a zero-length payload and may be skipped. This is never
     * the case when the config asks to use all provided QUIC packets.
     */
    private boolean isEmptyPacket(QuicPacket packet) {
        return !context.getConfig().isUseAllProvidedQuicPackets()
                && packet.getUnprotectedPayload() != null
                && packet.getUnprotectedPayload().getValue().length == 0;
    }

    /**
     * Sends {@code data} (an ACK) in a packet of the same type as the most recently received
     * packet (Initial, Handshake or 1-RTT).
     *
     * <p>The local connection end is marked as talking while sending. Afterwards the peer is
     * marked as talking, even if sending failed. Nothing is sent if no packet has been received
     * yet or the last received type is none of the three.
     */
    @Override
    public void sendAck(byte[] data) {
        context.setTalkingConnectionEndType(context.getConnection().getLocalConnectionEndType());
        try {
            QuicPacket ackContainer = createPacketMatchingLastReceived();
            if (ackContainer != null) {
                getLowerLayer().sendData(null, writePacket(data, ackContainer));
            }
        } catch (CryptoException | IOException e) {
            LOGGER.error("Could not send ACK", e);
        } finally {
            context.setTalkingConnectionEndType(
                    context.getConnection().getLocalConnectionEndType().getPeer());
        }
    }

    /** Returns a new packet matching the last received packet type, or {@code null} if none applies. */
    private QuicPacket createPacketMatchingLastReceived() {
        if (quicContext.getReceivedPackets().isEmpty()) {
            LOGGER.warn("Cannot send ACK before any QUIC packet was received.");
            return null;
        }
        QuicPacketType lastReceived = quicContext.getReceivedPackets().getLast();
        if (lastReceived == QuicPacketType.INITIAL_PACKET) {
            return new InitialPacket();
        }
        if (lastReceived == QuicPacketType.HANDSHAKE_PACKET) {
            return new HandshakePacket();
        }
        if (lastReceived == QuicPacketType.ONE_RTT_PACKET) {
            return new OneRTTPacket();
        }
        return null;
    }

    /** Drops every buffered received packet of every type. */
    public void clearReceivedPacketBuffer() {
        receivedPacketBuffer.values().forEach(ArrayList::clear);
    }
}

