Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minimal TLS1.3 support on supported windows versions #109

Merged
merged 1 commit into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on:

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: false
cancel-in-progress: true

env:
CARGO_INCREMENTAL: 0
Expand All @@ -24,6 +24,8 @@ jobs:
- target: i686-pc-windows-gnu
channel: 1.60.0
os: windows-2022
env:
SCHANNEL_SKIP_TLS_13_TEST: ${{ matrix.os == 'windows-2019' && '1' || '0' }}
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@master
Expand All @@ -32,6 +34,8 @@ jobs:
target: ${{ matrix.target }}
- if: matrix.target == 'i686-pc-windows-gnu'
uses: MinoruSekine/setup-scoop@main
- shell: cmd
run: echo %SCHANNEL_SKIP_TLS_13_TEST%
- if: matrix.target == 'i686-pc-windows-gnu'
run: |
scoop install -a 32bit mingw
Expand Down
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ default-target = "x86_64-pc-windows-msvc"
windows-sys = { version = "0.59", features = [
"Win32_Foundation", "Win32_Security_Cryptography",
"Win32_Security_Authentication_Identity", "Win32_Security_Credentials",
"Win32_System_Memory"] }
"Win32_System_LibraryLoader", "Win32_System_Memory"
] }

[dev-dependencies]
windows-sys = { version = "0.59", features = ["Win32_System_SystemInformation", "Win32_System_Time"] }
75 changes: 62 additions & 13 deletions src/schannel_cred.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,31 @@ impl Protocol {
}
}

fn verify_min_os_build(major: u32, build: u32) -> Option<()> {
use windows_sys::Win32::System::SystemInformation::OSVERSIONINFOW;

let handle = std::ptr::NonNull::new(unsafe {
windows_sys::Win32::System::LibraryLoader::GetModuleHandleW(windows_sys::w!("ntdll.dll"))
})?;
let rtl_get_ver = unsafe {
windows_sys::Win32::System::LibraryLoader::GetProcAddress(handle.as_ptr(), windows_sys::s!("RtlGetVersion"))
}?;

type RtlGetVersionFunc = unsafe extern "system" fn(*mut OSVERSIONINFOW) -> i32;
let proc: RtlGetVersionFunc = unsafe { mem::transmute(rtl_get_ver) };

let mut info: OSVERSIONINFOW = unsafe { mem::zeroed() };
info.dwOSVersionInfoSize = mem::size_of::<OSVERSIONINFOW>() as u32;

unsafe { proc(&mut info) };

if info.dwMajorVersion > major || (info.dwMajorVersion == major && info.dwBuildNumber >= build) {
Some(())
} else {
None
}
}

