diff --git a/NativeDefinitions.cs b/NativeDefinitions.cs index d1c7695..ddf9619 100644 --- a/NativeDefinitions.cs +++ b/NativeDefinitions.cs @@ -61,6 +61,7 @@ public enum EventControlCode : uint CaptureState = 2, } + [Flags] public enum EnableTraceProperties : uint { Sid = 0x1, @@ -151,6 +152,46 @@ public enum WNodeClientContext : uint CpuCycleCounter = 3 } + public enum TRACE_QUERY_INFO_CLASS : uint + { + TraceGuidQueryList = 0, + TraceGuidQueryInfo = 1, + TraceGuidQueryProcess = 2, + TraceStackTracingInfo = 3, + TraceSystemTraceEnableFlagsInfo = 4, + TraceSampledProfileIntervalInfo = 5, + TraceProfileSourceConfigInfo = 6, + TraceProfileSourceListInfo = 7, + TracePmcEventListInfo = 8, + TracePmcCounterListInfo = 9, + TraceSetDisallowList = 10, + TraceVersionInfo = 11, + TraceGroupQueryList = 12, + TraceGroupQueryInfo = 13, + TraceDisallowListQuery = 14, + TraceInfoReserved15, + TracePeriodicCaptureStateListInfo = 16, + TracePeriodicCaptureStateInfo = 17, + TraceProviderBinaryTracking = 18, + TraceMaxLoggersQuery = 19, + TraceLbrConfigurationInfo = 20, + TraceLbrEventListInfo = 21, + TraceMaxPmcCounterQuery = 22, + TraceStreamCount = 23, + TraceStackCachingInfo = 24, + TracePmcCounterOwners = 25, + TraceUnifiedStackCachingInfo = 26, + TracePmcSessionInformation = 27, + MaxTraceSetInfoClass = 28 + } + + [Flags] + public enum TRACE_PROVIDER_INSTANCE_FLAGS : uint + { + TRACE_PROVIDER_FLAG_LEGACY = 1, + TRACE_PROVIDER_FLAG_PRE_ENABLE = 2 + } + // // ETW filtering // @@ -321,6 +362,38 @@ public struct EVENT_FILTER_LEVEL_KW public bool FilterIn; } + // + // advapi structs for enumerating trace sessions + // + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct TRACE_GUID_INFO + { + public uint InstanceCount; + public uint Reserved; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct TRACE_PROVIDER_INSTANCE_INFO + { + public uint NextOffset; + public uint EnableCount; + public uint Pid; + public TRACE_PROVIDER_INSTANCE_FLAGS Flags; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct TRACE_ENABLE_INFO + { + public uint IsEnabled; + public EventTraceLevel Level; + public byte Reserved1; + public ushort LoggerId; + public EnableTraceProperties EnableProperty; + public uint Reserved2; + public ulong MatchAnyKeyword; + public ulong MatchAllKeyword; + } + #endregion #region APIs @@ -904,6 +977,16 @@ internal static extern uint TdhQueryProviderFieldInformation( [In, Out] nint Buffer, // PPROVIDER_FIELD_INFOARRAY [In, Out] ref uint BufferSize ); + + [DllImport("advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern uint EnumerateTraceGuidsEx( + [In] NativeTraceControl.TRACE_QUERY_INFO_CLASS InfoClass, + [In] nint InBuffer, + [In] uint InBufferSize, + [In, Out] nint OutBuffer, + [In] uint OutBufferSize, + [In, Out] ref uint ReturnLength + ); #endregion public const int ERROR_SUCCESS = 0; @@ -915,6 +998,7 @@ internal static extern uint TdhQueryProviderFieldInformation( public const int ERROR_NOT_FOUND = 1168; public const int ERROR_XML_PARSE_ERROR = 1465; public const int ERROR_RESOURCE_TYPE_NOT_FOUND = 1813; + public const int ERROR_WMI_GUID_NOT_FOUND = 4200; public const int ERROR_EMPTY = 4306; public const int ERROR_EVT_INVALID_EVENT_DATA = 15005; public const int ERROR_MUI_FILE_NOT_FOUND = 15100; diff --git a/ParsedEtwSession.cs b/ParsedEtwSession.cs new file mode 100644 index 0000000..52a2b6d --- /dev/null +++ b/ParsedEtwSession.cs @@ -0,0 +1,139 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ +using System.Text; + +namespace etwlib +{ + using static NativeTraceControl; + + public class SessionEnabledProvider + { + public Guid ProviderId; + public uint ProcessId; + public TRACE_PROVIDER_INSTANCE_FLAGS InstanceFlags; + public EventTraceLevel Level; + public EnableTraceProperties EnableProperty; + public ulong MatchAnyKeyword; + public ulong MatchAllKeyword; + + public SessionEnabledProvider( + Guid providerId, + uint processId, + TRACE_PROVIDER_INSTANCE_FLAGS instanceFlags, + EventTraceLevel level, + EnableTraceProperties enableProperty, + ulong matchAnyKeyword, + ulong matchAllKeyword) + { + ProviderId = providerId; + ProcessId = processId; + InstanceFlags = instanceFlags; + Level = level; + EnableProperty = enableProperty; + MatchAnyKeyword = matchAnyKeyword; + MatchAllKeyword = matchAllKeyword; + } + + public override string ToString() + { + var enablePropertyStr = ""; + if (EnableProperty != 0) + { + enablePropertyStr = $", EnableProperty={EnableProperty}"; + } + return $"{ProviderId} registered by PID {ProcessId}, InstanceFlags={InstanceFlags}, "+ + $"Level={Level}{enablePropertyStr}, AnyKeyword={MatchAnyKeyword:X}, "+ + $"AllKeyword={MatchAllKeyword:X}"; + } + } + + public class ParsedEtwSession : IEquatable, IComparable + { + public ushort LoggerId; + public List EnabledProviders; + + public ParsedEtwSession(ushort Id) + { + LoggerId = Id; + EnabledProviders = new List(); + } + + public override bool Equals(object? Other) + { + if (Other == null) + { + return false; + } + var field = Other as ParsedEtwSession; + if (field == null) + { + return false; + } + return Equals(Other); + } + + public bool Equals(ParsedEtwSession? Other) + { + if (Other == null) + { + return false; + } + return LoggerId == Other.LoggerId; + } + + public static bool operator ==(ParsedEtwSession? Session1, ParsedEtwSession? Session2) + { + if ((object)Session1 == null || (object)Session2 == null) + return Equals(Session1, Session2); + return Session1.Equals(Session2); + } + + public static bool operator !=(ParsedEtwSession? Session1, ParsedEtwSession? Session2) + { + if ((object)Session1 == null || (object)Session2 == null) + return !Equals(Session1, Session2); + return !(Session1.Equals(Session2)); + } + + public override int GetHashCode() + { + return LoggerId.GetHashCode(); + } + + public int CompareTo(ParsedEtwSession? Other) + { + if (Other == null) + { + return 1; + } + return LoggerId.CompareTo(Other.LoggerId); + } + + public override string ToString() + { + var sb = new StringBuilder(); + sb.AppendLine($"Logger ID {LoggerId}"); + foreach (var p in EnabledProviders) + { + sb.AppendLine($" {p}"); + } + return sb.ToString(); + } + } +} diff --git a/ProviderParser.cs b/ProviderParser.cs index e709a56..0651fe6 100644 --- a/ProviderParser.cs +++ b/ProviderParser.cs @@ -534,6 +534,48 @@ public static class ProviderParser return results; } + public + static + bool + IsManifestKnown(Guid ProviderGuid) + { + var buffer = nint.Zero; + try + { + uint bufferSize = 0; + for (; ; ) + { + var status = TdhEnumerateManifestProviderEvents( + ref ProviderGuid, + buffer, + ref bufferSize); + switch (status) + { + case ERROR_SUCCESS: + case ERROR_INSUFFICIENT_BUFFER: + { + return true; + } + case ERROR_NOT_FOUND: + case ERROR_FILE_NOT_FOUND: + case ERROR_RESOURCE_TYPE_NOT_FOUND: + case ERROR_MUI_FILE_NOT_FOUND: + { + return false; + } + default: + { + return false; + } + } + } + } + catch (Exception) + { + return false; + } + } + private static List diff --git a/SessionParser.cs b/SessionParser.cs new file mode 100644 index 0000000..d4a7d59 --- /dev/null +++ b/SessionParser.cs @@ -0,0 +1,253 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Xml.Linq; + +namespace etwlib +{ + using static TraceLogger; + using static NativeTraceConsumer; + using static NativeTraceControl; + + public static class SessionParser + { + public + static + List + GetSessions() + { + var results = new List(); + nint buffer = nint.Zero; + + try + { + uint result = 0; + uint bufferSize = 0; + uint returnLength = 0; + + for (; ; ) + { + result = EnumerateTraceGuidsEx( + TRACE_QUERY_INFO_CLASS.TraceGuidQueryList, + nint.Zero, + 0, + buffer, + bufferSize, + ref returnLength); + if (result == ERROR_SUCCESS) + { + break; + } + else if (result != ERROR_INSUFFICIENT_BUFFER) + { + var error = $"EnumerateTraceGuidsEx failed: 0x{result:X}"; + Trace(TraceLoggerType.EtwSessionParser, + TraceEventType.Warning, + error); + throw new Exception(error); + } + + buffer = Marshal.AllocHGlobal((int)returnLength); + bufferSize = returnLength; + if (buffer == nint.Zero) + { + throw new Exception("Out of memory"); + } + } + + if (buffer == nint.Zero || bufferSize == 0) + { + throw new Exception("EnumerateTraceGuidsEx returned null " + + " or empty buffer."); + } + + int numProviders = (int)bufferSize / Marshal.SizeOf(typeof(Guid)); + var pointer = buffer; + + for (int i = 0; i < numProviders; i++) + { + var guid = (Guid)Marshal.PtrToStructure(pointer, typeof(Guid))!; + var session = GetSessions(guid); + if (session == null) + { + continue; + } + results.AddRange(session); + pointer = nint.Add(pointer, Marshal.SizeOf(typeof(Guid))); + } + + results.Sort(); + return results; + } + catch (Exception ex) + { + Trace(TraceLoggerType.EtwProviderParser, + TraceEventType.Error, + $"Exception in GetProviders(): {ex.Message}"); + throw; + } + finally + { + if (buffer != nint.Zero) + { + Marshal.FreeHGlobal(buffer); + } + } + } + + public static List? GetSessions(string ProviderName) + { + var provider = ProviderParser.GetProvider(ProviderName); + if (provider == null) + { + throw new Exception($"Provider {ProviderName} not found"); + } + return GetSessions(provider.Id); + } + + public static List? GetSessions(Guid ProviderId) + { + var inBufferSize = (uint)Marshal.SizeOf(typeof(Guid)); + var inBuffer = Marshal.AllocHGlobal((int)inBufferSize); + if (inBuffer == nint.Zero) + { + throw new Exception("Out of memory"); + } + + Marshal.StructureToPtr(ProviderId, inBuffer, false); + + var sessions = new List(); + var outBuffer = nint.Zero; + uint outBufferSize = 0; + uint returnLength = 0; + + try + { + for ( ; ; ) + { + var result = EnumerateTraceGuidsEx( + TRACE_QUERY_INFO_CLASS.TraceGuidQueryInfo, + inBuffer, + inBufferSize, + outBuffer, + outBufferSize, + ref returnLength); + if (result == ERROR_SUCCESS) + { + break; + } + else if (result == ERROR_WMI_GUID_NOT_FOUND) + { + // + // This can occur if the GUID is registered but not loaded + // + return null; + } + else if (result != ERROR_INSUFFICIENT_BUFFER) + { + var error = $"EnumerateTraceGuidsEx failed: 0x{result:X}"; + Trace(TraceLoggerType.EtwSessionParser, + TraceEventType.Error, + error); + throw new Exception(error); + } + + outBuffer = Marshal.AllocHGlobal((int)returnLength); + outBufferSize = returnLength; + if (outBuffer == nint.Zero) + { + throw new Exception("Out of memory"); + } + } + + if (outBuffer == nint.Zero || outBufferSize == 0) + { + throw new Exception("EnumerateTraceGuidsEx returned null " + + " or empty buffer."); + } + + var pointer = outBuffer; + var info = (TRACE_GUID_INFO)Marshal.PtrToStructure( + pointer, typeof(TRACE_GUID_INFO))!; + pointer = nint.Add(pointer, Marshal.SizeOf(typeof(TRACE_GUID_INFO))); + + // + // NB: there can be multiple instances of a provider with the same + // GUID if they're hosted in a DLL loaded in multiple processes. + // + for (int i = 0; i < info.InstanceCount; i++) + { + var instance = (TRACE_PROVIDER_INSTANCE_INFO)Marshal.PtrToStructure( + pointer, typeof(TRACE_PROVIDER_INSTANCE_INFO))!; + if (instance.EnableCount > 0) + { + var sessionPointer = pointer; + for (int j = 0; j < instance.EnableCount; j++) + { + var sessionInfo = (TRACE_ENABLE_INFO)Marshal.PtrToStructure( + sessionPointer, typeof(TRACE_ENABLE_INFO))!; + var enabledProvider = new SessionEnabledProvider( + ProviderId, + instance.Pid, + instance.Flags, + sessionInfo.Level, + sessionInfo.EnableProperty, + sessionInfo.MatchAnyKeyword, + sessionInfo.MatchAllKeyword); + var session = new ParsedEtwSession(sessionInfo.LoggerId); + if (!sessions.Contains(session)) + { + sessions.Add(session); + } + else + { + session = sessions.FirstOrDefault(s => s == session); + } + session!.EnabledProviders.Add(enabledProvider); + sessionPointer = nint.Add(sessionPointer, + Marshal.SizeOf(typeof(TRACE_ENABLE_INFO))); + } + } + pointer = nint.Add(pointer, (int)instance.NextOffset); + } + return sessions; + } + catch (Exception ex) + { + Trace(TraceLoggerType.EtwSessionParser, + TraceEventType.Error, + $"Exception in GetSessionsForProvider(): {ex.Message}"); + throw; + } + finally + { + if (outBuffer != nint.Zero) + { + Marshal.FreeHGlobal(outBuffer); + } + if (inBuffer != nint.Zero) + { + Marshal.FreeHGlobal(inBuffer); + } + } + } + } +} diff --git a/TraceLogger.cs b/TraceLogger.cs index 25da590..28d7b38 100644 --- a/TraceLogger.cs +++ b/TraceLogger.cs @@ -37,6 +37,7 @@ public static class TraceLogger new TraceSource("EventParser", SourceLevels.Verbose), new TraceSource("ProviderParser", SourceLevels.Verbose), new TraceSource("ManifestParser", SourceLevels.Verbose), + new TraceSource("SessionParser", SourceLevels.Verbose), }; public enum TraceLoggerType @@ -46,6 +47,7 @@ public enum TraceLoggerType EtwEventParser, EtwProviderParser, EtwManifestParser, + EtwSessionParser, Max } diff --git a/UnitTests/SessionTests.cs b/UnitTests/SessionTests.cs new file mode 100644 index 0000000..7820a13 --- /dev/null +++ b/UnitTests/SessionTests.cs @@ -0,0 +1,98 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using etwlib; +using System; + +namespace UnitTests +{ + using static Shared; + + [TestClass] + public class SessionTests + { + [TestMethod] + public void SessionsByProviderName() + { + ConfigureLoggers(); + + try + { + var sessions = SessionParser.GetSessions(); + Assert.IsNotNull(sessions); + Assert.IsTrue(sessions.Count > 0); + foreach (var session in sessions) + { + var provider = ProviderParser.GetProvider( + session.EnabledProviders[0].ProviderId); + if (provider != null && !string.IsNullOrEmpty(provider.Name)) + { + var results = SessionParser.GetSessions(provider.Name); + Assert.IsNotNull(results); + Assert.IsTrue(session == results[0]); + return; + } + } + Assert.Fail(); + } + catch (Exception ex) + { + Assert.Fail(ex.Message); + } + } + + [TestMethod] + public void SingleProviderById() + { + ConfigureLoggers(); + + try + { + var sessions = SessionParser.GetSessions(); + Assert.IsNotNull(sessions); + Assert.IsTrue(sessions.Count > 0); + var session = sessions[0]; + var results = SessionParser.GetSessions(session.EnabledProviders[0].ProviderId); + Assert.IsNotNull(results); + Assert.IsTrue(session == results[0]); + } + catch (Exception ex) + { + Assert.Fail(ex.Message); + } + } + + [TestMethod] + public void AllSessions() + { + ConfigureLoggers(); + + try + { + var sessions = SessionParser.GetSessions(); + Assert.IsNotNull(sessions); + Assert.IsTrue(sessions.Count > 0); + } + catch (Exception ex) + { + Assert.Fail(ex.Message); + } + } + } +}