Skip to content

Commit

Permalink
Add minimal TLS1.3 support on supported windows versions
Browse files Browse the repository at this point in the history
  • Loading branch information
steffengy committed Sep 21, 2024
1 parent 0fd2026 commit 640f61d
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 13 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ default-target = "x86_64-pc-windows-msvc"
windows-sys = { version = "0.59", features = [
"Win32_Foundation", "Win32_Security_Cryptography",
"Win32_Security_Authentication_Identity", "Win32_Security_Credentials",
"Win32_System_Memory"] }
"Win32_System_LibraryLoader", "Win32_System_Memory"
] }

[dev-dependencies]
windows-sys = { version = "0.59", features = ["Win32_System_SystemInformation", "Win32_System_Time"] }
71 changes: 59 additions & 12 deletions src/schannel_cred.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,31 @@ impl Protocol {
}
}

fn verify_min_os_build(major: u32, build: u32) -> Option<()> {
use windows_sys::Win32::System::SystemInformation::OSVERSIONINFOW;

let handle = std::ptr::NonNull::new(unsafe {
windows_sys::Win32::System::LibraryLoader::GetModuleHandleW(windows_sys::w!("ntdll.dll"))
})?;
let rtl_get_ver = unsafe {
windows_sys::Win32::System::LibraryLoader::GetProcAddress(handle.as_ptr(), windows_sys::s!("RtlGetVersion"))
}?;

type RtlGetVersionFunc = unsafe extern "system" fn(*mut OSVERSIONINFOW) -> i32;
let proc: RtlGetVersionFunc = unsafe { mem::transmute(rtl_get_ver) };

let mut info: OSVERSIONINFOW = unsafe { mem::zeroed() };
info.dwOSVersionInfoSize = mem::size_of::<OSVERSIONINFOW>() as u32;

unsafe { proc(&mut info) };

if info.dwMajorVersion > major || (info.dwMajorVersion == major && info.dwBuildNumber >= build) {
Some(())
} else {
None
}
}

/// A builder type for `SchannelCred`s.
#[derive(Default, Debug)]
pub struct Builder {
Expand Down Expand Up @@ -220,36 +245,58 @@ impl Builder {
/// Creates a new `SchannelCred`.
pub fn acquire(&self, direction: Direction) -> io::Result<SchannelCred> {
unsafe {
let mut handle: Credentials::SecHandle = mem::zeroed();
let mut cred_data: Identity::SCHANNEL_CRED = mem::zeroed();
cred_data.dwVersion = Identity::SCHANNEL_CRED_VERSION;
cred_data.dwFlags =
Identity::SCH_USE_STRONG_CRYPTO | Identity::SCH_CRED_NO_DEFAULT_CREDS;
if let Some(ref supported_algorithms) = self.supported_algorithms {
cred_data.cSupportedAlgs = supported_algorithms.len() as u32;
cred_data.palgSupportedAlgs = supported_algorithms.as_ptr() as *mut _;
}
if let Some(ref enabled_protocols) = self.enabled_protocols {
cred_data.grbitEnabledProtocols = enabled_protocols
let mut enabled_protocols: u32 = 0;
if let Some(ref enable_list) = self.enabled_protocols {
enabled_protocols = enable_list
.iter()
.map(|p| p.dword(direction))
.fold(0, |acc, p| acc | p);
}

let mut cred_data: Identity::SCHANNEL_CRED = mem::zeroed();
cred_data.dwVersion = Identity::SCHANNEL_CRED_VERSION;
cred_data.dwFlags = Identity::SCH_USE_STRONG_CRYPTO | Identity::SCH_CRED_NO_DEFAULT_CREDS;
cred_data.grbitEnabledProtocols = enabled_protocols;
let mut certs = self.certs.iter().map(|c| c.as_inner()).collect::<Vec<_>>();
cred_data.cCreds = certs.len() as u32;
cred_data.paCred = certs.as_mut_ptr() as _;

let mut tls_param: Identity::TLS_PARAMETERS = mem::zeroed();
let mut cred_data2: Identity::SCH_CREDENTIALS = mem::zeroed();

let mut pauthdata: *const core::ffi::c_void = ptr::null();
if let Some(ref supported_algorithms) = self.supported_algorithms {
cred_data.cSupportedAlgs = supported_algorithms.len() as u32;
cred_data.palgSupportedAlgs = supported_algorithms.as_ptr() as *mut _;
} else if verify_min_os_build(10, 17763).is_some() {
// If no algorithms specified and should be supported, use new SCH_CREDENTIALS interface which supports TLS1.3.
// Although we check for win10 build 17763 above, I have only seen this work on win 11.
tls_param.grbitDisabledProtocols = !enabled_protocols;
// TODO: support something to select tls13-ciphers
cred_data2.dwVersion = Identity::SCH_CREDENTIALS_VERSION;
cred_data2.dwFlags = Identity::SCH_USE_STRONG_CRYPTO | Identity::SCH_CRED_NO_DEFAULT_CREDS;
cred_data2.cCreds = certs.len() as u32;
cred_data2.paCred = certs.as_mut_ptr() as _;
cred_data2.cTlsParameters = 1;
cred_data2.pTlsParameters = &mut tls_param;
pauthdata = &mut cred_data2 as *const _ as *const _;
}

if pauthdata.is_null() {
pauthdata = &mut cred_data as *const _ as *const _;
}
let direction = match direction {
Direction::Inbound => Identity::SECPKG_CRED_INBOUND,
Direction::Outbound => Identity::SECPKG_CRED_OUTBOUND,
};
let mut handle: Credentials::SecHandle = mem::zeroed();

match Identity::AcquireCredentialsHandleA(
ptr::null(),
Identity::UNISP_NAME_A,
direction,
ptr::null_mut(),
&mut cred_data as *const _ as *const _,
pauthdata,
None,
ptr::null_mut(),
&mut handle,
Expand Down
21 changes: 21 additions & 0 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,27 @@ fn verify_callback_success() {
assert!(out.ends_with(b"</html>\n"));
}

#[test]
fn tls_13() {
let creds = SchannelCred::builder()
.enabled_protocols(&[Protocol::Tls12, Protocol::Tls13])
.acquire(Direction::Outbound)
.unwrap();
let stream = TcpStream::connect("tls13.akamai.io:443").unwrap();
let mut stream = tls_stream::Builder::new()
.domain("tls13.akamai.io")
.connect(creds, stream)
.unwrap();
stream
.write_all(b"GET / HTTP/1.0\r\nHost: tls13.akamai.io\r\n\r\n")
.unwrap();
let mut out = vec![];
stream.read_to_end(&mut out).unwrap();

let pattern = b"Your client negotiated TLS 1.3";
assert!(out.windows(pattern.len()).any(|x| x == pattern));
}

#[test]
fn verify_callback_error() {
let creds = SchannelCred::builder()
Expand Down

0 comments on commit 640f61d

Please sign in to comment.