Skip to content

DO NOT MERGE! Oetc #427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ doc/_build
doc/generated
doc/api
.vscode
.idea
Highs.log
paper/
monkeytype.sqlite3
Expand Down
10 changes: 10 additions & 0 deletions linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
)
from linopy.matrices import MatrixAccessor
from linopy.objective import Objective
from linopy.remote import OetCloudHandler
from linopy.solvers import available_solvers, quadratic_solvers
from linopy.types import (
ConstantLike,
Expand Down Expand Up @@ -951,6 +952,7 @@ def solve(
env: None = None,
sanitize_zeros: bool = True,
remote: None = None,
remote_settings: dict = None,
**solver_options,
) -> tuple[str, str]:
"""
Expand Down Expand Up @@ -1003,6 +1005,8 @@ def solve(
Remote handler to use for solving model on a server. See
linopy.remote.RemoteHandler for more details.
remote_settings : dict, optional
Settings for the remote handler
**solver_options : kwargs
Options passed to the solver.

Expand All @@ -1015,6 +1019,10 @@ def solve(
# clear cached matrix properties potentially present from previous solve commands
self.matrices.clean_cached_properties()

if remote_settings:
remote = OetCloudHandler(settings=remote_settings["oetc"])
# TODO for consistency should we change the way the SSH remote handler is used to be to pass a similar string argument?

if remote:
solved = remote.solve_on_remote(
self,
Expand All @@ -1027,6 +1035,7 @@ def solve(
warmstart_fn=warmstart_fn,
keep_files=keep_files,
sanitize_zeros=sanitize_zeros,
structured_solver_options=solver_options, # TODO: When unifying remote handlers we should send solver options only once
**solver_options,
)

Expand All @@ -1038,6 +1047,7 @@ def solve(
for k, c in self.constraints.items():
if "dual" in solved.constraints[k]:
c.dual = solved.constraints[k].dual

return self.status, self.termination_condition

if len(available_solvers) == 0:
Expand Down
187 changes: 186 additions & 1 deletion linopy/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,22 @@

@author: fabian
"""

import gzip
import json
import logging
import tempfile
from dataclasses import dataclass
import os
from pathlib import Path

from google.cloud import storage
from google.cloud.storage import Client, Bucket, Blob
from google.oauth2 import service_account
from requests import Response, post, get, HTTPError
from time import sleep
from typing import Callable, Union

import linopy
from linopy.io import read_netcdf

paramiko_present = True
Expand Down Expand Up @@ -236,3 +246,178 @@ def solve_on_remote(self, model, **kwargs):
self.sftp_client.remove(self.model_solved_file)

return solved


# TODO perhaps we make RemoteHandler an abstract base class, and rename the class above as SshHandler,
# and have that and OetCloudHandler inherit from RemoteHandler?
class OetCloudHandler:
    """
    Remote handler that solves linopy models on the OET Cloud (OETC) compute app.

    The model is written to a netCDF file, gzip-compressed and uploaded to a
    Google Cloud Storage bucket. A compute job is then submitted to the OETC
    HTTP API and polled until it finishes, after which the first output file
    is downloaded, decompressed and read back as a linopy model.

    Parameters
    ----------
    settings : dict
        Handler settings. Recognised keys are ``url`` (default
        ``"http://127.0.0.1:5000"``), ``cpu_cores`` (default 2),
        ``ram_amount_gb`` (default 4) and ``disk_space_gb`` (default 10).

    Notes
    -----
    The environment variables ``GCP_SERVICE_KEY_PATH`` (path to a JSON
    service account key) and ``OETC_JWT`` (bearer token sent to the OETC API)
    are read on construction; a ``KeyError`` is raised if either is missing.
    """

    # TODO: Make bucket name dynamic if necessary
    _BUCKET_NAME = "oetc_files"
    _GCP_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
    # Files are (de)compressed in 1 MiB chunks to keep memory usage bounded.
    _CHUNK_SIZE = 1024 * 1024

    def __init__(self, settings: dict):
        self.oetc_url = settings.get("url", "http://127.0.0.1:5000")
        self.cpu_cores = settings.get("cpu_cores", 2)
        self.ram_amount_gb = settings.get("ram_amount_gb", 4)
        self.disk_space_gb = settings.get("disk_space_gb", 10)
        # TODO: Temporary solution, in the future read service key from auth API
        with open(
            os.environ["GCP_SERVICE_KEY_PATH"], "r", encoding="utf-8"
        ) as service_key_file:
            self.gcp_service_key = json.load(service_key_file)
        # TODO: Temporary solution, in the future use auth login to obtain the JWT
        self.jwt = os.environ["OETC_JWT"]

    def solve_on_remote(self, model, **kwargs):
        """
        Solve a linopy model on the OET Cloud compute app.

        Parameters
        ----------
        model : linopy.model.Model
        **kwargs :
            Keyword arguments passed to `linopy.model.Model.solve`. Only
            ``solver_name`` and ``structured_solver_options`` are forwarded
            to the compute job; both are required.

        Returns
        -------
        linopy.model.Model
            Solved model.

        Raises
        ------
        KeyError
            If ``solver_name`` or ``structured_solver_options`` is missing
            from ``kwargs``.
        """
        logger.warning(f'Ignoring these kwargs for now: {kwargs}')  # TODO

        # The temporary netCDF file only has to live until the upload is done.
        with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn:
            model.to_netcdf(fn.name)
            logger.info(f'Model written to: {fn.name}')
            input_file_name = self._upload_file_to_gcp(fn.name)

        job_uuid = self._submit_job(
            input_file_name,
            kwargs["solver_name"],
            kwargs["structured_solver_options"],
        )
        job_data = self._wait_and_get_job_data(job_uuid)

        out_file_name = job_data['output_files'][0]['name']
        out_file_name_path = self._download_file_from_gcp(out_file_name)

        # Remove the downloaded result even when reading it fails.
        try:
            solution = linopy.read_netcdf(out_file_name_path)
        finally:
            os.remove(out_file_name_path)
        logger.info(
            f'OETC result: {solution.status}, {solution.termination_condition}, Objective: {solution.objective.value:.2e}'
        )

        return solution

    def _submit_job(self, input_file_name, solver_name, solver_options):
        """
        Create a compute job on OETC and return its UUID.

        Parameters
        ----------
        input_file_name : str
            Name of the uploaded (compressed) model file in the bucket.
        solver_name : str
            Solver to run on the remote side.
        solver_options : dict
            Solver options sent along with the job.

        Returns
        -------
        str
            UUID of the created job as reported by the API.

        Raises
        ------
        requests.HTTPError
            If the API answers with an error status code.
        """
        logger.info('Calling OETC...')
        try:
            response: Response = post(
                f"{self.oetc_url}/compute-job/create",
                headers={"Authorization": f"Bearer {self.jwt}"},
                json={
                    "solver": solver_name,
                    "solver_options": solver_options,
                    "cpu_cores": self.cpu_cores,
                    "ram_amount_gb": self.ram_amount_gb,
                    "disk_space_gb": self.disk_space_gb,
                    "input_file_name": input_file_name,
                },
            )
            response.raise_for_status()
        except HTTPError as e:
            # The error body is not guaranteed to be JSON; logging the raw
            # text keeps a decode failure from masking the HTTP error.
            logger.error(e.response.text)
            raise

        content = response.json()
        logger.info(f'OETC job submitted successfully. ID: {content["uuid"]}')
        return content['uuid']

    def _wait_and_get_job_data(self, uuid: str, retry_every_s=60) -> dict:
        """
        Poll a job until it finishes and return its data.

        Parameters
        ----------
        uuid : str
            UUID of the job to poll.
        retry_every_s : int, optional
            Seconds to sleep between two status checks. Default is 60.

        Returns
        -------
        dict
            Job data as returned by the API once the status is ``FINISHED``.

        Raises
        ------
        ValueError
            If the API answers with an error status code, the response has no
            ``status`` field, or the status is neither ``FINISHED``,
            ``RUNNING`` nor ``PENDING``.
        """
        while True:
            logger.info('Checking job status...')
            response: Response = get(
                f"{self.oetc_url}/compute-job/{uuid}",
                headers={"Authorization": f"Bearer {self.jwt}"},
            )
            if not response.ok:
                raise ValueError(f'OETC Error: {response.text}')
            content = response.json()
            if "status" not in content:
                raise ValueError(f'Unexpected response: {response.text}')
            if content['status'] == 'FINISHED':
                logger.info('OETC completed job execution')
                return content
            if content['status'] not in {'RUNNING', 'PENDING'}:
                raise ValueError(f"Unexpected status: {content['status']}")
            logger.info('OETC still crunching...')
            sleep(retry_every_s)

    @staticmethod
    def _copy_in_chunks(f_in, f_out, chunk_size: int) -> None:
        """Copy the binary stream `f_in` to `f_out`, `chunk_size` bytes at a time."""
        for chunk in iter(lambda: f_in.read(chunk_size), b""):
            f_out.write(chunk)

    def _gzip_compress(self, source_path: str) -> str:
        """
        Gzip-compress a file and return the path of the compressed copy.

        The output is written next to the input as ``<source_path>.gz``; the
        input file is left untouched.
        """
        output_path = source_path + ".gz"

        with open(source_path, "rb") as f_in:
            with gzip.open(output_path, "wb", compresslevel=9) as f_out:
                self._copy_in_chunks(f_in, f_out, self._CHUNK_SIZE)

        return output_path

    def _gzip_decompress(self, input_path: str) -> str:
        """
        Decompress a gzip file and return the path of the decompressed copy.

        The output path is the input path with its last suffix removed
        (``model.nc.gz`` -> ``model.nc``). If the input has no suffix, the
        output is written to ``<input_path>.out`` instead.
        """
        source = Path(input_path)
        target = source.with_suffix("")
        if target == source:
            # Nothing to strip: writing to the same path would truncate the
            # compressed input before it has been read.
            target = source.with_name(source.name + ".out")
        output_path = str(target)

        with gzip.open(input_path, "rb") as f_in:
            with open(output_path, "wb") as f_out:
                self._copy_in_chunks(f_in, f_out, self._CHUNK_SIZE)

        return output_path

    def _get_bucket(self):
        """Return the storage bucket, authenticated with the service key."""
        credentials = service_account.Credentials.from_service_account_info(
            self.gcp_service_key,
            scopes=list(self._GCP_SCOPES),
        )
        storage_client: Client = storage.Client(credentials=credentials)
        return storage_client.bucket(self._BUCKET_NAME)

    def _upload_file_to_gcp(self, file_path: str) -> str:
        """
        Compress a file, upload it to the bucket and return the blob name.

        The intermediate ``.gz`` file is removed afterwards, whether or not
        the upload succeeded.
        """
        logger.info('Compressing model...')
        compressed_file_path = self._gzip_compress(file_path)
        try:
            logger.info(f'Model compressed to: {compressed_file_path}')
            compressed_file_name = os.path.basename(compressed_file_path)

            blob: Blob = self._get_bucket().blob(compressed_file_name)

            logger.info(f"Uploading model from {compressed_file_path}...")
            blob.upload_from_filename(compressed_file_path)
            logger.info(f"Uploaded {compressed_file_name}")
        finally:
            # The compressed copy is only needed for the upload itself.
            os.remove(compressed_file_path)

        return compressed_file_name

    def _download_file_from_gcp(self, file_name: str) -> str:
        """
        Download a gzip blob from the bucket and return the decompressed path.

        The blob is stored in the system temporary directory. The downloaded
        ``.gz`` file is removed once decompression has been attempted; the
        caller owns the returned decompressed file.
        """
        blob: Blob = self._get_bucket().blob(file_name)

        file_path: str = os.path.join(tempfile.gettempdir(), file_name)

        logger.info(f"Downloading model {file_name}")
        blob.download_to_filename(file_path)
        logger.info(f"Model saved at {file_path}")

        logger.info(f"Decompressing .gz file {file_name}")
        try:
            decompressed_file_path = self._gzip_decompress(file_path)
        finally:
            os.remove(file_path)
        logger.info(f"Decompressed input file: {decompressed_file_path}")

        return decompressed_file_path
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"polars",
"tqdm",
"deprecation",
"google-cloud-storage",
]

[project.urls]
Expand Down