Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def bootloader_shell(self):
pass
yield self.serial

def flash(
def flash( # noqa: C901
self,
path: PathBuf,
*,
Expand All @@ -81,10 +81,19 @@ def flash(
force_flash_bundle: str | None = None,
cacert_file: str | None = None,
insecure_tls: bool = False,
headers: dict[str, str] | None = None,
bearer_token: str | None = None,
):
if bearer_token:
bearer_token = self._validate_bearer_token(bearer_token)

if headers:
headers = self._validate_header_dict(headers)

"""Flash image to DUT"""
should_download_to_httpd = True
image_url = ""
original_http_url = None
operator_scheme = None
# initrmafs cannot handle https yet, fallback to using the exporter's http server
if path.startswith(("http://", "https://")) and not force_exporter_http:
Expand All @@ -94,7 +103,17 @@ def flash(
else:
# use the exporter's http server for the flasher image, we should download it first
if operator is None:
path, operator, operator_scheme = operator_for_path(path)
if path.startswith(("http://", "https://")) and bearer_token:
parsed = urlparse(path)
self.logger.info(f"Using Bearer token authentication for {parsed.netloc}")
original_http_url = path
operator = Operator(
"http", root="/", endpoint=f"{parsed.scheme}://{parsed.netloc}", token=bearer_token
)
operator_scheme = "http"
path = Path(parsed.path)
else:
path, operator, operator_scheme = operator_for_path(path)
image_url = self.http.get_url() + "/" + path.name

# start counting time for the flash operation
Expand All @@ -107,7 +126,16 @@ def flash(
# Start the storage write operation in the background
storage_thread = threading.Thread(
target=self._transfer_bg_thread,
args=(path, operator, operator_scheme, os_image_checksum, self.http.storage, error_queue, image_url),
args=(
path,
operator,
operator_scheme,
os_image_checksum,
self.http.storage,
error_queue,
original_http_url,
headers,
),
name="storage_transfer",
)
storage_thread.start()
Expand Down Expand Up @@ -152,9 +180,17 @@ def flash(
else:
stored_cacert = self._setup_flasher_ssl(console, manifest, cacert_file)


self._flash_with_progress(console, manifest, path, image_url, target_device,
insecure_tls, stored_cacert)
header_args = self._prepare_headers(headers, bearer_token)
self._flash_with_progress(
console,
manifest,
path,
image_url,
target_device,
insecure_tls,
stored_cacert,
header_args,
)

total_time = time.time() - start_time
# total time in minutes:seconds
Expand Down Expand Up @@ -222,7 +258,36 @@ def _curl_tls_args(self, insecure_tls: bool, stored_cacert: str | None) -> str:
tls_args += f"--cacert {stored_cacert} "
return tls_args.strip()

def _flash_with_progress(self, console, manifest, path, image_url, target_path, insecure_tls, stored_cacert):
def _curl_header_args(self, headers: dict[str, str] | None) -> str:
"""Generate header arguments for curl command"""
if not headers:
return ""

parts: list[str] = []

def _sq(s: str) -> str:
return s.replace("'", "'\"'\"'")

for k, v in headers.items():
k = str(k).strip()
v = str(v).strip()
if not k:
continue
parts.append(f"-H '{_sq(k)}: {_sq(v)}'")

return " ".join(parts)

def _flash_with_progress(
self,
console,
manifest,
path,
image_url,
target_path,
insecure_tls,
stored_cacert,
header_args: str,
):
"""Flash image to target device with progress monitoring.

Args:
Expand All @@ -241,11 +306,11 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path,
tls_args = self._curl_tls_args(insecure_tls, stored_cacert)

# Check if the image URL is accessible using curl and the TLS arguments
self._check_url_access(console, prompt, image_url, tls_args)
self._check_url_access(console, prompt, image_url, tls_args, header_args)

# Flash the image, we run curl -> decompress -> dd in the background, so we can monitor dd's progress
flash_cmd = (
f'( curl -fsSL {tls_args} "{image_url}" | '
f'( curl -fsSL {tls_args} {header_args} "{image_url}" | '
f"{decompress_cmd} "
f"dd of={target_path} bs=64k iflag=fullblock oflag=direct) &"
)
Expand Down Expand Up @@ -287,7 +352,7 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path,
console.sendline("sync")
console.expect(prompt, timeout=EXPECT_TIMEOUT_SYNC)

def _check_url_access(self, console, prompt, image_url: str, tls_args: str):
def _check_url_access(self, console, prompt, image_url: str, tls_args: str, header_args: str):
"""Check if the image URL is accessible using curl.

Args:
Expand All @@ -299,7 +364,9 @@ def _check_url_access(self, console, prompt, image_url: str, tls_args: str):
Raises:
RuntimeError: If the URL is not accessible
"""
console.sendline(f'curl --location --max-time 30 --fail -sS -r 0-0 -o /dev/null {tls_args} "{image_url}"')
console.sendline(
f'curl --location --max-time 30 --fail -sS -r 0-0 -o /dev/null {tls_args} {header_args} "{image_url}"'
)
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
curl_output = console.before.decode(errors="ignore").strip()
console.sendline("echo $?")
Expand Down Expand Up @@ -358,6 +425,7 @@ def _transfer_bg_thread(
to_storage: OpendalClient,
error_queue,
original_url: str | None = None,
headers: dict[str, str] | None = None,
):
"""Transfer image to exporter storage in the background
Args:
Expand All @@ -367,6 +435,7 @@ def _transfer_bg_thread(
error_queue: Queue to put exceptions in if any
known_hash: Known hash of the image
original_url: Original URL for HTTP fallback
headers: HTTP headers for requests
"""
self.logger.info(f"Writing image to storage in the background: {src_path}")
try:
Expand All @@ -392,7 +461,9 @@ def _transfer_bg_thread(
self.logger.info(f"Uploading image to storage: {filename}")
to_storage.write_from_path(filename, src_path, src_operator)

metadata, metadata_json = self._create_metadata_and_json(src_operator, src_path, file_hash, original_url)
metadata, metadata_json = self._create_metadata_and_json(
src_operator, src_path, file_hash, original_url, headers
)
metadata_file = filename + ".metadata"
to_storage.write_bytes(metadata_file, metadata_json.encode(errors="ignore"))

Expand All @@ -415,7 +486,7 @@ def _sha256_file(self, src_operator, src_path) -> str:
return m.hexdigest()

def _create_metadata_and_json(
self, src_operator, src_path, file_hash=None, original_url=None
self, src_operator, src_path, file_hash=None, original_url=None, headers: dict[str, str] | None = None
) -> tuple[Metadata | None, str]:
"""Create a metadata json string from a metadata object"""
metadata = None
Expand All @@ -436,7 +507,10 @@ def _create_metadata_and_json(

if original_url and original_url.startswith(("http://", "https://")):
try:
response = requests.head(original_url)
if headers:
response = requests.head(original_url, headers=headers)
else:
response = requests.head(original_url)

http_metadata = {}
if "content-length" in response.headers:
Expand Down Expand Up @@ -611,6 +685,71 @@ def manifest(self):
self._manifest = FlasherBundleManifestV1Alpha1.from_string(yaml_str)
return self._manifest

def _validate_header_dict(self, header_map: dict[str, str]) -> dict[str, str]:
token_re = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$")
seen: set[str] = set()
for key, value in header_map.items():
key = key.strip()
value = value.strip()
if not key:
raise ArgumentError(f"Invalid header key: '{key}'")

if not token_re.match(key):
raise ArgumentError(f"Invalid header name '{key}': must be an HTTP token (RFC7230)")
if any(c in ("\r", "\n") for c in key) or any(c in ("\r", "\n") for c in value):
raise ArgumentError("Header names/values must not contain CR/LF")
kl = key.lower()
if kl in seen:
raise ArgumentError(f"Duplicate header '{key}'")
seen.add(kl)
return header_map

def _parse_headers(self, headers: list[str]) -> dict[str, str]:
header_map: dict[str, str] = {}
for h in headers:
if ":" not in h:
raise click.ClickException(f"Invalid header format: {h!r}. Expected 'Key: Value'.")

key, value = h.split(":", 1)
header_map[key.strip()] = value.strip()

try:
return self._validate_header_dict(header_map)
except ArgumentError as e:
raise click.ClickException(str(e)) from e

def _prepare_headers(self, headers: dict[str, str] | None, bearer_token: str | None) -> str:
all_headers = headers.copy() if headers else {}
if bearer_token:
if any(k.lower() == "authorization" for k in all_headers.keys()):
self.logger.warning("Authorization header provided - ignoring bearer token")
else:
all_headers["Authorization"] = f"Bearer {bearer_token}"

if bearer_token and "Authorization" not in (headers or {}):
auth_header = {"Authorization": all_headers["Authorization"]}
self._validate_header_dict(auth_header)

return self._curl_header_args(all_headers)

def _validate_bearer_token(self, token: str | None) -> str | None:
if token is None:
return None

token = token.strip()
if not token:
raise click.ClickException("Bearer token cannot be empty")

# RFC 6750 allows token68 format (base64url-encoded) or other token formats
# Basic validation: printable ASCII excluding whitespace and special chars that could cause issues
if not all(32 < ord(c) < 127 and c not in ' "\\' for c in token):
raise click.ClickException("Bearer token contains invalid characters")

if len(token) > 4096:
raise click.ClickException("Bearer token is too long (max 4096 characters)")

return token

def cli(self):
@driver_click_group(self)
def base():
Expand All @@ -630,6 +769,17 @@ def base():
@click.option("--force-flash-bundle", type=str, help="Force use of a specific flasher OCI bundle")
@click.option("--cacert", type=click.Path(exists=True, dir_okay=False), help="CA certificate to use for HTTPS")
@click.option("--insecure-tls", is_flag=True, help="Skip TLS certificate verification")
@click.option(
"--header",
"header",
multiple=True,
help="Custom HTTP header in 'Key: Value' format",
)
@click.option(
"--bearer",
type=str,
help="Bearer token for HTTP authentication",
)
@debug_console_option
def flash(
file,
Expand All @@ -641,6 +791,8 @@ def flash(
force_flash_bundle,
cacert,
insecure_tls,
header,
bearer,
):
"""Flash image to DUT from file"""
if os_image_checksum_file and os.path.exists(os_image_checksum_file):
Expand All @@ -649,13 +801,18 @@ def flash(
self.logger.info(f"Read checksum from file: {os_image_checksum}")

self.set_console_debug(console_debug)

headers = self._parse_headers(header) if header else None

self.flash(
file,
partition=target,
force_exporter_http=force_exporter_http,
force_flash_bundle=force_flash_bundle,
cacert_file=cacert,
insecure_tls=insecure_tls,
headers=headers,
bearer_token=bearer,
)

@base.command()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import click
import pytest

from .client import BaseFlasherClient
from jumpstarter.common.exceptions import ArgumentError


class MockFlasherClient(BaseFlasherClient):
"""Mock client for testing without full initialization"""

def __init__(self):
self._manifest = None
self._console_debug = False
self.logger = type(
"MockLogger", (), {"warning": lambda msg: None, "info": lambda msg: None, "error": lambda msg: None}
)()

def close(self):
pass


def test_validate_bearer_token_fails_invalid():
"""Test bearer token validation fails with invalid tokens"""
client = MockFlasherClient()

with pytest.raises(click.ClickException, match="Bearer token cannot be empty"):
client._validate_bearer_token("")

with pytest.raises(click.ClickException, match="Bearer token contains invalid characters"):
client._validate_bearer_token("token with spaces")

with pytest.raises(click.ClickException, match="Bearer token contains invalid characters"):
client._validate_bearer_token('token"with"quotes')


def test_curl_header_args_handles_quotes():
"""Test curl header formatting safely handles quotes"""
client = MockFlasherClient()

result = client._curl_header_args({"Authorization": "Bearer abc'def"})
assert "'\"'\"'" in result
assert result.startswith("-H '")
assert result.endswith("'")


def test_flash_fails_with_invalid_headers():
"""Test flash method fails early with invalid headers"""
client = MockFlasherClient()

with pytest.raises(ArgumentError, match="Invalid header name 'Invalid Header': must be an HTTP token"):
client.flash("test.raw", headers={"Invalid Header": "value"})
Loading