|
1 | 1 | package analysis
|
| 2 | +import ir.transforms.ReadWriteAnalysis.* |
| 3 | +import ir.transforms.ReadWriteAnalysis |
2 | 4 |
|
3 | 5 | import ir.{DirectCall, LocalAssign, MemoryLoad, MemoryStore, Procedure, Program, Register}
|
4 | 6 |
|
5 | 7 | import scala.collection.mutable
|
6 | 8 |
|
7 | 9 | class WriteToAnalysis(program: Program) extends Analysis[Map[Procedure, Set[Register]]] {
|
8 | 10 |
|
9 |
| - val writesTo: mutable.Map[Procedure, Set[Register]] = mutable.Map() |
10 |
| - val mallocRegister = Register("R0", 64) |
11 |
| - val paramRegisters: Set[Register] = Set( |
12 |
| - mallocRegister, |
13 |
| - Register("R1", 64), |
14 |
| - Register("R2", 64), |
15 |
| - Register("R3", 64), |
16 |
| - Register("R4", 64), |
17 |
| - Register("R5", 64), |
18 |
| - Register("R6", 64), |
19 |
| - Register("R7", 64), |
20 |
| - ) |
| 11 | + lazy val result = ir.transforms.ReadWriteAnalysis.readWriteSets(program) |
21 | 12 |
|
22 |
| - def getWritesTos(proc: Procedure): Set[Register] = { |
23 |
| - if writesTo.contains(proc) then |
24 |
| - writesTo(proc) |
25 |
| - else |
26 |
| - val writtenTo: mutable.Set[Register] = mutable.Set() |
27 |
| - proc.blocks.foreach { block => |
28 |
| - block.statements.foreach { |
29 |
| - case LocalAssign(variable: Register, _, _) if paramRegisters.contains(variable) => |
30 |
| - writtenTo.add(variable) |
31 |
| - case MemoryLoad(lhs: Register, _, _, _, _, _) if paramRegisters.contains(lhs) => |
32 |
| - writtenTo.add(lhs) |
33 |
| - case DirectCall(target, _, _, _) if target.name == "malloc" => |
34 |
| - writtenTo.add(mallocRegister) |
35 |
| - case d: DirectCall if program.procedures.contains(d.target) => |
36 |
| - writtenTo.addAll(getWritesTos(d.target)) |
37 |
| - case _ => |
38 |
| - } |
39 |
| - } |
| 13 | + val overApprox = ((0 to 31).toSet -- (19 to 28).toSet).map(i => Register(s"R${i}", 64)).toSet |
40 | 14 |
|
41 |
| - writesTo.update(proc, writtenTo.toSet) |
42 |
| - writesTo(proc) |
| 15 | + def getWritesTos(proc: Procedure): Set[Register] = { |
| 16 | + result.get(proc).map { |
| 17 | + case Some(r) => r.writes.collect { |
| 18 | + case reg: Register => reg |
| 19 | + }.toSet |
| 20 | + case None => overApprox |
| 21 | + }.toSet.flatten |
43 | 22 | }
|
44 | 23 |
|
45 | 24 | def analyze(): Map[Procedure, Set[Register]] =
|
46 |
| - program.procedures.foreach(getWritesTos) |
47 |
| - writesTo.toMap |
| 25 | + result.keySet.map(p => p -> getWritesTos(p)).toMap |
48 | 26 | }
|
0 commit comments