44import Accelerate
55import CoreML
66
7+ /// How to space timesteps for inference
8+ public enum TimeStepSpacing {
9+ case linspace
10+ case leading
11+ case karras
12+ }
13+
714/// A scheduler used to compute a de-noised image
815///
916/// This implementation matches:
@@ -32,6 +39,8 @@ public final class DPMSolverMultistepScheduler: Scheduler {
3239 public let solverOrder = 2
3340 private( set) var lowerOrderStepped = 0
3441
42+ private var usingKarrasSigmas = false
43+
3544 /// Whether to use lower-order solvers in the final steps. Only valid for less than 15 inference steps.
3645 /// We empirically find this trick can stabilize the sampling of DPM-Solver, especially with 10 or fewer steps.
3746 public let useLowerOrderFinal = true
@@ -47,13 +56,15 @@ public final class DPMSolverMultistepScheduler: Scheduler {
4756 /// - betaSchedule: Method to schedule betas from betaStart to betaEnd
4857 /// - betaStart: The starting value of beta for inference
4958 /// - betaEnd: The end value for beta for inference
59+ /// - timeStepSpacing: How to space time steps
5060 /// - Returns: A scheduler ready for its first step
5161 public init (
5262 stepCount: Int = 50 ,
5363 trainStepCount: Int = 1000 ,
5464 betaSchedule: BetaSchedule = . scaledLinear,
5565 betaStart: Float = 0.00085 ,
56- betaEnd: Float = 0.012
66+ betaEnd: Float = 0.012 ,
67+ timeStepSpacing: TimeStepSpacing = . linspace
5768 ) {
5869 self . trainStepCount = trainStepCount
5970 self . inferenceStepCount = stepCount
@@ -72,20 +83,60 @@ public final class DPMSolverMultistepScheduler: Scheduler {
7283 }
7384 self . alphasCumProd = alphasCumProd
7485
75- // Currently we only support VP-type noise shedule
76- self . alpha_t = vForce. sqrt ( self . alphasCumProd)
77- self . sigma_t = vForce. sqrt ( vDSP. subtract ( [ Float] ( repeating: 1 , count: self . alphasCumProd. count) , self . alphasCumProd) )
78- self . lambda_t = zip ( self . alpha_t, self . sigma_t) . map { α, σ in log ( α) - log( σ) }
86+ switch timeStepSpacing {
87+ case . linspace:
88+ self . timeSteps = linspace ( 0 , Float ( self . trainStepCount- 1 ) , stepCount+ 1 ) . dropFirst ( ) . reversed ( ) . map { Int ( round ( $0) ) }
89+ self . alpha_t = vForce. sqrt ( self . alphasCumProd)
90+ self . sigma_t = vForce. sqrt ( vDSP. subtract ( [ Float] ( repeating: 1 , count: self . alphasCumProd. count) , self . alphasCumProd) )
91+ case . leading:
92+ let lastTimeStep = trainStepCount - 1
93+ let stepRatio = lastTimeStep / ( stepCount + 1 )
94+ // Creates integer timesteps by multiplying by ratio
95+ self . timeSteps = ( 0 ... stepCount) . map { 1 + $0 * stepRatio } . dropFirst ( ) . reversed ( )
96+ self . alpha_t = vForce. sqrt ( self . alphasCumProd)
97+ self . sigma_t = vForce. sqrt ( vDSP. subtract ( [ Float] ( repeating: 1 , count: self . alphasCumProd. count) , self . alphasCumProd) )
98+ case . karras:
99+ // sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
100+ let scaled = vDSP. multiply (
101+ subtraction: ( [ Float] ( repeating: 1 , count: self . alphasCumProd. count) , self . alphasCumProd) ,
102+ subtraction: ( vDSP. divide ( 1 , self . alphasCumProd) , [ Float] ( repeating: 0 , count: self . alphasCumProd. count) )
103+ )
104+ let sigmas = vForce. sqrt ( scaled)
105+ let logSigmas = sigmas. map { log ( $0) }
106+
107+ let sigmaMin = sigmas. first!
108+ let sigmaMax = sigmas. last!
109+ let rho : Float = 7
110+ let ramp = linspace ( 0 , 1 , stepCount)
111+ let minInvRho = pow ( sigmaMin, ( 1 / rho) )
112+ let maxInvRho = pow ( sigmaMax, ( 1 / rho) )
79113
80- self . timeSteps = linspace ( 0 , Float ( self . trainStepCount- 1 ) , stepCount+ 1 ) . dropFirst ( ) . reversed ( ) . map { Int ( round ( $0) ) }
114+ var karrasSigmas = ramp. map { pow ( maxInvRho + $0 * ( minInvRho - maxInvRho) , rho) }
115+ let karrasTimeSteps = karrasSigmas. map { sigmaToTimestep ( sigma: $0, logSigmas: logSigmas) }
116+ self . timeSteps = karrasTimeSteps
117+
118+ karrasSigmas. append ( karrasSigmas. last!)
119+
120+ self . alpha_t = vDSP. divide ( 1 , vForce. sqrt ( vDSP. add ( 1 , vDSP. square ( karrasSigmas) ) ) )
121+ self . sigma_t = vDSP. multiply ( karrasSigmas, self . alpha_t)
122+ usingKarrasSigmas = true
123+ }
124+
125+ self . lambda_t = zip ( self . alpha_t, self . sigma_t) . map { α, σ in log ( α) - log( σ) }
126+ }
127+
128+ func timestepToIndex( _ timestep: Int ) -> Int {
129+ guard usingKarrasSigmas else { return timestep }
130+ return self . timeSteps. firstIndex ( of: timestep) ?? 0
81131 }
82132
83133 /// Convert the model output to the corresponding type the algorithm needs.
84134 /// This implementation is for second-order DPM-Solver++ assuming epsilon prediction.
85135 func convertModelOutput( modelOutput: MLShapedArray < Float32 > , timestep: Int , sample: MLShapedArray < Float32 > ) -> MLShapedArray < Float32 > {
86136 assert ( modelOutput. scalarCount == sample. scalarCount)
87137 let scalarCount = modelOutput. scalarCount
88- let ( alpha_t, sigma_t) = ( self . alpha_t [ timestep] , self . sigma_t [ timestep] )
138+ let sigmaIndex = timestepToIndex ( timestep)
139+ let ( alpha_t, sigma_t) = ( self . alpha_t [ sigmaIndex] , self . sigma_t [ sigmaIndex] )
89140
90141 return MLShapedArray ( unsafeUninitializedShape: modelOutput. shape) { scalars, _ in
91142 assert ( scalars. count == scalarCount)
@@ -108,9 +159,11 @@ public final class DPMSolverMultistepScheduler: Scheduler {
108159 prevTimestep: Int ,
109160 sample: MLShapedArray < Float32 >
110161 ) -> MLShapedArray < Float32 > {
111- let ( p_lambda_t, lambda_s) = ( Double ( lambda_t [ prevTimestep] ) , Double ( lambda_t [ timestep] ) )
112- let p_alpha_t = Double ( alpha_t [ prevTimestep] )
113- let ( p_sigma_t, sigma_s) = ( Double ( sigma_t [ prevTimestep] ) , Double ( sigma_t [ timestep] ) )
162+ let prevIndex = timestepToIndex ( prevTimestep)
163+ let currIndex = timestepToIndex ( timestep)
164+ let ( p_lambda_t, lambda_s) = ( Double ( lambda_t [ prevIndex] ) , Double ( lambda_t [ currIndex] ) )
165+ let p_alpha_t = Double ( alpha_t [ prevIndex] )
166+ let ( p_sigma_t, sigma_s) = ( Double ( sigma_t [ prevIndex] ) , Double ( sigma_t [ currIndex] ) )
114167 let h = p_lambda_t - lambda_s
115168 // x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
116169 let x_t = weightedSum (
@@ -130,9 +183,13 @@ public final class DPMSolverMultistepScheduler: Scheduler {
130183 ) -> MLShapedArray < Float32 > {
131184 let ( s0, s1) = ( timesteps [ back: 1 ] , timesteps [ back: 2 ] )
132185 let ( m0, m1) = ( modelOutputs [ back: 1 ] , modelOutputs [ back: 2 ] )
133- let ( p_lambda_t, lambda_s0, lambda_s1) = ( Double ( lambda_t [ t] ) , Double ( lambda_t [ s0] ) , Double ( lambda_t [ s1] ) )
134- let p_alpha_t = Double ( alpha_t [ t] )
135- let ( p_sigma_t, sigma_s0) = ( Double ( sigma_t [ t] ) , Double ( sigma_t [ s0] ) )
186+ let ( p_lambda_t, lambda_s0, lambda_s1) = (
187+ Double ( lambda_t [ timestepToIndex ( t) ] ) ,
188+ Double ( lambda_t [ timestepToIndex ( s0) ] ) ,
189+ Double ( lambda_t [ timestepToIndex ( s1) ] )
190+ )
191+ let p_alpha_t = Double ( alpha_t [ timestepToIndex ( t) ] )
192+ let ( p_sigma_t, sigma_s0) = ( Double ( sigma_t [ timestepToIndex ( t) ] ) , Double ( sigma_t [ timestepToIndex ( s0) ] ) )
136193 let ( h, h_0) = ( p_lambda_t - lambda_s0, lambda_s0 - lambda_s1)
137194 let r0 = h_0 / h
138195 let D0 = m0
@@ -186,3 +243,31 @@ public final class DPMSolverMultistepScheduler: Scheduler {
186243 return prevSample
187244 }
188245}
246+
247+ func sigmaToTimestep( sigma: Float , logSigmas: [ Float ] ) -> Int {
248+ let logSigma = log ( sigma)
249+ let dists = logSigmas. map { logSigma - $0 }
250+
251+ // last index that is not negative, clipped to last index - 1
252+ var lowIndex = dists. reduce ( - 1 ) { partialResult, dist in
253+ return dist >= 0 && partialResult < dists. endIndex- 2 ? partialResult + 1 : partialResult
254+ }
255+ lowIndex = max ( lowIndex, 0 )
256+ let highIndex = lowIndex + 1
257+
258+ let low = logSigmas [ lowIndex]
259+ let high = logSigmas [ highIndex]
260+
261+ // Interpolate sigmas
262+ let w = ( ( low - logSigma) / ( low - high) ) . clipped ( to: 0 ... 1 )
263+
264+ // transform interpolated value to time range
265+ let t = ( 1 - w) * Float( lowIndex) + w * Float( highIndex)
266+ return Int ( round ( t) )
267+ }
268+
269+ extension FloatingPoint {
270+ func clipped( to range: ClosedRange < Self > ) -> Self {
271+ return min ( max ( self , range. lowerBound) , range. upperBound)
272+ }
273+ }
0 commit comments