/// A builder type for `SchannelCred`s.
#[derive(Default, Debug)]
pub struct Builder {
Expand Down Expand Up @@ -219,37 +244,61 @@ impl Builder {

/// Creates a new `SchannelCred`.
pub fn acquire(&self, direction: Direction) -> io::Result<SchannelCred> {
let mut enabled_protocols: u32 = 0;
if let Some(ref enable_list) = self.enabled_protocols {
enabled_protocols = enable_list
.iter()
.map(|p| p.dword(direction))
.fold(0, |acc, p| acc | p);
}

unsafe {
let mut handle: Credentials::SecHandle = mem::zeroed();
let mut cred_data: Identity::SCHANNEL_CRED = mem::zeroed();
cred_data.dwVersion = Identity::SCHANNEL_CRED_VERSION;
cred_data.dwFlags =
Identity::SCH_USE_STRONG_CRYPTO | Identity::SCH_CRED_NO_DEFAULT_CREDS;
cred_data.dwFlags = Identity::SCH_USE_STRONG_CRYPTO | Identity::SCH_CRED_NO_DEFAULT_CREDS;
cred_data.grbitEnabledProtocols = enabled_protocols;
let mut certs = self.certs.iter().map(|c| c.as_inner()).collect::<Vec<_>>();
cred_data.cCreds = certs.len() as u32;
cred_data.paCred = certs.as_mut_ptr() as _;

let mut tls_param: Identity::TLS_PARAMETERS = mem::zeroed();
let mut cred_data2: Identity::SCH_CREDENTIALS = mem::zeroed();

let mut pauthdata: *const core::ffi::c_void = ptr::null();
if let Some(ref supported_algorithms) = self.supported_algorithms {
cred_data.cSupportedAlgs = supported_algorithms.len() as u32;
cred_data.palgSupportedAlgs = supported_algorithms.as_ptr() as *mut _;
} else if verify_min_os_build(10, 17763).is_some() {
// If no algorithms specified and should be supported, use new SCH_CREDENTIALS interface which supports TLS1.3.
// Although we check for win10 build 17763 above, I have only seen this work on win 11.
if enabled_protocols != 0 {
tls_param.grbitDisabledProtocols = !enabled_protocols;
}
// TODO: support something to select tls13-ciphers
cred_data2.dwVersion = Identity::SCH_CREDENTIALS_VERSION;
cred_data2.dwFlags = Identity::SCH_USE_STRONG_CRYPTO | Identity::SCH_CRED_NO_DEFAULT_CREDS;
cred_data2.cCreds = certs.len() as u32;
cred_data2.paCred = certs.as_mut_ptr() as _;
cred_data2.cTlsParameters = 1;
cred_data2.pTlsParameters = &mut tls_param;
pauthdata = &mut cred_data2 as *const _ as *const _;
}
if let Some(ref enabled_protocols) = self.enabled_protocols {
cred_data.grbitEnabledProtocols = enabled_protocols
.iter()
.map(|p| p.dword(direction))
.fold(0, |acc, p| acc | p);

if pauthdata.is_null() {
pauthdata = &mut cred_data as *const _ as *const _;
}
let mut certs = self.certs.iter().map(|c| c.as_inner()).collect::<Vec<_>>();
cred_data.cCreds = certs.len() as u32;
cred_data.paCred = certs.as_mut_ptr() as _;

let direction = match direction {
Direction::Inbound => Identity::SECPKG_CRED_INBOUND,
Direction::Outbound => Identity::SECPKG_CRED_OUTBOUND,
};
let mut handle: Credentials::SecHandle = mem::zeroed();

match Identity::AcquireCredentialsHandleA(
ptr::null(),
Identity::UNISP_NAME_A,
direction,
ptr::null_mut(),
&mut cred_data as *const _ as *const _,
pauthdata,
None,
ptr::null_mut(),
&mut handle,
Expand Down
29 changes: 29 additions & 0 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,31 @@ fn verify_callback_success() {
assert!(out.ends_with(b"</html>\n"));
}

#[test]
fn tls_13() {
if env::var("SCHANNEL_SKIP_TLS_13_TEST") == Ok("1".to_owned()) {
return
}

let creds = SchannelCred::builder()
.enabled_protocols(&[Protocol::Tls12, Protocol::Tls13])
.acquire(Direction::Outbound)
.unwrap();
let stream = TcpStream::connect("tls13.akamai.io:443").unwrap();
let mut stream = tls_stream::Builder::new()
.domain("tls13.akamai.io")
.connect(creds, stream)
.unwrap();
stream
.write_all(b"GET / HTTP/1.0\r\nHost: tls13.akamai.io\r\n\r\n")
.unwrap();
let mut out = vec![];
stream.read_to_end(&mut out).unwrap();

let pattern = b"Your client negotiated TLS 1.3";
assert!(out.windows(pattern.len()).any(|x| x == pattern));
}

#[test]
fn verify_callback_error() {
let creds = SchannelCred::builder()
Expand Down Expand Up @@ -345,6 +370,8 @@ fn no_session_resumed() {
#[test]
fn basic_session_resumed() {
let creds = SchannelCred::builder()
// TOOD: figure out why Tls13 doesnt resume
.enabled_protocols(&[Protocol::Tls12])
.acquire(Direction::Outbound)
.unwrap();
let creds_copy = creds.clone();
Expand All @@ -367,6 +394,8 @@ fn basic_session_resumed() {
#[test]
fn session_resumption_thread_safety() {
let creds = SchannelCred::builder()
// TOOD: figure out why Tls13 doesnt resume
.enabled_protocols(&[Protocol::Tls12])
.acquire(Direction::Outbound)
.unwrap();

Expand Down
3 changes: 2 additions & 1 deletion src/tls_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,8 @@ where
Foundation::SEC_E_OK => {
let start = bufs[1].pvBuffer as usize - self.enc_in.get_ref().as_ptr() as usize;
let end = start + bufs[1].cbBuffer as usize;
self.dec_in.get_mut().clear();
let dec_in_read_pos = self.dec_in.position() as usize;
self.dec_in.get_mut().drain(..dec_in_read_pos);
self.dec_in
.get_mut()
.extend_from_slice(&self.enc_in.get_ref()[start..end]);
Expand Down
Loading