From e102e352bebccd560e8917eecc4bf727dbd91953 Mon Sep 17 00:00:00 2001 From: ErinWeisbart Date: Thu, 16 Jun 2022 14:07:18 -0700 Subject: [PATCH] correct role handling --- run.py | 110 +++++++++++++++++++++++++++---------------- worker/run-worker.sh | 10 ++-- 2 files changed, 73 insertions(+), 47 deletions(-) diff --git a/run.py b/run.py index a0c3a64..c5e7071 100644 --- a/run.py +++ b/run.py @@ -60,14 +60,45 @@ # AUXILIARY FUNCTIONS ################################# -def get_aws_credentials(AWS_PROFILE): - session = boto3.Session(profile_name=AWS_PROFILE) - credentials = session.get_credentials() - return credentials.access_key, credentials.secret_key - def generate_task_definition(AWS_PROFILE): + taskRoleArn = False task_definition = TASK_DEFINITION.copy() - key, secret = get_aws_credentials(AWS_PROFILE) + + config = configparser.ConfigParser() + config.read(f"{os.environ['HOME']}/.aws/config") + + if config.has_section(AWS_PROFILE): + profile_name = AWS_PROFILE + elif config.has_section(f'profile {AWS_PROFILE}'): + profile_name = f'profile {AWS_PROFILE}' + else: + print ('Problem handling profile') + + if config.has_option(profile_name, 'role_arn'): + print ("Using role for credentials", config[profile_name]['role_arn']) + taskRoleArn = config[profile_name]['role_arn'] + else: + if config.has_option(profile_name, 'source_profile'): + creds = configparser.ConfigParser() + creds.read(f"{os.environ['HOME']}/.aws/credentials") + source_profile = config[profile_name]['source_profile'] + aws_access_key = creds[source_profile]['aws_access_key_id'] + aws_secret_key = creds[source_profile]['aws_secret_access_key'] + elif config.has_option(profile_name, 'aws_access_key_id'): + aws_access_key = config[profile_name]['aws_access_key_id'] + aws_secret_key = config[profile_name]['aws_secret_access_key'] + else: + print ("Problem getting credentials") + task_definition['containerDefinitions'][0]['environment'] += [ + { + "name": "AWS_ACCESS_KEY_ID", + "value": aws_access_key + }, + { + "name": "AWS_SECRET_ACCESS_KEY", + "value": aws_secret_key + }] + sqs = boto3.client('sqs') queue_name = get_queue_url(sqs) task_definition['containerDefinitions'][0]['environment'] += [ @@ -79,14 +110,6 @@ def generate_task_definition(AWS_PROFILE): 'name': 'SQS_QUEUE_URL', 'value': queue_name }, - { - "name": "AWS_ACCESS_KEY_ID", - "value": key - }, - { - "name": "AWS_SECRET_ACCESS_KEY", - "value": secret - }, { "name": "AWS_BUCKET", "value": AWS_BUCKET @@ -119,8 +142,13 @@ def generate_task_definition(AWS_PROFILE): return task_definition def update_ecs_task_definition(ecs, ECS_TASK_NAME, AWS_PROFILE): - task_definition = generate_task_definition(AWS_PROFILE) - ecs.register_task_definition(family=ECS_TASK_NAME,containerDefinitions=task_definition['containerDefinitions']) + task_definition, taskRoleArn = generate_task_definition(AWS_PROFILE) + if not taskRoleArn: + ecs.register_task_definition(family=ECS_TASK_NAME,containerDefinitions=task_definition['containerDefinitions']) + elif taskRoleArn: + ecs.register_task_definition(family=ECS_TASK_NAME,containerDefinitions=task_definition['containerDefinitions'],taskRoleArn=taskRoleArn) + else: + print('Mistake in handling role for Task Definition.') print('Task definition registered') def get_or_create_cluster(ecs): @@ -179,14 +207,14 @@ def killdeadAlarms(fleetId,monitorapp,ec2,cloud): todel.append(eachevent['EventInformation']['InstanceId']) existing_alarms = [x['AlarmName'] for x in cloud.describe_alarms(AlarmNamePrefix=monitorapp)['MetricAlarms']] - + for eachmachine in todel: monitorname = monitorapp+'_'+eachmachine if monitorname in existing_alarms: cloud.delete_alarms(AlarmNames=[monitorname]) print('Deleted', monitorname, 'if it existed') time.sleep(3) - + print('Old alarms deleted') def generateECSconfig(ECS_CLUSTER,APP_NAME,AWS_BUCKET,s3client): @@ -232,7 +260,7 @@ def removequeue(queueName): for eachUrl in queueoutput["QueueUrls"]: if eachUrl.split('/')[-1] == queueName: queueUrl=eachUrl - + sqs.delete_queue(QueueUrl=queueUrl) def deregistertask(taskName, ecs): @@ -262,7 +290,7 @@ def downscaleSpotFleet(queue, spotFleetID, ec2, manual=False): def export_logs(logs, loggroupId, starttime, bucketId): result = logs.create_export_task(taskName = loggroupId, logGroupName = loggroupId, fromTime = int(starttime), to = int(time.time()*1000), destination = bucketId, destinationPrefix = 'exportedlogs/'+loggroupId) - + logExportId = result['taskId'] while True: @@ -285,7 +313,7 @@ def __init__(self,name=None): self.queue = self.sqs.get_queue_by_name(QueueName=SQS_QUEUE_NAME) else: self.queue = self.sqs.get_queue_by_name(QueueName=name) - self.inProcess = -1 + self.inProcess = -1 self.pending = -1 def scheduleBatch(self, data): @@ -342,7 +370,7 @@ def submitJob(): # Step 1: Read the job configuration file jobInfo = loadConfig(sys.argv[2]) - templateMessage = {'Metadata': '', + templateMessage = {'Metadata': '', 'output_file_location': jobInfo["output_file_location"], 'shared_metadata': jobInfo["shared_metadata"] } @@ -357,7 +385,7 @@ def submitJob(): print('Job submitted. Check your queue') ################################# -# SERVICE 3: START CLUSTER +# SERVICE 3: START CLUSTER ################################# def startCluster(): @@ -376,7 +404,7 @@ def startCluster(): spotfleetConfig['SpotPrice'] = '%.2f' %MACHINE_PRICE DOCKER_BASE_SIZE = int(round(float(EBS_VOL_SIZE)/int(TASKS_PER_MACHINE))) - 2 userData=generateUserData(ecsConfigFile,DOCKER_BASE_SIZE) - for LaunchSpecification in range(0,len(spotfleetConfig['LaunchSpecifications'])): + for LaunchSpecification in range(0,len(spotfleetConfig['LaunchSpecifications'])): spotfleetConfig['LaunchSpecifications'][LaunchSpecification]["UserData"]=userData spotfleetConfig['LaunchSpecifications'][LaunchSpecification]['BlockDeviceMappings'][1]['Ebs']["VolumeSize"]= EBS_VOL_SIZE spotfleetConfig['LaunchSpecifications'][LaunchSpecification]['InstanceType'] = MACHINE_TYPE[LaunchSpecification] @@ -399,7 +427,7 @@ def startCluster(): createMonitor.write('"MONITOR_LOG_GROUP_NAME" : "'+LOG_GROUP_NAME+'",\n') createMonitor.write('"MONITOR_START_TIME" : "'+ starttime+'"}\n') createMonitor.close() - + # Step 4: Create a log group for this app and date if one does not already exist logclient=boto3.client('logs') loggroupinfo=logclient.describe_log_groups(logGroupNamePrefix=LOG_GROUP_NAME) @@ -410,13 +438,13 @@ def startCluster(): if LOG_GROUP_NAME+'_perInstance' not in groupnames: logclient.create_log_group(logGroupName=LOG_GROUP_NAME+'_perInstance') logclient.put_retention_policy(logGroupName=LOG_GROUP_NAME+'_perInstance', retentionInDays=60) - + # Step 5: update the ECS service to be ready to inject docker containers in EC2 instances print('Updating service') ecs = boto3.client('ecs') ecs.update_service(cluster=ECS_CLUSTER, service=APP_NAME+'Service', desiredCount=CLUSTER_MACHINES*TASKS_PER_MACHINE) - print('Service updated.') - + print('Service updated.') + # Step 6: Monitor the creation of the instances until all are present status = ec2client.describe_spot_fleet_instances(SpotFleetRequestId=requestInfo['SpotFleetRequestId']) #time.sleep(15) # This is now too fast, so sometimes the spot fleet request history throws an error! @@ -436,7 +464,7 @@ def startCluster(): return ec2client.cancel_spot_fleet_requests(SpotFleetRequestIds=[requestInfo['SpotFleetRequestId']], TerminateInstances=True) return - + # If everything seems good, just bide your time until you're ready to go print('.') time.sleep(20) @@ -445,21 +473,21 @@ def startCluster(): print('Spot fleet successfully created. Your job should start in a few minutes.') ################################# -# SERVICE 4: MONITOR JOB +# SERVICE 4: MONITOR JOB ################################# def monitor(cheapest=False): if len(sys.argv) < 3: print('Use: run.py monitor spotFleetIdFile') sys.exit() - + if '.json' not in sys.argv[2]: print('Use: run.py monitor spotFleetIdFile') sys.exit() if len(sys.argv) == 4: cheapest = sys.argv[3] - + monitorInfo = loadConfig(sys.argv[2]) monitorcluster=monitorInfo["MONITOR_ECS_CLUSTER"] monitorapp=monitorInfo["MONITOR_APP_NAME"] @@ -467,17 +495,17 @@ def monitor(cheapest=False): queueId=monitorInfo["MONITOR_QUEUE_NAME"] ec2 = boto3.client('ec2') - cloud = boto3.client('cloudwatch') + cloud = boto3.client('cloudwatch') # Optional Step 0 - decide if you're going to be cheap rather than fast. This means that you'll get 15 minutes # from the start of the monitor to get as many machines as you get, and then it will set the requested number to 1. - # Benefit: this will always be the cheapest possible way to run, because if machines die they'll die fast, - # Potential downside- if machines are at low availability when you start to run, you'll only ever get a small number + # Benefit: this will always be the cheapest possible way to run, because if machines die they'll die fast, + # Potential downside- if machines are at low availability when you start to run, you'll only ever get a small number # of machines (as opposed to getting more later when they become available), so it might take VERY long to run if that happens. if cheapest: queue = JobQueue(name=queueId) startcountdown = time.time() - while queue.pendingLoad(): + while queue.pendingLoad(): if time.time() - startcountdown > 900: downscaleSpotFleet(queue, fleetId, ec2, manual=1) break @@ -486,7 +514,7 @@ def monitor(cheapest=False): # Step 1: Create job and count messages periodically queue = JobQueue(name=queueId) while queue.pendingLoad(): - #Once an hour (except at midnight) check for terminated machines and delete their alarms. + #Once an hour (except at midnight) check for terminated machines and delete their alarms. #This is slooooooow, which is why we don't just do it at the end curtime=datetime.datetime.now().strftime('%H%M') if curtime[-2:]=='00': @@ -499,7 +527,7 @@ def monitor(cheapest=False): if curtime[-1:]=='9': downscaleSpotFleet(queue, fleetId, ec2) time.sleep(MONITOR_TIME) - + # Step 2: When no messages are pending, stop service # Reload the monitor info, because for long jobs new fleets may have been started, etc monitorInfo = loadConfig(sys.argv[2]) @@ -509,7 +537,7 @@ def monitor(cheapest=False): queueId=monitorInfo["MONITOR_QUEUE_NAME"] bucketId=monitorInfo["MONITOR_BUCKET_NAME"] loggroupId=monitorInfo["MONITOR_LOG_GROUP_NAME"] - starttime=monitorInfo["MONITOR_START_TIME"] + starttime=monitorInfo["MONITOR_START_TIME"] ecs = boto3.client('ecs') ecs.update_service(cluster=monitorcluster, service=monitorapp+'Service', desiredCount=0) @@ -560,14 +588,14 @@ def monitor(cheapest=False): print('All export tasks done') ################################# -# MAIN USER INTERACTION +# MAIN USER INTERACTION ################################# if __name__ == '__main__': if len(sys.argv) < 2: print('Use: run.py setup | submitJob | startCluster | monitor') sys.exit() - + if sys.argv[1] == 'setup': setup() elif sys.argv[1] == 'submitJob': diff --git a/worker/run-worker.sh b/worker/run-worker.sh index 996f9e1..5be5e4b 100644 --- a/worker/run-worker.sh +++ b/worker/run-worker.sh @@ -5,8 +5,6 @@ echo "Queue $SQS_QUEUE_URL" echo "Bucket $AWS_BUCKET" # 1. CONFIGURE AWS CLI -aws configure set aws_access_key_id $AWS_ACCESS_KEY_ID -aws configure set aws_secret_access_key $AWS_SECRET_ACCESS_KEY aws configure set default.region $AWS_REGION MY_INSTANCE_ID=$(curl http://169.254.169.254/latest/meta-data/instance-id) echo "Instance ID $MY_INSTANCE_ID" @@ -17,15 +15,15 @@ aws ec2 create-tags --resources $VOL_0_ID --tags Key=Name,Value=${APP_NAME}Worke VOL_1_ID=$(aws ec2 describe-instance-attribute --instance-id $MY_INSTANCE_ID --attribute blockDeviceMapping --output text --query BlockDeviceMappings[1].Ebs.[VolumeId]) aws ec2 create-tags --resources $VOL_1_ID --tags Key=Name,Value=${APP_NAME}Worker -# 2. MOUNT S3 +# 2. MOUNT S3 echo $AWS_ACCESS_KEY_ID:$AWS_SECRET_ACCESS_KEY > /credentials.txt chmod 600 /credentials.txt mkdir -p /home/ubuntu/bucket mkdir -p /home/ubuntu/local_output -stdbuf -o0 s3fs $AWS_BUCKET /home/ubuntu/bucket -o passwd_file=/credentials.txt +stdbuf -o0 s3fs $AWS_BUCKET /home/ubuntu/bucket -o passwd_file=/credentials.txt # 3. SET UP ALARMS -aws cloudwatch put-metric-alarm --alarm-name ${APP_NAME}_${MY_INSTANCE_ID} --alarm-actions arn:aws:swf:${AWS_REGION}:${OWNER_ID}:action/actions/AWS_EC2.InstanceId.Terminate/1.0 --statistic Maximum --period 60 --threshold 1 --comparison-operator LessThanThreshold --metric-name CPUUtilization --namespace AWS/EC2 --evaluation-periods 15 --dimensions "Name=InstanceId,Value=${MY_INSTANCE_ID}" +aws cloudwatch put-metric-alarm --alarm-name ${APP_NAME}_${MY_INSTANCE_ID} --alarm-actions arn:aws:swf:${AWS_REGION}:${OWNER_ID}:action/actions/AWS_EC2.InstanceId.Terminate/1.0 --statistic Maximum --period 60 --threshold 1 --comparison-operator LessThanThreshold --metric-name CPUUtilization --namespace AWS/EC2 --evaluation-periods 15 --dimensions "Name=InstanceId,Value=${MY_INSTANCE_ID}" # 4. DOWNLOAD PLUGIN FILE wget -P /opt/fiji/Fiji.app/plugins/ $SCRIPT_DOWNLOAD_URL @@ -35,4 +33,4 @@ wget -P /opt/fiji/Fiji.app/plugins/ $SCRIPT_DOWNLOAD_URL python3 instance-monitor.py & # 6. RUN FIJI WORKER -python3 fiji-worker.py |& tee $k.out +python3 fiji-worker.py |& tee $k.out