Skip to content

Commit feb9b90

Browse files
committed
Treat local mutable vars as capabilities
1 parent 125d9ab commit feb9b90

25 files changed

+315
-36
lines changed

compiler/src/dotty/tools/dotc/cc/CCState.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ class CCState:
7979
object Unrecorded extends VarState.Unrecorded
8080
object ClosedUnrecorded extends VarState.ClosedUnrecorded
8181

82+
// ----- Mirrors for local vars -------------------------
83+
84+
/** A cache for mirrors of local mutable vars */
85+
val varMirrors = util.EqHashMap[Symbol, Symbol]()
86+
8287
// ------ Context info accessed from companion object when isCaptureCheckingOrSetup is true
8388

8489
private var openExistentialScopes: List[MethodType] = Nil

compiler/src/dotty/tools/dotc/cc/Capability.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,18 @@ object Capabilities:
459459
case self: CoreCapability => self.isTrackableRef
460460
case _ => true
461461

462+
/** Under separatiion checking: Is this a mutable var owned by a term that is
463+
* not annotated with @untrackedCaptures? Such mutable variables need to be
464+
* tracked as capabilities. Since mutable variables are not trackable, we do
465+
* this by adding a varMirror symbol to such variables which represents the capability.
466+
*/
467+
final def isLocalMutable(using Context): Boolean = this match
468+
case tp @ TermRef(NoPrefix, _) =>
469+
ccConfig.newScheme && ccConfig.strictMutability
470+
&& tp.symbol.isMutableVar
471+
&& !tp.symbol.hasAnnotation(defn.UntrackedCapturesAnnot)
472+
case _ => false
473+
462474
/** The non-derived capability underlying this capability */
463475
final def core: CoreCapability | RootCapability = this match
464476
case self: (CoreCapability | RootCapability) => self

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ extension (tp: Type)
7474
GlobalCap
7575
case ref: Capability if ref.isTrackableRef =>
7676
ref
77+
case ref: TermRef if ref.isLocalMutable =>
78+
ref.mapLocalMutable
7779
case _ =>
7880
// if this was compiled from cc syntax, problem should have been reported at Typer
7981
throw IllegalCaptureRef(tp)
@@ -493,6 +495,12 @@ extension (tp: MethodType)
493495
def marksExistentialScope(using Context): Boolean =
494496
!tp.resType.isInstanceOf[MethodOrPoly]
495497

498+
extension (ref: TermRef | ThisType)
499+
/** Map a local mutable var to its mirror */
500+
def mapLocalMutable(using Context): TermRef | ThisType = ref match
501+
case ref: TermRef if ref.isLocalMutable => ref.symbol.varMirror.termRef
502+
case _ => ref
503+
496504
extension (cls: ClassSymbol)
497505

498506
def pureBaseClass(using Context): Option[Symbol] =
@@ -669,6 +677,13 @@ extension (sym: Symbol)
669677
def isDisallowedInCapset(using Context): Boolean =
670678
sym.isOneOf(if ccConfig.newScheme && ccConfig.strictMutability then Method else UnstableValueFlags)
671679

680+
def varMirror(using Context): Symbol =
681+
ccState.varMirrors.getOrElseUpdate(sym,
682+
sym.copy(
683+
flags = Flags.EmptyFlags,
684+
info = defn.Caps_Var.typeRef.appliedTo(sym.info)
685+
.capturing(FreshCap(sym, Origin.InDecl(sym)))))
686+
672687
extension (tp: AnnotatedType)
673688
/** Is this a boxed capturing type? */
674689
def isBoxed(using Context): Boolean = tp.annot match

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,14 @@ object CheckCaptures:
106106
/** Check that a @retains annotation only mentions references that can be tracked.
107107
* This check is performed at Typer.
108108
*/
109-
def checkWellformed(parent: Tree, ann: Tree)(using Context): Unit =
109+
def checkWellformedRetains(parent: Tree, ann: Tree)(using Context): Unit =
110110
def check(elem: Type): Unit = elem match
111111
case ref: TypeRef =>
112112
val refSym = ref.symbol
113113
if refSym.isType && !refSym.info.derivesFrom(defn.Caps_CapSet) then
114114
report.error(em"$elem is not a legal element of a capture set", ann.srcPos)
115115
case ref: CoreCapability =>
116-
if !ref.isTrackableRef && !ref.isCapRef then
116+
if !ref.isTrackableRef && !ref.isCapRef && !ref.isLocalMutable then
117117
report.error(em"$elem cannot be tracked since it is not a parameter or local value", ann.srcPos)
118118
case ReachCapability(ref) =>
119119
check(ref)
@@ -306,7 +306,7 @@ class CheckCaptures extends Recheck, SymTransformer:
306306
*/
307307
private val useInfos = mutable.ArrayBuffer[(Tree, CaptureSet, Env)]()
308308

