Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 43e0c45

Browse files
committedMar 17, 2025
dependency inject app configuration
1 parent 1d0a34c commit 43e0c45

24 files changed

+154
-245
lines changed
 

‎pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ dependencies = [
5757
"pytest-xdist~=3.6",
5858
"pytest-asyncio~=0.24",
5959
"pytest-httpx~=0.30",
60+
"tomli-w>=1.2.0",
6061
]
6162

6263
[[tool.hatch.envs.hatch-test.matrix]]

‎src/shelloracle/__main__.py

+2-66
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,4 @@
1-
import argparse
2-
import logging
3-
import sys
4-
from importlib.metadata import version
5-
6-
from shelloracle import shelloracle
7-
from shelloracle.config import initialize_config
8-
from shelloracle.settings import Settings
9-
from shelloracle.tty_log_handler import TtyLogHandler
10-
11-
logger = logging.getLogger(__name__)
12-
13-
14-
def configure_logging():
15-
root_logger = logging.getLogger()
16-
root_logger.setLevel(logging.DEBUG)
17-
18-
file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
19-
file_handler = logging.FileHandler(Settings.shelloracle_home / "shelloracle.log")
20-
file_handler.setLevel(logging.DEBUG)
21-
file_handler.setFormatter(file_formatter)
22-
23-
tty_formatter = logging.Formatter("%(message)s")
24-
tty_handler = TtyLogHandler()
25-
tty_handler.setLevel(logging.WARNING)
26-
tty_handler.setFormatter(tty_formatter)
27-
28-
root_logger.addHandler(file_handler)
29-
root_logger.addHandler(tty_handler)
30-
31-
32-
def configure():
33-
# nest this import in a function to avoid expensive module loads
34-
from shelloracle.bootstrap import bootstrap_shelloracle
35-
36-
bootstrap_shelloracle()
37-
38-
39-
def parse_args() -> argparse.Namespace:
40-
parser = argparse.ArgumentParser()
41-
parser.add_argument("--version", action="version", version=f"{__package__} {version(__package__)}")
42-
43-
subparsers = parser.add_subparsers()
44-
configure_subparser = subparsers.add_parser("configure", help=f"install {__package__} keybindings")
45-
configure_subparser.set_defaults(action=configure)
46-
47-
return parser.parse_args()
48-
49-
50-
def main() -> None:
51-
args = parse_args()
52-
configure_logging()
53-
54-
if action := getattr(args, "action", None):
55-
action()
56-
sys.exit(0)
57-
58-
try:
59-
initialize_config()
60-
except FileNotFoundError:
61-
logger.warning("ShellOracle configuration not found. Run `shor configure` to initialize.")
62-
sys.exit(1)
63-
64-
shelloracle.cli()
65-
66-
671
if __name__ == "__main__":
2+
from shelloracle.cli import main
3+
684
main()

‎src/shelloracle/bootstrap.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from prompt_toolkit.shortcuts import confirm
1313

1414
from shelloracle.providers import Provider, Setting, get_provider, list_providers
15-
from shelloracle.settings import Settings
1615

1716
if TYPE_CHECKING:
1817
from collections.abc import Iterator, Sequence
@@ -104,7 +103,7 @@ def correct_name_setting():
104103
yield from correct_name_setting()
105104

106105

107-
def write_shelloracle_config(provider: type[Provider], settings: dict[str, Any]) -> None:
106+
def write_shelloracle_config(provider: type[Provider], settings: dict[str, Any], config_path: Path) -> None:
108107
config = tomlkit.document()
109108

110109
shor_table = tomlkit.table()
@@ -119,8 +118,7 @@ def write_shelloracle_config(provider: type[Provider], settings: dict[str, Any])
119118
provider_configuration_table.add(setting, value)
120119
provider_table.add(provider.name, provider_configuration_table)
121120

