Commit

Merge pull request #90 from lsst-camera-dh/LSSTTD-1531_stage_BOT_data_to_local_scratch

LSSTTD-1531: stage BOT data to local scratch
jchiang87 authored Oct 17, 2020
2 parents dfe2890 + 72c5ef0 commit dc8ee6e
Showing 4 changed files with 270 additions and 13 deletions.
@@ -11,6 +11,7 @@
from bot_eo_analyses import repackage_summary_files, \
run_python_task_or_cl_script
from raft_results_task import raft_results_task
+from stage_bot_data import clean_up_scratch

def make_focal_plane_plots():
#
@@ -150,3 +151,6 @@ def make_focal_plane_plots():
f'--css_file {css_file} --htmldir {htmldir} --overwrite')
print(command)
subprocess.check_call(command, shell=True)

+# Clean up scratch areas
+clean_up_scratch(run_number)
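The stage_bot_data module that provides clean_up_scratch is added by this PR but its diff is not rendered in this view (see the file list at the end). A minimal sketch of what the helper might look like, assuming it simply removes a per-run staging directory from local scratch; the LCATR_SCRATCH_DIR name and the per-run directory layout are illustrative assumptions, not taken from the diff:

    import os
    import shutil

    def clean_up_scratch(run_number, scratch_dir=None):
        """Remove this run's staging area from local scratch (sketch).

        Assumed layout: one directory per run under a scratch root.
        """
        if scratch_dir is None:
            scratch_dir = os.environ.get('LCATR_SCRATCH_DIR', '/scratch')
        run_area = os.path.join(scratch_dir, str(run_number))
        # Tolerate a missing directory, e.g. if staging was never enabled.
        if os.path.isdir(run_area):
            shutil.rmtree(run_area)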
6 changes: 4 additions & 2 deletions python/bot_eo_analyses.py
@@ -156,6 +156,7 @@ def get_mask_files(det_name):
mask_files = hj_fp_server.get_files('pixel_defects_BOT',
f'{det_name}*mask*.fits',
run=badpixel_run)
+mask_files = siteUtils.get_scratch_files(mask_files)
print(f"Mask files from run {badpixel_run} and {det_name}:")
for item in mask_files:
print(item)
@@ -168,6 +169,7 @@ def get_mask_files(det_name):
rolloff_mask_files = hj_fp_server.get_files('bias_frame_BOT',
f'{det_name}_*mask*.fits',
run=bias_run)
+rolloff_mask_files = siteUtils.get_scratch_files(rolloff_mask_files)
print(f"Edge rolloff mask file from run {bias_run} and {det_name}:")
for item in rolloff_mask_files:
print(item)
@@ -214,6 +216,7 @@ def medianed_dark_frame(det_name):
pattern = f'{det_name}_*_median_dark_current.fits'
filename = hj_fp_server.get_files('dark_current_BOT', pattern,
run=dark_run)[0]
+filename = siteUtils.get_scratch_files([filename])[0]
print("Dark frame:")
print(filename)
return filename
@@ -237,6 +240,7 @@ def bias_filename(run, det_name):
filename = hj_fp_server.get_files('bias_frame_BOT',
f'*{det_name}*median_bias.fits',
run=bias_run)[0]
+filename = siteUtils.get_scratch_files([filename])[0]
print("Bias frame:")
print(filename)
return filename
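Each of these four retrieval functions now routes its data-catalog paths through siteUtils.get_scratch_files, whose implementation lives outside this diff. A plausible sketch, assuming it substitutes a local scratch copy for each file when one has been staged and falls back to the original path otherwise (the scratch root and basename-based layout are assumptions; only the call signature is visible here):

    import os

    def get_scratch_files(files, scratch_dir='/scratch'):
        """Prefer staged local copies of catalog files (sketch)."""
        staged = []
        for path in files:
            candidate = os.path.join(scratch_dir, os.path.basename(path))
            # Use the local copy if staging produced one; else the original.
            staged.append(candidate if os.path.isfile(candidate) else path)
        return staged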
@@ -1146,7 +1150,6 @@ def run_jh_tasks(*jh_tasks, device_names=None, processes=None, walltime=3600):
These functions should take a device name as its only argument, and
the parallelization will take place over device_names.
-
Parameters
----------
jh_tasks: list-like container of functions
@@ -1168,7 +1171,6 @@ def run_jh_tasks(*jh_tasks, device_names=None, processes=None, walltime=3600):
------
parsl.app.errors.AppTimeout
-
Notes
-----
Because the number of jh_task functions can vary, the keyword arguments
69 changes: 58 additions & 11 deletions python/ssh_dispatcher.py
@@ -5,6 +5,7 @@
import os
import copy
import time
+import json
import socket
import logging
import subprocess
@@ -87,13 +88,16 @@ def __init__(self, script, working_dir, setup, max_retries=1,
self.task_ids = dict()
self.log_files = dict()
self.retries = defaultdict(zero_func)
+self.host_map = None

-def make_log_file(self, task_id, clean_up=True):
+def make_log_file(self, task_id, clean_up=True, params=None):
"""
Create a log filename from the task name and task_id and
clean up any existing log files in the logging directory.
"""
-script, _, _ = self.params
+if params is None:
+    params = self.params
+script, _, _ = params
task_name = os.path.basename(script).split('.')[0]
log_file = os.path.join(self.log_dir, f'{task_name}_{task_id}.log')
self.task_ids[log_file] = task_id
@@ -102,14 +106,17 @@ def make_log_file(self, task_id, clean_up=True):
os.remove(log_file)
return log_file

-def launch_script(self, remote_host, task_id, niceness=10, *args):
+def launch_script(self, remote_host, task_id, *args, niceness=10,
+                  params=None, wait=False):
"""
Function to launch the script as a remote process via ssh.
"""
logger = logging.getLogger('TaskRunner.launch_script')
logger.setLevel(logging.INFO)

-script, working_dir, setup = self.params
+if params is None:
+    params = self.params
+script, working_dir, setup = params
log_file = self.log_files[task_id]
command = f'ssh {remote_host} '
command += f'"cd {working_dir}; source {setup}; '
@@ -119,7 +126,10 @@ def launch_script(self, remote_host, task_id, niceness=10, *args):
command += ' '.join([str(_) for _ in args])
command += r' && echo Task succeeded on \`hostname\`'
command += r' || echo Task failed on \`hostname\`)'
-command += f' &>> {log_file}&"'
+if wait:
+    command += f' &>> {log_file}"'
+else:
+    command += f' &>> {log_file}&"'
if self.verbose:
logger.info(command)
logger.info('Launching %s on %s', script, remote_host)
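The new wait keyword changes only the tail of the assembled command: with wait=False the remote compound command ends in '&', so the ssh call returns as soon as the process is backgrounded, while wait=True drops the '&' and the call blocks until the remote script exits, which is what data staging requires before the analysis tasks start. A runnable sketch of the string assembly, with hypothetical host, paths, and log file:

    remote_host, working_dir, setup = 'lsst-dc01', '/work', 'setup.sh'
    script, log_file, niceness = 'stage_bot_data.py', 'stage.log', 10
    for wait in (False, True):
        command = f'ssh {remote_host} "cd {working_dir}; source {setup}; '
        command += f'(nice -{niceness} python {script}'
        command += r' && echo Task succeeded on \`hostname\`'
        command += r' || echo Task failed on \`hostname\`)'
        # Same branch as above: omit the trailing '&' when waiting.
        command += f' &>> {log_file}"' if wait else f' &>> {log_file}&"'
        print(command)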
@@ -177,7 +187,7 @@ def monitor_tasks(self, max_time=None, interval=1):
logger.info('Retrying tasks for: ')
for item in to_retry:
logger.info(' %s', item)
-self.submit_jobs(to_retry)
+self.submit_jobs(to_retry, retry=True)
time.sleep(interval)
messages = []
if log_files:
@@ -189,7 +199,41 @@ def monitor_tasks(self, max_time=None, interval=1):
if messages:
raise RuntimeError('\n'.join(messages))

-def submit_jobs(self, device_names):
+def stage_data(self, device_map_file='device_list_map.json'):
+    """
+    Function to dispatch data staging script to the remote hosts.
+    """
+    # Make inverse index of host -> list of devices and
+    # save as json for the staging script to use.
+    device_map = defaultdict(list)
+    for device_name, host in self.host_map.items():
+        device_map[host].append(device_name)
+    with open(device_map_file, 'w') as fd:
+        json.dump(dict(device_map), fd)
+    # Loop over hosts and launch the copy script on each host.
+    copy_script = os.path.join(os.environ['EOANALYSISJOBSDIR'],
+                               'python', 'stage_bot_data.py')
+    # Set params to override self.params in self.make_log_file
+    # and self.launch_script
+    params = (copy_script, *self.params[1:])
+    # Loop over hosts and launch staging script.
+    with multiprocessing.Pool(processes=len(device_map)) as pool:
+        workers = []
+        for host in device_map:
+            if host not in self.log_files:
+                self.make_log_file(host, params=params)
+            args = (host, host)
+            kwds = dict(params=params, wait=True)
+            time.sleep(0.5)
+            workers.append(pool.apply_async(self.launch_script, args, kwds))
+        pool.close()
+        pool.join()
+        _ = [_.get() for _ in workers]
+
+    # Clear self.log_files of staging script entries.
+    self.log_files = dict()
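The device_list_map.json file written above is the inverse of self.host_map: it keys each remote host to the list of devices assigned to it, presumably so the staging script running on each host copies only the data for its own devices. With hypothetical host and device names:

    import json
    from collections import defaultdict

    # Hypothetical assignment: three CCDs spread over two hosts.
    host_map = {'R22_S11': 'lsst-dc01',
                'R22_S12': 'lsst-dc01',
                'R22_S21': 'lsst-dc02'}
    device_map = defaultdict(list)
    for device_name, host in host_map.items():
        device_map[host].append(device_name)
    print(json.dumps(dict(device_map)))
    # {"lsst-dc01": ["R22_S11", "R22_S12"], "lsst-dc02": ["R22_S21"]}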

+def submit_jobs(self, device_names, retry=False):
"""
Submit a task script process for each device.
@@ -199,17 +243,20 @@ def submit_jobs(self, device_names):
List of devices for which the task script will be run.
"""
num_tasks = len(device_names)
+if not retry:
+    self.host_map = dict(zip(device_names, self.remote_hosts))
+if bool(os.environ.get('LCATR_STAGE_DATA', False)):
+    self.stage_data()

# Using multiprocessing allows one to launch the scripts much
# faster since it can be done asynchronously.
with multiprocessing.Pool(processes=num_tasks) as pool:
outputs = []
-for device_name, remote_host in zip(device_names,
-                                    self.remote_hosts):
+for device_name, remote_host in self.host_map.items():
if device_name not in self.log_files:
self.make_log_file(device_name)
args = remote_host, device_name
-time.sleep(0.1)
+time.sleep(0.5)
outputs.append(pool.apply_async(self.launch_script, args))
pool.close()
pool.join()
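Staging is opt-in through the LCATR_STAGE_DATA environment variable, and the bool(os.environ.get(...)) test above treats any non-empty string as true: even LCATR_STAGE_DATA=0 or LCATR_STAGE_DATA=False would trigger staging, and only an unset or empty variable skips it. A quick check:

    import os

    for value in (None, '', '0', 'False', '1'):
        if value is None:
            os.environ.pop('LCATR_STAGE_DATA', None)
        else:
            os.environ['LCATR_STAGE_DATA'] = value
        # Any non-empty string is truthy, whatever it spells.
        print(repr(value), bool(os.environ.get('LCATR_STAGE_DATA', False)))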
@@ -265,7 +312,7 @@ def ssh_device_analysis_pool(task_script, device_names, cwd='.', setup=None,
num_batches = 2 if ndev > 100 else 1

# Use override value from LCATR_NUM_BATCHES if it is set.
-num_batches = os.environ.get('LCATR_NUM_BATCHES', num_batches)
+num_batches = int(os.environ.get('LCATR_NUM_BATCHES', num_batches))
print("# devices, # batches, # hosts:",
ndev, num_batches, task_runner.remote_hosts.num_hosts)

204 changes: 204 additions & 0 deletions python/stage_bot_data.py (new file; diff not shown)
