diff --git a/docs/api-walkthrough.md b/docs/api-walkthrough.md index d345478..a29b930 100644 --- a/docs/api-walkthrough.md +++ b/docs/api-walkthrough.md @@ -97,6 +97,31 @@ if (inst is BinaryInst { Op: BinaryOp.Sub, Left: BinaryInst { Op: BinaryOp.Add } } ``` +or you can use the pattern matching dsl: +```csharp +if (ins.Match("(sub {lhs: (add {x} $y)} $y)", out var outputs)) +{ + inst.ReplaceWith(outputs["lhs"]); +} +``` + +Pattern Matching Operations: +| Example | Description | +|------------|------------------------------------------------------------------------| +| $x | Use the value later for matching | +| {x} | Output the operand | +| {x: (add)} | Output the operand if the subpattern matched | +| _ | Ignore the operand | +| ! | Matches if the subpattern doesn't match | +| #instr | Matches if the operand is an instruction | +| #const | Matches if the operand is a constant | +| #int | Matches if the operand is a constant of type int32 | +| >5 | Matches if the value is greater than 5 | +| <5 | Matches if the value is less than 5 | +| *"hello"* | Matches if the operand is a string constant that contains the value | +| *"hello" | Matches if the operand is a string constant that end with the value | +| "hello"* | Matches if the operand is a string constant that starts with the value | +======= ## Finding Methods The module resolver has an `FindMethod` extension method that makes it easier to find methods from the IR. diff --git a/src/DistIL/AsmIO/ResolvingUtils.cs b/src/DistIL/AsmIO/ResolvingUtils.cs index 76dcce4..c8da389 100644 --- a/src/DistIL/AsmIO/ResolvingUtils.cs +++ b/src/DistIL/AsmIO/ResolvingUtils.cs @@ -70,7 +70,7 @@ public static class ResolvingUtils return methods.FirstOrDefault(); } - private static MethodSelector GetSelector(this ModuleResolver resolver, string selector) + internal static MethodSelector GetSelector(this ModuleResolver resolver, string selector) { var spl = selector.Split("::"); var type = FindType(resolver, spl[0].Trim()); diff --git a/src/DistIL/IR/DSL/IInstructionPatternArgument.cs b/src/DistIL/IR/DSL/IInstructionPatternArgument.cs new file mode 100644 index 0000000..cc43159 --- /dev/null +++ b/src/DistIL/IR/DSL/IInstructionPatternArgument.cs @@ -0,0 +1,6 @@ +namespace DistIL.IR.DSL; + +internal interface IInstructionPatternArgument +{ + +} \ No newline at end of file diff --git a/src/DistIL/IR/DSL/InstructionPattern.cs b/src/DistIL/IR/DSL/InstructionPattern.cs new file mode 100644 index 0000000..f0f375f --- /dev/null +++ b/src/DistIL/IR/DSL/InstructionPattern.cs @@ -0,0 +1,227 @@ +namespace DistIL.IR.DSL; + +using System; +using System.Collections.Generic; + +using PatternArguments; +using Utils.Parser; + +internal record InstructionPattern( + Opcode OpCode, + string Operation, + List Arguments) + : IInstructionPatternArgument +{ + public static IInstructionPatternArgument? Parse(ReadOnlySpan pattern) + { + // Remove whitespace and validate parentheses balance + pattern = pattern.Trim(); + if (pattern.Length == 0) { + return null; + } + + if (pattern[0] != '(' || pattern[^1] != ')') + return ParseArgument(pattern); + + // Remove the outer parentheses + pattern = pattern[1..^1].Trim(); + + // Split the operation from its arguments + int spaceIndex = pattern.IndexOf(' '); + if (spaceIndex == -1) { + spaceIndex = pattern.Length; + } + + var op = pattern[..spaceIndex].ToString(); + var operation = Opcodes.TryParse(op); // TryParse does not support span yet + var argsString = pattern[spaceIndex..].Trim(); + + List arguments = new List(); + + if (operation.Op is Opcode.Call or Opcode.CallVirt) { + var selector = argsString[..argsString.IndexOf(' ')].ToString(); + arguments.Add(new MethodRefArgument(selector)); + + argsString = argsString[argsString.IndexOf(' ')..]; + } + + ParseArguments(argsString, arguments); + + return new InstructionPattern(operation.Op, op, arguments); + } + + private static IInstructionPatternArgument? ParseEval(ReadOnlySpan pattern) + { + var op = pattern[1..].Trim(); + + return new EvalArgument(ParseArgument(op)); + } + + private static void ParseArguments(ReadOnlySpan argsString, List arguments) + { + int depth = 0; + string currentArg = ""; + Stack outputStack = new(); + + foreach (var c in argsString) + { + if (c == '{') { + outputStack.Push(c); + } + else if (c == '}') { + outputStack.Pop(); + } + + if (c == '(') + { + depth++; + currentArg += c; + } + else if (c == ')') + { + depth--; + currentArg += c; + if (depth == 0 && outputStack.Count == 0) + { + // Completed a nested argument + arguments.Add(Parse(currentArg.AsSpan())!); + currentArg = ""; + } + } + else if (char.IsWhiteSpace(c) && depth == 0) + { + // End of a top-level argument + if (!string.IsNullOrWhiteSpace(currentArg)) + { + arguments.Add(ParseArgument(currentArg.Trim())); + currentArg = ""; + } + } + else + { + currentArg += c; + } + } + + // Add any remaining argument + if (!string.IsNullOrWhiteSpace(currentArg)) + { + arguments.Add(ParseArgument(currentArg.Trim())); + } + } + + private static IInstructionPatternArgument ParseArgument(ReadOnlySpan arg) + { + if (arg[0] == '(' && arg[^1] == ')') { + return Parse(arg)!; + } + + if (arg.Contains('#')) + { + var left = arg[..arg.IndexOf('#')]; + var typeSpecifier = arg[arg.IndexOf('#')..].TrimStart('#'); + + var argument = left is not "" ? ParseArgument(left) : null; + return new TypedArgument(argument, typeSpecifier.ToString()); + } + + if (arg[0] == '!') { + return ParseNot(arg); + } + + if (arg[0] == '$') { + return ParseBuffer(arg); + } + + if (arg[0] == '<' || arg[0] == '>') { + return ParseNumOperator(arg); + } + + if (arg[0] == '#') { + return new TypedArgument(default, arg[1..].ToString()); + } + + if (arg[0] == '*' || arg[0] == '\'') + { + return ParseStringArgument(arg); + } + + if (arg[0] == '{' && arg[^1] == '}') { + return ParseOutputArgument(arg); + } + + if (arg[0] == '_') + { + return new IgnoreArgument(); + } + + if (long.TryParse(arg, out var number)) + { + return new ConstantArgument(number, PrimType.Int64); + } + if (double.TryParse(arg, out var dnumber)) + { + return new ConstantArgument(dnumber, PrimType.Double); + } + + throw new ArgumentException("Invalid Argument"); + } + + private static IInstructionPatternArgument ParseOutputArgument(ReadOnlySpan arg) + { + arg = arg[1..^1]; + + if (arg.Contains(':')) { + var name = arg[..arg.IndexOf(':')]; + var subPattern = ParseArgument(arg[(arg.IndexOf(':') + 1)..]); + + return new OutputArgument(name.ToString(), subPattern); + } + + return new OutputArgument(arg.ToString()); + } + + private static IInstructionPatternArgument ParseBuffer(ReadOnlySpan arg) + { + return new BufferArgument(arg[1..].ToString()); + } + + private static IInstructionPatternArgument ParseNumOperator(ReadOnlySpan arg) + { + var op = arg[0]; + + return new NumberOperatorArgument(op, ParseArgument(arg[1..])); + } + + + private static IInstructionPatternArgument ParseNot(ReadOnlySpan arg) + { + var trimmed = arg.TrimStart('!'); + + return new NotArgument(ParseArgument(trimmed)); + } + + private static IInstructionPatternArgument ParseStringArgument(ReadOnlySpan arg) + { + StringOperation operation = StringOperation.None; + + if (arg[0] == '*' && arg[^1] == '*') { + operation = StringOperation.Contains; + } + else if(arg[0] == '*') { + operation = StringOperation.EndsWith; + } + else if(arg[^1] == '*') { + operation = StringOperation.StartsWith; + } + + arg = arg.TrimStart('*').TrimEnd('*'); + + if (arg[0] == '\'' && arg[^1] == '\'') { + return new StringArgument(arg[1..^1].ToString(), operation); + } + + throw new ArgumentException("Invalid string"); + } + +} \ No newline at end of file diff --git a/src/DistIL/IR/DSL/OutputPattern.cs b/src/DistIL/IR/DSL/OutputPattern.cs new file mode 100644 index 0000000..43ff36c --- /dev/null +++ b/src/DistIL/IR/DSL/OutputPattern.cs @@ -0,0 +1,48 @@ +namespace DistIL.IR.DSL; + +public readonly struct OutputPattern +{ + private readonly Dictionary _outputs = []; + private readonly Dictionary _buffer = []; + internal readonly InstructionPattern? Pattern = null; + + public OutputPattern(ReadOnlySpan input) + { + Pattern = InstructionPattern.Parse(input) as InstructionPattern; + } + + internal void Add(string key, Value value) + { + _outputs[key] = value; + } + + internal void AddToBuffer(string key, Value value) + { + _buffer[key] = value; + } + + public Value this[string name] => _outputs[name]; + + public Value this[int position] { + get { + string name = _outputs.Keys.ElementAt(position); + + return _outputs[name]; + } + } + + internal Value GetFromBuffer(string name) + { + return _buffer[name]; + } + + internal Value? Get(string name) + { + return _buffer.TryGetValue(name, out Value? value) ? value : _outputs.GetValueOrDefault(name); + } + + internal bool IsValueInBuffer(string name) + { + return _buffer.ContainsKey(name); + } +} \ No newline at end of file diff --git a/src/DistIL/IR/DSL/PatternArguments/BufferArgument.cs b/src/DistIL/IR/DSL/PatternArguments/BufferArgument.cs new file mode 100644 index 0000000..52830df --- /dev/null +++ b/src/DistIL/IR/DSL/PatternArguments/BufferArgument.cs @@ -0,0 +1,3 @@ +namespace DistIL.IR.DSL.PatternArguments; + +internal record BufferArgument(string Name) : IInstructionPatternArgument; \ No newline at end of file diff --git a/src/DistIL/IR/DSL/PatternArguments/EvalArgument.cs b/src/DistIL/IR/DSL/PatternArguments/EvalArgument.cs new file mode 100644 index 0000000..383ce58 --- /dev/null +++ b/src/DistIL/IR/DSL/PatternArguments/EvalArgument.cs @@ -0,0 +1,3 @@ +namespace DistIL.IR.DSL.PatternArguments; + +internal record EvalArgument(IInstructionPatternArgument OP) : IInstructionPatternArgument; \ No newline at end of file diff --git a/src/DistIL/IR/DSL/PatternArguments/IgnoreArgument.cs b/src/DistIL/IR/DSL/PatternArguments/IgnoreArgument.cs new file mode 100644 index 0000000..fdaf504 --- /dev/null +++ b/src/DistIL/IR/DSL/PatternArguments/IgnoreArgument.cs @@ -0,0 +1,6 @@ +namespace DistIL.IR.DSL.PatternArguments; + +internal record IgnoreArgument : IInstructionPatternArgument +{ + +} diff --git a/src/DistIL/IR/DSL/PatternArguments/MethodRefArgument.cs b/src/DistIL/IR/DSL/PatternArguments/MethodRefArgument.cs new file mode 100644 index 0000000..a480c17 --- /dev/null +++ b/src/DistIL/IR/DSL/PatternArguments/MethodRefArgument.cs @@ -0,0 +1,3 @@ +namespace DistIL.IR.DSL.PatternArguments; + +internal record MethodRefArgument(string Selector) : IInstructionPatternArgument; \ No newline at end of file diff --git a/src/DistIL/IR/DSL/PatternArguments/NotArgument.cs b/src/DistIL/IR/DSL/PatternArguments/NotArgument.cs new file mode 100644 index 0000000..ccb35ad --- /dev/null +++ b/src/DistIL/IR/DSL/PatternArguments/NotArgument.cs @@ -0,0 +1,6 @@ +namespace DistIL.IR.DSL.PatternArguments; + +internal record NotArgument(IInstructionPatternArgument Inner) : IInstructionPatternArgument +{ + +} \ No newline at end of file diff --git a/src/DistIL/IR/DSL/PatternArguments/NumberArgument.cs b/src/DistIL/IR/DSL/PatternArguments/NumberArgument.cs new file mode 100644 index 0000000..6d85a77 --- /dev/null +++ b/src/DistIL/IR/DSL/PatternArguments/NumberArgument.cs @@ -0,0 +1,6 @@ +namespace DistIL.IR.DSL.PatternArguments; + +internal record ConstantArgument(object Value, TypeDesc? Type) : IInstructionPatternArgument +{ + +} diff --git a/src/DistIL/IR/DSL/PatternArguments/NumberOperatorArgument.cs b/src/DistIL/IR/DSL/PatternArguments/NumberOperatorArgument.cs new file mode 100644 index 0000000..da97841 --- /dev/null +++ b/src/DistIL/IR/DSL/PatternArguments/NumberOperatorArgument.cs @@ -0,0 +1,6 @@ +namespace DistIL.IR.DSL.PatternArguments; + +internal record NumberOperatorArgument(char Operator, IInstructionPatternArgument Argument) : IInstructionPatternArgument +{ + +} \ No newline at end of file diff --git a/src/DistIL/IR/DSL/PatternArguments/OutputArgument.cs b/src/DistIL/IR/DSL/PatternArguments/OutputArgument.cs new file mode 100644 index 0000000..9386137 --- /dev/null +++ b/src/DistIL/IR/DSL/PatternArguments/OutputArgument.cs @@ -0,0 +1,6 @@ +namespace DistIL.IR.DSL.PatternArguments; + +internal record OutputArgument(string Name, IInstructionPatternArgument? SubPattern = null) : IInstructionPatternArgument +{ + +} \ No newline at end of file diff --git a/src/DistIL/IR/DSL/PatternArguments/StringArgument.cs b/src/DistIL/IR/DSL/PatternArguments/StringArgument.cs new file mode 100644 index 0000000..3fa3e77 --- /dev/null +++ b/src/DistIL/IR/DSL/PatternArguments/StringArgument.cs @@ -0,0 +1,6 @@ +namespace DistIL.IR.DSL.PatternArguments; + +internal record StringArgument(object Value, StringOperation Operation) : ConstantArgument(Value, PrimType.String) +{ + +} \ No newline at end of file diff --git a/src/DistIL/IR/DSL/PatternArguments/StringOperation.cs b/src/DistIL/IR/DSL/PatternArguments/StringOperation.cs new file mode 100644 index 0000000..0a4255f --- /dev/null +++ b/src/DistIL/IR/DSL/PatternArguments/StringOperation.cs @@ -0,0 +1,9 @@ +namespace DistIL.IR.DSL.PatternArguments; + +internal enum StringOperation +{ + None, + StartsWith, + EndsWith, + Contains +} diff --git a/src/DistIL/IR/DSL/PatternArguments/TypedArgument.cs b/src/DistIL/IR/DSL/PatternArguments/TypedArgument.cs new file mode 100644 index 0000000..041b244 --- /dev/null +++ b/src/DistIL/IR/DSL/PatternArguments/TypedArgument.cs @@ -0,0 +1,6 @@ +namespace DistIL.IR.DSL.PatternArguments; + +internal record TypedArgument(IInstructionPatternArgument? Argument, string Type) : IInstructionPatternArgument +{ + +} diff --git a/src/DistIL/IR/MatchExtensions.cs b/src/DistIL/IR/MatchExtensions.cs new file mode 100644 index 0000000..8182741 --- /dev/null +++ b/src/DistIL/IR/MatchExtensions.cs @@ -0,0 +1,238 @@ +namespace DistIL.IR; + +using DSL.PatternArguments; +using DSL; +using Utils.Parser; + +public static class MatchExtensions +{ + public static bool Match(this Instruction instruction, string pattern, out OutputPattern outputs) + { + outputs = new OutputPattern(pattern); + + return MatchInstruction(instruction, outputs.Pattern!, outputs); + } + + public static bool Match(this Instruction instruction, string pattern) + { + var outputs = new OutputPattern(pattern); + + return MatchInstruction(instruction, outputs.Pattern!, outputs); + } + + internal static bool MatchInstruction(this Instruction instruction, InstructionPattern instrPattern, OutputPattern outputs) + { + if (instruction is BinaryInst bin) { + return MatchBinary(bin, instrPattern, outputs); + } + if (instruction is CompareInst comp) { + return MatchCompare(comp, instrPattern, outputs); + } + if (instruction is UnaryInst un) { + return MatchUnary(un, instrPattern, outputs); + } + + return MatchOtherInstruction(instruction, instrPattern, outputs); + } + + private static bool MatchOtherInstruction(Instruction instruction, InstructionPattern pattern, OutputPattern outputs) + { + var op = pattern.OpCode; + var ops = MatchOpCode((op, instruction)); + + return MatchOperands(instruction, pattern, outputs); + } + + private static bool MatchOpCode((Opcode op, Instruction instruction) opInstrTuple) + { + return opInstrTuple switch { + (Opcode.Unknown, _) => false, + (Opcode.Goto, BranchInst) => true, + (Opcode.Switch, SwitchInst) => true, + (Opcode.Ret, ReturnInst) => true, + (Opcode.Phi, PhiInst) => true, + (Opcode.Select, SelectInst) => true, + (Opcode.Lea, PtrOffsetInst) => true, + (Opcode.Getfld,FieldExtractInst) => true, + (Opcode.Setfld, FieldInsertInst) => true, + (Opcode.ArrAddr, ArrayAddrInst) => true, + (Opcode.FldAddr, FieldAddrInst) => true, + (Opcode.Load, LoadInst) => true, + (Opcode.Store, StoreInst) => true, + (Opcode.Conv, ConvertInst) => true, + _ => false + }; + } + + private static bool MatchOperands(Instruction instruction, InstructionPattern pattern, OutputPattern outputs) + { + if (pattern.Arguments.Count > instruction.Operands.Length) { + return false; + } + + for (int index = 0; index < pattern.Arguments.Count; index++) { + Value? operand = instruction.Operands[index]; + if (!MatchValue(operand, pattern.Arguments[index], outputs)) { + return false; + } + } + + return true; + } + + private static bool MatchArgument(Value value, IInstructionPatternArgument argument, OutputPattern outputs) + { + switch (argument) { + case NotArgument not: + return !MatchArgument(value, not.Inner, outputs); + case IgnoreArgument: + return true; + case BufferArgument buffer: + return MatchBuffer(value, buffer, outputs); + case OutputArgument output: + return MatchOutput(value, outputs, output); + case ConstantArgument constArg when value is Const constant: + return MatchConstArgument(constArg, constant); + case InstructionPattern pattern: + return MatchValue(value, pattern, outputs); + case TypedArgument typed: + return MatchTypeSpecifier(value, typed, outputs); + case NumberOperatorArgument numOp: + return MatchNumOperator(value, numOp, outputs); + default: + return false; + } + } + + private static bool MatchOutput(Value value, OutputPattern outputs, OutputArgument output) + { + if (output.SubPattern is null) { + outputs.Add(output.Name, value); + return true; + } + + if (MatchValue(value, output.SubPattern, outputs)) { + outputs.Add(output.Name, value); + return true; + } + + return false; + } + + private static bool MatchBuffer(Value value, BufferArgument buffer, OutputPattern outputs) + { + if (outputs.IsValueInBuffer(buffer.Name)) { + var bufferedValue = outputs.GetFromBuffer(buffer.Name); + + return bufferedValue == value; + } + + outputs.AddToBuffer(buffer.Name, value); + return true; + } + + private static bool MatchNumOperator(Value value, NumberOperatorArgument numOp, OutputPattern outputs) + { + if (numOp.Argument is not ConstantArgument constantArg) { + return false; + } + + if (constantArg.Type != PrimType.Int64 && constantArg.Type != PrimType.Double) { + return false; + } + + dynamic constant = constantArg; + dynamic val = value; + + if (numOp.Operator == '<') { + return val.Value < constant.Value; + } else if (numOp.Operator == '>') { + return val.Value > constant.Value; + } + + return false; + } + + + private static bool MatchTypeSpecifier(Value value, TypedArgument typed, OutputPattern outputs) + { + bool result = true; + if (typed.Argument is not null) { + result = MatchArgument(value, typed.Argument, outputs); + } + + if (typed.Type is "const") { + result &= value is Const; + } else if (typed.Type is "instr") { + result &= value is Instruction; + } else { + result &= PrimType.GetFromAlias(typed.Type) == value.ResultType; + } + + return result; + } + + private static bool MatchValue(Value value, IInstructionPatternArgument pattern, OutputPattern outputs) + { + return pattern switch { + InstructionPattern p when value is Instruction instruction => MatchInstruction(instruction, p, outputs), + _ => MatchArgument(value, pattern, outputs) + }; + } + + private static bool MatchConstArgument(ConstantArgument constantArg, Const constant) + { + if (constantArg is StringArgument strArg) { + return constant is ConstString str && MatchStringArg(strArg, str); + } + + // TODO: consider supporting explicit number typing suffixes in ConstantArgument + object? value = constant switch { + ConstInt constInt => constInt.Value, + ConstFloat constFloat => constFloat.Value, + ConstNull => null, + _ => null + }; + return constantArg.Value.Equals(value); + } + + private static bool MatchStringArg(StringArgument strArg, ConstString constant) + { + if (strArg.Operation == StringOperation.StartsWith) { + return constant.Value.StartsWith(strArg.Value.ToString()!); + } + if (strArg.Operation == StringOperation.EndsWith) { + return constant.Value.EndsWith(strArg.Value.ToString()!); + } + if (strArg.Operation == StringOperation.Contains) { + return constant.Value.Contains(strArg.Value.ToString()!); + } + + return strArg.Value.Equals(constant.Value); + } + + + private static bool MatchBinary(BinaryInst bin, InstructionPattern pattern, OutputPattern outputs) + { + if (pattern.OpCode.IsBinaryOp() && pattern.OpCode.GetBinaryOp() == bin.Op) { + return MatchOperands(bin, pattern, outputs); + } + return false; + } + + private static bool MatchCompare(CompareInst comp, InstructionPattern pattern, OutputPattern outputs) + { + if (pattern.OpCode.IsCompareOp() && pattern.OpCode.GetCompareOp() == comp.Op) { + return MatchOperands(comp, pattern, outputs); + } + return false; + } + + private static bool MatchUnary(UnaryInst un, InstructionPattern pattern, OutputPattern outputs) + { + if (pattern.OpCode.IsUnaryOp() && pattern.OpCode.GetUnaryOp() == un.Op) { + return MatchOperands(un, pattern, outputs); + } + return false; + } +} \ No newline at end of file diff --git a/src/DistIL/IR/ReplaceExtensions.cs b/src/DistIL/IR/ReplaceExtensions.cs new file mode 100644 index 0000000..6f2495c --- /dev/null +++ b/src/DistIL/IR/ReplaceExtensions.cs @@ -0,0 +1,71 @@ +namespace DistIL.IR; + +using DSL; +using DSL.PatternArguments; + +using Utils.Parser; + +public static class ReplaceExtensions +{ + public static void Replace(this Instruction instruction, ReadOnlySpan replacementPattern) + { + var parts = new Range[2]; + replacementPattern.Split(parts, "->", StringSplitOptions.TrimEntries); + + var outputs = new OutputPattern(replacementPattern[parts[0]]); + var matched = instruction.MatchInstruction(outputs.Pattern!, outputs); + + if (matched) { + var pattern = InstructionPattern.Parse(replacementPattern[parts[1]]); + var newInstr = Evaluate(pattern, outputs); + instruction.ReplaceUses(newInstr); + + if (instruction.NumUses == 0) { + instruction.Remove(); + } + } + } + + private static Value Evaluate(IInstructionPatternArgument replacementPattern, OutputPattern outputs) + { + return replacementPattern switch { + BufferArgument b => outputs.Get(b.Name)!, + OutputArgument o => outputs.Get(o.Name)!, + InstructionPattern instr => CreateInstruction(instr, outputs), + ConstantArgument constant => CreateConstant(constant), + _ => throw new ArgumentException($"Invalid replacement pattern type: {replacementPattern.GetType()}") + }; + } + + private static Value CreateConstant(ConstantArgument constant) + { + if (constant.Type == PrimType.Single) { + return ConstFloat.CreateS((float)constant.Value); + } + if (constant.Type == PrimType.Double) { + return ConstFloat.CreateD((double)constant.Value); + } + if (constant.Type == PrimType.Int32) { + return ConstInt.CreateI((int)constant.Value); + } + if (constant.Type == PrimType.Int64) { + return ConstInt.CreateL((long)constant.Value); + } + + throw new ArgumentOutOfRangeException(nameof(constant.Type)); + } + + private static Value CreateInstruction(InstructionPattern instr, OutputPattern outputs) + { + var args = instr.Arguments.Select(a => Evaluate(a, outputs)).ToArray(); + + if (instr.OpCode.IsBinaryOp()) { + return new BinaryInst(instr.OpCode.GetBinaryOp(), args[0], args[1]); + } + if (instr.OpCode.IsCompareOp()) { + return new CompareInst(instr.OpCode.GetCompareOp(), args[0], args[1]); + } + + throw new ArgumentException("Invalid instruction opcode"); + } +} \ No newline at end of file diff --git a/src/DistIL/IR/Utils/Parser/IRParser.cs b/src/DistIL/IR/Utils/Parser/IRParser.cs index fdff791..1c366e8 100644 --- a/src/DistIL/IR/Utils/Parser/IRParser.cs +++ b/src/DistIL/IR/Utils/Parser/IRParser.cs @@ -347,30 +347,22 @@ private TypeDesc ParseResultType() private Instruction ParseMultiOpInst(Opcode op, AbsRange pos) { - if (op is > Opcode._Bin_First and < Opcode._Bin_Last) { - return Schedule(PendingInst.Kind.Binary, op - (Opcode._Bin_First + 1)); - } - if (op is > Opcode._Cmp_First and < Opcode._Cmp_Last) { - return Schedule(PendingInst.Kind.Compare, op - (Opcode._Cmp_First + 1)); - } - throw _ctx.Fatal("Unknown instruction", pos); - - // Some insts have dynamic result types and depend on the real value type, - // once they're found, we'll materialize them. - Instruction Schedule(PendingInst.Kind kind, int op) - { + // Some insts have dynamic result types and depend on the real value type. + // We'll defer materialization until they have been parsed. + if (op.IsBinaryOp() || op.IsCompareOp()) { var left = ParseValue(); _lexer.Expect(TokenType.Comma); var right = ParseValue(); var type = ParseResultType(); - if (PendingInst.Resolve(kind, op, left, right) is { } resolved) { + if (PendingInst.Resolve(op, left, right) is { } resolved) { return resolved; } - var inst = new PendingInst(kind, op, left, right, type); + var inst = new PendingInst(op, left, right, type); _pendingInsts.Add(inst); return inst; } + throw _ctx.Fatal("Unknown instruction", pos); } // Goto := Label | (Value "?" Label ":" Label) @@ -827,19 +819,17 @@ sealed class PendingValue : TrackedValue } sealed class PendingInst : Instruction { - public Kind InstKind; - public int Op; + public Opcode Op; - public override string InstName => "pending." + InstKind; + public override string InstName => "pending." + Op; public override void Accept(InstVisitor visitor) => throw new InvalidOperationException(); - public PendingInst(Kind kind, int op, Value left, Value right, TypeDesc resultType) - : base(left, right) - => (InstKind, Op, ResultType) = (kind, op, resultType); + public PendingInst(Opcode op, Value left, Value right, TypeDesc resultType) : base(left, right) + => (Op, ResultType) = (op, resultType); public bool TryResolve() { - if (Resolve(InstKind, Op, Operands[0], Operands[1]) is { } resolved) { + if (Resolve(Op, Operands[0], Operands[1]) is { } resolved) { Ensure.That(resolved.ResultType == ResultType); ReplaceWith(resolved, insertIfInst: true); return true; @@ -847,15 +837,18 @@ public bool TryResolve() return false; } - public static Instruction? Resolve(Kind kind, int op, Value left, Value right) + public static Instruction? Resolve(Opcode op, Value left, Value right) { if (left is PendingValue || right is PendingValue) { return null; } - return kind switch { - PendingInst.Kind.Binary => new BinaryInst((BinaryOp)op, left, right), - PendingInst.Kind.Compare => new CompareInst((CompareOp)op, left, right) - }; + if (op.IsBinaryOp()) { + return new BinaryInst(op.GetBinaryOp(), left, right); + } + if (op.IsCompareOp()) { + return new CompareInst(op.GetCompareOp(), left, right); + } + throw new InvalidOperationException(); } public enum Kind { Binary, Compare }; diff --git a/src/DistIL/IR/Utils/Parser/Opcodes.cs b/src/DistIL/IR/Utils/Parser/Opcodes.cs index dde7ff5..98c931f 100644 --- a/src/DistIL/IR/Utils/Parser/Opcodes.cs +++ b/src/DistIL/IR/Utils/Parser/Opcodes.cs @@ -18,32 +18,38 @@ internal enum Opcode Load, Store, Conv, - // Note: Entries must be keept in the same order as in BinaryOp - _Bin_First, - Bin_Add, Bin_Sub, Bin_Mul, - Bin_SDiv, Bin_UDiv, - Bin_SRem, Bin_URem, - - Bin_And, Bin_Or, Bin_Xor, - Bin_Shl, // << Shift left - Bin_Shra, // >> Shift right (arithmetic) - Bin_Shrl, // >>> Shift right (logical) - - Bin_FAdd, Bin_FSub, Bin_FMul, Bin_FDiv, Bin_FRem, - - Bin_AddOvf, Bin_SubOvf, Bin_MulOvf, - Bin_UAddOvf, Bin_USubOvf, Bin_UMulOvf, - _Bin_Last, - - // Note: Entries must be keept in the same order as in CompareOp - _Cmp_First, + // BinaryOp + // NOTE: order must match respective enums + _FirstBinaryOp, + Add, Sub, Mul, + SDiv, UDiv, + SRem, URem, + + And, Or, Xor, + Shl, // << Shift left + Shra, // >> Shift right (arithmetic) + Shrl, // >>> Shift right (logical) + + FAdd, FSub, FMul, FDiv, FRem, + + AddOvf, SubOvf, MulOvf, + UAddOvf, USubOvf, UMulOvf, + _LastBinaryOp, + + // UnaryOp + _FirstUnaryOp, + Neg, Not, FNeg, + _LastUnaryOp, + + // CompareOp + _FirstCompareOp, Cmp_Eq, Cmp_Ne, Cmp_Slt, Cmp_Sgt, Cmp_Sle, Cmp_Sge, Cmp_Ult, Cmp_Ugt, Cmp_Ule, Cmp_Uge, Cmp_FOlt, Cmp_FOgt, Cmp_FOle, Cmp_FOge, Cmp_FOeq, Cmp_FOne, Cmp_FUlt, Cmp_FUgt, Cmp_FUle, Cmp_FUge, Cmp_FUeq, Cmp_FUne, - _Cmp_Last, + _LastCompareOp, } [Flags] @@ -77,30 +83,34 @@ public static (Opcode Op, OpcodeModifiers Mods) TryParse(string str) "getfld" => Opcode.Getfld, "setfld" => Opcode.Setfld, - "add" => Opcode.Bin_Add, - "sub" => Opcode.Bin_Sub, - "mul" => Opcode.Bin_Mul, - "sdiv" => Opcode.Bin_SDiv, - "srem" => Opcode.Bin_SRem, - "udiv" => Opcode.Bin_UDiv, - "urem" => Opcode.Bin_URem, - "and" => Opcode.Bin_And, - "or" => Opcode.Bin_Or, - "xor" => Opcode.Bin_Xor, - "shl" => Opcode.Bin_Shl, - "shra" => Opcode.Bin_Shra, - "shrl" => Opcode.Bin_Shrl, - "fadd" => Opcode.Bin_FAdd, - "fsub" => Opcode.Bin_FSub, - "fmul" => Opcode.Bin_FMul, - "fdiv" => Opcode.Bin_FDiv, - "frem" => Opcode.Bin_FRem, - "add.ovf" => Opcode.Bin_AddOvf, - "sub.ovf" => Opcode.Bin_SubOvf, - "mul.ovf" => Opcode.Bin_MulOvf, - "uadd.ovf" => Opcode.Bin_UAddOvf, - "usub.ovf" => Opcode.Bin_USubOvf, - "umul.ovf" => Opcode.Bin_UMulOvf, + "add" => Opcode.Add, + "sub" => Opcode.Sub, + "mul" => Opcode.Mul, + "sdiv" => Opcode.SDiv, + "srem" => Opcode.SRem, + "udiv" => Opcode.UDiv, + "urem" => Opcode.URem, + "and" => Opcode.And, + "or" => Opcode.Or, + "xor" => Opcode.Xor, + "shl" => Opcode.Shl, + "shra" => Opcode.Shra, + "shrl" => Opcode.Shrl, + "fadd" => Opcode.FAdd, + "fsub" => Opcode.FSub, + "fmul" => Opcode.FMul, + "fdiv" => Opcode.FDiv, + "frem" => Opcode.FRem, + "add.ovf" => Opcode.AddOvf, + "sub.ovf" => Opcode.SubOvf, + "mul.ovf" => Opcode.MulOvf, + "uadd.ovf" => Opcode.UAddOvf, + "usub.ovf" => Opcode.USubOvf, + "umul.ovf" => Opcode.UMulOvf, + + "not" => Opcode.Not, + "neg" => Opcode.Neg, + "fneg" => Opcode.FNeg, "cmp.eq" => Opcode.Cmp_Eq, "cmp.ne" => Opcode.Cmp_Ne, @@ -157,6 +167,25 @@ private static OpcodeModifiers ParseModifiers(string str) (str.Contains(".volatile") ? OpcodeModifiers.Volatile : 0) | (str.Contains(".inbounds") ? OpcodeModifiers.InBounds : 0) | (str.Contains(".readonly") ? OpcodeModifiers.ReadOnly : 0); + } + + public static bool IsBinaryOp(this Opcode op) => op is > Opcode._FirstBinaryOp and < Opcode._LastBinaryOp; + public static bool IsUnaryOp(this Opcode op) => op is > Opcode._FirstUnaryOp and < Opcode._LastUnaryOp; + public static bool IsCompareOp(this Opcode op) => op is > Opcode._FirstCompareOp and < Opcode._LastCompareOp; + public static BinaryOp GetBinaryOp(this Opcode op) + { + Ensure.That(op.IsBinaryOp()); + return (BinaryOp)(op - (Opcode._FirstBinaryOp + 1)); + } + public static UnaryOp GetUnaryOp(this Opcode op) + { + Ensure.That(op.IsUnaryOp()); + return (UnaryOp)(op - (Opcode._FirstUnaryOp + 1)); + } + public static CompareOp GetCompareOp(this Opcode op) + { + Ensure.That(op.IsCompareOp()); + return (CompareOp)(op - (Opcode._FirstCompareOp + 1)); } } \ No newline at end of file diff --git a/tests/Benchmarks/AutoVectorization.cs b/tests/Benchmarks/AutoVectorization.cs index cf8c6f3..8d18a31 100644 --- a/tests/Benchmarks/AutoVectorization.cs +++ b/tests/Benchmarks/AutoVectorization.cs @@ -5,7 +5,7 @@ using System.Runtime.Intrinsics; using System.Runtime.InteropServices; -[Optimize(TryVectorize = true), DisassemblyDiagnoser] +[Optimize(), DisassemblyDiagnoser] public class AutoVectorization { [Params(4096)] diff --git a/tests/DistIL.Tests/IR/MatchingTests.cs b/tests/DistIL.Tests/IR/MatchingTests.cs new file mode 100644 index 0000000..cf86fc1 --- /dev/null +++ b/tests/DistIL.Tests/IR/MatchingTests.cs @@ -0,0 +1,139 @@ +namespace DistIL.Tests.IR; + +using DistIL.AsmIO; +using DistIL.IR; +using DistIL.IR.Utils; + +[Collection("ModuleResolver")] +public class MatchingTests +{ + private readonly ModuleResolver _modResolver; + private MethodDesc _stub; + private readonly ModuleDef _module; + private readonly TypeDef _testType; + + + public MatchingTests(ModuleResolverFixture mrf) + { + _modResolver = mrf.Resolver; + + _module = _modResolver.Resolve("TestAsm"); + _testType = _module.FindType("TestAsm", "MatcherStub")!; + _stub = _testType.FindMethod("StubMethod"); + } + + [Fact] + public void TestMatch() + { + var inst = new BinaryInst(BinaryOp.Add, ConstInt.CreateI(42), new BinaryInst(BinaryOp.Mul, ConstInt.CreateI(1), ConstInt.CreateI(3))); + + Assert.True(inst.Match("(add 42 {instr})", out var outputs)); + var instr = (BinaryInst)outputs["instr"]; + Assert.IsType(instr); + Assert.Equal(BinaryOp.Mul, instr.Op); + + Assert.True(inst.Match("(add {x} (mul _ _))", out outputs)); + var x = (ConstInt)outputs["x"]; + Assert.IsType(x); + Assert.Equal(42L, x.Value); + } + + [Fact] + public void TestSubMatchOutput() + { + var inst = new BinaryInst(BinaryOp.Sub, new BinaryInst(BinaryOp.Add, ConstInt.CreateI(1), ConstInt.CreateI(3)), ConstInt.CreateI(3)); + + Assert.True(inst.Match("(sub {lhs:(add {x} $y)} $y)", out var outputs)); + var instr = (BinaryInst)outputs["lhs"]; + Assert.IsType(instr); + Assert.Equal(BinaryOp.Add, instr.Op); + } + + [Fact] + public void TestCallMatchOutput() + { + var inst = new CallInst(_stub, [ConstNull.Create(), ConstNull.Create()]); + + /* cannot work properly because method matching is not implemented + Assert.True(inst.Match("(call DistIL.Tests.IR.MatchingTests.SubMethod())", out var outputs)); + var instr = (BinaryInst)outputs["lhs"]; + Assert.IsType(instr); + Assert.Equal(BinaryOp.Add, instr.Op); + */ + } + + [Fact] + public void TestReplace() + { + var method = _testType.CreateMethod("ReplaceMe", new TypeSig(PrimType.Void), []); + var body = new MethodBody(method); + var builder = new IRBuilder(body.CreateBlock()); + + var inst = new BinaryInst(BinaryOp.Sub, new BinaryInst(BinaryOp.Add, ConstInt.CreateL(1), ConstInt.CreateL(3)), ConstInt.CreateL(3)); + builder.Emit(inst); + + inst.Replace("(sub {lhs:(add $x $y)} $y) -> (add $x 0)"); + } + + [Fact] + public void TestNot() + { + var inst = new BinaryInst(BinaryOp.Add, ConstInt.CreateI(42), new BinaryInst(BinaryOp.Mul, ConstInt.CreateI(1), ConstInt.CreateI(3))); + + Assert.True(inst.Match("(add _ !42)")); + } + + [Fact] + public void TestReturn() + { + var inst = new ReturnInst(new BinaryInst(BinaryOp.Mul, ConstInt.CreateI(1), ConstInt.CreateI(3))); + + Assert.True(inst.Match("(ret _)")); + } + + [Fact] + public void TestCompare() + { + var inst = new CompareInst(CompareOp.Eq, ConstInt.CreateI(1), ConstInt.CreateI(3)); + + Assert.True(inst.Match("(cmp.eq)")); + } + + [Fact] + public void TestUnary() + { + var inst = new UnaryInst(UnaryOp.Neg, new UnaryInst(UnaryOp.Neg, ConstInt.CreateI(2))); + + Assert.True(inst.Match("(neg (neg {x}))")); + } + + [Fact] + public void TestTypedArgument() + { + var inst = new BinaryInst(BinaryOp.Add, ConstInt.CreateI(42), new BinaryInst(BinaryOp.Mul, ConstInt.CreateI(1), ConstInt.CreateI(3))); + + Assert.True(inst.Match("(add #int !42)")); + } + + [Fact] + public void TestNumberOperator() + { + var inst = new BinaryInst(BinaryOp.Add, ConstInt.CreateI(42), new BinaryInst(BinaryOp.Mul, ConstInt.CreateI(1), ConstInt.CreateI(3))); + + Assert.True(inst.Match($"(add >5 _)")); + } + + [Fact] + public void Test_Strings() + { + var instr = new CallInst(_stub, [ConstString.Create("hello"), ConstString.Create("world")]); + + /* cannot work properly because method matching is not implemented + Assert.True(instr.Match($"(call 'hello' _)")); + Assert.True(instr.Match($"(call *'o' _)")); + Assert.True(instr.Match($"(call _ 'h'*)")); + Assert.True(instr.Match($"(call *'l'* _)")); + */ + } + +} \ No newline at end of file diff --git a/tests/misc/TestAsm/Stub.cs b/tests/misc/TestAsm/Stub.cs new file mode 100644 index 0000000..7a34000 --- /dev/null +++ b/tests/misc/TestAsm/Stub.cs @@ -0,0 +1,9 @@ +namespace TestAsm; + +public class MatcherStub +{ + public static void StubMethod(string first, string second) + { + System.Console.WriteLine(first + second); + } +} \ No newline at end of file