Skip to content

Commit 28cb58e

Browse files
committed
querysets can be passed to model object completers, fixes #96
1 parent bb20629 commit 28cb58e

File tree

8 files changed

+148
-21
lines changed

8 files changed

+148
-21
lines changed

django_typer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
model_parser_completer, # noqa: F401
4848
)
4949

50-
VERSION = (2, 1, 3)
50+
VERSION = (2, 2, 0)
5151

5252
__title__ = "Django Typer"
5353
__version__ = ".".join(str(i) for i in VERSION)

django_typer/completers.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# pylint: disable=line-too-long
1818

19+
import inspect
1920
import os
2021
import pkgutil
2122
import sys
@@ -38,12 +39,15 @@
3839
FloatField,
3940
GenericIPAddressField,
4041
IntegerField,
42+
Manager,
4143
Max,
4244
Model,
4345
Q,
4446
TextField,
4547
UUIDField,
4648
)
49+
from django.db.models.query import QuerySet
50+
from django.utils.translation import gettext as _
4751

4852
Completer = t.Callable[[Context, Parameter, str], t.List[CompletionItem]]
4953
Strings = t.Union[t.Sequence[str], t.KeysView[str], t.Generator[str, None, None]]
@@ -107,7 +111,7 @@ def handle(
107111
function that returns a configured parser and completer for a model object
108112
and helps reduce boilerplate.
109113
110-
:param model_cls: The Django model class to query.
114+
:param model_or_qry: The Django model class or a queryset to filter against.
111115
:param lookup_field: The name of the model field to use for lookup.
112116
:param help_field: The name of the model field to use for help text or None if
113117
no help text should be provided.
@@ -130,6 +134,7 @@ def handle(
130134
QueryBuilder = t.Callable[["ModelObjectCompleter", Context, Parameter, str], Q]
131135

132136
model_cls: t.Type[Model]
137+
_queryset: t.Optional[QuerySet] = None
133138
lookup_field: str
134139
help_field: t.Optional[str] = None
135140
query: t.Callable[[Context, Parameter, str], Q]
@@ -144,6 +149,10 @@ def handle(
144149

145150
_field: Field
146151

152+
@property
153+
def queryset(self) -> t.Union[QuerySet, Manager[Model]]:
154+
return self._queryset or self.model_cls.objects
155+
147156
def to_str(self, obj: t.Any) -> str:
148157
return str(obj)
149158

@@ -253,7 +262,11 @@ def uuid_query(self, context: Context, parameter: Parameter, incomplete: str) ->
253262
self._offset += 1
254263

255264
if len(uuid) > 32:
256-
raise ValueError(f"Too many UUID characters: {incomplete}")
265+
raise ValueError(
266+
_("Too many UUID characters: {incomplete}").format(
267+
incomplete=incomplete
268+
)
269+
)
257270
min_uuid = UUID(uuid + "0" * (32 - len(uuid)))
258271
max_uuid = UUID(uuid + "f" * (32 - len(uuid)))
259272
return Q(**{f"{self.lookup_field}__gte": min_uuid}) & Q(
@@ -262,15 +275,23 @@ def uuid_query(self, context: Context, parameter: Parameter, incomplete: str) ->
262275

263276
def __init__(
264277
self,
265-
model_cls: t.Type[Model],
278+
model_or_qry: t.Union[t.Type[Model], QuerySet],
266279
lookup_field: t.Optional[str] = None,
267280
help_field: t.Optional[str] = help_field,
268281
query: t.Optional[QueryBuilder] = None,
269282
limit: t.Optional[int] = limit,
270283
case_insensitive: bool = case_insensitive,
271284
distinct: bool = distinct,
272285
):
273-
self.model_cls = model_cls
286+
if inspect.isclass(model_or_qry) and issubclass(model_or_qry, Model):
287+
self.model_cls = model_or_qry
288+
elif isinstance(model_or_qry, QuerySet): # type: ignore
289+
self.model_cls = model_or_qry.model
290+
self._queryset = model_or_qry
291+
else:
292+
raise ValueError(
293+
_("ModelObjectCompleter requires a Django model class or queryset.")
294+
)
274295
self.lookup_field = str(
275296
lookup_field or getattr(self.model_cls._meta.pk, "name", "id")
276297
)
@@ -295,7 +316,9 @@ def __init__(
295316
self.query = self.float_query
296317
else:
297318
raise ValueError(
298-
f"Unsupported lookup field class: {self._field.__class__.__name__}"
319+
_("Unsupported lookup field class: {cls}").format(
320+
cls=self._field.__class__.__name__
321+
)
299322
)
300323

301324
def __call__(
@@ -343,9 +366,7 @@ def __call__(
343366
],
344367
help=getattr(obj, self.help_field, None) if self.help_field else "",
345368
)
346-
for obj in getattr(self.model_cls, "objects")
347-
.filter(completion_qry)
348-
.distinct()[0 : self.limit]
369+
for obj in self.queryset.filter(completion_qry).distinct()[0 : self.limit]
349370
if (
350371
getattr(obj, self.lookup_field) is not None
351372
and self.to_str(getattr(obj, self.lookup_field))

django_typer/management/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from django.core.management.base import OutputWrapper as BaseOutputWrapper
1616
from django.core.management.color import Style as ColorStyle
1717
from django.db.models import Model
18+
from django.db.models.query import QuerySet
1819
from django.utils.functional import Promise, classproperty
1920
from django.utils.translation import gettext as _
2021

@@ -109,7 +110,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
109110

110111

111112
def model_parser_completer(
112-
model_cls: t.Type[Model],
113+
model_or_qry: t.Union[t.Type[Model], QuerySet],
113114
lookup_field: t.Optional[str] = None,
114115
case_insensitive: bool = False,
115116
help_field: t.Optional[str] = ModelObjectCompleter.help_field,
@@ -139,7 +140,7 @@ def handle(
139140
...
140141
141142
142-
:param model_cls: the model class to use for lookup
143+
:param model_or_qry: the model class or QuerySet to use for lookup
143144
:param lookup_field: the field to use for lookup, by default the primary key
144145
:param case_insensitive: whether to perform case insensitive lookups and
145146
completions, default: False
@@ -155,13 +156,13 @@ def handle(
155156
"""
156157
return {
157158
"parser": ModelObjectParser(
158-
model_cls,
159+
model_or_qry if inspect.isclass(model_or_qry) else model_or_qry.model, # type: ignore
159160
lookup_field,
160161
case_insensitive=case_insensitive,
161162
on_error=on_error,
162163
),
163164
"shell_complete": ModelObjectCompleter(
164-
model_cls,
165+
model_or_qry,
165166
lookup_field,
166167
case_insensitive=case_insensitive,
167168
help_field=help_field,

doc/source/changelog.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
Change Log
33
==========
44

5+
v2.2.0 (26-JUL-2024)
6+
====================
7+
8+
* Implemented `ModelObjectCompleter should optionally accept a QuerySet in place of a Model class. <https://github.com/bckohan/django-typer/issues/96>`_
9+
510
v2.1.3 (15-JUL-2024)
611
====================
712

doc/source/shell_completion.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ Model Objects
392392

393393
* completer: :class:`~django_typer.completers.ModelObjectCompleter`
394394
* parser: :class:`~django_typer.parsers.ModelObjectParser`
395-
* convenience: :func:`~django_typer.model_parser_completer`
395+
* convenience: :func:`~django_typer.management.model_parser_completer`
396396

397397
This completer/parser pairing provides the ability to fetch a model object from one of its fields.
398398
Most field types are supported. Additionally any other field can be set as the help text that some
@@ -419,7 +419,7 @@ shells support. Refer to the reference documentation and the
419419
ModelClass,
420420
typer.Argument(
421421
**model_parser_completer(
422-
ModelClass,
422+
ModelClass, # may also accept a QuerySet for pre-filtering
423423
'field_name', # the field that should be matched (defaults to id)
424424
help_field='other_field' # optionally provide some additional help text
425425
),

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "django-typer"
3-
version = "2.1.3"
3+
version = "2.2.0"
44
description = "Use Typer to define the CLI for your Django management commands."
55
authors = ["Brian Kohan <[email protected]>"]
66
license = "MIT"

tests/apps/test_app/management/commands/model_fields.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import typer
1212
from django.utils.translation import gettext_lazy as _
13+
from django.db.models import Q
1314

1415
from django_typer.management import (
1516
TyperCommand,
@@ -55,7 +56,7 @@ def test(
5556
),
5657
] = None,
5758
text: Annotated[
58-
t.List[ShellCompleteTester],
59+
t.Optional[t.List[ShellCompleteTester]],
5960
typer.Option(
6061
**model_parser_completer(
6162
ShellCompleteTester,
@@ -66,7 +67,7 @@ def test(
6667
),
6768
] = None,
6869
itext: Annotated[
69-
t.List[ShellCompleteTester],
70+
t.Optional[t.List[ShellCompleteTester]],
7071
typer.Option(
7172
**model_parser_completer(
7273
ShellCompleteTester,
@@ -113,26 +114,41 @@ def test(
113114
),
114115
] = None,
115116
ip: Annotated[
116-
t.List[ShellCompleteTester],
117+
t.Optional[t.List[ShellCompleteTester]],
117118
typer.Option(
118119
**model_parser_completer(ShellCompleteTester, "ip_field"),
119120
help=_("Fetch objects by their IP address fields."),
120121
),
121122
] = None,
122123
email: Annotated[
123-
t.List[ShellCompleteTester],
124+
t.Optional[t.List[ShellCompleteTester]],
124125
typer.Option(
125126
**model_parser_completer(ShellCompleteTester, "email_field"),
126127
help=_("Fetch objects by their email fields."),
127128
),
128129
] = None,
129130
url: Annotated[
130-
t.List[ShellCompleteTester],
131+
t.Optional[t.List[ShellCompleteTester]],
131132
typer.Option(
132133
**model_parser_completer(ShellCompleteTester, "url_field"),
133134
help=_("Fetch objects by their url fields."),
134135
),
135136
] = None,
137+
filtered: Annotated[
138+
t.Optional[t.List[ShellCompleteTester]],
139+
typer.Option(
140+
**model_parser_completer(
141+
ShellCompleteTester.objects.filter(
142+
~(
143+
Q(text_field__istartswith="a")
144+
| Q(text_field__istartswith="s")
145+
)
146+
),
147+
"text_field",
148+
),
149+
help=_("Fetch objects by their text fields."),
150+
),
151+
] = None,
136152
):
137153
assert self.__class__ is Command
138154
objects = {}
@@ -166,4 +182,8 @@ def test(
166182
for addr in ip:
167183
assert isinstance(addr, ShellCompleteTester)
168184
objects["ip"] = [{addr.id: addr.ip_field} for addr in ip]
185+
if filtered is not None:
186+
for txt in filtered:
187+
assert isinstance(txt, ShellCompleteTester)
188+
objects["filtered"] = [{txt.id: txt.text_field} for txt in filtered]
169189
return json.dumps(objects)

tests/test_parser_completers.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,77 @@ def test_text_field(self):
468468
},
469469
)
470470

471+
def test_filtered_text_field(self):
472+
result = StringIO()
473+
with contextlib.redirect_stdout(result):
474+
call_command("shellcompletion", "complete", "model_fields test --filtered ")
475+
result = result.getvalue()
476+
self.assertFalse("sockeye" in result)
477+
self.assertTrue("chinook" in result)
478+
self.assertFalse("steelhead" in result)
479+
self.assertTrue("coho" in result)
480+
self.assertFalse("atlantic" in result)
481+
self.assertTrue("pink" in result)
482+
self.assertTrue("chum" in result)
483+
484+
result = StringIO()
485+
with contextlib.redirect_stdout(result):
486+
call_command(
487+
"shellcompletion", "complete", "model_fields test --filtered ch"
488+
)
489+
result = result.getvalue()
490+
self.assertFalse("sockeye" in result)
491+
self.assertTrue("chinook" in result)
492+
self.assertFalse("steelhead" in result)
493+
self.assertFalse("coho" in result)
494+
self.assertFalse("atlantic" in result)
495+
self.assertFalse("pink" in result)
496+
self.assertTrue("chum" in result)
497+
498+
# distinct completions by default
499+
result = StringIO()
500+
with contextlib.redirect_stdout(result):
501+
call_command(
502+
"shellcompletion",
503+
"complete",
504+
"model_fields test --filtered coho --filtered chinook --filtered ",
505+
)
506+
result = result.getvalue()
507+
self.assertFalse("sockeye" in result)
508+
self.assertFalse("chinook" in result)
509+
self.assertFalse("steelhead" in result)
510+
self.assertFalse("coho" in result)
511+
self.assertFalse("atlantic" in result)
512+
self.assertTrue("pink" in result)
513+
self.assertTrue("chum" in result)
514+
515+
self.assertEqual(
516+
json.loads(
517+
call_command(
518+
"model_fields",
519+
"test",
520+
"--filtered",
521+
"coho",
522+
"--filtered",
523+
"chinook",
524+
)
525+
),
526+
{
527+
"filtered": [
528+
{
529+
str(
530+
ShellCompleteTester.objects.get(text_field="coho").pk
531+
): "coho"
532+
},
533+
{
534+
str(
535+
ShellCompleteTester.objects.get(text_field="chinook").pk
536+
): "chinook"
537+
},
538+
]
539+
},
540+
)
541+
471542
def test_uuid_field(self):
472543
from uuid import UUID
473544

@@ -1387,3 +1458,12 @@ def test_databases_completer(self):
13871458
].strip()
13881459

13891460
self.assertTrue("default" in result)
1461+
1462+
def test_model_completer_argument_test(self):
1463+
from django_typer.completers import ModelObjectCompleter
1464+
1465+
class NotAModel:
1466+
pass
1467+
1468+
with self.assertRaises(ValueError):
1469+
ModelObjectCompleter(NotAModel, "char_field", "test")

0 commit comments

Comments
 (0)