diff --git a/Directory.Packages.props b/Directory.Packages.props
index e2940d07..46686db7 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -4,13 +4,14 @@
-
+
+
-
-
+
+
@@ -23,7 +24,7 @@
-
+
diff --git a/docs/rules/DAP245.md b/docs/rules/DAP245.md
new file mode 100644
index 00000000..674caf3d
--- /dev/null
+++ b/docs/rules/DAP245.md
@@ -0,0 +1,24 @@
+# DAP245
+
+It is possible for an identifier to be *technically valid* to use without quoting, yet highly confusing As an example, the following TSQL is *entirely valid*:
+
+``` sql
+CREATE TABLE GO (GO int not null)
+GO
+INSERT GO ( GO ) VALUES (42)
+GO
+SELECT GO FROM GO
+```
+
+However, this can confuse readers and parsing tools. It would be *hugely*
+advantageous to use delimited identifiers appropriately:
+
+``` sql
+CREATE TABLE [GO] ([GO] int not null)
+GO
+INSERT [GO] ( [GO] ) VALUES (42)
+GO
+SELECT [GO] FROM [GO]
+```
+
+Or... maybe just use a different name?
\ No newline at end of file
diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.Diagnostics.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.Diagnostics.cs
index a47041b1..3d481f48 100644
--- a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.Diagnostics.cs
+++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.Diagnostics.cs
@@ -95,6 +95,7 @@ public static readonly DiagnosticDescriptor
InterpolatedStringSqlExpression = SqlWarning("DAP241", "Interpolated string usage", "Data values should not be interpolated into SQL string - use parameters instead"),
ConcatenatedStringSqlExpression = SqlWarning("DAP242", "Concatenated string usage", "Data values should not be concatenated into SQL string - use parameters instead"),
InvalidDatepartToken = SqlWarning("DAP243", "Valid datepart token expected", "Date functions require a recognized datepart argument"),
- SelectAggregateMismatch = SqlWarning("DAP244", "SELECT aggregate mismatch", "SELECT has mixture of aggregate and non-aggregate expressions");
+ SelectAggregateMismatch = SqlWarning("DAP244", "SELECT aggregate mismatch", "SELECT has mixture of aggregate and non-aggregate expressions"),
+ DangerousNonDelimitedIdentifier = SqlWarning("DAP245", "Dangerous non-delimited identifier", "The identifier '{0}' can be confusing when not delimited; consider delimiting it with [...]");
}
}
diff --git a/src/Dapper.AOT.Analyzers/Internal/DiagnosticTSqlProcessor.cs b/src/Dapper.AOT.Analyzers/Internal/DiagnosticTSqlProcessor.cs
index bb6356ea..a5c46a10 100644
--- a/src/Dapper.AOT.Analyzers/Internal/DiagnosticTSqlProcessor.cs
+++ b/src/Dapper.AOT.Analyzers/Internal/DiagnosticTSqlProcessor.cs
@@ -161,4 +161,7 @@ protected override void OnInvalidNullExpression(Location location)
protected override void OnTrivialOperand(Location location)
=> OnDiagnostic(DapperAnalyzer.Diagnostics.TrivialOperand, location);
+
+ protected override void OnDangerousNonDelimitedIdentifier(Location location, string name)
+ => OnDiagnostic(DapperAnalyzer.Diagnostics.DangerousNonDelimitedIdentifier, location, name);
}
diff --git a/src/Dapper.AOT.Analyzers/Internal/GeneralSqlParser.cs b/src/Dapper.AOT.Analyzers/Internal/GeneralSqlParser.cs
new file mode 100644
index 00000000..d593da3c
--- /dev/null
+++ b/src/Dapper.AOT.Analyzers/Internal/GeneralSqlParser.cs
@@ -0,0 +1,562 @@
+using Dapper.SqlAnalysis;
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+
+namespace Dapper.Internal.SqlParsing;
+
+internal readonly struct CommandVariable : IEquatable
+{
+ public CommandVariable(string name, int index)
+ {
+ Name = name;
+ Index = index;
+ }
+ public int Index { get; }
+ public string Name { get; }
+
+ public ParameterKind Kind
+ {
+ get
+ {
+ var name = Name;
+ if (name is { Length: >= 2 } && GeneralSqlParser.IsParameterPrefix(name[0]))
+ {
+ if ((char.IsLetter(name[1]) || name[1] == '_'))
+ return ParameterKind.Nominal;
+
+ if (char.IsNumber(name[1]))
+ return ParameterKind.Ordinal;
+ }
+
+ return ParameterKind.Unknown;
+ }
+ }
+
+ public override int GetHashCode() => Index;
+ public override bool Equals(object obj) => obj is CommandVariable other && Equals(other);
+ public bool Equals(CommandVariable other)
+ => Index == other.Index && Name == other.Name;
+ public override string ToString() => $"@{Index}:{Name}";
+}
+internal readonly struct CommandBatch : IEquatable
+{
+ public ImmutableArray Variables { get; }
+ public string Sql { get; }
+
+ public CommandBatch(string sql) : this(ImmutableArray.Empty, sql) { }
+ public CommandBatch(string sql, CommandVariable var0) : this(ImmutableArray.Create(var0), sql) { }
+ public CommandBatch(string sql, CommandVariable var0, CommandVariable var1) : this(ImmutableArray.Create(var0, var1), sql) { }
+ public CommandBatch(string sql, CommandVariable var0, CommandVariable var1, CommandVariable var2) : this(ImmutableArray.Create(var0, var1, var2), sql) { }
+ public CommandBatch(string sql, CommandVariable var0, CommandVariable var1, CommandVariable var2, CommandVariable var3) : this(ImmutableArray.Create(var0, var1, var2, var3), sql) { }
+ public CommandBatch(string sql, CommandVariable var0, CommandVariable var1, CommandVariable var2, CommandVariable var3, CommandVariable var4) : this(ImmutableArray.Create(var0, var1, var2, var3, var4), sql) { }
+ public static CommandBatch Create(string sql, params CommandVariable[] variables)
+ => new(ImmutableArray.Create(variables), sql);
+ public static CommandBatch Create(string sql, ImmutableArray variables)
+ => new(variables, sql);
+ // invert order to solve some .ctor ambiguity issues
+ private CommandBatch(ImmutableArray variables, string sql)
+ {
+ Sql = sql;
+ Variables = variables;
+ }
+
+ public ParameterKind ParameterKind
+ {
+ get
+ {
+ if (Variables.IsDefaultOrEmpty)
+ {
+ return ParameterKind.NonParametrized;
+ }
+ var first = Variables[0].Kind;
+ foreach (var arg in Variables.AsSpan().Slice(1))
+ {
+ if (arg.Kind != first) return ParameterKind.Mixed;
+ }
+ return first;
+ }
+ }
+
+ public override int GetHashCode() => Sql.GetHashCode(); // args are a component of the sql; no need to hash them
+ public override string ToString() => Variables.IsDefaultOrEmpty ? Sql :
+ (Sql + " with " + string.Join(", ", Variables));
+
+ public override bool Equals(object obj) => obj is CommandBatch other && Equals(other);
+
+ public bool Equals(CommandBatch other)
+ => Sql == other.Sql && Variables.SequenceEqual(other.Variables);
+
+ public OrdinalResult TryMakeOrdinal(IList inputArgs, Func argName, Func argFactory, out IList args, out string sql, int argIndex = 0)
+ {
+ static bool TryFindByName(string name, IList inputArgs, Func argNameSelector, out T found)
+ {
+ if (string.IsNullOrWhiteSpace(name) || name.Length < 2
+ || !GeneralSqlParser.IsParameterPrefix(name[0]))
+ {
+ // general preconditions for nominal match: looking for @x
+ // i.e. must have parameter symbol and at least one token character
+ found = default!;
+ return false;
+
+ }
+ foreach (var arg in inputArgs)
+ {
+ var argName = argNameSelector(arg);
+ if (string.IsNullOrWhiteSpace(argName)) continue; // looking for nominal match
+
+ // check for exact match including prefix, i.e. "@foo" vs "@foo"
+ if (string.Equals(name, argName, StringComparison.OrdinalIgnoreCase))
+ {
+ found = arg;
+ return true;
+ }
+ // check for input name excluding prefix, i.e. "foo" vs detected "@foo"
+ // (when using Dapper, this is the normal usage)
+ if (argName.Length == name.Length - 1 && !GeneralSqlParser.IsParameterPrefix(argName[0])
+ && name.EndsWith(argName, StringComparison.OrdinalIgnoreCase))
+ {
+ found = arg;
+ return true;
+ }
+ }
+
+ found = default!;
+ return false;
+ }
+ sql = Sql;
+ var kind = ParameterKind;
+ switch (kind)
+ {
+ case ParameterKind.NonParametrized:
+ args = [];
+ return OrdinalResult.NoChange;
+ case ParameterKind.Mixed:
+ args = inputArgs;
+ return OrdinalResult.MixedParameters;
+ case ParameterKind.Ordinal:
+ // TODO: rewrite, filtering and ordering; i.e.
+ // where Id = $4 and Name = $3 -- no mention of $1 or $2
+ // could be
+ // where Id = $1 and Name = $2
+ args = inputArgs;
+ return OrdinalResult.NoChange;
+ case ParameterKind.Nominal:
+ break; // below
+ default:
+ args = inputArgs;
+ return OrdinalResult.UnsupportedScenario;
+ }
+
+ Debug.Assert(kind == ParameterKind.Nominal);
+
+ var map = new Dictionary(Variables.Length, StringComparer.OrdinalIgnoreCase);
+ args = new List();
+ var sb = new StringBuilder(sql);
+ int delta = 0; // change in length of string
+ foreach (var queryArg in Variables)
+ {
+ if (!map.TryGetValue(queryArg.Name, out var finalArg))
+ {
+ if (!TryFindByName(queryArg.Name, inputArgs, argName, out var found))
+ {
+ args = inputArgs;
+ return OrdinalResult.UnsupportedScenario;
+ }
+ finalArg = argFactory(found, argIndex++);
+ map.Add(queryArg.Name, finalArg);
+ args.Add(finalArg);
+ }
+ var newName = argName(finalArg);
+ // could potentially be more efficient with forwards-only write
+ sb.Remove(queryArg.Index + delta, queryArg.Name.Length);
+ sb.Insert(queryArg.Index + delta, newName);
+ delta += newName.Length - queryArg.Name.Length;
+ }
+ sql = sb.ToString();
+ return sql == Sql ? OrdinalResult.NoChange : OrdinalResult.Success;
+ }
+
+ internal static Func OrdinalNaming { get; } = (name, index) => $"${index + 1}";
+}
+
+internal enum ParameterKind
+{
+ NonParametrized,
+ Mixed,
+ Ordinal, // $1
+ Nominal, // @foo
+ Unknown,
+}
+internal enum OrdinalResult
+{
+ NoChange,
+ MixedParameters,
+ Success,
+ UnsupportedScenario,
+}
+
+internal static class GeneralSqlParser
+{
+ private enum ParseState
+ {
+ None,
+ Token,
+ Variable,
+ LineComment,
+ BlockComment,
+ String,
+ Whitespace,
+ }
+
+ ///
+ /// Tokenize a sql fragment into batches, extracting the variables/locals in use
+ ///
+ /// This is a basic parse only; no syntax processing - just literals, identifiers, etc
+ public static List Parse(string sql, SqlSyntax syntax, bool strip = false)
+ {
+ // this is a basic first pass; TODO: rewrite using a forwards seek approach, i.e.
+ // "find first [@$:'"...] etc, copy backfill then search for end of that symbol and process
+ // accordingly
+
+ int bIndex = 0, parenDepth = 0;
+ char[] buffer = ArrayPool.Shared.Rent(sql.Length + 1);
+
+ char stringType = '\0';
+ var state = ParseState.None;
+ int i = 0, elementStartbIndex = 0;
+ ImmutableArray.Builder? variables = null;
+ var result = new List();
+
+ bool BatchSemicolon() => syntax == SqlSyntax.PostgreSql;
+
+ char LookAhead(int delta = 1)
+ {
+ var ci = i + delta;
+ return ci >= 0 && ci < sql.Length ? sql[ci] : '\0';
+ }
+ char Last(int offset)
+ {
+ var ci = bIndex - (offset + 1);
+ return ci >= 0 && ci < bIndex ? buffer[ci] : '\0';
+ }
+ char LookBehind(int delta = 1) => LookAhead(-delta);
+ void Discard() => bIndex--;
+ void NormalizeSpace()
+ {
+ if (strip)
+ {
+ if (bIndex > 1 && buffer[bIndex - 2] == ' ')
+ {
+ Discard();
+ }
+ else
+ {
+ buffer[bIndex - 1] = ' ';
+ }
+ }
+ }
+ bool ActivateStringPrefix()
+ {
+ if (ElementLength() == 2) // already written, so: N'... E'... etc
+ {
+ stringType = Last(0);
+ return true;
+ };
+ return false;
+ }
+ void SkipLeadingWhitespace(char v)
+ {
+ if (bIndex == 1 && ((v is '\r' or '\n') || strip))
+ {
+ // always omit leading CRLFs; omit leading whitespace
+ // when stripping
+ Discard();
+ }
+ else if (strip && Last(0) == ';')
+ {
+ Discard(); // don't write whitespace after ;
+ }
+ else
+ {
+ NormalizeSpace();
+ }
+ }
+ int ElementLength() => bIndex - elementStartbIndex + 1;
+
+ void FlushBatch()
+ {
+ if (IsGo()) bIndex -= 2; // don't retain the GO
+
+ //bool removedSemicolon = false;
+ if ((strip || BatchSemicolon()) && Last(0) == ';')
+ {
+ Discard();
+ //removedSemicolon = true;
+ }
+
+ if (strip) // remove trailing whitespace
+ {
+ while (bIndex > 0 && char.IsWhiteSpace(buffer[bIndex - 1]))
+ {
+ bIndex--;
+ }
+ }
+
+ if (!IsWhitespace()) // anything left?
+ {
+ //if (removedSemicolon)
+ //{
+ // // reattach
+ // buffer[bIndex++] = ';';
+ //}
+
+ var batchSql = new string(buffer, 0, bIndex);
+ var args = variables is null ? ImmutableArray.Empty : variables.ToImmutable();
+ result.Add(CommandBatch.Create(batchSql, args));
+ }
+ // logical reset
+ bIndex = 0;
+ variables?.Clear();
+ state = ParseState.None;
+
+ // lose any same-line simple space between batches
+ while (LookAhead() is ' ' or '\t')
+ {
+ i++; // same as Advance();Discard();
+ }
+ }
+ bool IsWhitespace()
+ {
+ if (bIndex == 0) return true;
+ for (int i = 0; i < bIndex; i++)
+ {
+ if (!char.IsWhiteSpace(buffer[i])) return false;
+ }
+ return true;
+ }
+ bool IsGo()
+ {
+ return syntax == SqlSyntax.SqlServer && ElementLength() == 2
+ && Last(1) is 'g' or 'G' && Last(0) is 'o' or 'O';
+ }
+ void FlushVariable()
+ {
+ int varLen = ElementLength(), varStart = bIndex - varLen;
+ var name = new string(buffer, varStart, varLen);
+ variables ??= ImmutableArray.CreateBuilder();
+ variables.Add(new(name, varStart));
+ }
+
+ bool IsString(char c) => state == ParseState.String && stringType == c;
+
+ bool IsSingleQuoteString() => state == ParseState.String && (stringType == '\'' || char.IsLetter(stringType));
+ void Advance() => buffer[bIndex++] = sql[++i];
+
+ for (; i < sql.Length; i++)
+ {
+ var c = i == sql.Length ? ';' : sql[i]; // spoof a ; at the end to simplify end-of-block handling
+
+ // detect end of GO token
+ if (state == ParseState.Token && !IsToken(c) && IsGo())
+ {
+ FlushBatch(); // and keep going
+ }
+ else if (state == ParseState.Variable && !IsMidToken(c))
+ {
+ FlushVariable();
+ }
+
+ // store by default, we'll backtrack in the rare scenarios that we don't want it
+ buffer[bIndex++] = sql[i];
+
+ switch (state)
+ {
+ case ParseState.Whitespace when char.IsWhiteSpace(c): // more contiguous whitespace
+ if (strip) Discard();
+ else SkipLeadingWhitespace(c);
+ continue;
+ case ParseState.LineComment when c is '\r' or '\n': // end of line comment
+ case ParseState.BlockComment when c == '/' && LookBehind() == '*': // end of block comment
+ if (strip) Discard();
+ else NormalizeSpace();
+ state = ParseState.Whitespace;
+ continue;
+ case ParseState.BlockComment or ParseState.LineComment: // keep ignoring line comment
+ if (strip) Discard();
+ continue;
+ // string-escape characters
+ case ParseState.String when c == '\'' && IsSingleQuoteString() && LookAhead() == '\'': // [?]'...''...'
+ case ParseState.String when c == '"' && IsString('"') && LookAhead() == '\"': // "...""..."
+ case ParseState.String when c == '\\' && (IsString('E') || IsString('e')) && LookAhead() != '\0' && AllowEscapedStrings(): // E'...\*...'
+ case ParseState.String when c == ']' && IsString('[') && LookAhead() == ']': // [...]]...]
+ // escaped or double-quote; move forwards immediately
+ Advance();
+ continue;
+ // end string
+ case ParseState.String when c == '"' && IsString('"'): // "....."
+ case ParseState.String when c == ']' && IsString('['): // [.....]
+ case ParseState.String when c == '\'' && IsSingleQuoteString(): // [?]'....'
+ state = ParseState.None;
+ continue;
+ case ParseState.String:
+ // ongoing string content
+ continue;
+ case ParseState.Token when c == '\'' && ActivateStringPrefix(): // E'..., N'... etc
+ continue;
+ case ParseState.Token or ParseState.Variable when IsMidToken(c):
+ // ongoing token / variable content
+ continue;
+ case ParseState.Variable: // end of variable
+ case ParseState.Whitespace: // end of whitespace chunk
+ case ParseState.Token: // end of token
+ case ParseState.None: // not started
+ state = ParseState.None;
+ break; // make sure we still handle the value, below
+ default:
+ throw new InvalidOperationException($"Token kind not handled: {state}");
+ }
+
+ if (c == '-' && LookAhead() == '-')
+ {
+ state = ParseState.LineComment;
+ if (strip) Discard();
+ continue;
+ }
+ if (c == '/' && LookAhead() == '*')
+ {
+ state = ParseState.BlockComment;
+ if (strip) Discard();
+ continue;
+ }
+
+ if (c == '(') parenDepth++;
+ if (c == ')') parenDepth--;
+ if (c == ';')
+ {
+ if (BatchSemicolon() && parenDepth == 0)
+ {
+ FlushBatch();
+ continue;
+ }
+
+ // otherwise end-statement
+ // (prevent unnecessary additional whitespace when stripping)
+ state = ParseState.Whitespace;
+ if (strip && Last(1) == ';')
+ { // squash down to just one
+ Discard();
+ }
+ continue;
+ }
+
+ if (char.IsWhiteSpace(c))
+ {
+ SkipLeadingWhitespace(c);
+ state = ParseState.Whitespace;
+ continue;
+ }
+
+ elementStartbIndex = bIndex;
+
+ if (c is '"' or '\'' or '[') // TODO: '$' dollar quoting
+ {
+ // start a new string
+ state = ParseState.String;
+ stringType = c;
+ continue;
+ }
+
+ if (c is '$' && AllowDollarQuotedStrings())
+ {
+ TryReadDollarQuotedString();
+ continue;
+ }
+
+ if (IsParameterPrefix(c)
+ && IsToken(LookAhead()) && LookBehind() != c) // avoid @>, @@IDENTTIY etc
+ {
+ // start a new variable
+ state = ParseState.Variable;
+ continue;
+ }
+
+ if (IsToken(c))
+ {
+ // start a new token
+ state = ParseState.Token;
+ continue;
+ }
+
+ // other arbitrary syntax - operators etc
+ }
+
+ // deal with any remaining bits
+ if (state == ParseState.Variable) FlushVariable();
+ if (BatchSemicolon())
+ {
+ // spoof a final ;
+ buffer[bIndex++] = ';';
+ }
+ FlushBatch();
+
+ ArrayPool.Shared.Return(buffer);
+
+ return result;
+
+ bool IsMidToken(char c) => IsToken(c)
+ || (syntax == SqlSyntax.PostgreSql && (
+ (c == '$' && state != ParseState.Variable) // postgresql identifiers can contain $, but variables can't
+ || (c == '.' && state == ParseState.Variable) // postgresql mapped variables (only) can contain .
+ ));
+
+ bool IsToken(char c) => c == '_' || char.IsLetterOrDigit(c);
+
+ bool AllowEscapedStrings() => syntax == SqlSyntax.PostgreSql;
+ bool AllowDollarQuotedStrings() => syntax == SqlSyntax.PostgreSql;
+
+ void TryReadDollarQuotedString()
+ {
+ // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-DOLLAR-QUOTING
+
+ // A dollar-quoted string constant consists of a dollar sign ($), an optional “tag” of zero or more characters,
+ // another dollar sign, an arbitrary sequence of characters that makes up the string content, a dollar sign,
+ // the same tag that began this dollar quote, and a dollar sign.
+
+ // The tag, if any, of a dollar-quoted string follows the same rules as an
+ // unquoted identifier, except that it cannot contain a dollar sign
+
+ // note this is too complex to process in the iterative way; we'll switch to forwards looking
+ int len = 1; // $
+ while (true)
+ {
+ var next = LookAhead(len++);
+ if (next == '$') break; // found end of marker
+ if (!IsToken(next)) return; // covers _, letters and digits
+ if (len == 2 && char.IsDigit(next)) return; // identifier rules; cannot start with digit
+ }
+
+ var sqlSpan = sql.AsSpan();
+ var hunt = sqlSpan.Slice(i, len);
+ var remaining = sqlSpan.Slice(i + len);
+ var found = remaining.IndexOf(hunt);
+ if (found < 0) return; // non-terminated; ignore
+
+ var toCopy = len * 2 + found - 1; // we already copied the first $
+ for (int j = 0; j < toCopy; j++)
+ {
+ Advance();
+ }
+
+ }
+ }
+
+ internal static bool IsParameterPrefix(char c)
+ => SqlTools.ParameterPrefixCharacters.IndexOf(c) >= 0;
+}
+
+
diff --git a/src/Dapper.AOT.Analyzers/SqlAnalysis/TSqlProcessor.cs b/src/Dapper.AOT.Analyzers/SqlAnalysis/TSqlProcessor.cs
index 875fa09c..5e15e454 100644
--- a/src/Dapper.AOT.Analyzers/SqlAnalysis/TSqlProcessor.cs
+++ b/src/Dapper.AOT.Analyzers/SqlAnalysis/TSqlProcessor.cs
@@ -257,7 +257,8 @@ protected virtual void OnInvalidDatepartToken(Location location)
=> OnError($"Valid datepart token expected", location);
protected virtual void OnTopWithOffset(Location location)
=> OnError($"TOP cannot be used when OFFSET is specified", location);
-
+ protected virtual void OnDangerousNonDelimitedIdentifier(Location location, string name)
+ => OnError($"The identifier '{name}' can be confusing when not delimited", location);
internal readonly struct Location
{
@@ -756,6 +757,15 @@ static bool IsAnyCaseInsensitive(string value, string[] options)
"nanosecond", "ns"
];
+ public override void Visit(Identifier node)
+ {
+ if (node.QuoteType == QuoteType.NotQuoted && string.Equals("GO", node.Value, StringComparison.OrdinalIgnoreCase))
+ {
+ parser.OnDangerousNonDelimitedIdentifier(node, node.Value);
+ }
+ base.Visit(node);
+ }
+
private void ValidateDateArg(ScalarExpression value)
{
if (!(value is ColumnReferenceExpression col
diff --git a/src/Dapper.AOT/CommandFactory.cs b/src/Dapper.AOT/CommandFactory.cs
index ea0e0dda..cab68b2d 100644
--- a/src/Dapper.AOT/CommandFactory.cs
+++ b/src/Dapper.AOT/CommandFactory.cs
@@ -159,6 +159,57 @@ protected static bool TryRecycle(ref DbCommand? storage, DbCommand command)
command.Transaction = null;
return Interlocked.CompareExchange(ref storage, command, null) is null;
}
+
+
+#if NET6_0_OR_GREATER
+ ///
+ /// Provides an opportunity to recycle and reuse batch instances
+ ///
+ protected static bool TryRecycle(ref DbBatch? storage, DbBatch batch)
+ {
+ // detach and recycle
+ batch.Connection = null;
+ batch.Transaction = null;
+ return Interlocked.CompareExchange(ref storage, batch, null) is null;
+ }
+
+ ///
+ /// Provides an opportunity to recycle and reuse batch instances
+ ///
+ public virtual bool TryRecycle(DbBatch batch) => false;
+#endif
+
+ ///
+ /// Creates and initializes new instances.
+ ///
+ public virtual DbParameter CreateNewParameter(in UnifiedCommand command)
+ => command.DefaultCreateParameter();
+
+ ///
+ /// Creates and initializes new instances.
+ ///
+ public virtual DbCommand CreateNewCommand(DbConnection connection)
+ => connection.CreateCommand();
+
+#if NET6_0_OR_GREATER
+ ///
+ /// Creates and initializes new instances.
+ ///
+ public virtual DbBatch CreateNewBatch(DbConnection connection)
+ => connection.CreateBatch();
+
+ ///
+ /// Creates and initializes new instances.
+ ///
+ public virtual DbBatchCommand CreateNewCommand(DbBatch batch)
+ => batch.CreateBatchCommand();
+#endif
+
+
+ ///
+ /// Indicates where it is required to invoke post-operation logic to update parameter values.
+ ///
+ public virtual bool RequirePostProcess => false;
}
///
@@ -182,18 +233,13 @@ protected CommandFactory() { }
public virtual DbCommand GetCommand(DbConnection connection, string sql, CommandType commandType, T args)
{
// default behavior assumes no args, no special logic
- var cmd = connection.CreateCommand();
- Initialize(new(cmd), sql, commandType, args);
+ var cmd = CreateNewCommand(connection);
+ var unified = new UnifiedCommand(this, cmd);
+ unified.SetCommand(sql, commandType);
+ AddParameters(in unified, args);
return cmd;
}
- internal void Initialize(in UnifiedCommand cmd,
- string sql, CommandType commandType, T args)
- {
- cmd.CommandText = sql;
- cmd.CommandType = commandType != 0 ? commandType : DapperAotExtensions.GetCommandType(sql);
- AddParameters(in cmd, args);
- }
internal override sealed void PostProcessObject(in UnifiedCommand command, object? args, int rowCount) => PostProcess(in command, (T)args!, rowCount);
@@ -214,9 +260,10 @@ public virtual void AddParameters(in UnifiedCommand command, T args)
///
public virtual void UpdateParameters(in UnifiedCommand command, T args)
{
- if (command.Parameters.Count != 0) // try to avoid rogue "dirty" checks
+ var ps = command.Parameters;
+ if (ps.Count != 0) // try to avoid rogue "dirty" checks
{
- command.Parameters.Clear();
+ ps.Clear();
}
AddParameters(in command, args);
}
@@ -232,14 +279,80 @@ public virtual void UpdateParameters(in UnifiedCommand command, T args)
// try to avoid any dirty detection in the setters
if (cmd.CommandText != sql) cmd.CommandText = sql;
if (cmd.CommandType != commandType) cmd.CommandType = commandType;
- UpdateParameters(new(cmd), args);
+ UpdateParameters(new UnifiedCommand(this, cmd), args);
}
return cmd;
}
+#pragma warning disable IDE0079 // following will look unnecessary on up-level
+#pragma warning disable CS1574 // DbBatchCommand will not resolve on down-level TFMs
///
- /// Indicates where it is required to invoke post-operation logic to update parameter values.
+ /// Indicates whether the factory wishes to split this command into a multi-command batch.
///
- public virtual bool RequirePostProcess => false;
+ /// This may or may not be implemented using , depending on the capabilities
+ /// of the runtime and ADO.NET provider.
+ /// #pragma warning disable IDE0079 // following will look unnecessary on up-level
+#pragma warning restore CS1574 // DbBatchCommand will not resolve on down-level TFMs
+#pragma warning restore IDE0079 // following will look unnecessary on up-level
+ public virtual bool UseBatch(string sql) => false;
+
+#if NET6_0_OR_GREATER
+ ///
+ /// Create a populated batch from a command
+ ///
+ public virtual DbBatch GetBatch(DbConnection connection, string sql, CommandType commandType, T args)
+ {
+ Debug.Assert(connection.CanCreateBatch);
+ var batch = CreateNewBatch(connection);
+ // initialize with a command
+ batch.BatchCommands.Add(CreateNewCommand(batch));
+ AddCommands(new(this, batch), sql, args);
+ return batch;
+ }
+
+ ///
+ /// Provides an opportunity to recycle and reuse batch instances
+ ///
+ protected DbBatch? TryReuse(ref DbBatch? storage, T args)
+ {
+ var batch = Interlocked.Exchange(ref storage, null);
+ if (batch is not null)
+ {
+ // try to avoid any dirty detection in the setters
+ UpdateParameters(new UnifiedBatch(this, batch), args);
+ }
+ return batch;
+ }
+#endif
+
+ ///
+ /// Allows the caller to rewrite a composite command into a multi-command batch.
+ ///
+ public virtual void AddCommands(in UnifiedBatch batch, string sql, T args)
+ {
+ // implement as basic mode
+ batch.SetCommand(sql, CommandType.Text);
+ AddParameters(in batch.Command, args);
+ }
+
+ ///
+ /// Allows the caller to update the parameter values of a multi-command batch.
+ ///
+ public virtual void UpdateParameters(in UnifiedBatch batch, T args)
+ {
+ UpdateParameters(in batch.Command, args);
+ }
+
+ ///
+ /// Allows an implementation to process output parameters etc after a multi-command batch has completed.
+ ///
+ /// This API is only invoked when reported true, and
+ /// corresponds to
+ public virtual void PostProcess(in UnifiedBatch batch, T args, int rowCount) { }
+
+ internal void PostProcess(in UnifiedCommand command, TArgs? val, object recordsAffected)
+ {
+ throw new NotImplementedException();
+ }
}
\ No newline at end of file
diff --git a/src/Dapper.AOT/CommandT.Batch.cs b/src/Dapper.AOT/CommandT.Batch.cs
index 65e10286..26675730 100644
--- a/src/Dapper.AOT/CommandT.Batch.cs
+++ b/src/Dapper.AOT/CommandT.Batch.cs
@@ -1,10 +1,9 @@
using Dapper.Internal;
using System;
+using System.Buffers;
using System.Collections.Generic;
using System.Collections.Immutable;
-using System.Data.Common;
using System.Diagnostics;
-using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
@@ -142,24 +141,6 @@ public Task ExecuteAsync(TArgs[] values, int offset, int count, int batchSi
};
}
- internal void Recycle(ref SyncCommandState state)
- {
- Debug.Assert(state.Command is not null);
- if (commandFactory.TryRecycle(state.Command!))
- {
- state.Command = null;
- }
- }
-
- internal void Recycle(AsyncCommandState state)
- {
- Debug.Assert(state.Command is not null);
- if (commandFactory.TryRecycle(state.Command!))
- {
- state.Command = null;
- }
- }
-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private int ExecuteMulti(ReadOnlySpan source, int batchSize)
{
@@ -180,7 +161,7 @@ private int ExecuteMultiSequential(ReadOnlySpan source)
var current = source[0];
var local = state.ExecuteNonQuery(GetCommand(current));
- UnifiedCommand cmdState = new(state.Command);
+ UnifiedCommand cmdState = new(commandFactory, state.Command);
commandFactory.PostProcess(in cmdState, current, local);
total += local;
@@ -193,7 +174,7 @@ private int ExecuteMultiSequential(ReadOnlySpan source)
total += local;
}
- Recycle(ref state);
+ state.UnifiedBatch.TryRecycle();
return total;
}
finally
@@ -206,68 +187,140 @@ private int ExecuteMultiSequential(ReadOnlySpan source)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private bool UseBatch(int batchSize) => batchSize != 0 && connection is { CanCreateBatch: true };
- private DbBatchCommand AddCommand(ref UnifiedCommand state, TArgs args)
+ void Add(ref UnifiedBatch batch, TArgs args)
{
- var cmd = state.UnsafeCreateNewCommand();
- commandFactory.Initialize(state, sql, commandType, args);
- state.AssertBatchCommands.Add(cmd);
- return cmd;
+ if (batch.IsDefault)
+ {
+ // create a new batch initialized on a ready command
+ batch = new(commandFactory, connection, transaction);
+ if (timeout >= 0) batch.TimeoutSeconds = timeout;
+ Debug.Assert(batch.Mode is BatchMode.MultiCommandDbBatchCommand); // current expectations
+ batch.SetCommand(sql, commandType);
+ commandFactory.AddParameters(in batch.Command, args);
+ }
+ else if (batch.IsLastCommand)
+ {
+ // create a new command at the end of the batch
+ batch.CreateNextBatchGroup(sql, commandType);
+ commandFactory.AddParameters(in batch.Command, args);
+ }
+ else
+ {
+ // overwriting
+ batch.OverwriteNextBatchGroup();
+ commandFactory.UpdateParameters(in batch.Command, args);
+ }
}
- private int ExecuteMultiBatch(ReadOnlySpan source, int batchSize) // TODO: sub-batching
+ private int ExecuteMultiBatch(ReadOnlySpan source, int batchSize)
{
Debug.Assert(source.Length > 1);
- UnifiedCommand batch = default;
+ SyncCommandState state = default;
+ // note that we currently only use single-command-per-TArg mode, i.e. UseBatch is ignored
try
{
+ int sum = 0, ppOffset = 0;
foreach (var arg in source)
{
- if (!batch.HasBatch) batch = new(connection.CreateBatch());
- AddCommand(ref batch, arg);
+ Add(ref state.UnifiedBatch, arg);
+ if (state.UnifiedBatch.GroupCount == batchSize)
+ {
+ sum += state.ExecuteNonQueryUnified();
+ PostProcessMultiBatch(in state.UnifiedBatch, ref ppOffset, source);
+ }
}
- if (!batch.HasBatch) return 0;
-
- var result = batch.AssertBatch.ExecuteNonQuery();
-
- if (commandFactory.RequirePostProcess)
+ if (state.UnifiedBatch.GroupCount != 0)
{
- batch.PostProcess(source, commandFactory);
+ state.UnifiedBatch.Trim();
+ sum += state.ExecuteNonQueryUnified();
+ PostProcessMultiBatch(in state.UnifiedBatch, ref ppOffset, source);
}
- return result;
+ state.UnifiedBatch.TryRecycle();
+ return sum;
}
finally
{
- batch.Cleanup();
+ state.Dispose();
}
}
- private int ExecuteMultiBatch(IEnumerable source, int batchSize) // TODO: sub-batching
+
+ private void PostProcessMultiBatch(in UnifiedBatch batch, ReadOnlySpan args)
{
+ int i = 0;
+ PostProcessMultiBatch(batch, ref i, args);
+ }
+
+ private void PostProcessMultiBatch(in UnifiedBatch batch, ref int argOffset, ReadOnlySpan args)
+ {
+ // TODO: we'd need to buffer the sub-batch from IEnumerable
+ if (batch.IsDefault) return;
+
+ // assert that we currently expect only single commands per element
+ if (batch.Command.CommandCount != batch.GroupCount) Throw();
+
if (commandFactory.RequirePostProcess)
{
- // try to ensure it is repeatable
- source = (source as IReadOnlyCollection) ?? source.ToList();
+ batch.Command.UnsafeMoveTo(0);
+ foreach (var val in args.Slice(argOffset, batch.GroupCount))
+ {
+ commandFactory.PostProcess(in batch.Command, val, batch.Command.RecordsAffected);
+ batch.Command.UnsafeAdvance();
+ }
+ argOffset += batch.GroupCount;
}
+ // prepare for the next batch, if one
+ batch.UnsafeMoveBeforeFirst();
+
+ static void Throw() => throw new InvalidOperationException("The number of operations should have matched the number of groups!");
+ }
+
+ private TArgs[]? GetMultiBatchBuffer(ref int batchSize)
+ {
+ if (!commandFactory.RequirePostProcess) return null; // no problem, then
- UnifiedCommand batch = default;
+ const int MAX_SIZE = 1024;
+ if (batchSize < 0 || batchSize > 1024) batchSize = MAX_SIZE;
+
+ return ArrayPool.Shared.Rent(batchSize);
+ }
+ private static void RecycleMultiBatchBuffer(TArgs[]? buffer)
+ {
+ if (buffer is not null) ArrayPool.Shared.Return(buffer);
+ }
+
+ private int ExecuteMultiBatch(IEnumerable source, int batchSize)
+ {
+ SyncCommandState state = default;
+ var buffer = GetMultiBatchBuffer(ref batchSize);
try
{
+ int sum = 0, ppOffset = 0;
foreach (var arg in source)
{
- if (!batch.HasBatch) batch = new(connection.CreateBatch());
- AddCommand(ref batch, arg);
+ Add(ref state.UnifiedBatch, arg);
+ if (buffer is not null) buffer[ppOffset++] = arg;
+
+ if (state.UnifiedBatch.GroupCount == batchSize)
+ {
+ sum += state.ExecuteNonQueryUnified();
+ PostProcessMultiBatch(in state.UnifiedBatch, buffer);
+ ppOffset = 0;
+ }
}
- if (!batch.HasBatch) return 0;
- var result = batch.AssertBatch.ExecuteNonQuery();
- if (commandFactory.RequirePostProcess)
+ if (state.UnifiedBatch.GroupCount != 0)
{
- batch.PostProcess(source, commandFactory);
+ state.UnifiedBatch.Trim();
+ sum += state.ExecuteNonQueryUnified();
+ PostProcessMultiBatch(in state.UnifiedBatch, buffer);
}
- return result;
+ state.UnifiedBatch.TryRecycle();
+ return sum;
}
finally
{
- batch.Cleanup();
+ RecycleMultiBatchBuffer(buffer);
+ state.Dispose();
}
}
#endif
@@ -280,7 +333,7 @@ private int ExecuteMulti(IEnumerable source, int batchSize)
#endif
return ExecuteMultiSequential(source);
}
-
+
private int ExecuteMultiSequential(IEnumerable source)
{
SyncCommandState state = default;
@@ -294,7 +347,7 @@ private int ExecuteMultiSequential(IEnumerable source)
bool haveMore = iterator.MoveNext();
if (haveMore && commandFactory.CanPrepare) state.PrepareBeforeExecute();
var local = state.ExecuteNonQuery(GetCommand(current));
- UnifiedCommand cmdState = new(state.Command);
+ UnifiedCommand cmdState = new(commandFactory, state.Command);
commandFactory.PostProcess(in cmdState, current, local);
total += local;
@@ -308,7 +361,7 @@ private int ExecuteMultiSequential(IEnumerable source)
haveMore = iterator.MoveNext();
}
- Recycle(ref state);
+ state.UnifiedBatch.TryRecycle();
return total;
}
return total;
@@ -385,16 +438,16 @@ public Task ExecuteAsync(ImmutableArray values, int batchSize = -1,
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private Task ExecuteMultiAsync(ReadOnlyMemory source, int batchSize, CancellationToken cancellationToken)
{
-//#if NET6_0_OR_GREATER
-// if (UseBatch(batchSize)) return ExecuteMultiBatchAsync(source, batchSize, cancellationToken);
-//#endif
+ //#if NET6_0_OR_GREATER
+ // if (UseBatch(batchSize)) return ExecuteMultiBatchAsync(source, batchSize, cancellationToken);
+ //#endif
return ExecuteMultiSequentialAsync(source, cancellationToken);
}
private async Task ExecuteMultiSequentialAsync(ReadOnlyMemory source, CancellationToken cancellationToken)
{
Debug.Assert(source.Length > 1);
- AsyncCommandState state = new();
+ var state = AsyncCommandState.Create();
try
{
if (commandFactory.CanPrepare) state.PrepareBeforeExecute();
@@ -402,7 +455,7 @@ private async Task ExecuteMultiSequentialAsync(ReadOnlyMemory source
var current = source.Span[0];
var local = await state.ExecuteNonQueryAsync(GetCommand(current), cancellationToken);
- UnifiedCommand cmdState = new(state.Command);
+ UnifiedCommand cmdState = new(commandFactory, state.Command);
commandFactory.PostProcess(in cmdState, current, local);
total += local;
@@ -415,37 +468,41 @@ private async Task ExecuteMultiSequentialAsync(ReadOnlyMemory source
total += local;
}
- Recycle(state);
+ state.UnifiedBatch.TryRecycle();
return total;
}
finally
{
await state.DisposeAsync();
+ state.Recycle();
}
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private Task ExecuteMultiAsync(IAsyncEnumerable source, int batchSize, CancellationToken cancellationToken)
{
-//#if NET6_0_OR_GREATER
-// if (UseBatch(batchSize)) return ExecuteMultiBatchAsync(source, batchSize, cancellationToken);
-//#endif
+ //#if NET6_0_OR_GREATER
+ // if (UseBatch(batchSize)) return ExecuteMultiBatchAsync(source, batchSize, cancellationToken);
+ //#endif
return ExecuteMultiSequentialAsync(source, cancellationToken);
}
private async Task ExecuteMultiSequentialAsync(IAsyncEnumerable source, CancellationToken cancellationToken)
{
- AsyncCommandState state = new();
+ AsyncCommandState? state = null;
var iterator = source.GetAsyncEnumerator(cancellationToken);
try
{
int total = 0;
if (await iterator.MoveNextAsync())
{
+ state ??= AsyncCommandState.Create();
+
var current = iterator.Current;
bool haveMore = await iterator.MoveNextAsync();
if (haveMore && commandFactory.CanPrepare) state.PrepareBeforeExecute();
+
var local = await state.ExecuteNonQueryAsync(GetCommand(current), cancellationToken);
- UnifiedCommand cmdState = new(state.Command);
+ UnifiedCommand cmdState = new(commandFactory, state.Command);
commandFactory.PostProcess(in cmdState, current, local);
total += local;
@@ -459,41 +516,46 @@ private async Task ExecuteMultiSequentialAsync(IAsyncEnumerable sour
haveMore = await iterator.MoveNextAsync();
}
- Recycle(state);
- return total;
+ state.UnifiedBatch.TryRecycle();
}
return total;
}
finally
{
await iterator.DisposeAsync();
- await state.DisposeAsync();
+ if (state is not null)
+ {
+ await state.DisposeAsync();
+ state.Recycle();
+ }
}
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private Task ExecuteMultiAsync(IEnumerable source, int batchSize, CancellationToken cancellationToken)
{
-//#if NET6_0_OR_GREATER
-// if (UseBatch(batchSize)) return ExecuteMultiBatchAsync(source, batchSize, cancellationToken);
-//#endif
+#if NET6_0_OR_GREATER
+ if (UseBatch(batchSize)) return ExecuteMultiBatchAsync(source, batchSize, cancellationToken);
+#endif
return ExecuteMultiSequentialAsync(source, cancellationToken);
}
private async Task ExecuteMultiSequentialAsync(IEnumerable source, CancellationToken cancellationToken)
{
- AsyncCommandState state = new();
+ AsyncCommandState? state = null;
var iterator = source.GetEnumerator();
try
{
int total = 0;
if (iterator.MoveNext())
{
+ state ??= AsyncCommandState.Create();
+
var current = iterator.Current;
bool haveMore = iterator.MoveNext();
if (haveMore && commandFactory.CanPrepare) state.PrepareBeforeExecute();
var local = await state.ExecuteNonQueryAsync(GetCommand(current), cancellationToken);
- UnifiedCommand cmdState = new(state.Command);
+ UnifiedCommand cmdState = new(commandFactory, state.Command);
commandFactory.PostProcess(in cmdState, current, local);
total += local;
@@ -507,18 +569,66 @@ private async Task ExecuteMultiSequentialAsync(IEnumerable source, C
haveMore = iterator.MoveNext();
}
- Recycle(state);
- return total;
+ state.UnifiedBatch.TryRecycle();
}
return total;
}
finally
{
iterator.Dispose();
- await state.DisposeAsync();
+ if (state is not null)
+ {
+ await state.DisposeAsync();
+ state.Recycle();
+ }
}
}
+#if NET6_0_OR_GREATER
+ private async Task ExecuteMultiBatchAsync(IEnumerable source, int batchSize, CancellationToken cancellationToken)
+ {
+ AsyncCommandState? state = null;
+ var buffer = GetMultiBatchBuffer(ref batchSize);
+ try
+ {
+ int sum = 0, ppOffset = 0;
+ foreach (var arg in source)
+ {
+ state ??= AsyncCommandState.Create();
+ Add(ref state.UnifiedBatch, arg);
+ if (buffer is not null) buffer[ppOffset++] = arg;
+
+ if (state.UnifiedBatch.GroupCount == batchSize)
+ {
+ sum += await state.ExecuteNonQueryUnifiedAsync(cancellationToken);
+ PostProcessMultiBatch(in state.UnifiedBatch, buffer);
+ ppOffset = 0;
+ }
+ }
+
+ if (state is not null)
+ {
+ if (state.UnifiedBatch.GroupCount != 0)
+ {
+ sum += await state.ExecuteNonQueryUnifiedAsync(cancellationToken);
+ PostProcessMultiBatch(in state.UnifiedBatch, buffer);
+ }
+ state.UnifiedBatch.TryRecycle();
+ }
+ return sum;
+ }
+ finally
+ {
+ RecycleMultiBatchBuffer(buffer);
+ if (state is not null)
+ {
+ await state.DisposeAsync();
+ state.Recycle();
+ }
+ }
+ }
+#endif
+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private Task ExecuteMultiAsync(TArgs[] source, int offset, int count, int batchSize, CancellationToken cancellationToken)
{
@@ -529,7 +639,8 @@ private Task ExecuteMultiAsync(TArgs[] source, int offset, int count, int b
}
private async Task ExecuteMultiSequentialAsync(TArgs[] source, int offset, int count, CancellationToken cancellationToken)
{
- AsyncCommandState state = new();
+ Debug.Assert(count > 0);
+ var state = AsyncCommandState.Create();
try
{
// count is now actually "end"
@@ -540,7 +651,7 @@ private async Task ExecuteMultiSequentialAsync(TArgs[] source, int offset,
var current = source[offset++];
var local = await state.ExecuteNonQueryAsync(GetCommand(current), cancellationToken);
- UnifiedCommand cmdState = new(state.Command);
+ UnifiedCommand cmdState = new(commandFactory, state.Command);
commandFactory.PostProcess(in cmdState, current, local);
total += local;
@@ -553,41 +664,54 @@ private async Task ExecuteMultiSequentialAsync(TArgs[] source, int offset,
total += local;
}
- Recycle(state);
+ state.UnifiedBatch.TryRecycle();
return total;
}
finally
{
await state.DisposeAsync();
+ state.Recycle();
}
}
#if NET6_0_OR_GREATER
- private async Task ExecuteMultiBatchAsync(TArgs[] source, int offset, int count, int batchSize, CancellationToken cancellationToken) // TODO: sub-batching
+ private async Task ExecuteMultiBatchAsync(TArgs[] source, int offset, int count, int batchSize, CancellationToken cancellationToken)
{
Debug.Assert(source.Length > 1);
- UnifiedCommand batch = default;
+ AsyncCommandState? state = null;
var end = offset + count;
try
{
- for (int i = offset ; i < end; i++)
+ int sum = 0, ppOffset = offset;
+ for (int i = offset; i < end; i++)
{
- if (!batch.HasBatch) batch = new(connection.CreateBatch());
- AddCommand(ref batch, source[i]);
+ state ??= AsyncCommandState.Create();
+ Add(ref state.UnifiedBatch, source[i]);
+ if (state.UnifiedBatch.GroupCount == batchSize)
+ {
+ sum += await state.ExecuteNonQueryUnifiedAsync(cancellationToken);
+ PostProcessMultiBatch(in state.UnifiedBatch, ref ppOffset, source);
+ }
}
- if (!batch.HasBatch) return 0;
- var result = await batch.AssertBatch.ExecuteNonQueryAsync(cancellationToken);
-
- if (commandFactory.RequirePostProcess)
+ if (state is not null)
{
- batch.PostProcess(new ReadOnlySpan(source, offset, count), commandFactory);
+ if (state.UnifiedBatch.GroupCount != 0)
+ {
+ sum += await state.ExecuteNonQueryUnifiedAsync(cancellationToken);
+ PostProcessMultiBatch(in state.UnifiedBatch, ref ppOffset, source);
+ }
+ state.UnifiedBatch.TryRecycle();
}
- return result;
+ return sum;
}
finally
{
- batch.Cleanup();
+ if (state is not null)
+ {
+ await state.DisposeAsync();
+ state.Recycle();
+ }
}
}
#endif
diff --git a/src/Dapper.AOT/CommandT.Execute.cs b/src/Dapper.AOT/CommandT.Execute.cs
index 461ece74..2a975b32 100644
--- a/src/Dapper.AOT/CommandT.Execute.cs
+++ b/src/Dapper.AOT/CommandT.Execute.cs
@@ -14,8 +14,9 @@ public int Execute(TArgs args)
SyncCommandState state = default;
try
{
- var result = state.ExecuteNonQuery(GetCommand(args));
- PostProcessAndRecycle(ref state, args, result);
+ GetUnifiedBatch(out state.UnifiedBatch, args);
+ var result = state.ExecuteNonQueryUnified();
+ PostProcessAndRecycleUnified(state.UnifiedBatch, args, result);
return result;
}
finally
@@ -29,16 +30,18 @@ public int Execute(TArgs args)
///
public async Task ExecuteAsync(TArgs args, CancellationToken cancellationToken = default)
{
- AsyncCommandState state = new();
+ var state = AsyncCommandState.Create();
try
{
- var result = await state.ExecuteNonQueryAsync(GetCommand(args), cancellationToken);
- PostProcessAndRecycle(state, args, result);
+ GetUnifiedBatch(out state.UnifiedBatch, args);
+ var result = await state.ExecuteNonQueryUnifiedAsync(cancellationToken);
+ PostProcessAndRecycleUnified(in state.UnifiedBatch, args, result);
return result;
}
finally
{
await state.DisposeAsync();
+ state.Recycle();
}
}
}
diff --git a/src/Dapper.AOT/CommandT.ExecuteScalar.cs b/src/Dapper.AOT/CommandT.ExecuteScalar.cs
index 89cdab3e..fdc5e94b 100644
--- a/src/Dapper.AOT/CommandT.ExecuteScalar.cs
+++ b/src/Dapper.AOT/CommandT.ExecuteScalar.cs
@@ -1,4 +1,5 @@
using Dapper.Internal;
+using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
@@ -29,7 +30,7 @@ partial struct Command
///
public async Task