309-
private var usedSet = util.EqHashMap[Tree, CaptureSet]()
309+
private val usedSet = util.EqHashMap[Tree, CaptureSet]()
310310

311311
/** The set of symbols that were rechecked via a completer */
312312
private val completed = new mutable.HashSet[Symbol]
@@ -720,7 +720,7 @@ class CheckCaptures extends Recheck, SymTransformer:
720720
val sel = ref.select(pt.selector).asInstanceOf[TermRef]
721721
markPathFree(sel, pt.pt, pt.select)
722722
case _ =>
723-
markFree(ref.adjustReadOnly(pt), tree)
723+
markFree(ref.mapLocalMutable.adjustReadOnly(pt), tree)
724724

725725
/** The expected type for the qualifier of a selection. If the selection
726726
* could be part of a capability path or is a a read-only method, we return
@@ -1069,7 +1069,7 @@ class CheckCaptures extends Recheck, SymTransformer:
10691069
recheck(tree.rhs, lhsType.widen)
10701070
lhsType match
10711071
case lhsType @ TermRef(qualType, _)
1072-
if (qualType ne NoPrefix) && !lhsType.symbol.hasAnnotation(defn.UntrackedCapturesAnnot) =>
1072+
if !lhsType.symbol.hasAnnotation(defn.UntrackedCapturesAnnot) =>
10731073
checkUpdate(qualType, tree.srcPos)(i"Cannot assign to field ${lhsType.name} of ${qualType.showRef}")
10741074
case _ =>
10751075
defn.UnitType
@@ -1142,6 +1142,13 @@ class CheckCaptures extends Recheck, SymTransformer:
11421142
openClosures = openClosures.tail
11431143
end recheckClosureBlock
11441144

1145+
/** Add var mirrors to the list of block-local symbols to avoid */
1146+
override def avoidLocals(tp: Type, symsToAvoid: => List[Symbol])(using Context): Type =
1147+
val locals = symsToAvoid
1148+
val varMirrors = locals.collect:
1149+
case local if local.termRef.isLocalMutable => local.varMirror
1150+
super.avoidLocals(tp, varMirrors ++ locals)
1151+
11451152
/** Elements of a SeqLiteral instantiate a Seq or Array parameter, so they
11461153
* should be boxed.
11471154
*/

compiler/src/dotty/tools/dotc/cc/Mutability.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ object Mutability:
121121
&& (!tp.isStatefulType || tp.captureSet.mutability == CaptureSet.Mutability.Reader)
122122

123123
extension (ref: TermRef | ThisType)
124-
/** Map `ref` to `ref.readOnly` if its type extends Mutble, and one of the
124+
/** Map `ref` to `ref.readOnly` if its type extends Mutable, and one of the
125125
* following is true:
126126
* - it appears in a non-exclusive context,
127127
* - the expected type is a value type that is not a stateful type,

compiler/src/dotty/tools/dotc/cc/SepCheck.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,9 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
744744
val captured = genPart.deepCaptureSet.elems
745745
val hiddenSet = captured.transHiddenSet.pruned
746746
val clashSet = otherPart.deepCaptureSet.elems
747-
val deepClashSet = clashSet.completeFootprint.nonPeaks.pruned
747+
var deepClashSet = clashSet.completeFootprint.nonPeaks.pruned
748+
if deepClashSet.isEmpty then
749+
deepClashSet = clashSet.completeFootprint.pruned
748750
report.error(
749751
em"""Separation failure in ${role.description} $tpe.
750752
|One part, $genPart, hides capabilities ${CaptureSet(hiddenSet)}.

compiler/src/dotty/tools/dotc/cc/ccConfig.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ object ccConfig:
5454

5555
/** Not used currently. Handy for trying out new features */
5656
def newScheme(using ctx: Context): Boolean =
57-
Feature.sourceVersion.stable.isAtLeast(SourceVersion.`3.7`)
57+
Feature.sourceVersion.stable.isAtLeast(SourceVersion.`3.8`)
5858