122-
filepath = Settings.shelloracle_home / "config.toml"
123-
with filepath.open("w") as config_file:
121+
with config_path.open("w") as config_file:
124122
tomlkit.dump(config, config_file)
125123

126124

@@ -164,7 +162,7 @@ def user_select_provider() -> type[Provider]:
164162
return get_provider(provider_name)
165163

166164

167-
def bootstrap_shelloracle() -> None:
165+
def bootstrap_shelloracle(config_path: Path) -> None:
168166
try:
169167
provider = user_select_provider()
170168
settings = user_configure_settings(provider)
@@ -173,5 +171,5 @@ def bootstrap_shelloracle() -> None:
173171
return
174172
except KeyboardInterrupt:
175173
return
176-
write_shelloracle_config(provider, settings)
174+
write_shelloracle_config(provider, settings, config_path)
177175
install_keybindings()

‎src/shelloracle/cli/__init__.py

+19-15
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
11
import logging
22
import sys
3-
from importlib.metadata import version
3+
from pathlib import Path
44

55
import click
66

77
from shelloracle import shelloracle
8+
from shelloracle.cli.application import Application
89
from shelloracle.cli.config import config
9-
from shelloracle.config import initialize_config
10-
from shelloracle.settings import Settings
10+
from shelloracle.config import Configuration
1111
from shelloracle.tty_log_handler import TtyLogHandler
1212

1313
logger = logging.getLogger(__name__)
1414

1515

16-
def configure_logging():
16+
def configure_logging(log_path: Path):
1717
root_logger = logging.getLogger()
1818
root_logger.setLevel(logging.DEBUG)
1919

2020
file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
21-
file_handler = logging.FileHandler(Settings.shelloracle_home / "shelloracle.log")
21+
file_handler = logging.FileHandler(log_path)
2222
file_handler.setLevel(logging.DEBUG)
2323
file_handler.setFormatter(file_formatter)
2424

@@ -32,21 +32,25 @@ def configure_logging():
3232

3333

3434
@click.group(invoke_without_command=True)
35-
@click.version_option(version=version("shelloracle"))
35+
@click.version_option()
3636
@click.pass_context
3737
def cli(ctx):
3838
"""ShellOracle command line interface."""
39-
configure_logging()
39+
app = Application()
40+
configure_logging(app.log_path)
41+
ctx.obj = app
4042

41-
# If no subcommand is invoked, run the main CLI
42-
if ctx.invoked_subcommand is None:
43-
try:
44-
initialize_config()
45-
except FileNotFoundError:
46-
logger.warning("ShellOracle configuration not found. Run `shor config init` to initialize.")
47-
sys.exit(1)
43+
if ctx.invoked_subcommand is not None:
44+
# If no subcommand is invoked, run the main CLI
45+
return
4846

49-
shelloracle.cli()
47+
try:
48+
app.configuration = Configuration.from_file(app.config_path)
49+
except FileNotFoundError:
50+
logger.warning("Configuration not found. Run `shor config init` to initialize.")
51+
sys.exit(1)
52+
53+
shelloracle.cli(app)
5054

5155

5256
cli.add_command(config)

‎src/shelloracle/cli/application.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from pathlib import Path
2+
3+
from shelloracle.config import Configuration
4+
5+
shelloracle_home = Path.home() / ".shelloracle"
6+
shelloracle_home.mkdir(exist_ok=True)
7+
8+
9+
class Application:
10+
configuration: Configuration
11+
12+
def __init__(self):
13+
self.config_path = shelloracle_home / "config.toml"
14+
self.log_path = shelloracle_home / "shelloracle.log"

‎src/shelloracle/cli/config/init.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import click
22

3+
from shelloracle.cli import Application
4+
35

46
@click.command()
5-
def init():
7+
@click.pass_obj
8+
def init(app: Application):
69
"""Install shelloracle keybindings."""
710
# nest this import in a function to avoid expensive module loads
811
from shelloracle.bootstrap import bootstrap_shelloracle
912

