Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions src/main/kotlin/jamule/ec/packet/PacketParser.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import jamule.exception.InvalidECException
import org.slf4j.Logger
import java.io.ByteArrayOutputStream
import java.io.InputStream
import java.util.zip.DataFormatException
import java.util.zip.Inflater

@ExperimentalUnsignedTypes
Expand Down Expand Up @@ -87,16 +88,40 @@ internal class PacketParser(
private fun decompressPayload(stream: InputStream, length: UInt): UByteArray {
val compressed = stream.readNBytes(length.toInt())
val inflater = Inflater()
inflater.setInput(compressed)
val outputStream = ByteArrayOutputStream(length.toInt())
val buffer = ByteArray(8192)
while (!inflater.finished()) {
val count = inflater.inflate(buffer)
outputStream.write(buffer, 0, count)
try {
inflater.setInput(compressed)
val outputStream = ByteArrayOutputStream(length.toInt().coerceAtMost(MAX_DECOMPRESSED_SIZE))
val buffer = ByteArray(8192)
while (!inflater.finished()) {
val count = inflater.inflate(buffer)
if (count > 0) {
val decompressedSize = outputStream.size() + count
if (decompressedSize > MAX_DECOMPRESSED_SIZE) {
throw InvalidECException(
"Packet decompressed size $decompressedSize exceeds limit $MAX_DECOMPRESSED_SIZE"
)
}
outputStream.write(buffer, 0, count)
continue
}
when {
inflater.needsDictionary() ->
throw InvalidECException("Compressed payload requires a dictionary")

inflater.needsInput() ->
throw InvalidECException("Compressed payload ended before decompression completed")

else ->
throw InvalidECException("Inflater made no progress while decompressing payload")
}
}
outputStream.close()
return outputStream.toByteArray().toUByteArray()
} catch (e: DataFormatException) {
throw InvalidECException("Compressed payload is malformed", e)
} finally {
inflater.end()
}
outputStream.close()
inflater.end()
return outputStream.toByteArray().toUByteArray()
}

private fun InputStream.readUInt(): UInt =
Expand All @@ -112,6 +137,7 @@ internal class PacketParser(
companion object {
const val INDEX_TAG_COUNT = 1 // Index of the tag count in the payload
const val TAG_COUNT_SIZE = LEN_USHORT // Size of the tag count in bytes
const val MAX_DECOMPRESSED_SIZE = 50 * 1024 * 1024
}

}
}
13 changes: 12 additions & 1 deletion src/test/kotlin/jamule/ec/packet/PacketParserTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import jamule.ec.ECOpCode
import jamule.ec.ECTagName
import jamule.ec.tag.TagParser
import jamule.ec.tag.UShortTag
import jamule.exception.InvalidECException
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith

@OptIn(ExperimentalUnsignedTypes::class)
class PacketParserTest {
Expand Down Expand Up @@ -35,4 +37,13 @@ class PacketParserTest {
assertEquals(UShortTag(ECTagName.EC_TAG_STATS_UL_SPEED, 1664u), packet.tags[0])
}

}
@Test
fun `rejects malformed compressed payloads instead of spinning forever`() {
val parser = PacketParser(TagParser(logger), logger)

assertFailsWith<InvalidECException> {
parser.parse(SamplePackets.malformedCompressedPacket.inputStream())
}
}

}
4 changes: 3 additions & 1 deletion src/test/kotlin/jamule/ec/packet/SamplePackets.kt
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,7 @@ internal class SamplePackets {
"1d4e48541404041d4e485419")
.hexToByteArray()

val malformedCompressedPacket = "000000230000000100".hexToByteArray()

}
}
}
Loading