Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Key Update Crash on Allocation Failure #4447

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/platform/crypt.c
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ QuicPacketKeyUpdate(
_Out_ QUIC_PACKET_KEY** NewKey
)
{
if (OldKey->Type != QUIC_PACKET_KEY_1_RTT) {
if (OldKey == NULL || OldKey->Type != QUIC_PACKET_KEY_1_RTT) {
return QUIC_STATUS_INVALID_STATE;
}

Expand Down
5 changes: 5 additions & 0 deletions src/test/MsQuicTests.h
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,11 @@ QuicDrillTestServerVNPacket(
_In_ int Family
);

void
QuicDrillTestKeyUpdateDuringHandshake(
_In_ int Family
);

//
// Datagram tests
//
Expand Down
9 changes: 9 additions & 0 deletions src/test/bin/quic_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2298,6 +2298,15 @@ TEST_P(WithDrillInitialPacketTokenArgs, QuicDrillTestServerVNPacket) {
}
}

TEST_P(WithDrillInitialPacketTokenArgs, QuicDrillTestKeyUpdateDuringHandshake) {
TestLoggerT<ParamType> Logger("QuicDrillTestKeyUpdateDuringHandshake", GetParam());
if (TestingKernelMode) {
//ASSERT_TRUE(DriverClient.Run(IOCTL_QUIC_RUN_DRILL_VN_PACKET_TOKEN, GetParam().Family));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to add kernel support.

} else {
QuicDrillTestKeyUpdateDuringHandshake(GetParam().Family);
}
}

TEST_P(WithDatagramNegotiationArgs, DatagramNegotiation) {
TestLoggerT<ParamType> Logger("QuicTestDatagramNegotiation", GetParam());
if (TestingKernelMode) {
Expand Down
141 changes: 126 additions & 15 deletions src/test/lib/DrillDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
--*/

#include "precomp.h"
#include <quic_crypt.h>
#include <msquichelper.h>

#ifdef QUIC_CLOG
#include "DrillDescriptor.cpp.clog.h"
#endif
Expand Down Expand Up @@ -152,7 +155,7 @@ DrillVNPacketDescriptor::write(
return PacketBuffer;
}

DrillInitialPacketDescriptor::DrillInitialPacketDescriptor()
DrillInitialPacketDescriptor::DrillInitialPacketDescriptor(uint8_t SrcCidLength)
{
Type = Initial;
Header.FixedBit = 1;
Expand All @@ -161,18 +164,20 @@ DrillInitialPacketDescriptor::DrillInitialPacketDescriptor()
const uint8_t CidValMax = 8;
for (uint8_t CidVal = 0; CidVal <= CidValMax; CidVal++) {
DestCid.push_back(CidVal);
SourceCid.push_back(CidValMax - CidVal);
}

for (uint8_t CidVal = 0; CidVal < SrcCidLength; CidVal++) {
SourceCid.push_back(SrcCidLength - CidVal);
}
}

DrillBuffer
DrillInitialPacketDescriptor::write(
bool EncryptPayload
) const
{
DrillBuffer PacketBuffer = DrillPacketDescriptor::write();

size_t CalculatedPacketLength = PacketBuffer.size();

DrillBuffer EncodedTokenLength;
if (TokenLen != nullptr) {
EncodedTokenLength = QuicDrillEncodeQuicVarInt(*TokenLen);
Expand All @@ -181,24 +186,20 @@ DrillInitialPacketDescriptor::write(
}
PacketBuffer.insert(PacketBuffer.end(), EncodedTokenLength.begin(), EncodedTokenLength.end());

CalculatedPacketLength += EncodedTokenLength.size();

if (Token.size()) {
PacketBuffer.insert(PacketBuffer.end(), Token.begin(), Token.end());
CalculatedPacketLength += Token.size();
}

//
// Note: this ignores the bits in the Header that specify how many bytes
// are used. The caller must ensure these are in-sync.
// Packet number buffer.
//
DrillBuffer PacketNumberBuffer;
if (PacketNumber < 0x100) {
if (Header.PacketNumLen == 0) {
PacketNumberBuffer.push_back((uint8_t) PacketNumber);
} else if (PacketNumber < 0x10000) {
} else if (Header.PacketNumLen == 1) {
PacketNumberBuffer.push_back((uint8_t) (PacketNumber >> 8));
PacketNumberBuffer.push_back((uint8_t) PacketNumber);
} else if (PacketNumber < 0x1000000) {
} else if (Header.PacketNumLen == 2) {
PacketNumberBuffer.push_back((uint8_t) (PacketNumber >> 16));
PacketNumberBuffer.push_back((uint8_t) (PacketNumber >> 8));
PacketNumberBuffer.push_back((uint8_t) PacketNumber);
Expand All @@ -209,16 +210,17 @@ DrillInitialPacketDescriptor::write(
PacketNumberBuffer.push_back((uint8_t) PacketNumber);
}

CalculatedPacketLength += PacketNumberBuffer.size();
CalculatedPacketLength += Payload.size();

//
// Write packet length.
//
DrillBuffer PacketLengthBuffer;
if (PacketLength != nullptr) {
PacketLengthBuffer = QuicDrillEncodeQuicVarInt(*PacketLength);
} else {
size_t CalculatedPacketLength = PacketNumberBuffer.size() + Payload.size();
if (EncryptPayload) {
CalculatedPacketLength += CXPLAT_ENCRYPTION_OVERHEAD;
}
PacketLengthBuffer = QuicDrillEncodeQuicVarInt(CalculatedPacketLength);
}
PacketBuffer.insert(PacketBuffer.end(), PacketLengthBuffer.begin(), PacketLengthBuffer.end());
Expand All @@ -228,12 +230,121 @@ DrillInitialPacketDescriptor::write(
//
PacketBuffer.insert(PacketBuffer.end(), PacketNumberBuffer.begin(), PacketNumberBuffer.end());

auto HeaderLength = (uint16_t)PacketBuffer.size();

//
// Write payload.
//
if (Payload.size() > 0) {
PacketBuffer.insert(PacketBuffer.end(), Payload.begin(), Payload.end());
}

if (EncryptPayload) {
for (uint8_t i = 0; i < CXPLAT_ENCRYPTION_OVERHEAD; ++i) {
PacketBuffer.push_back(0);
}
encrypt(PacketBuffer, HeaderLength, (uint8_t)PacketNumberBuffer.size());
}

return PacketBuffer;
}

struct StrBuffer {
uint8_t* Data;
uint16_t Length;

StrBuffer(const char* HexBytes)
{
Length = (uint16_t)(strlen(HexBytes) / 2);
Data = new uint8_t[Length];

for (uint16_t i = 0; i < Length; ++i) {
Data[i] =
(DecodeHexChar(HexBytes[i * 2]) << 4) |
DecodeHexChar(HexBytes[i * 2 + 1]);
}
}

~StrBuffer() { delete [] Data; }
};

void
DrillInitialPacketDescriptor::encrypt(
DrillBuffer& PacketBuffer,
uint16_t HeaderLength,
uint8_t PacketNumberLength
) const
{
const QUIC_HKDF_LABELS HkdfLabels = { "quic key", "quic iv", "quic hp", "quic ku" };
const StrBuffer InitialSalt("38762cf7f55934b34d179ae6a4c80cadccbb7f0a");

QUIC_PACKET_KEY* WriteKey;
QuicPacketKeyCreateInitial(
FALSE,
&HkdfLabels,
InitialSalt.Data,
(uint8_t)DestCid.size(),
DestCid.data(),
nullptr,
&WriteKey);

uint8_t Iv[CXPLAT_IV_LENGTH];
uint64_t FullPacketNumber = PacketNumber;
QuicCryptoCombineIvAndPacketNumber(
WriteKey->Iv, (uint8_t*)&FullPacketNumber, Iv);

CxPlatEncrypt(
WriteKey->PacketKey,
Iv,
HeaderLength,
PacketBuffer.data(),
(uint16_t)PacketBuffer.size() - HeaderLength,
PacketBuffer.data() + HeaderLength);

uint8_t HpMask[16];
CxPlatHpComputeMask(
WriteKey->HeaderKey,
1,
PacketBuffer.data() + HeaderLength,
HpMask);

uint16_t PacketNumberOffset = HeaderLength - PacketNumberLength;
PacketBuffer[0] ^= HpMask[0] & 0x0F;
for (uint8_t i = 0; i < PacketNumberLength; ++i) {
PacketBuffer[PacketNumberOffset + i] ^= HpMask[i + 1];
}

QuicPacketKeyFree(WriteKey);
}

union QuicShortHeader {
uint8_t HeaderByte;
struct {
uint8_t PacketNumLen : 2;
uint8_t KeyPhase : 1;
uint8_t Reserved : 2;
uint8_t SpinBit : 1;
uint8_t FixedBit : 1;
uint8_t LongHeader : 1;
};
};

DrillBuffer
Drill1RttPacketDescriptor::write(
) const
{
DrillBuffer PacketBuffer;
QuicShortHeader Header = { 0 };
Header.PacketNumLen = 3;
Header.KeyPhase = KeyPhase;

PacketBuffer.push_back(Header.HeaderByte);
PacketBuffer.insert(PacketBuffer.end(), DestCid.begin(), DestCid.end());
PacketBuffer.push_back((uint8_t) (PacketNumber >> 24));// TODO - different packet number sizes
PacketBuffer.push_back((uint8_t) (PacketNumber >> 16));
PacketBuffer.push_back((uint8_t) (PacketNumber >> 8));
PacketBuffer.push_back((uint8_t) PacketNumber);
PacketBuffer.insert(PacketBuffer.end(), Payload.begin(), Payload.end());

return PacketBuffer;
}
25 changes: 23 additions & 2 deletions src/test/lib/DrillDescriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,34 @@ struct DrillInitialPacketDescriptor : DrillPacketDescriptor {

DrillBuffer Payload;

DrillInitialPacketDescriptor(uint8_t SrcCidLength = 9);

DrillInitialPacketDescriptor();
//
// Write this descriptor to a byte array to send on the wire.
//
virtual DrillBuffer write(bool EncryptPayload = false) const;

private:

void encrypt(DrillBuffer& PacketBuffer, uint16_t HeaderLength, uint8_t PacketNumberLength) const;
};

struct Drill1RttPacketDescriptor {

DrillBuffer DestCid;

uint8_t KeyPhase {0};

uint32_t PacketNumber {0};

DrillBuffer Payload;

Drill1RttPacketDescriptor() {}

//
// Write this descriptor to a byte array to send on the wire.
//
virtual DrillBuffer write() const;
DrillBuffer write() const;
};

enum DrillVarIntSize {
Expand Down
49 changes: 49 additions & 0 deletions src/test/lib/QuicDrill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,3 +538,52 @@ QuicDrillTestServerVNPacket(

CxPlatSleep(500);
}

void
QuicDrillTestKeyUpdateDuringHandshake(
_In_ int Family
)
{
MsQuicRegistration Registration(true);
TEST_QUIC_SUCCEEDED(Registration.GetInitStatus());

if (QuitTestIsFeatureSupported(CXPLAT_DATAPATH_FEATURE_RAW)) {
return;
}

QUIC_ADDRESS_FAMILY QuicAddrFamily = (Family == 4) ? QUIC_ADDRESS_FAMILY_INET : QUIC_ADDRESS_FAMILY_INET6;
QuicAddr ServerLocalAddr(QuicAddrFamily);

MsQuicAutoAcceptListener Listener(Registration, MsQuicConnection::NoOpCallback);
TEST_QUIC_SUCCEEDED(Listener.Start("MsQuicTest", &ServerLocalAddr.SockAddr));
TEST_QUIC_SUCCEEDED(Listener.GetInitStatus());
TEST_QUIC_SUCCEEDED(Listener.GetLocalAddr(ServerLocalAddr));

DrillSender Sender;
TEST_QUIC_SUCCEEDED(
Sender.Initialize(
QUIC_TEST_LOOPBACK_FOR_AF(QuicAddrFamily),
QuicAddrFamily,
(QuicAddrFamily == QUIC_ADDRESS_FAMILY_INET) ?
ServerLocalAddr.SockAddr.Ipv4.sin_port :
ServerLocalAddr.SockAddr.Ipv6.sin6_port));

DrillInitialPacketDescriptor InitialPacketBuffer(0);
InitialPacketBuffer.Header.PacketNumLen = 3;
InitialPacketBuffer.Payload.push_back(1); // Ping frame
for (uint16_t i = 0; i < 1199; ++i) { InitialPacketBuffer.Payload.push_back(0); } // Padding frames

Drill1RttPacketDescriptor OneRttPacketBuffer;
OneRttPacketBuffer.DestCid.insert(
OneRttPacketBuffer.DestCid.end(),
InitialPacketBuffer.DestCid.begin(),
InitialPacketBuffer.DestCid.end());
OneRttPacketBuffer.KeyPhase = 1;
OneRttPacketBuffer.Payload.push_back(1); // Ping frame
for (uint16_t i = 0; i < 80; ++i) { OneRttPacketBuffer.Payload.push_back(0); } // Padding frames

TEST_QUIC_SUCCEEDED(Sender.Send(InitialPacketBuffer.write(true)));
TEST_QUIC_SUCCEEDED(Sender.Send(OneRttPacketBuffer.write()));

CxPlatSleep(500);
}
Loading