10-
bootstrap_shelloracle()
13+
bootstrap_shelloracle(app.config_path)

‎src/shelloracle/config.py

+13-35
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77

88
from yaspin.spinners import SPINNERS_DATA
99

10-
from shelloracle.settings import Settings
11-
1210
if TYPE_CHECKING:
1311
from pathlib import Path
1412

13+
1514
if sys.version_info < (3, 11):
1615
import tomli as tomllib
1716
else:
@@ -21,15 +20,13 @@
2120

2221

2322
class Configuration(Mapping):
24-
def __init__(self, filepath: Path) -> None:
23+
def __init__(self, config: dict[str, Any]) -> None:
2524
"""ShellOracle application configuration
2625
27-
:param filepath: Path to the configuration file
26+
:param config: configuration dict
2827
:raises FileNotFoundError: if the configuration file does not exist
2928
"""
30-
self.filepath = filepath
31-
with filepath.open("rb") as config_file:
32-
self._config = tomllib.load(config_file)
29+
self._config = config
3330

3431
def __getitem__(self, key: str) -> Any:
3532
return self._config[key]
@@ -46,6 +43,10 @@ def __str__(self):
4643
def __repr__(self) -> str:
4744
return str(self)
4845

46+
@property
47+
def raw_config(self) -> dict[str, Any]:
48+
return self._config
49+
4950
@property
5051
def provider(self) -> str:
5152
return self["shelloracle"]["provider"]
@@ -60,31 +61,8 @@ def spinner_style(self) -> str | None:
6061
return None
6162
return style
6263

63-
64-
_config: Configuration | None = None
65-
66-
67-
def initialize_config() -> None:
68-
"""Initialize the configuration file
69-
70-
:raises RuntimeError: if the config is already initialized
71-
:raises FileNotFoundError: if the config file is not found
72-
"""
73-
global _config # noqa: PLW0603
74-
if _config:
75-
msg = "Configuration already initialized"
76-
raise RuntimeError(msg)
77-
filepath = Settings.shelloracle_home / "config.toml"
78-
_config = Configuration(filepath)
79-
80-
81-
def get_config() -> Configuration:
82-
"""Returns the global configuration object.
83-
84-
:return: the global configuration
85-
:raises RuntimeError: if the configuration is not initialized
86-
"""
87-
if _config is None:
88-
msg = "Configuration not initialized"
89-
raise RuntimeError(msg)
90-
return _config
64+
@classmethod
65+
def from_file(cls, filepath: Path):
66+
with filepath.open("rb") as config_file:
67+
config = tomllib.load(config_file)
68+
return cls(config)

‎src/shelloracle/providers/__init__.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

33
from abc import abstractmethod
4-
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable
5-
6-
from shelloracle.config import get_config
4+
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
75

86
if TYPE_CHECKING:
97
from collections.abc import AsyncIterator
108

