Fix some small protocol inconsistencies (#1772)

This commit is contained in:
booky
2026-04-12 21:18:42 +02:00
committed by GitHub
parent 1a41b77ccb
commit 339a4c1887
3 changed files with 47 additions and 56 deletions

View File

@@ -67,7 +67,7 @@ import io.netty.util.ReferenceCountUtil;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.util.HashMap; import java.util.EnumMap;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@@ -109,7 +109,7 @@ public class MinecraftConnection extends ChannelInboundHandlerAdapter {
this.server = server; this.server = server;
this.state = StateRegistry.HANDSHAKE; this.state = StateRegistry.HANDSHAKE;
this.sessionHandlers = new HashMap<>(); this.sessionHandlers = new EnumMap<>(StateRegistry.class);
} }
@Override @Override

View File

@@ -57,7 +57,11 @@ public class MinecraftDecoder extends ChannelInboundHandlerAdapter {
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof ByteBuf buf) { if (msg instanceof ByteBuf buf) {
tryDecode(ctx, buf); try {
tryDecode(ctx, buf);
} finally {
buf.release();
}
} else { } else {
ctx.fireChannelRead(msg); ctx.fireChannelRead(msg);
} }
@@ -65,7 +69,6 @@ public class MinecraftDecoder extends ChannelInboundHandlerAdapter {
private void tryDecode(ChannelHandlerContext ctx, ByteBuf buf) throws Exception { private void tryDecode(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
if (!ctx.channel().isActive() || !buf.isReadable()) { if (!ctx.channel().isActive() || !buf.isReadable()) {
buf.release();
return; return;
} }
@@ -75,27 +78,22 @@ public class MinecraftDecoder extends ChannelInboundHandlerAdapter {
if (packet == null) { if (packet == null) {
buf.readerIndex(originalReaderIndex); buf.readerIndex(originalReaderIndex);
if (this.direction == ProtocolUtils.Direction.SERVERBOUND && this.state != StateRegistry.PLAY) { if (this.direction == ProtocolUtils.Direction.SERVERBOUND && this.state != StateRegistry.PLAY) {
buf.release();
throw this.handleInvalidPacketId(packetId); throw this.handleInvalidPacketId(packetId);
} }
ctx.fireChannelRead(buf); ctx.fireChannelRead(buf.retain());
} else { } else {
doLengthSanityChecks(buf, packet);
try { try {
doLengthSanityChecks(buf, packet); packet.decode(buf, direction, registry.version);
} catch (Exception e) {
try { throw handleDecodeFailure(e, packet, packetId);
packet.decode(buf, direction, registry.version);
} catch (Exception e) {
throw handleDecodeFailure(e, packet, packetId);
}
if (buf.isReadable()) {
throw handleOverflow(packet, buf.readerIndex(), buf.writerIndex());
}
ctx.fireChannelRead(packet);
} finally {
buf.release();
} }
if (buf.isReadable()) {
throw handleOverflow(packet, buf.readerIndex(), buf.writerIndex());
}
ctx.fireChannelRead(packet);
} }
} }

View File

@@ -94,7 +94,6 @@ public class MinecraftVarintFrameDecoder extends ByteToMessageDecoder {
in.readerIndex(packetStart); in.readerIndex(packetStart);
// try to read the length of the packet // try to read the length of the packet
in.markReaderIndex();
try { try {
int length = readRawVarInt21(in); int length = readRawVarInt21(in);
if (packetStart == in.readerIndex()) { if (packetStart == in.readerIndex()) {
@@ -107,6 +106,7 @@ public class MinecraftVarintFrameDecoder extends ByteToMessageDecoder {
if (length > 0) { if (length > 0) {
if (state == StateRegistry.HANDSHAKE && direction == ProtocolUtils.Direction.SERVERBOUND) { if (state == StateRegistry.HANDSHAKE && direction == ProtocolUtils.Direction.SERVERBOUND) {
if (validateServerboundHandshakePacket(in, length)) { if (validateServerboundHandshakePacket(in, length)) {
in.readerIndex(packetStart);
return; return;
} }
} }
@@ -115,7 +115,7 @@ public class MinecraftVarintFrameDecoder extends ByteToMessageDecoder {
// note that zero-length packets are ignored // note that zero-length packets are ignored
if (length > 0) { if (length > 0) {
if (in.readableBytes() < length) { if (in.readableBytes() < length) {
in.resetReaderIndex(); in.readerIndex(packetStart);
} else { } else {
// If enabled, rate-limit serverbound payload bytes based on frame length // If enabled, rate-limit serverbound payload bytes based on frame length
if (packetLimiter != null) { if (packetLimiter != null) {
@@ -130,7 +130,7 @@ public class MinecraftVarintFrameDecoder extends ByteToMessageDecoder {
} }
} catch (Exception e) { } catch (Exception e) {
// Reset buffer to consistent state before propagating exception to prevent memory leaks // Reset buffer to consistent state before propagating exception to prevent memory leaks
in.resetReaderIndex(); in.readerIndex(packetStart);
throw e; throw e;
} }
} }
@@ -140,40 +140,33 @@ public class MinecraftVarintFrameDecoder extends ByteToMessageDecoder {
state.getProtocolRegistry(direction, ProtocolVersion.MINIMUM_VERSION); state.getProtocolRegistry(direction, ProtocolVersion.MINIMUM_VERSION);
final int index = in.readerIndex(); final int index = in.readerIndex();
try { final int packetId = readRawVarInt21(in);
final int packetId = readRawVarInt21(in); // Index hasn't changed, we've read nothing
// Index hasn't changed, we've read nothing if (index == in.readerIndex()) {
if (index == in.readerIndex()) { return true;
in.resetReaderIndex();
return true;
}
final int payloadLength = length - ProtocolUtils.varIntBytes(packetId);
MinecraftPacket packet = registry.createPacket(packetId);
// We handle every packet in this phase, if you said something we don't know, something is really wrong
if (packet == null) {
throw UNKNOWN_PACKET;
}
// We 'technically' have the incoming bytes of a payload here, and so, these can actually parse
// the packet if needed, so, we'll take advantage of the existing methods
int expectedMinLen = packet.decodeExpectedMinLength(in, direction, registry.version);
int expectedMaxLen = packet.decodeExpectedMaxLength(in, direction, registry.version);
if (expectedMaxLen != -1 && payloadLength > expectedMaxLen) {
throw handleOverflow(packet, expectedMaxLen, in.readableBytes());
}
if (payloadLength < expectedMinLen) {
throw handleUnderflow(packet, expectedMaxLen, in.readableBytes());
}
in.readerIndex(index);
return false;
} catch (Exception e) {
// Reset buffer to consistent state before propagating exception to prevent memory leaks
in.readerIndex(index);
throw e;
} }
final int payloadLength = length - ProtocolUtils.varIntBytes(packetId);
MinecraftPacket packet = registry.createPacket(packetId);
// We handle every packet in this phase, if you said something we don't know, something is really wrong
if (packet == null) {
throw UNKNOWN_PACKET;
}
// We 'technically' have the incoming bytes of a payload here, and so, these can actually parse
// the packet if needed, so, we'll take advantage of the existing methods
int expectedMinLen = packet.decodeExpectedMinLength(in, direction, registry.version);
int expectedMaxLen = packet.decodeExpectedMaxLength(in, direction, registry.version);
if (expectedMaxLen != -1 && payloadLength > expectedMaxLen) {
throw handleOverflow(packet, expectedMaxLen, in.readableBytes());
}
if (payloadLength < expectedMinLen) {
throw handleUnderflow(packet, expectedMaxLen, in.readableBytes());
}
in.readerIndex(index);
return false;
} }
@Override @Override