Skip to content

Commit

Permalink
shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
Aidan63 committed Nov 25, 2024
1 parent d6ef45d commit 23e1c7c
Showing 1 changed file with 103 additions and 59 deletions.
162 changes: 103 additions & 59 deletions src/hx/libs/ssl/windows/SSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,24 @@
#include <algorithm>
#include <assert.h>

#define printf(...)
//#define printf(...)

namespace
{
static void init_sec_buffer(SecBuffer* buffer, unsigned long type, void* data, unsigned long size)
{
buffer->cbBuffer = size;
buffer->BufferType = type;
buffer->pvBuffer = data;
}

static void init_sec_buffer_desc(SecBufferDesc* desc, SecBuffer* buffers, unsigned long buffer_count)
{
desc->ulVersion = SECBUFFER_VERSION;
desc->pBuffers = buffers;
desc->cBuffers = buffer_count;
}

struct SocketWrapper : public hx::Object
{
HX_IS_INSTANCE_OF enum { _hx_ClassId = hx::clsIdSocket };
Expand Down Expand Up @@ -66,6 +80,8 @@ namespace

SecPkgContext_StreamSizes sizes;

bool connected;

SChannelContext(::String inHost)
: host(inHost)
, socket(null())
Expand All @@ -81,6 +97,7 @@ namespace
, requestFlags(ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY | ISC_REQ_STREAM)
, contextFlags(0)
, sizes()
, connected(false)
{
HX_OBJ_WB_NEW_MARKED_OBJECT(this);
}
Expand Down Expand Up @@ -137,6 +154,72 @@ namespace
decryptedUpper = size;
}

void Shutdown()
{
if (!connected)
{
return;
}

connected = false;

hx::strbuf hostBuffer;

auto inputBuffer = SecBuffer();
auto inputBufferDescription = SecBufferDesc();

auto result = SECURITY_STATUS{ SEC_E_OK };
auto shutdown = SCHANNEL_SHUTDOWN;

init_sec_buffer(&inputBuffer, SECBUFFER_TOKEN, &shutdown, sizeof(shutdown));
init_sec_buffer_desc(&inputBufferDescription, &inputBuffer, 1);

if (SEC_E_OK != (result = ApplyControlToken(&ctxtHandle, &inputBufferDescription)))
{
// How should we surface the error?
}

auto outputBuffer = SecBuffer();
auto outputBufferDescription = SecBufferDesc();
auto staging = std::array<uint8_t, 1024>();

init_sec_buffer(&outputBuffer, SECBUFFER_EMPTY, staging.data(), staging.size());
init_sec_buffer_desc(&outputBufferDescription, &outputBuffer, 1);

result =
InitializeSecurityContextA(
&credHandle,
&ctxtHandle,
const_cast<SEC_CHAR*>(host.utf8_str(&hostBuffer)),
requestFlags,
0,
0,
&inputBufferDescription,
0,
&ctxtHandle,
&outputBufferDescription,
&contextFlags,
&ctxtTimestamp);

if (result == SEC_E_OK || result == SEC_I_CONTEXT_EXPIRED)
{
assert(SECBUFFER_DATA == outputBuffer.BufferType);

auto wrapper = (SocketWrapper*)socket.mPtr;
auto target = 0;
while (target < outputBuffer.cbBuffer)
{
auto sent = send(wrapper->socket, static_cast<char*>(outputBuffer.pvBuffer), outputBuffer.cbBuffer - target, 0);
if (sent <= 0)
{
return; // Error??
}

target += sent;
}
}
}

void __Mark(HX_MARK_PARAMS) override
{
HX_MARK_MEMBER(host);
Expand Down Expand Up @@ -172,20 +255,6 @@ namespace
NCryptFreeObject(key->ctx);
}

static void init_sec_buffer(SecBuffer* buffer, unsigned long type, void* data, unsigned long size)
{
buffer->cbBuffer = size;
buffer->BufferType = type;
buffer->pvBuffer = data;
}

static void init_sec_buffer_desc(SecBufferDesc* desc, SecBuffer* buffers, unsigned long buffer_count)
{
desc->ulVersion = SECBUFFER_VERSION;
desc->pBuffers = buffers;
desc->cBuffers = buffer_count;
}

void block_error()
{
auto err = WSAGetLastError();
Expand Down Expand Up @@ -227,14 +296,14 @@ void _hx_ssl_init()

Dynamic _hx_ssl_new(Dynamic hconf)
{
printf("creating new schannel context\n");

return new SChannelContext(HX_CSTRING(""));
}

void _hx_ssl_close(Dynamic hssl)
{
//
auto ctx = (SChannelContext*)hssl.mPtr;

ctx->Shutdown();
}

void _hx_ssl_debug_set(int i)
Expand Down Expand Up @@ -308,19 +377,16 @@ void _hx_ssl_handshake(Dynamic handle)
{
case SEC_E_OK:
{
printf("handshake complete\n");

QueryContextAttributes(&ctx->ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &ctx->sizes);

ctx->connected = true;
ctx->encryptedBuffer->EnsureSize(ctx->sizes.cbMaximumMessage);
ctx->decryptedBuffer->EnsureSize(ctx->sizes.cbMaximumMessage);

if (SECBUFFER_EXTRA == inputBuffers[1].BufferType)
{
ctx->encryptedBuffer->memcpy(0, static_cast<uint8_t*>(inputBuffers[1].pvBuffer), inputBuffers[1].cbBuffer);
ctx->encryptedCursor = inputBuffers[1].cbBuffer;

printf("%i bytes of extra data found\n", inputBuffers[1].cbBuffer);
}

return;
Expand Down Expand Up @@ -463,8 +529,6 @@ int _hx_ssl_send(Dynamic hssl, Array<unsigned char> buf, int p, int l)
hx::Throw(String::create(_com_error(result).ErrorMessage()));
}

printf("encrypted %zi bytes\n", usage);

auto sent = 0;
auto total = ctx->sizes.cbHeader + usage + ctx->sizes.cbTrailer;
while (sent < total)
Expand All @@ -481,8 +545,6 @@ int _hx_ssl_send(Dynamic hssl, Array<unsigned char> buf, int p, int l)
remaining -= sent;
}

printf("sent %i bytes\n", l);

return l;
}

Expand Down Expand Up @@ -510,8 +572,6 @@ int _hx_ssl_recv(Dynamic hssl, Array<unsigned char> buf, int p, int l)
{
auto taking = std::min(l, ctx->DecryptedBytes());

printf("taking %i cached bytes\n", taking);

buf->memcpy(p, &ctx->decryptedBuffer[ctx->decryptedLower], taking);

ctx->decryptedLower += taking;
Expand All @@ -532,68 +592,56 @@ int _hx_ssl_recv(Dynamic hssl, Array<unsigned char> buf, int p, int l)
switch (result = DecryptMessage(&ctx->ctxtHandle, &bufferDescription, 0, nullptr))
{
case SEC_E_OK:
case SEC_I_RENEGOTIATE:
case SEC_I_CONTEXT_EXPIRED:
{
assert(buffers[0].BufferType == SECBUFFER_STREAM_HEADER);
assert(buffers[1].BufferType == SECBUFFER_DATA);
assert(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER);

printf("decrypting successful! (given %i bytes)\n", ctx->received);
printf("header ( ptr : %p, len : %i)\n", buffers[0].pvBuffer, buffers[0].cbBuffer);
printf("message ( ptr : %p, len : %i)\n", buffers[1].pvBuffer, buffers[1].cbBuffer);
printf("trailer ( ptr : %p, len : %i)\n", buffers[2].pvBuffer, buffers[2].cbBuffer);

if (buffers[3].BufferType == SECBUFFER_EXTRA)
if (SECBUFFER_DATA == buffers[1].BufferType)
{
printf("extra ( ptr : %p, len : %i)\n", buffers[3].pvBuffer, buffers[3].cbBuffer);
ctx->AppendDecrypted(static_cast<uint8_t*>(buffers[1].pvBuffer), buffers[1].cbBuffer);
}

ctx->AppendDecrypted(static_cast<uint8_t*>(buffers[1].pvBuffer), buffers[1].cbBuffer);

if (SECBUFFER_EXTRA == buffers[3].BufferType)
{
printf("moving %i to extra buffer\n", buffers[3].cbBuffer);

std::memmove(ctx->encryptedBuffer->getBase(), buffers[3].pvBuffer, buffers[3].cbBuffer);

ctx->encryptedCursor = buffers[3].cbBuffer;
}
else
{
printf("no extra buffer, resetting recieved\n");

ctx->encryptedCursor = 0;
}

if (SEC_I_RENEGOTIATE == result)
{
hx::Throw(String::create(_com_error(result).ErrorMessage()));
}
if (SEC_I_CONTEXT_EXPIRED == result)
{
ctx->Shutdown();

return hx::Throw(HX_CSTRING("EOF"));
}

break;
}
case SEC_E_INCOMPLETE_MESSAGE:
{
assert(buffers[0].BufferType == SECBUFFER_MISSING);
assert(buffers[0].cbBuffer > 0);

printf("incomplete message\n");
printf("\tSECBUFFER_MISSING indicates it wants %i more bytes\n", buffers[0].cbBuffer);
printf("\tcurrent receive position is %i, so %i free space\n", ctx->received, ctx->input->length - ctx->received);

if (ctx->encryptedCursor + buffers[0].cbBuffer > ctx->encryptedBuffer->length)
{
printf("\t\tgrowing input buffer\n");

ctx->encryptedBuffer->EnsureSize(ctx->encryptedCursor + buffers[0].cbBuffer);
}

auto count = recv(wrapper->socket, ctx->encryptedBuffer->getBase() + ctx->encryptedCursor, buffers[0].cbBuffer, 0);
if (count <= 0)
{
printf("about to throw leaving behind %i encrypted and %i decrypted bytes\n", ctx->received, ctx->decrypted->length);

block_error();
}

ctx->encryptedCursor += count;

printf("socket read, added %i\n", count);

break;
}
default:
Expand All @@ -603,16 +651,12 @@ int _hx_ssl_recv(Dynamic hssl, Array<unsigned char> buf, int p, int l)
}
else
{
printf("no buffered input, reading block from socket (%i)\n", ctx->sizes.cbBlockSize);

auto count = recv(wrapper->socket, ctx->encryptedBuffer->getBase(), ctx->encryptedBuffer->length, 0);
if (count <= 0)
{
block_error();
}

printf("added to received buffer (total %i)\n", count);

ctx->encryptedCursor = count;
}
}
Expand Down

0 comments on commit 23e1c7c

Please sign in to comment.