From eb2fa8f1bf39a61d0634228539a00575e1ab1b94 Mon Sep 17 00:00:00 2001
From: Hunter Park <hpark@innovationdx.com>
Date: Sun, 29 Jul 2018 23:08:04 -0500
Subject: [PATCH 1/2] Fix __eq__ not logical, check if a UnitTagID exists in a
 unit group, as well as include unit test

---
 sc2/unit.py        | 3 +++
 sc2/units.py       | 6 ++++++
 test/test_units.py | 6 ++++++
 3 files changed, 15 insertions(+)

diff --git a/sc2/unit.py b/sc2/unit.py
index d54f0d522..77982e347 100644
--- a/sc2/unit.py
+++ b/sc2/unit.py
@@ -464,6 +464,9 @@ def __call__(self, ability, *args, **kwargs):
     def __repr__(self):
         return f"Unit(name={self.name !r}, tag={self.tag})"
 
+    def __eq__(self, ID):
+        return ID == self.type_id
+
 class UnitOrder(object):
     @classmethod
     def from_proto(cls, proto, game_data):
diff --git a/sc2/units.py b/sc2/units.py
index 612a73e69..1187a0fc4 100644
--- a/sc2/units.py
+++ b/sc2/units.py
@@ -39,6 +39,12 @@ def __sub__(self, other: "Units") -> "Units":
         units = [unit for unit in self if unit.tag not in tags]
         return Units(units, self.game_data)
 
+    def __contains__(self, TypeID):
+        for tag in self.tags:
+            if TypeID == self.find_by_tag(tag):
+                return True
+        return False
+
     @property
     def amount(self) -> int:
         return len(self)
diff --git a/test/test_units.py b/test/test_units.py
index 642be43ef..8121c882d 100644
--- a/test/test_units.py
+++ b/test/test_units.py
@@ -141,5 +141,11 @@ def test_owned(self):
     def test_enemy(self):
         self.assertEqual(self.marines.enemy, self.emptyUnitsGroup)
 
+    def test_contains(self):
+        self.assertTrue(UnitTypeId.MARINE in self.marines)
+
+    def test_not_contains(self):
+        self.assertFalse(UnitTypeId.ADEPT in self.marines)
+
 if __name__ == "__main__":
     unittest.main()
\ No newline at end of file

From 1a36707eb79e700d8108061e20caad30533932d8 Mon Sep 17 00:00:00 2001
From: Hunter Park <hpark@innovationdx.com>
Date: Mon, 30 Jul 2018 23:42:54 -0500
Subject: [PATCH 2/2] started improving ondistributed workers

---
 examples/distributed_workers.py | 26 +++++++++-
 sc2/bot_ai.py                   | 91 ++++++++++++++++++++++-----------
 2 files changed, 85 insertions(+), 32 deletions(-)

diff --git a/examples/distributed_workers.py b/examples/distributed_workers.py
index e7233288d..d75f0e469 100644
--- a/examples/distributed_workers.py
+++ b/examples/distributed_workers.py
@@ -2,12 +2,15 @@
 from sc2 import run_game, maps, Race, Difficulty
 from sc2.player import Bot, Computer
 from sc2.constants import *
+import logging
 
+logger = logging.getLogger(__name__)
 
 class TerranBot(sc2.BotAI):
     async def on_step(self, iteration):
         await self.distribute_workers()
         await self.build_supply()
+        await self.build_geyser()
         await self.build_workers()
         await self.expand()
 
@@ -28,8 +31,29 @@ async def build_supply(self):
                 if self.can_afford(UnitTypeId.SUPPLYDEPOT):
                     await self.build(UnitTypeId.SUPPLYDEPOT, near=cc.position.towards(self.game_info.map_center, 5))
 
+    async def build_geyser(self):
+        for CC in self.units(UnitTypeId.COMMANDCENTER):
+            # get the number of nearby refineries
+            nearby_refineries = []
+            for refinery in self.units(UnitTypeId.REFINERY):
+                logging.warn(CC.distance_to(refinery))
+                if CC.distance_to(refinery) < 10:
+                    nearby_refineries.append(refinery)
+            refineries_to_build = 2 - len(nearby_refineries)
+            if refineries_to_build <= 0:
+                return
 