5959
/** Allow @use annotations */
6060
def allowUse(using Context): Boolean =

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,7 @@ class Definitions {
10311031
@tu lazy val Caps_ContainsModule: Symbol = requiredModule("scala.caps.Contains")
10321032
@tu lazy val Caps_containsImpl: TermSymbol = Caps_ContainsModule.requiredMethod("containsImpl")
10331033
@tu lazy val Caps_freeze: TermSymbol = CapsModule.requiredMethod("freeze")
1034+
@tu lazy val Caps_Var: ClassSymbol = requiredClass("scala.caps.internal.Var")
10341035

10351036
@tu lazy val PureClass: ClassSymbol = requiredClass("scala.caps.Pure")
10361037

@@ -1991,7 +1992,7 @@ class Definitions {
19911992
Caps_Shared, RequiresCapabilityAnnot,
19921993
captureRoot, Caps_CapSet, Caps_ContainsTrait, Caps_ContainsModule, Caps_ContainsModule.moduleClass,
19931994
ConsumeAnnot, UseAnnot, ReserveAnnot,
1994-
CapsUnsafeModule, CapsUnsafeModule.moduleClass, Caps_freeze,
1995+
CapsUnsafeModule, CapsUnsafeModule.moduleClass, Caps_freeze, Caps_Var,
19951996
CapsInternalModule, CapsInternalModule.moduleClass,
19961997
RetainsAnnot, RetainsCapAnnot, RetainsByNameAnnot)
19971998

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -566,9 +566,7 @@ object TypeOps:
566566
def avoid(tp: Type, symsToAvoid: => List[Symbol])(using Context): Type = {
567567
val widenMap = new AvoidMap {
568568
@threadUnsafe lazy val forbidden = symsToAvoid.toSet
569-
def toAvoid(tp: NamedType) =
570-
val sym = tp.symbol
571-
forbidden.contains(sym)
569+
def toAvoid(tp: NamedType) = forbidden.contains(tp.symbol)
572570

573571
override def apply(tp: Type): Type = tp match
574572
case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) =>

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,10 +392,13 @@ abstract class Recheck extends Phase, SymTransformer:
392392
recheck(tree.rhs, lhsType.widen)
393393
defn.UnitType
394394

395+
protected def avoidLocals(tp: Type, symsToAvoid: => List[Symbol])(using Context): Type =
396+
TypeOps.avoid(tp, symsToAvoid)
397+
395398
private def recheckBlock(stats: List[Tree], expr: Tree, pt: Type)(using Context): Type =
396399
recheckStats(stats)
397400
val exprType = recheck(expr, pt)
398-
TypeOps.avoid(exprType, localSyms(stats).filterConserve(_.isTerm))
401+
avoidLocals(exprType, localSyms(stats).filterConserve(_.isTerm))
399402

400403
def recheckBlock(tree: Block, pt: Type)(using Context): Type = tree match
401404
case Block((mdef : DefDef) :: Nil, closure: Closure) =>
@@ -439,7 +442,7 @@ abstract class Recheck extends Phase, SymTransformer:
439442
// for blocks since a case branch is of the form `return[MatchResultN] { ... }`
440443
// For source-level returns from methods, there's nothing to avoid, since the
441444
// result type of a method with a return must be given explicitly.
442-
def avoidMap = new TypeOps.AvoidMap:
445+
val avoidMap = new TypeOps.AvoidMap:
443446
def toAvoid(tp: NamedType) =
444447
tp.symbol.is(Case) && tp.symbol.owner.isContainedIn(ctx.owner)
445448

0 commit comments

Comments
 (0)