diff --git a/user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py b/user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py index bd17016bd..e1ead67d0 100644 --- a/user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py +++ b/user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py @@ -422,10 +422,20 @@ def _init_nodes(self): SparkNodeType.MASTER: master_node } + def _set_zone_from_props(self, prop_container: JSONPropertiesContainer): + """ + Extracts the 'zoneUri' from the properties container and updates the environment variable dictionary. + """ + if prop_container: + zone_uri = prop_container.get_value_silent('config', 'gceClusterConfig', 'zoneUri') + if zone_uri: + self.cli.env_vars['zone'] = FSUtil.get_resource_name(zone_uri) + def _init_connection(self, cluster_id: str = None, props: str = None) -> dict: cluster_args = super()._init_connection(cluster_id=cluster_id, props=props) - # propagate zone to the cluster + # extract and update zone to the environment variable and cluster + self._set_zone_from_props(cluster_args['props']) cluster_args.setdefault('zone', self.cli.get_env_var('zone')) return cluster_args @@ -514,6 +524,7 @@ class DataprocSavingsEstimator(SavingsEstimator): """ A class that calculates the savings based on Dataproc price provider """ + def _calculate_group_cost(self, cluster_inst: ClusterGetAccessor, node_type: SparkNodeType): nodes_cnt = cluster_inst.get_nodes_cnt(node_type) cores_count = cluster_inst.get_node_core_count(node_type)