From c016afc1a8f95f63b415a9d7e9e0b988695fdf34 Mon Sep 17 00:00:00 2001
From: Michael Grosse Huelsewiesche <mihuelsewiesche@twilio.com>
Date: Tue, 3 Dec 2024 12:23:39 -0500
Subject: [PATCH 1/2] Improve thread safety for telemetry

---
 .../analytics/kotlin/core/Telemetry.kt        | 59 +++++++++++--------
 .../analytics/kotlin/core/TelemetryTest.kt    | 12 ++--
 2 files changed, 41 insertions(+), 30 deletions(-)

diff --git a/core/src/main/java/com/segment/analytics/kotlin/core/Telemetry.kt b/core/src/main/java/com/segment/analytics/kotlin/core/Telemetry.kt
index 7b7210c4..5f51343d 100644
--- a/core/src/main/java/com/segment/analytics/kotlin/core/Telemetry.kt
+++ b/core/src/main/java/com/segment/analytics/kotlin/core/Telemetry.kt
@@ -14,6 +14,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.Executors
 import kotlin.math.min
 import kotlin.math.roundToInt
+import java.util.concurrent.atomic.AtomicBoolean
 
 class MetricsRequestFactory : RequestFactory() {
     override fun upload(apiHost: String): HttpURLConnection {
@@ -76,7 +77,14 @@ object Telemetry: Subscriber {
     var host: String = Constants.DEFAULT_API_HOST
     // 1.0 is 100%, will get set by Segment setting before start()
     // Values are adjusted by the sampleRate on send
-    var sampleRate: Double = 1.0
+    @Volatile private var _sampleRate: Double = 1.0
+    var sampleRate: Double
+        get() = _sampleRate
+        set(value) {
+            synchronized(this) {
+                _sampleRate = value
+            }
+        }
     var flushTimer: Int = 30 * 1000 // 30s
     var httpClient: HTTPClient = HTTPClient("", MetricsRequestFactory())
     var sendWriteKeyOnError: Boolean = true
@@ -93,9 +101,9 @@ object Telemetry: Subscriber {
 
     private val queue = ConcurrentLinkedQueue<RemoteMetric>()
     private var queueBytes = 0
-    private var started = false
+    private var started = AtomicBoolean(false)
     private var rateLimitEndTime: Long = 0
-    private var flushFirstError = true
+    private var flushFirstError = AtomicBoolean(true)
     private val exceptionHandler = CoroutineExceptionHandler { _, t ->
         errorHandler?.let {
             it( Exception(
@@ -113,8 +121,8 @@ object Telemetry: Subscriber {
      * Called automatically when Telemetry.enable is set to true and when configuration data is received from Segment.
      */
     fun start() {
-        if (!enable || started || sampleRate == 0.0) return
-        started = true
+        if (!enable || started.get() || sampleRate == 0.0) return
+        started.set(true)
 
         // Everything queued was sampled at default 100%, downsample adjustment and send will adjust values
         if (Math.random() > sampleRate) {
@@ -124,7 +132,7 @@ object Telemetry: Subscriber {
         telemetryJob = telemetryScope.launch(telemetryDispatcher) {
             while (isActive) {
                 if (!enable) {
-                    started = false
+                    started.set(false)
                     return@launch
                 }
                 try {
@@ -148,7 +156,7 @@ object Telemetry: Subscriber {
     fun reset() {
         telemetryJob?.cancel()
         resetQueue()
-        started = false
+        started.set(false)
         rateLimitEndTime = 0
     }
 
@@ -202,8 +210,8 @@ object Telemetry: Subscriber {
 
         addRemoteMetric(metric, filteredTags, log=logData)
 
-        if(flushFirstError) {
-            flushFirstError = false
+        if(flushFirstError.compareAndSet(true, false)) {
+            // Atomic test-and-clear: only one caller can win the flag, so the eager first-error flush fires once.
             flush()
         }
     }
@@ -218,7 +226,6 @@ object Telemetry: Subscriber {
 
         try {
             send()
-            queueBytes = 0
         } catch (error: Throwable) {
             errorHandler?.invoke(error)
             sampleRate = 0.0
@@ -227,16 +234,14 @@ object Telemetry: Subscriber {
 
     private fun send() {
         if (sampleRate == 0.0) return
-        var queueCount = queue.size
-        // Reset queue data size counter since all current queue items will be removed
-        queueBytes = 0
-        val sendQueue = mutableListOf<RemoteMetric>()
-        while (queueCount-- > 0 && !queue.isEmpty()) {
-            val m = queue.poll()
-            if(m != null) {
-                m.value = (m.value / sampleRate).roundToInt()
-                sendQueue.add(m)
-            }
+        val sendQueue: MutableList<RemoteMetric>
+        synchronized(queue) {
+            sendQueue = queue.toMutableList()
+            queue.clear()
+            queueBytes = 0
+        }
+        sendQueue.forEach { m ->
+            m.value = (m.value / sampleRate).roundToInt()
         }
         try {
             // Json.encodeToString by default does not include default values
@@ -309,9 +314,11 @@ object Telemetry: Subscriber {
             tags = fullTags
         )
         val newMetricSize = newMetric.toString().toByteArray().size
-        if (queueBytes + newMetricSize <= maxQueueBytes) {
-            queue.add(newMetric)
-            queueBytes += newMetricSize
+        synchronized(queue) {
+            if (queueBytes + newMetricSize <= maxQueueBytes) {
+                queue.add(newMetric)
+                queueBytes += newMetricSize
+            }
         }
     }
 
@@ -338,7 +345,9 @@ object Telemetry: Subscriber {
     }
 
     private fun resetQueue() {
-        queue.clear()
-        queueBytes = 0
+        synchronized(queue) {
+            queue.clear()
+            queueBytes = 0
+        }
     }
 }
\ No newline at end of file
diff --git a/core/src/test/kotlin/com/segment/analytics/kotlin/core/TelemetryTest.kt b/core/src/test/kotlin/com/segment/analytics/kotlin/core/TelemetryTest.kt
index df1ff354..ea2ed543 100644
--- a/core/src/test/kotlin/com/segment/analytics/kotlin/core/TelemetryTest.kt
+++ b/core/src/test/kotlin/com/segment/analytics/kotlin/core/TelemetryTest.kt
@@ -10,13 +10,15 @@ import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.CountDownLatch
 import java.util.concurrent.Executors
 import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicBoolean
 import kotlin.random.Random
 
 class TelemetryTest {
     fun TelemetryResetFlushFirstError() {
         val field: Field = Telemetry::class.java.getDeclaredField("flushFirstError")
         field.isAccessible = true
-        field.set(true, true)
+        val atomicBoolean = field.get(Telemetry) as AtomicBoolean
+        atomicBoolean.set(true)
     }
     fun TelemetryQueueSize(): Int {
         val queueField: Field = Telemetry::class.java.getDeclaredField("queue")
@@ -29,11 +31,11 @@ class TelemetryTest {
         queueBytesField.isAccessible = true
         return queueBytesField.get(Telemetry) as Int
     }
-    var TelemetryStarted: Boolean
+    var TelemetryStarted: AtomicBoolean
         get() {
             val startedField: Field = Telemetry::class.java.getDeclaredField("started")
             startedField.isAccessible = true
-            return startedField.get(Telemetry) as Boolean
+            return startedField.get(Telemetry) as AtomicBoolean
         }
         set(value) {
             val startedField: Field = Telemetry::class.java.getDeclaredField("started")
@@ -78,11 +80,11 @@ class TelemetryTest {
         Telemetry.sampleRate = 0.0
         Telemetry.enable = true
         Telemetry.start()
-        assertEquals(false, TelemetryStarted)
+        assertEquals(false, TelemetryStarted.get())
 
         Telemetry.sampleRate = 1.0
         Telemetry.start()
-        assertEquals(true, TelemetryStarted)
+        assertEquals(true, TelemetryStarted.get())
         assertEquals(0,errors.size)
     }
 

From c03f7a017cab13731407127bc163960c0279d48e Mon Sep 17 00:00:00 2001
From: Michael Grosse Huelsewiesche <mihuelsewiesche@twilio.com>
Date: Wed, 4 Dec 2024 10:56:50 -0500
Subject: [PATCH 2/2] Another attempt without synchronized, tests need work

---
 .../analytics/kotlin/core/Telemetry.kt        | 93 +++++++++++--------
 .../analytics/kotlin/core/TelemetryTest.kt    | 29 ++++--
 2 files changed, 77 insertions(+), 45 deletions(-)

diff --git a/core/src/main/java/com/segment/analytics/kotlin/core/Telemetry.kt b/core/src/main/java/com/segment/analytics/kotlin/core/Telemetry.kt
index 5f51343d..fa96249b 100644
--- a/core/src/main/java/com/segment/analytics/kotlin/core/Telemetry.kt
+++ b/core/src/main/java/com/segment/analytics/kotlin/core/Telemetry.kt
@@ -15,6 +15,9 @@ import java.util.concurrent.Executors
 import kotlin.math.min
 import kotlin.math.roundToInt
 import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.atomic.AtomicReference
+import kotlinx.coroutines.channels.Channel
 
 class MetricsRequestFactory : RequestFactory() {
     override fun upload(apiHost: String): HttpURLConnection {
@@ -77,21 +80,14 @@ object Telemetry: Subscriber {
     var host: String = Constants.DEFAULT_API_HOST
     // 1.0 is 100%, will get set by Segment setting before start()
     // Values are adjusted by the sampleRate on send
-    @Volatile private var _sampleRate: Double = 1.0
-    var sampleRate: Double
-        get() = _sampleRate
-        set(value) {
-            synchronized(this) {
-                _sampleRate = value
-            }
-        }
-    var flushTimer: Int = 30 * 1000 // 30s
+    private var sampleRate = AtomicReference<Double>(1.0)
+    private var flushTimer: Int = 30 * 1000 // 30s
     var httpClient: HTTPClient = HTTPClient("", MetricsRequestFactory())
     var sendWriteKeyOnError: Boolean = true
     var sendErrorLogData: Boolean = false
     var errorHandler: ((Throwable) -> Unit)? = ::logError
-    var maxQueueSize: Int = 20
-    var errorLogSizeMax: Int = 4000
+    private var maxQueueSize: Int = 20
+    private var errorLogSizeMax: Int = 4000
 
     private const val MAX_QUEUE_BYTES = 28000
     var maxQueueBytes: Int = MAX_QUEUE_BYTES
@@ -100,7 +96,7 @@ object Telemetry: Subscriber {
         }
 
     private val queue = ConcurrentLinkedQueue<RemoteMetric>()
-    private var queueBytes = 0
+    private var queueBytes = AtomicInteger(0)
     private var started = AtomicBoolean(false)
     private var rateLimitEndTime: Long = 0
     private var flushFirstError = AtomicBoolean(true)
@@ -116,16 +112,27 @@ object Telemetry: Subscriber {
     private var telemetryDispatcher: ExecutorCoroutineDispatcher = Executors.newSingleThreadExecutor().asCoroutineDispatcher()
     private var telemetryJob: Job? = null
 
+    // Flush requests are idempotent signals (performFlush drains the whole queue), so keep at most one pending.
+    private val flushChannel = Channel<Unit>(Channel.CONFLATED)
+    // Single consumer: flush requests are handled one at a time on telemetryDispatcher.
+    init {
+        telemetryScope.launch(telemetryDispatcher) {
+            for (request in flushChannel) {
+                performFlush()
+            }
+        }
+    }
+
     /**
      * Starts the telemetry if it is enabled and not already started, and the sample rate is greater than 0.
      * Called automatically when Telemetry.enable is set to true and when configuration data is received from Segment.
      */
     fun start() {
-        if (!enable || started.get() || sampleRate == 0.0) return
+        if (!enable || sampleRate.get() == 0.0 || !started.compareAndSet(false, true)) return // atomically claim the start
         started.set(true)
 
         // Everything queued was sampled at default 100%, downsample adjustment and send will adjust values
-        if (Math.random() > sampleRate) {
+        if (Math.random() > sampleRate.get()) {
             resetQueue()
         }
 
@@ -170,10 +177,10 @@ object Telemetry: Subscriber {
         val tags = mutableMapOf<String, String>()
         buildTags(tags)
 
-        if (!enable || sampleRate == 0.0) return
+        if (!enable || sampleRate.get() == 0.0) return
         if (!metric.startsWith(METRICS_BASE_TAG)) return
         if (tags.isEmpty()) return
-        if (Math.random() > sampleRate) return
+        if (Math.random() > sampleRate.get()) return
 
         addRemoteMetric(metric, tags)
     }
@@ -189,10 +196,10 @@ object Telemetry: Subscriber {
         val tags = mutableMapOf<String, String>()
         buildTags(tags)
 
-        if (!enable || sampleRate == 0.0) return
+        if (!enable || sampleRate.get() == 0.0) return
         if (!metric.startsWith(METRICS_BASE_TAG)) return
         if (tags.isEmpty()) return
-        if (Math.random() > sampleRate) return
+        if (Math.random() > sampleRate.get()) return
 
         var filteredTags = if(sendWriteKeyOnError) {
             tags.toMap()
@@ -216,33 +223,41 @@ object Telemetry: Subscriber {
         }
     }
 
-    @Synchronized
     fun flush() {
+        if (!enable) return
+        flushChannel.trySend(Unit)
+    }
+
+    private fun performFlush() {
         if (!enable || queue.isEmpty()) return
         if (rateLimitEndTime > (System.currentTimeMillis() / 1000).toInt()) {
             return
         }
         rateLimitEndTime = 0
-
+        flushFirstError.set(false)
         try {
             send()
         } catch (error: Throwable) {
             errorHandler?.invoke(error)
-            sampleRate = 0.0
+            sampleRate.set(0.0)
         }
     }
 
     private fun send() {
-        if (sampleRate == 0.0) return
-        val sendQueue: MutableList<RemoteMetric>
-        synchronized(queue) {
-            sendQueue = queue.toMutableList()
-            queue.clear()
-            queueBytes = 0
-        }
-        sendQueue.forEach { m ->
-            m.value = (m.value / sampleRate).roundToInt()
+        if (sampleRate.get() == 0.0) return
+        val sendQueue = mutableListOf<RemoteMetric>()
+        // Reset queue data size counter since all current queue items will be removed
+        queueBytes.set(0)
+        var queueCount = queue.size
+        while(queueCount > 0 && !queue.isEmpty()) {
+            --queueCount
+            val m = queue.poll()
+            if(m != null) {
+                m.value = (m.value / sampleRate.get()).roundToInt()
+                sendQueue.add(m)
+            }
         }
+        // No emptiness assertion here: producers may enqueue concurrently; leftovers are picked up by the next send.
         try {
             // Json.encodeToString by default does not include default values
             //  We're using this to leave off the 'log' parameter if unset.
@@ -314,10 +329,12 @@ object Telemetry: Subscriber {
             tags = fullTags
         )
         val newMetricSize = newMetric.toString().toByteArray().size
-        synchronized(queue) {
-            if (queueBytes + newMetricSize <= maxQueueBytes) {
-                queue.add(newMetric)
-                queueBytes += newMetricSize
+        // Avoid synchronization issue by adding the size before checking.
+        if (queueBytes.addAndGet(newMetricSize) <= maxQueueBytes) {
+            queue.add(newMetric)
+        } else {
+            if(queueBytes.addAndGet(-newMetricSize) < 0) {
+                queueBytes.set(0)
             }
         }
     }
@@ -334,7 +351,7 @@ object Telemetry: Subscriber {
     private suspend fun systemUpdate(system: com.segment.analytics.kotlin.core.System) {
         system.settings?.let { settings ->
             settings.metrics["sampleRate"]?.jsonPrimitive?.double?.let {
-                sampleRate = it
+                sampleRate.set(it)
                 // We don't want to start telemetry until two conditions are met:
                 // Telemetry.enable is set to true
                 // Settings from the server have adjusted the sampleRate
@@ -345,9 +362,7 @@ object Telemetry: Subscriber {
     }
 
     private fun resetQueue() {
-        synchronized(queue) {
-            queue.clear()
-            queueBytes = 0
-        }
+        queue.clear()
+        queueBytes.set(0)
     }
 }
\ No newline at end of file
diff --git a/core/src/test/kotlin/com/segment/analytics/kotlin/core/TelemetryTest.kt b/core/src/test/kotlin/com/segment/analytics/kotlin/core/TelemetryTest.kt
index ea2ed543..8c0d39b5 100644
--- a/core/src/test/kotlin/com/segment/analytics/kotlin/core/TelemetryTest.kt
+++ b/core/src/test/kotlin/com/segment/analytics/kotlin/core/TelemetryTest.kt
@@ -11,6 +11,7 @@ import java.util.concurrent.CountDownLatch
 import java.util.concurrent.Executors
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.AtomicReference
 import kotlin.random.Random
 
 class TelemetryTest {
@@ -31,6 +32,22 @@ class TelemetryTest {
         queueBytesField.isAccessible = true
         return queueBytesField.get(Telemetry) as Int
     }
+    fun TelemetryMaxQueueSize(): Int {
+        val maxQueueSizeField: Field = Telemetry::class.java.getDeclaredField("maxQueueSize")
+        maxQueueSizeField.isAccessible = true
+        return maxQueueSizeField.get(Telemetry) as Int
+    }
+    var TelemetrySampleRate: Double
+        get() {
+            val sampleRateField: Field = Telemetry::class.java.getDeclaredField("sampleRate")
+            sampleRateField.isAccessible = true
+            return (sampleRateField.get(Telemetry) as AtomicReference<Double>).get()
+        }
+        set(value) {
+            val sampleRateField: Field = Telemetry::class.java.getDeclaredField("sampleRate")
+            sampleRateField.isAccessible = true
+            (sampleRateField.get(Telemetry) as AtomicReference<Double>).set(value)
+        }
     var TelemetryStarted: AtomicBoolean
         get() {
             val startedField: Field = Telemetry::class.java.getDeclaredField("started")
@@ -69,7 +86,7 @@ class TelemetryTest {
         Telemetry.reset()
         Telemetry.errorHandler = ::errorHandler
         errors.clear()
-        Telemetry.sampleRate = 1.0
+        TelemetrySampleRate = 1.0
         MockKAnnotations.init(this)
         mockTelemetryHTTPClient()
         // Telemetry.enable = true <- this will call start(), so don't do it here
@@ -77,12 +94,12 @@ class TelemetryTest {
 
     @Test
     fun `Test telemetry start`() {
-        Telemetry.sampleRate = 0.0
+        TelemetrySampleRate = 0.0
         Telemetry.enable = true
         Telemetry.start()
         assertEquals(false, TelemetryStarted.get())
 
-        Telemetry.sampleRate = 1.0
+        TelemetrySampleRate = 1.0
         Telemetry.start()
         assertEquals(true, TelemetryStarted.get())
         assertEquals(0,errors.size)
@@ -186,11 +203,11 @@ class TelemetryTest {
     fun `Test increment and error methods when queue is full`() {
         Telemetry.enable = true
         Telemetry.start()
-        for (i in 1..Telemetry.maxQueueSize + 1) {
+        for (i in 1..TelemetryMaxQueueSize() + 1) {
             Telemetry.increment(Telemetry.INVOKE_METRIC) { it["test"] = "test" + i }
             Telemetry.error(Telemetry.INVOKE_ERROR_METRIC, "error") { it["error"] = "test" + i }
         }
-        assertEquals(Telemetry.maxQueueSize, TelemetryQueueSize())
+        assertEquals(TelemetryMaxQueueSize(), TelemetryQueueSize())
     }
 
     @Test
@@ -239,6 +256,6 @@ class TelemetryTest {
         } finally {
             executor.shutdown()
         }
-        assertTrue(TelemetryQueueSize() == Telemetry.maxQueueSize)
+        assertTrue(TelemetryQueueSize() == TelemetryMaxQueueSize())
     }
 }