diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py index 123e0c005..e9d9117b3 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py @@ -70,7 +70,7 @@ def bootloader_shell(self): pass yield self.serial - def flash( + def flash( # noqa: C901 self, path: PathBuf, *, @@ -81,10 +81,19 @@ def flash( force_flash_bundle: str | None = None, cacert_file: str | None = None, insecure_tls: bool = False, + headers: dict[str, str] | None = None, + bearer_token: str | None = None, ): + if bearer_token: + bearer_token = self._validate_bearer_token(bearer_token) + + if headers: + headers = self._validate_header_dict(headers) + """Flash image to DUT""" should_download_to_httpd = True image_url = "" + original_http_url = None operator_scheme = None # initrmafs cannot handle https yet, fallback to using the exporter's http server if path.startswith(("http://", "https://")) and not force_exporter_http: @@ -94,7 +103,17 @@ def flash( else: # use the exporter's http server for the flasher image, we should download it first if operator is None: - path, operator, operator_scheme = operator_for_path(path) + if path.startswith(("http://", "https://")) and bearer_token: + parsed = urlparse(path) + self.logger.info(f"Using Bearer token authentication for {parsed.netloc}") + original_http_url = path + operator = Operator( + "http", root="/", endpoint=f"{parsed.scheme}://{parsed.netloc}", token=bearer_token + ) + operator_scheme = "http" + path = Path(parsed.path) + else: + path, operator, operator_scheme = operator_for_path(path) image_url = self.http.get_url() + "/" + path.name # start counting time for the flash operation @@ -107,7 +126,16 @@ def flash( # Start the storage write operation in the background storage_thread = threading.Thread( target=self._transfer_bg_thread, - args=(path, operator, operator_scheme, os_image_checksum, self.http.storage, error_queue, image_url), + args=( + path, + operator, + operator_scheme, + os_image_checksum, + self.http.storage, + error_queue, + original_http_url, + headers, + ), name="storage_transfer", ) storage_thread.start() @@ -152,9 +180,17 @@ def flash( else: stored_cacert = self._setup_flasher_ssl(console, manifest, cacert_file) - - self._flash_with_progress(console, manifest, path, image_url, target_device, - insecure_tls, stored_cacert) + header_args = self._prepare_headers(headers, bearer_token) + self._flash_with_progress( + console, + manifest, + path, + image_url, + target_device, + insecure_tls, + stored_cacert, + header_args, + ) total_time = time.time() - start_time # total time in minutes:seconds @@ -222,7 +258,36 @@ def _curl_tls_args(self, insecure_tls: bool, stored_cacert: str | None) -> str: tls_args += f"--cacert {stored_cacert} " return tls_args.strip() - def _flash_with_progress(self, console, manifest, path, image_url, target_path, insecure_tls, stored_cacert): + def _curl_header_args(self, headers: dict[str, str] | None) -> str: + """Generate header arguments for curl command""" + if not headers: + return "" + + parts: list[str] = [] + + def _sq(s: str) -> str: + return s.replace("'", "'\"'\"'") + + for k, v in headers.items(): + k = str(k).strip() + v = str(v).strip() + if not k: + continue + parts.append(f"-H '{_sq(k)}: {_sq(v)}'") + + return " ".join(parts) + + def _flash_with_progress( + self, + console, + manifest, + path, + image_url, + target_path, + insecure_tls, + stored_cacert, + header_args: str, + ): """Flash image to target device with progress monitoring. Args: @@ -241,11 +306,11 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path, tls_args = self._curl_tls_args(insecure_tls, stored_cacert) # Check if the image URL is accessible using curl and the TLS arguments - self._check_url_access(console, prompt, image_url, tls_args) + self._check_url_access(console, prompt, image_url, tls_args, header_args) # Flash the image, we run curl -> decompress -> dd in the background, so we can monitor dd's progress flash_cmd = ( - f'( curl -fsSL {tls_args} "{image_url}" | ' + f'( curl -fsSL {tls_args} {header_args} "{image_url}" | ' f"{decompress_cmd} " f"dd of={target_path} bs=64k iflag=fullblock oflag=direct) &" ) @@ -287,7 +352,7 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path, console.sendline("sync") console.expect(prompt, timeout=EXPECT_TIMEOUT_SYNC) - def _check_url_access(self, console, prompt, image_url: str, tls_args: str): + def _check_url_access(self, console, prompt, image_url: str, tls_args: str, header_args: str): """Check if the image URL is accessible using curl. Args: @@ -299,7 +364,9 @@ def _check_url_access(self, console, prompt, image_url: str, tls_args: str): Raises: RuntimeError: If the URL is not accessible """ - console.sendline(f'curl --location --max-time 30 --fail -sS -r 0-0 -o /dev/null {tls_args} "{image_url}"') + console.sendline( + f'curl --location --max-time 30 --fail -sS -r 0-0 -o /dev/null {tls_args} {header_args} "{image_url}"' + ) console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT) curl_output = console.before.decode(errors="ignore").strip() console.sendline("echo $?") @@ -358,6 +425,7 @@ def _transfer_bg_thread( to_storage: OpendalClient, error_queue, original_url: str | None = None, + headers: dict[str, str] | None = None, ): """Transfer image to exporter storage in the background Args: @@ -367,6 +435,7 @@ def _transfer_bg_thread( error_queue: Queue to put exceptions in if any known_hash: Known hash of the image original_url: Original URL for HTTP fallback + headers: HTTP headers for requests """ self.logger.info(f"Writing image to storage in the background: {src_path}") try: @@ -392,7 +461,9 @@ def _transfer_bg_thread( self.logger.info(f"Uploading image to storage: {filename}") to_storage.write_from_path(filename, src_path, src_operator) - metadata, metadata_json = self._create_metadata_and_json(src_operator, src_path, file_hash, original_url) + metadata, metadata_json = self._create_metadata_and_json( + src_operator, src_path, file_hash, original_url, headers + ) metadata_file = filename + ".metadata" to_storage.write_bytes(metadata_file, metadata_json.encode(errors="ignore")) @@ -415,7 +486,7 @@ def _sha256_file(self, src_operator, src_path) -> str: return m.hexdigest() def _create_metadata_and_json( - self, src_operator, src_path, file_hash=None, original_url=None + self, src_operator, src_path, file_hash=None, original_url=None, headers: dict[str, str] | None = None ) -> tuple[Metadata | None, str]: """Create a metadata json string from a metadata object""" metadata = None @@ -436,7 +507,10 @@ def _create_metadata_and_json( if original_url and original_url.startswith(("http://", "https://")): try: - response = requests.head(original_url) + if headers: + response = requests.head(original_url, headers=headers) + else: + response = requests.head(original_url) http_metadata = {} if "content-length" in response.headers: @@ -611,6 +685,71 @@ def manifest(self): self._manifest = FlasherBundleManifestV1Alpha1.from_string(yaml_str) return self._manifest + def _validate_header_dict(self, header_map: dict[str, str]) -> dict[str, str]: + token_re = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$") + seen: set[str] = set() + for key, value in header_map.items(): + key = key.strip() + value = value.strip() + if not key: + raise ArgumentError(f"Invalid header key: '{key}'") + + if not token_re.match(key): + raise ArgumentError(f"Invalid header name '{key}': must be an HTTP token (RFC7230)") + if any(c in ("\r", "\n") for c in key) or any(c in ("\r", "\n") for c in value): + raise ArgumentError("Header names/values must not contain CR/LF") + kl = key.lower() + if kl in seen: + raise ArgumentError(f"Duplicate header '{key}'") + seen.add(kl) + return header_map + + def _parse_headers(self, headers: list[str]) -> dict[str, str]: + header_map: dict[str, str] = {} + for h in headers: + if ":" not in h: + raise click.ClickException(f"Invalid header format: {h!r}. Expected 'Key: Value'.") + + key, value = h.split(":", 1) + header_map[key.strip()] = value.strip() + + try: + return self._validate_header_dict(header_map) + except ArgumentError as e: + raise click.ClickException(str(e)) from e + + def _prepare_headers(self, headers: dict[str, str] | None, bearer_token: str | None) -> str: + all_headers = headers.copy() if headers else {} + if bearer_token: + if any(k.lower() == "authorization" for k in all_headers.keys()): + self.logger.warning("Authorization header provided - ignoring bearer token") + else: + all_headers["Authorization"] = f"Bearer {bearer_token}" + + if bearer_token and "Authorization" not in (headers or {}): + auth_header = {"Authorization": all_headers["Authorization"]} + self._validate_header_dict(auth_header) + + return self._curl_header_args(all_headers) + + def _validate_bearer_token(self, token: str | None) -> str | None: + if token is None: + return None + + token = token.strip() + if not token: + raise click.ClickException("Bearer token cannot be empty") + + # RFC 6750 allows token68 format (base64url-encoded) or other token formats + # Basic validation: printable ASCII excluding whitespace and special chars that could cause issues + if not all(32 < ord(c) < 127 and c not in ' "\\' for c in token): + raise click.ClickException("Bearer token contains invalid characters") + + if len(token) > 4096: + raise click.ClickException("Bearer token is too long (max 4096 characters)") + + return token + def cli(self): @driver_click_group(self) def base(): @@ -630,6 +769,17 @@ def base(): @click.option("--force-flash-bundle", type=str, help="Force use of a specific flasher OCI bundle") @click.option("--cacert", type=click.Path(exists=True, dir_okay=False), help="CA certificate to use for HTTPS") @click.option("--insecure-tls", is_flag=True, help="Skip TLS certificate verification") + @click.option( + "--header", + "header", + multiple=True, + help="Custom HTTP header in 'Key: Value' format", + ) + @click.option( + "--bearer", + type=str, + help="Bearer token for HTTP authentication", + ) @debug_console_option def flash( file, @@ -641,6 +791,8 @@ def flash( force_flash_bundle, cacert, insecure_tls, + header, + bearer, ): """Flash image to DUT from file""" if os_image_checksum_file and os.path.exists(os_image_checksum_file): @@ -649,6 +801,9 @@ def flash( self.logger.info(f"Read checksum from file: {os_image_checksum}") self.set_console_debug(console_debug) + + headers = self._parse_headers(header) if header else None + self.flash( file, partition=target, @@ -656,6 +811,8 @@ def flash( force_flash_bundle=force_flash_bundle, cacert_file=cacert, insecure_tls=insecure_tls, + headers=headers, + bearer_token=bearer, ) @base.command() diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py new file mode 100644 index 000000000..67282b55d --- /dev/null +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py @@ -0,0 +1,51 @@ +import click +import pytest + +from .client import BaseFlasherClient +from jumpstarter.common.exceptions import ArgumentError + + +class MockFlasherClient(BaseFlasherClient): + """Mock client for testing without full initialization""" + + def __init__(self): + self._manifest = None + self._console_debug = False + self.logger = type( + "MockLogger", (), {"warning": lambda msg: None, "info": lambda msg: None, "error": lambda msg: None} + )() + + def close(self): + pass + + +def test_validate_bearer_token_fails_invalid(): + """Test bearer token validation fails with invalid tokens""" + client = MockFlasherClient() + + with pytest.raises(click.ClickException, match="Bearer token cannot be empty"): + client._validate_bearer_token("") + + with pytest.raises(click.ClickException, match="Bearer token contains invalid characters"): + client._validate_bearer_token("token with spaces") + + with pytest.raises(click.ClickException, match="Bearer token contains invalid characters"): + client._validate_bearer_token('token"with"quotes') + + +def test_curl_header_args_handles_quotes(): + """Test curl header formatting safely handles quotes""" + client = MockFlasherClient() + + result = client._curl_header_args({"Authorization": "Bearer abc'def"}) + assert "'\"'\"'" in result + assert result.startswith("-H '") + assert result.endswith("'") + + +def test_flash_fails_with_invalid_headers(): + """Test flash method fails early with invalid headers""" + client = MockFlasherClient() + + with pytest.raises(ArgumentError, match="Invalid header name 'Invalid Header': must be an HTTP token"): + client.flash("test.raw", headers={"Invalid Header": "value"})