diff --git a/jina/parsers/peapods/pod.py b/jina/parsers/peapods/pod.py index 28b07314bf3e9..c297570a18f32 100644 --- a/jina/parsers/peapods/pod.py +++ b/jina/parsers/peapods/pod.py @@ -39,7 +39,7 @@ def mixin_base_pod_parser(parser): choices=list(PollingType), default=PollingType.ANY, help=''' -The polling strategy of the Pod (when `parallel>1`) +The polling strategy of the Pod (when `parallel>1`) - ANY: only one (whoever is idle) Pea polls the message - ALL: all Peas poll the message (like a broadcast) ''', @@ -64,12 +64,9 @@ def mixin_base_pod_parser(parser): ) gp.add_argument( '--peas-hosts', - action=KVAppendAction, - metavar='KEY: VALUE', - nargs='*', - help='''The hosts of the peas when parallel greater than 1, - pea have a new host address if the pea_id present in the map. - otherwise pea host will be identical to the host of pod. - Represented as a key value pair in argument. - key is the pea_id, and value is the host address.''', + nargs='+', + type=str, + help='''The hosts of the peas when parallel greater than 1. + Peas will be evenly distributed among the hosts. By default, + peas are running in the same host as the pod.''', ) diff --git a/jina/peapods/pods/helper.py b/jina/peapods/pods/helper.py index b94de28953705..879f5e3daf96c 100644 --- a/jina/peapods/pods/helper.py +++ b/jina/peapods/pods/helper.py @@ -1,6 +1,7 @@ import copy from argparse import Namespace from typing import List, Optional +from itertools import cycle from ... import __default_host__ from ...enums import SchedulerType, SocketType, PeaRoleType @@ -12,8 +13,15 @@ def _set_peas_args( args: Namespace, head_args: Optional[Namespace] = None, tail_args: Namespace = None ) -> List[Namespace]: result = [] + _host_list = ( + args.peas_hosts + if args.peas_hosts + else [ + args.host, + ] + ) - for idx in range(args.parallel): + for idx, pea_host in zip(range(args.parallel), cycle(_host_list)): _args = copy.deepcopy(args) if args.parallel > 1: @@ -21,7 +29,7 @@ def _set_peas_args( _args.pea_role = PeaRoleType.PARALLEL _args.identity = random_identity() if _args.peas_hosts: - _args.host = _args.peas_hosts.get(str(_args.pea_id), args.host) + _args.host = pea_host if _args.name: _args.name += f'/{_args.pea_id}' else: diff --git a/tests/distributed/test_index_query_with_shards/flow_distributed_peas_in_pod.yml b/tests/distributed/test_index_query_with_shards/flow_distributed_peas_in_pod.yml index 1781d104632e8..93ff97319bee1 100644 --- a/tests/distributed/test_index_query_with_shards/flow_distributed_peas_in_pod.yml +++ b/tests/distributed/test_index_query_with_shards/flow_distributed_peas_in_pod.yml @@ -18,7 +18,7 @@ pods: polling: all host: $JINA_INDEXER_HOST peas_hosts: - 1: $JINA_ENCODER_HOST + - $JINA_ENCODER_HOST port_expose: 8000 - name: slice uses: slice.yml diff --git a/tests/distributed/test_simple_distributed_with_shards/flow_distributed_peas_in_pod.yml b/tests/distributed/test_simple_distributed_with_shards/flow_distributed_peas_in_pod.yml index ecc1a27eeb5f1..11482f91b98ec 100644 --- a/tests/distributed/test_simple_distributed_with_shards/flow_distributed_peas_in_pod.yml +++ b/tests/distributed/test_simple_distributed_with_shards/flow_distributed_peas_in_pod.yml @@ -10,12 +10,12 @@ pods: parallel: 3 host: $JINA_POD1_HOST peas_hosts: - 1: $JINA_POD2_HOST + - $JINA_POD2_HOST port_expose: 8000 - name: pod2 uses: _pass parallel: 3 host: $JINA_POD2_HOST peas_hosts: - 1: $JINA_POD1_HOST + - $JINA_POD1_HOST port_expose: 8000 diff --git a/tests/unit/peapods/pods/test_pods.py b/tests/unit/peapods/pods/test_pods.py index 8498f54946c52..79401c648f949 100644 --- a/tests/unit/peapods/pods/test_pods.py +++ b/tests/unit/peapods/pods/test_pods.py @@ -173,7 +173,7 @@ def test_pod_args_remove_uses_ba(): def test_pod_remote_pea_without_parallel(): args = set_pod_parser().parse_args( - ['--peas-hosts', '1: 0.0.0.1', '--parallel', str(1)] + ['--peas-hosts', '0.0.0.1', '--parallel', str(1)] ) with Pod(args) as pod: peas = pod.peas @@ -196,7 +196,7 @@ def test_pod_remote_pea_parallel_pea_host_set_partially( expected_host_out, ): args = set_pod_parser().parse_args( - ['--peas-hosts', f'1: {pea1_host}', '--parallel', str(2), '--host', pod_host] + ['--peas-hosts', f'{pea1_host}', '--parallel', str(2), '--host', pod_host] ) assert args.host == pod_host pod = Pod(args) @@ -205,7 +205,7 @@ def test_pod_remote_pea_parallel_pea_host_set_partially( assert v.host == args.host else: for pea_arg in v: - if pea_arg.pea_id == 1: + if pea_arg.pea_id in (1, 2): assert pea_arg.host == pea1_host assert pea_arg.host_in == expected_host_in assert pea_arg.host_out == expected_host_out @@ -232,8 +232,8 @@ def test_pod_remote_pea_parallel_pea_host_set_completely( args = set_pod_parser().parse_args( [ '--peas-hosts', - f'1: {peas_hosts[0]}', - f'2: {peas_hosts[1]}', + f'{peas_hosts[0]}', + f'{peas_hosts[1]}', '--parallel', str(2), '--host',