diff --git a/proxy/src/main/java/com/velocitypowered/proxy/connection/MinecraftConnection.java b/proxy/src/main/java/com/velocitypowered/proxy/connection/MinecraftConnection.java index 8a6a22e90..0071716d9 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/connection/MinecraftConnection.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/connection/MinecraftConnection.java @@ -67,7 +67,7 @@ import io.netty.util.ReferenceCountUtil; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.GeneralSecurityException; -import java.util.HashMap; +import java.util.EnumMap; import java.util.Map; import java.util.Objects; import java.util.concurrent.TimeUnit; @@ -109,7 +109,7 @@ public class MinecraftConnection extends ChannelInboundHandlerAdapter { this.server = server; this.state = StateRegistry.HANDSHAKE; - this.sessionHandlers = new HashMap<>(); + this.sessionHandlers = new EnumMap<>(StateRegistry.class); } @Override diff --git a/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftDecoder.java b/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftDecoder.java index ae84e7495..ce0d34268 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftDecoder.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftDecoder.java @@ -57,7 +57,11 @@ public class MinecraftDecoder extends ChannelInboundHandlerAdapter { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (msg instanceof ByteBuf buf) { - tryDecode(ctx, buf); + try { + tryDecode(ctx, buf); + } finally { + buf.release(); + } } else { ctx.fireChannelRead(msg); } @@ -65,7 +69,6 @@ public class MinecraftDecoder extends ChannelInboundHandlerAdapter { private void tryDecode(ChannelHandlerContext ctx, ByteBuf buf) throws Exception { if (!ctx.channel().isActive() || !buf.isReadable()) { - buf.release(); return; } @@ -75,27 +78,22 @@ public class MinecraftDecoder extends ChannelInboundHandlerAdapter { if (packet == null) { buf.readerIndex(originalReaderIndex); if (this.direction == ProtocolUtils.Direction.SERVERBOUND && this.state != StateRegistry.PLAY) { - buf.release(); throw this.handleInvalidPacketId(packetId); } - ctx.fireChannelRead(buf); + ctx.fireChannelRead(buf.retain()); } else { + doLengthSanityChecks(buf, packet); + try { - doLengthSanityChecks(buf, packet); - - try { - packet.decode(buf, direction, registry.version); - } catch (Exception e) { - throw handleDecodeFailure(e, packet, packetId); - } - - if (buf.isReadable()) { - throw handleOverflow(packet, buf.readerIndex(), buf.writerIndex()); - } - ctx.fireChannelRead(packet); - } finally { - buf.release(); + packet.decode(buf, direction, registry.version); + } catch (Exception e) { + throw handleDecodeFailure(e, packet, packetId); } + + if (buf.isReadable()) { + throw handleOverflow(packet, buf.readerIndex(), buf.writerIndex()); + } + ctx.fireChannelRead(packet); } } diff --git a/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftVarintFrameDecoder.java b/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftVarintFrameDecoder.java index 96c676593..56c58619f 100644 --- a/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftVarintFrameDecoder.java +++ b/proxy/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftVarintFrameDecoder.java @@ -94,7 +94,6 @@ public class MinecraftVarintFrameDecoder extends ByteToMessageDecoder { in.readerIndex(packetStart); // try to read the length of the packet - in.markReaderIndex(); try { int length = readRawVarInt21(in); if (packetStart == in.readerIndex()) { @@ -107,6 +106,7 @@ public class MinecraftVarintFrameDecoder extends ByteToMessageDecoder { if (length > 0) { if (state == StateRegistry.HANDSHAKE && direction == ProtocolUtils.Direction.SERVERBOUND) { if (validateServerboundHandshakePacket(in, length)) { + in.readerIndex(packetStart); return; } } @@ -115,7 +115,7 @@ public class MinecraftVarintFrameDecoder extends ByteToMessageDecoder { // note that zero-length packets are ignored if (length > 0) { if (in.readableBytes() < length) { - in.resetReaderIndex(); + in.readerIndex(packetStart); } else { // If enabled, rate-limit serverbound payload bytes based on frame length if (packetLimiter != null) { @@ -130,7 +130,7 @@ public class MinecraftVarintFrameDecoder extends ByteToMessageDecoder { } } catch (Exception e) { // Reset buffer to consistent state before propagating exception to prevent memory leaks - in.resetReaderIndex(); + in.readerIndex(packetStart); throw e; } } @@ -140,40 +140,33 @@ public class MinecraftVarintFrameDecoder extends ByteToMessageDecoder { state.getProtocolRegistry(direction, ProtocolVersion.MINIMUM_VERSION); final int index = in.readerIndex(); - try { - final int packetId = readRawVarInt21(in); - // Index hasn't changed, we've read nothing - if (index == in.readerIndex()) { - in.resetReaderIndex(); - return true; - } - final int payloadLength = length - ProtocolUtils.varIntBytes(packetId); - - MinecraftPacket packet = registry.createPacket(packetId); - - // We handle every packet in this phase, if you said something we don't know, something is really wrong - if (packet == null) { - throw UNKNOWN_PACKET; - } - - // We 'technically' have the incoming bytes of a payload here, and so, these can actually parse - // the packet if needed, so, we'll take advantage of the existing methods - int expectedMinLen = packet.decodeExpectedMinLength(in, direction, registry.version); - int expectedMaxLen = packet.decodeExpectedMaxLength(in, direction, registry.version); - if (expectedMaxLen != -1 && payloadLength > expectedMaxLen) { - throw handleOverflow(packet, expectedMaxLen, in.readableBytes()); - } - if (payloadLength < expectedMinLen) { - throw handleUnderflow(packet, expectedMaxLen, in.readableBytes()); - } - - in.readerIndex(index); - return false; - } catch (Exception e) { - // Reset buffer to consistent state before propagating exception to prevent memory leaks - in.readerIndex(index); - throw e; + final int packetId = readRawVarInt21(in); + // Index hasn't changed, we've read nothing + if (index == in.readerIndex()) { + return true; } + final int payloadLength = length - ProtocolUtils.varIntBytes(packetId); + + MinecraftPacket packet = registry.createPacket(packetId); + + // We handle every packet in this phase, if you said something we don't know, something is really wrong + if (packet == null) { + throw UNKNOWN_PACKET; + } + + // We 'technically' have the incoming bytes of a payload here, and so, these can actually parse + // the packet if needed, so, we'll take advantage of the existing methods + int expectedMinLen = packet.decodeExpectedMinLength(in, direction, registry.version); + int expectedMaxLen = packet.decodeExpectedMaxLength(in, direction, registry.version); + if (expectedMaxLen != -1 && payloadLength > expectedMaxLen) { + throw handleOverflow(packet, expectedMaxLen, in.readableBytes()); + } + if (payloadLength < expectedMinLen) { + throw handleUnderflow(packet, expectedMaxLen, in.readableBytes()); + } + + in.readerIndex(index); + return false; } @Override