9+
from shelloracle.config import Configuration
10+
1111
system_prompt = (
1212
"Based on the following user description, generate a corresponding shell command. Focus solely "
1313
"on interpreting the requirements and translating them into a single, executable Bash command. "
@@ -22,7 +22,6 @@ class ProviderError(Exception):
2222
"""LLM providers raise this error to gracefully indicate something has gone wrong."""
2323

2424

25-
@runtime_checkable
2625
class Provider(Protocol):
2726
"""
2827
LLM Provider Protocol
@@ -31,6 +30,10 @@ class Provider(Protocol):
3130
"""
3231

3332
name: str
33+
config: Configuration
34+
35+
def __init__(self, config: Configuration) -> None:
36+
self.config = config
3437

3538
@abstractmethod
3639
def generate(self, prompt: str) -> AsyncIterator[str]:
@@ -64,9 +67,8 @@ def __get__(self, instance: Provider, owner: type[Provider]) -> T:
6467
# inspect.get_members from determining the object type
6568
msg = "Settings must be accessed through a provider instance."
6669
raise AttributeError(msg)
67-
config = get_config()
6870
try:
69-
return config["provider"][owner.name][self.name]
71+
return instance.config["provider"][owner.name][self.name]
7072
except KeyError:
7173
if self.default is None:
7274
raise

‎src/shelloracle/providers/deepseek.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ class Deepseek(Provider):
1111
api_key = Setting(default="")
1212
model = Setting(default="deepseek-chat")
1313

14-
def __init__(self):
14+
def __init__(self, *args, **kwargs):
15+
super().__init__(*args, **kwargs)
1516
if not self.api_key:
1617
msg = "No API key provided"
1718
raise ProviderError(msg)

‎src/shelloracle/providers/google.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ class Google(Provider):
1111
api_key = Setting(default="")
1212
model = Setting(default="gemini-2.0-flash") # Assuming a default model name
1313

14-
def __init__(self):
14+
def __init__(self, *args, **kwargs):
15+
super().__init__(*args, **kwargs)
1516
if not self.api_key:
1617
msg = "No API key provided"
1718
raise ProviderError(msg)

‎src/shelloracle/providers/localai.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ class LocalAI(Provider):
1616
def endpoint(self) -> str:
1717
return f"http://{self.host}:{self.port}"
1818

19-
def __init__(self):
19+
def __init__(self, *args, **kwargs):
20+
super().__init__(*args, **kwargs)
2021
# Use a placeholder API key so the client will work
2122
self.client = AsyncOpenAI(api_key="sk-xxx", base_url=self.endpoint)
2223

‎src/shelloracle/providers/openai.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ class OpenAI(Provider):
1111
api_key = Setting(default="")
1212
model = Setting(default="gpt-3.5-turbo")
1313

14-
def __init__(self):
14+
def __init__(self, *args, **kwargs):
15+
super().__init__(*args, **kwargs)
1516
if not self.api_key:
1617
msg = "No API key provided"
1718
raise ProviderError(msg)

‎src/shelloracle/providers/openai_compat.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ class OpenAICompat(Provider):
1212
api_key = Setting(default="")
1313
model = Setting(default="")
1414

15-
def __init__(self):
15+
def __init__(self, *args, **kwargs):
16+
super().__init__(*args, **kwargs)
1617
if not self.api_key:
1718
msg = "No API key provided. Use a dummy placeholder if no key is required"
1819
raise ProviderError(msg)

‎src/shelloracle/providers/xai.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ class XAI(Provider):
1111
api_key = Setting(default="")
1212
model = Setting(default="grok-beta")
1313

14-
def __init__(self):
14+
def __init__(self, *args, **kwargs):
15+
super().__init__(*args, **kwargs)
1516
if not self.api_key:
1617
msg = "No API key provided"
1718
raise ProviderError(msg)

‎src/shelloracle/settings.py

-6
This file was deleted.

‎src/shelloracle/shelloracle.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
from yaspin import yaspin
1515
from yaspin.spinners import Spinners
1616

17-
from shelloracle.config import get_config
1817
from shelloracle.providers import get_provider
1918

2019
if TYPE_CHECKING:
2120
from yaspin.core import Yaspin
2221

22+
from shelloracle.cli import Application
23+
2324
logger = logging.getLogger(__name__)
2425

2526

@@ -47,19 +48,19 @@ def get_query_from_pipe() -> str | None:
4748
return lines[0].rstrip()
4849

4950

50-
def spinner() -> Yaspin:
51+
def spinner(style: str | None) -> Yaspin:
5152
"""Get the correct spinner based on the user's configuration
5253
54+
:param style: The spinner style
5355
:returns: yaspin object
5456
"""
55-
config = get_config()
56-
if not config.spinner_style:
57-
return yaspin()
58-
style = getattr(Spinners, config.spinner_style)
59-
return yaspin(style)
57+
if style:
58+
style = getattr(Spinners, style)
59+
return yaspin(style)
60+
return yaspin()
6061

6162

62-
async def shelloracle() -> None:
63+
async def shelloracle(app: Application) -> None:
6364
"""ShellOracle program entrypoint
6465
6566
If there is a query from the input pipe, it processes the query to generate a response.
@@ -76,11 +77,10 @@ async def shelloracle() -> None:
7677
prompt = await prompt_user(default_prompt)
7778
logger.info("user prompt: %s", prompt)
7879

79-
config = get_config()
80-
provider = get_provider(config.provider)()
80+
provider = get_provider(app.configuration.provider)(app.configuration)
8181

8282
shell_command = ""
83-
with create_app_session_from_tty(), patch_stdout(raw=True), spinner() as sp:
83+
with create_app_session_from_tty(), patch_stdout(raw=True), spinner(app.configuration.spinner_style) as sp:
8484
async for token in provider.generate(prompt):
8585
# some models may erroneously return a newline, which causes issues with the status spinner
8686
shell_command += token.replace("\n", "")
@@ -89,13 +89,13 @@ async def shelloracle() -> None:
8989
sys.stdout.write(shell_command)
9090

9191

92-
def cli() -> None:
92+
def cli(app: Application) -> None:
9393
"""Run the ShellOracle command line interface
9494
9595
:returns: None
9696
"""
9797
try:
98-
asyncio.run(shelloracle())
98+
asyncio.run(shelloracle(app))
9999
except (KeyboardInterrupt, asyncio.exceptions.CancelledError):
100100
return
101101
except Exception:

‎tests/conftest.py

+3-20
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,8 @@
11
import pytest
2-
import tomlkit
32

4-
from shelloracle.config import Configuration
5-
6-
7-
@pytest.fixture(autouse=True)
8-
def tmp_shelloracle_home(monkeypatch, tmp_path):
9-
monkeypatch.setattr("shelloracle.settings.Settings.shelloracle_home", tmp_path)
10-
return tmp_path
3+
from shelloracle.cli import Application
114

125

136
@pytest.fixture
14-
def set_config(monkeypatch, tmp_shelloracle_home):
15-
config_path = tmp_shelloracle_home / "config.toml"
16-
17-
def _set_config(config: dict) -> None:
18-
with config_path.open("w") as f:
19-
tomlkit.dump(config, f)
20-
configuration = Configuration(config_path)
21-
monkeypatch.setattr("shelloracle.config._config", configuration)
22-
23-
yield _set_config
24-
25-
config_path.unlink()
7+
def global_app():
8+
return Application()

‎tests/providers/test_deepseek.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import pytest
22

3+
from shelloracle.config import Configuration
34
from shelloracle.providers.deepseek import Deepseek
45

56

67
class TestOpenAI:
78
@pytest.fixture
8-
def deepseek_config(self, set_config):
9+
def deepseek_config(self):
910
config = {
1011
"shelloracle": {"provider": "Deepseek"},
1112
"provider": {
@@ -15,11 +16,11 @@ def deepseek_config(self, set_config):
1516
}
1617
},
1718
}
18-
set_config(config)
19+
return Configuration(config)
1920

2021
@pytest.fixture
2122
def deepseek_instance(self, deepseek_config):
22-
return Deepseek()
23+
return Deepseek(deepseek_config)
2324

2425
def test_name(self):
2526
assert Deepseek.name == "Deepseek"

‎tests/providers/test_localai.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,15 @@
55

66
class TestOpenAI:
77
@pytest.fixture
8-
def localai_config(self, set_config):
9-
config = {
8+
def localai_config(self):
9+
return {
1010
"shelloracle": {"provider": "LocalAI"},
1111
"provider": {"LocalAI": {"host": "localhost", "port": 8080, "model": "mistral-openorca"}},
1212
}
13-
set_config(config)
1413

1514
@pytest.fixture
1615
def localai_instance(self, localai_config):
17-
return LocalAI()
16+
return LocalAI(localai_config)
1817

1918
def test_name(self):
2019
assert LocalAI.name == "LocalAI"

‎tests/providers/test_ollama.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
import pytest
22
from pytest_httpx import IteratorStream
33

4+
from shelloracle.config import Configuration
45
from shelloracle.providers.ollama import Ollama
56

67

78
class TestOllama:
89
@pytest.fixture
9-
def ollama_config(self, set_config):
10+
def ollama_config(self):
1011
config = {
1112
"shelloracle": {"provider": "Ollama"},
1213
"provider": {"Ollama": {"host": "localhost", "port": 11434, "model": "dolphin-mistral"}},
1314
}
14-
set_config(config)
15+
return Configuration(config)
1516

1617
@pytest.fixture
1718
def ollama_instance(self, ollama_config):
18-
return Ollama()
19+
return Ollama(ollama_config)
1920

2021
def test_name(self):
2122
assert Ollama.name == "Ollama"

‎tests/providers/test_openai.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
11
import pytest
22

3+
from shelloracle.config import Configuration
34
from shelloracle.providers.openai import OpenAI
45

56

67
class TestOpenAI:
78
@pytest.fixture
8-
def openai_config(self, set_config):
9+
def openai_config(self):
910
config = {
1011
"shelloracle": {"provider": "OpenAI"},
1112
"provider": {
1213
"OpenAI": {"api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", "model": "gpt-3.5-turbo"}
1314
},
1415
}
15-
set_config(config)
16+
return Configuration(config)
1617

1718
@pytest.fixture
1819
def openai_instance(self, openai_config):
19-
return OpenAI()
20+
return OpenAI(openai_config)
2021

2122
def test_name(self):
2223
assert OpenAI.name == "OpenAI"

‎tests/providers/test_xai.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import pytest
22

3+
from shelloracle.config import Configuration
34
from shelloracle.providers.xai import XAI
45

56

67
class TestOpenAI:
78
@pytest.fixture
8-
def xai_config(self, set_config):
9+
def xai_config(self):
910
config = {
1011
"shelloracle": {"provider": "XAI"},
1112
"provider": {
@@ -15,11 +16,11 @@ def xai_config(self, set_config):
1516
}
1617
},
1718
}
18-
set_config(config)
19+
return Configuration(config)
1920

2021
@pytest.fixture
2122
def xai_instance(self, xai_config):
22-
return XAI()
23+
return XAI(xai_config)
2324

2425
def test_name(self):
2526
assert XAI.name == "XAI"

‎tests/test_config.py

+37-41
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,59 @@
11
from __future__ import annotations
22

33
import pytest
4+
import tomli_w
45

5-
from shelloracle.config import get_config, initialize_config
6+
from shelloracle.config import Configuration
67

78

89
class TestConfiguration:
910
@pytest.fixture
10-
def default_config(self, set_config):
11-
config = {
12-
"shelloracle": {"provider": "Ollama", "spinner_style": "earth"},
13-
"provider": {"Ollama": {"host": "localhost", "port": 11434, "model": "dolphin-mistral"}},
14-
}
15-
set_config(config)
16-
return config
17-
18-
def test_initialize_config(self, default_config):
19-
with pytest.raises(RuntimeError):
20-
initialize_config()
21-
22-
def test_from_file(self, default_config):
23-
assert get_config() == default_config
11+
def default_config(self):
12+
return Configuration(
13+
{
14+
"shelloracle": {"provider": "Ollama", "spinner_style": "earth"},
15+
"provider": {"Ollama": {"host": "localhost", "port": 11434, "model": "dolphin-mistral"}},
16+
}
17+
)
18+
19+
def test_from_file(self, default_config, tmp_path):
20+
config_path = tmp_path / "config.toml"
21+
with config_path.open("wb") as f:
22+
tomli_w.dump(default_config.raw_config, f)
23+
assert Configuration.from_file(config_path) == default_config
2424

2525
def test_getitem(self, default_config):
2626
for key in default_config:
27-
assert default_config[key] == get_config()[key]
27+
assert default_config[key] == default_config.raw_config[key]
2828

2929
def test_len(self, default_config):
30-
assert len(default_config) == len(get_config())
30+
assert len(default_config) == len(default_config.raw_config)
3131

3232
def test_iter(self, default_config):
33-
assert list(iter(default_config)) == list(iter(get_config()))
34-
35-
def test_str(self, default_config):
36-
assert str(get_config()) == f"Configuration({default_config})"
37-
38-
def test_repr(self, default_config):
39-
assert repr(default_config) == str(default_config)
33+
assert list(iter(default_config)) == list(iter(default_config.raw_config))
4034

4135
def test_provider(self, default_config):
42-
assert get_config().provider == "Ollama"
36+
assert default_config.provider == "Ollama"
4337

4438
def test_spinner_style(self, default_config):
45-
assert get_config().spinner_style == "earth"
46-
47-
def test_no_spinner_style(self, caplog, set_config):
48-
config_dict = {
49-
"shelloracle": {"provider": "Ollama"},
50-
"provider": {"Ollama": {"host": "localhost", "port": 11434, "model": "dolphin-mistral"}},
51-
}
52-
set_config(config_dict)
53-
assert get_config().spinner_style is None
39+
assert default_config.spinner_style == "earth"
40+
41+
def test_no_spinner_style(self, caplog):
42+
config = Configuration(
43+
{
44+
"shelloracle": {"provider": "Ollama"},
45+
"provider": {"Ollama": {"host": "localhost", "port": 11434, "model": "dolphin-mistral"}},
46+
}
47+
)
48+
assert config.spinner_style is None
5449
assert "invalid spinner style" not in caplog.text
5550

56-
def test_invalid_spinner_style(self, caplog, set_config):
57-
config_dict = {
58-
"shelloracle": {"provider": "Ollama", "spinner_style": "invalid"},
59-
"provider": {"Ollama": {"host": "localhost", "port": 11434, "model": "dolphin-mistral"}},
60-
}
61-
set_config(config_dict)
62-
assert get_config().spinner_style is None
51+
def test_invalid_spinner_style(self, caplog):
52+
config = Configuration(
53+
{
54+
"shelloracle": {"provider": "Ollama", "spinner_style": "invalid"},
55+
"provider": {"Ollama": {"host": "localhost", "port": 11434, "model": "dolphin-mistral"}},
56+
}
57+
)
58+
assert config.spinner_style is None
6359
assert "invalid spinner style" in caplog.text

‎tests/test_shelloracle.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,15 @@ def mock_yaspin(monkeypatch):
1717
return mock
1818

1919

20-
@pytest.fixture
21-
def mock_config(monkeypatch):
22-
config = MagicMock()
23-
monkeypatch.setattr("shelloracle.config._config", config)
24-
return config
25-
26-
2720
@pytest.mark.parametrize(("spinner_style", "expected"), [(None, call()), ("earth", call(Spinners.earth))])
28-
def test_spinner(spinner_style, expected, mock_config, mock_yaspin):
29-
mock_config.spinner_style = spinner_style
30-
spinner()
21+
def test_spinner(spinner_style, expected, mock_yaspin):
22+
spinner(spinner_style)
3123
assert mock_yaspin.call_args == expected
3224

3325

34-
def test_spinner_fail(mock_yaspin, mock_config):
35-
mock_config.spinner_style = "not a spinner style"
26+
def test_spinner_fail(mock_yaspin):
3627
with pytest.raises(AttributeError):
37-
spinner()
28+
spinner("not a spinner style")
3829

3930

4031
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)
Please sign in to comment.