-run_game(maps.get("Abyssal Reef LE"), [
+            # get a worker from the CC
+            scv = self.workers.closest_to(CC)
+
+            # get the closest geyser
+            target = self.state.vespene_geyser.closest_to(scv.position)
+            if scv.position.distance_to(target) < 25 and self.can_afford(UnitTypeId.REFINERY) and not self.already_pending(UnitTypeId.SUPPLYDEPOT):
+                err = await self.do(scv.build(UnitTypeId.REFINERY, target))
+
+
+
+run_game(maps.get("AbyssalReefLE"), [
     Bot(Race.Terran, TerranBot()),
     Computer(Race.Protoss, Difficulty.Medium)
 ], realtime=False)
diff --git a/sc2/bot_ai.py b/sc2/bot_ai.py
index fe7d74bcb..5659205fe 100644
--- a/sc2/bot_ai.py
+++ b/sc2/bot_ai.py
@@ -137,24 +137,14 @@ def is_near_to_expansion(t):
 
         return closest
 
-    async def distribute_workers(self):
-        """
-        Distributes workers across all the bases taken.
-        WARNING: This is quite slow when there are lots of workers or multiple bases.
-        """
-
-        # TODO:
-        # OPTIMIZE: Assign idle workers smarter
-        # OPTIMIZE: Never use same worker mutltiple times
 
-        expansion_locations = self.expansion_locations
-        owned_expansions = self.owned_expansions
-        worker_pool = []
-        for idle_worker in self.workers.idle:
-            mf = self.state.mineral_field.closest_to(idle_worker)
-            await self.do(idle_worker.gather(mf))
+    @staticmethod
+    def getNearestWorker(location, workers: list):
+        return sorted(workers, key=lambda worker: worker.distance_to(location))[0]
 
-        for location, townhall in owned_expansions.items():
+    # ensure that each townhall has the right amount of workers
+    def check_townhalls(self, worker_pool):
+        for location, townhall in self.owned_expansions.items():
             workers = self.workers.closer_than(20, location)
             actual = townhall.assigned_harvesters
             ideal = townhall.ideal_harvesters
@@ -162,6 +152,9 @@ async def distribute_workers(self):
             if actual > ideal:
                 worker_pool.extend(workers.random_group_of(min(excess, len(workers))))
                 continue
+
+    # ensure that each geyser has the right amount of workers
+    def check_geysers(self, worker_pool):
         for g in self.geysers:
             workers = self.workers.closer_than(5, g)
             actual = g.assigned_harvesters
@@ -171,35 +164,71 @@ async def distribute_workers(self):
                 worker_pool.extend(workers.random_group_of(min(excess, len(workers))))
                 continue
 
+    async def assign_geysers(self, worker_pool):
         for g in self.geysers:
             actual = g.assigned_harvesters
             ideal = g.ideal_harvesters
             deficit = ideal - actual
 
             for x in range(0, deficit):
-                if worker_pool:
-                    w = worker_pool.pop()
-                    if len(w.orders) == 1 and w.orders[0].ability.id in [AbilityId.HARVEST_RETURN]:
-                        await self.do(w.move(g))
-                        await self.do(w.return_resource(queue=True))
+                worker = None
+                if self.workers.idle:
+                    worker = self.getNearestWorker(g,  self.workers.idle)
+                    self.workers.idle.remove(worker)
+                elif worker_pool:
+                    worker = self.getNearestWorker(g, worker_pool)
+                    worker_pool.remove(worker)
+                if worker:
+                    if len(worker.orders) == 1 and worker.orders[0].ability.id in [AbilityId.HARVEST_RETURN]:
+                        await self.do(worker.move(g))
+                        await self.do(worker.return_resource(queue=True))
                     else:
-                        await self.do(w.gather(g))
+                        await self.do(worker.gather(g))
 
-        for location, townhall in owned_expansions.items():
+    async def assign_townhalls(self, worker_pool):
+        for location, townhall in self.owned_expansions.items():
             actual = townhall.assigned_harvesters
             ideal = townhall.ideal_harvesters
 
             deficit = ideal - actual
-            for x in range(0, deficit):
-                if worker_pool:
-                    w = worker_pool.pop()
+            for _ in range(0, deficit):
+                worker = None
+                if  self.workers.idle:
+                    worker = self.getNearestWorker(location,  self.workers.idle)
+                    self.workers.idle.remove(worker)
+                elif worker_pool:
+                    worker = self.getNearestWorker(location, worker_pool)
+                    worker_pool.remove(worker)
+                if worker:
                     mf = self.state.mineral_field.closest_to(townhall)
-                    if len(w.orders) == 1 and w.orders[0].ability.id in [AbilityId.HARVEST_RETURN]:
-                        await self.do(w.move(townhall))
-                        await self.do(w.return_resource(queue=True))
-                        await self.do(w.gather(mf, queue=True))
+                    if len(worker.orders) == 1 and worker.orders[0].ability.id in [AbilityId.HARVEST_RETURN]:
+                        await self.do(worker.move(townhall))
+                        await self.do(worker.return_resource(queue=True))
+                        await self.do(worker.gather(mf, queue=True))
                     else:
-                        await self.do(w.gather(mf))
+                        await self.do(worker.gather(mf))
+
+    async def distribute_workers(self, geyser_first: bool = False):
+        """
+        Distributes workers across all the bases taken.
+        WARNING: This is quite slow when there are lots of workers or multiple bases.
+        """
+
+        # TODO:
+        # OPTIMIZE: Assign idle workers smarter
+        # OPTIMIZE: Never use same worker multiple times
+
+        worker_pool = []
+
+        self.check_townhalls(worker_pool)
+        self.check_geysers(worker_pool)
+        if geyser_first:
+            await self.assign_geysers(worker_pool)
+            await self.assign_townhalls(worker_pool)
+        else:
+            await self.assign_townhalls(worker_pool)
+            await self.assign_geysers(worker_pool)
+
 
     @property
     def owned_expansions(self):