-
Notifications
You must be signed in to change notification settings - Fork 33
feat(tools): add HuggingFace config → architecture_config converter #198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,59 @@ | ||||||||||||||||||||||
| #!/usr/bin/env python3 | ||||||||||||||||||||||
| """Convert HuggingFace config.json to architecture_config format.""" | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import json | ||||||||||||||||||||||
| import sys | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| REQUIRED_MAPPINGS = { | ||||||||||||||||||||||
| "numLayers": "num_hidden_layers", | ||||||||||||||||||||||
| "hiddenSize": "hidden_size", | ||||||||||||||||||||||
| "numAttentionHeads": "num_attention_heads", | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def convert_hf_config(hf_config: dict) -> dict: | ||||||||||||||||||||||
| """Convert HuggingFace config to architecture_config format.""" | ||||||||||||||||||||||
| arch_config = {"type": "transformer"} | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for arch_key, hf_key in REQUIRED_MAPPINGS.items(): | ||||||||||||||||||||||
| if hf_key not in hf_config: | ||||||||||||||||||||||
| raise ValueError(f"missing required field: {hf_key}") | ||||||||||||||||||||||
| value = hf_config[hf_key] | ||||||||||||||||||||||
| if not isinstance(value, int) or isinstance(value, bool): | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current type check for integers is functionally correct but a bit subtle because it relies on
Suggested change
|
||||||||||||||||||||||
| raise ValueError(f"field {hf_key} must be an integer, got {type(value).__name__}") | ||||||||||||||||||||||
| if value < 1: | ||||||||||||||||||||||
| raise ValueError(f"field {hf_key} must be >= 1, got {value}") | ||||||||||||||||||||||
| arch_config[arch_key] = value | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return arch_config | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def main(): | ||||||||||||||||||||||
| if len(sys.argv) != 2: | ||||||||||||||||||||||
| print(f"usage: {sys.argv[0]} <config.json>", file=sys.stderr) | ||||||||||||||||||||||
| sys.exit(1) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| config_path = sys.argv[1] | ||||||||||||||||||||||
|
Comment on lines
+33
to
+37
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For better command-line interface design and maintainability, consider using the
Suggested change
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| with open(config_path, "r") as f: | ||||||||||||||||||||||
|
||||||||||||||||||||||
| with open(config_path, "r") as f: | |
| with open(config_path, "r", encoding="utf-8") as f: |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,146 @@ | ||||||||||||||||||||||||||||
| #!/usr/bin/env python3 | ||||||||||||||||||||||||||||
| """Tests for hf_to_arch.py""" | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| import json | ||||||||||||||||||||||||||||
| import subprocess | ||||||||||||||||||||||||||||
| import sys | ||||||||||||||||||||||||||||
| import tempfile | ||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| SCRIPT_PATH = os.path.join(os.path.dirname(__file__), "hf_to_arch.py") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def run_script(config_content: str) -> tuple: | ||||||||||||||||||||||||||||
| """Run hf_to_arch.py with given config content, return (exitcode, stdout, stderr).""" | ||||||||||||||||||||||||||||
| with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: | ||||||||||||||||||||||||||||
| f.write(config_content) | ||||||||||||||||||||||||||||
| f.flush() | ||||||||||||||||||||||||||||
| temp_path = f.name | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||
| result = subprocess.run( | ||||||||||||||||||||||||||||
| [sys.executable, SCRIPT_PATH, temp_path], | ||||||||||||||||||||||||||||
| capture_output=True, | ||||||||||||||||||||||||||||
| text=True, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| return result.returncode, result.stdout, result.stderr | ||||||||||||||||||||||||||||
| finally: | ||||||||||||||||||||||||||||
| os.unlink(temp_path) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_valid_config(): | ||||||||||||||||||||||||||||
| """Valid HuggingFace config produces correct output.""" | ||||||||||||||||||||||||||||
| config = json.dumps({ | ||||||||||||||||||||||||||||
| "num_hidden_layers": 32, | ||||||||||||||||||||||||||||
| "hidden_size": 4096, | ||||||||||||||||||||||||||||
| "num_attention_heads": 32, | ||||||||||||||||||||||||||||
| "vocab_size": 32000, | ||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| exitcode, stdout, stderr = run_script(config) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| assert exitcode == 0, f"expected exit 0, got {exitcode}: {stderr}" | ||||||||||||||||||||||||||||
| output = json.loads(stdout) | ||||||||||||||||||||||||||||
| assert output == { | ||||||||||||||||||||||||||||
| "type": "transformer", | ||||||||||||||||||||||||||||
| "numLayers": 32, | ||||||||||||||||||||||||||||
| "hiddenSize": 4096, | ||||||||||||||||||||||||||||
| "numAttentionHeads": 32, | ||||||||||||||||||||||||||||
| }, f"unexpected output: {output}" | ||||||||||||||||||||||||||||
| print("PASS: test_valid_config") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_missing_field(): | ||||||||||||||||||||||||||||
| """Missing required field produces error.""" | ||||||||||||||||||||||||||||
| config = json.dumps({ | ||||||||||||||||||||||||||||
| "num_hidden_layers": 32, | ||||||||||||||||||||||||||||
| "hidden_size": 4096, | ||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| exitcode, stdout, stderr = run_script(config) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| assert exitcode != 0, "expected non-zero exit for missing field" | ||||||||||||||||||||||||||||
| assert "num_attention_heads" in stderr, f"error should mention missing field: {stderr}" | ||||||||||||||||||||||||||||
| print("PASS: test_missing_field") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_invalid_json(): | ||||||||||||||||||||||||||||
| """Invalid JSON produces error.""" | ||||||||||||||||||||||||||||
| exitcode, stdout, stderr = run_script("not valid json {") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| assert exitcode != 0, "expected non-zero exit for invalid JSON" | ||||||||||||||||||||||||||||
| assert "invalid JSON" in stderr.lower() or "json" in stderr.lower(), f"error should mention JSON: {stderr}" | ||||||||||||||||||||||||||||
| print("PASS: test_invalid_json") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_file_not_found(): | ||||||||||||||||||||||||||||
| """Non-existent file produces error.""" | ||||||||||||||||||||||||||||
| result = subprocess.run( | ||||||||||||||||||||||||||||
| [sys.executable, SCRIPT_PATH, "/nonexistent/path/config.json"], | ||||||||||||||||||||||||||||
| capture_output=True, | ||||||||||||||||||||||||||||
| text=True, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| assert result.returncode != 0, "expected non-zero exit for missing file" | ||||||||||||||||||||||||||||
| assert "not found" in result.stderr.lower(), f"error should mention file not found: {result.stderr}" | ||||||||||||||||||||||||||||
| print("PASS: test_file_not_found") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_invalid_field_type(): | ||||||||||||||||||||||||||||
| """Non-integer field produces error.""" | ||||||||||||||||||||||||||||
| config = json.dumps({ | ||||||||||||||||||||||||||||
| "num_hidden_layers": "32", | ||||||||||||||||||||||||||||
| "hidden_size": 4096, | ||||||||||||||||||||||||||||
| "num_attention_heads": 32, | ||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| exitcode, stdout, stderr = run_script(config) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| assert exitcode != 0, "expected non-zero exit for invalid type" | ||||||||||||||||||||||||||||
| assert "integer" in stderr.lower(), f"error should mention type: {stderr}" | ||||||||||||||||||||||||||||
| print("PASS: test_invalid_field_type") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_zero_value(): | ||||||||||||||||||||||||||||
| """Zero value produces error.""" | ||||||||||||||||||||||||||||
| config = json.dumps({ | ||||||||||||||||||||||||||||
| "num_hidden_layers": 0, | ||||||||||||||||||||||||||||
| "hidden_size": 4096, | ||||||||||||||||||||||||||||
| "num_attention_heads": 32, | ||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| exitcode, stdout, stderr = run_script(config) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| assert exitcode != 0, "expected non-zero exit for zero value" | ||||||||||||||||||||||||||||
| assert ">= 1" in stderr, f"error should mention minimum: {stderr}" | ||||||||||||||||||||||||||||
| print("PASS: test_zero_value") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_bool_value(): | ||||||||||||||||||||||||||||
| """Boolean value produces error.""" | ||||||||||||||||||||||||||||
| config = json.dumps({ | ||||||||||||||||||||||||||||
| "num_hidden_layers": True, | ||||||||||||||||||||||||||||
| "hidden_size": 4096, | ||||||||||||||||||||||||||||
| "num_attention_heads": 32, | ||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| exitcode, stdout, stderr = run_script(config) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| assert exitcode != 0, "expected non-zero exit for bool value" | ||||||||||||||||||||||||||||
| assert "integer" in stderr.lower(), f"error should mention type: {stderr}" | ||||||||||||||||||||||||||||
| print("PASS: test_bool_value") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def main(): | ||||||||||||||||||||||||||||
| test_valid_config() | ||||||||||||||||||||||||||||
| test_missing_field() | ||||||||||||||||||||||||||||
| test_invalid_json() | ||||||||||||||||||||||||||||
| test_file_not_found() | ||||||||||||||||||||||||||||
| test_invalid_field_type() | ||||||||||||||||||||||||||||
| test_zero_value() | ||||||||||||||||||||||||||||
| test_bool_value() | ||||||||||||||||||||||||||||
| print("\nAll tests passed.") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||
| main() | ||||||||||||||||||||||||||||
|
Comment on lines
+134
to
+146
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test file implements a custom test runner by calling test functions from
Migrating to
Comment on lines
+134
to
+146
|
||||||||||||||||||||||||||||
| def main(): | |
| test_valid_config() | |
| test_missing_field() | |
| test_invalid_json() | |
| test_file_not_found() | |
| test_invalid_field_type() | |
| test_zero_value() | |
| test_bool_value() | |
| print("\nAll tests passed.") | |
| if __name__ == "__main__": | |
| main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
convert_hf_configassumeshf_configis a dict-like JSON object. If the input JSON is a list/string/etc., thehf_key not in hf_config/hf_config[hf_key]access can raise aTypeError(or behave unexpectedly) and will bypass the currentValueErrorhandling, resulting in a raw stack trace. Add an explicit top-level type check (e.g., requiredict) and raise aValueErrorwith a clear message when the JSON root is not an object.