Skip to content

Commit 7dbc9c4

Browse files
authored
Merge pull request #4 from Abishevs/orm-based-db
Orm based db manager with SQLalchemy
2 parents db2cbed + c5f146d commit 7dbc9c4

File tree

3 files changed

+95
-186
lines changed

3 files changed

+95
-186
lines changed
Lines changed: 50 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -1,157 +1,67 @@
1-
import sqlite3
21
from typing import List
3-
from tech_cache.models.item import Item
2+
from sqlalchemy import create_engine
3+
from sqlalchemy.orm import sessionmaker
4+
from sqlalchemy.exc import SQLAlchemyError
5+
from tech_cache.models.item import Base, Item
46

57
class DatabaseManager:
68
"""Handles database connection and data persitance
79
Idea is that an instance of this class is passed to
810
models, so that views can then interpret and draw
911
data.
1012
"""
11-
def __init__(self):
12-
self.con: sqlite3.Connection
13-
self.cur: sqlite3.Cursor
14-
self.db_name = "test.db"
15-
self.connect()
13+
def __init__(self, db_url="sqlite:///test.db"):
14+
self.engine = create_engine(db_url, echo=True)
15+
Base.metadata.create_all(self.engine)
16+
self.Session = sessionmaker(bind=self.engine)
1617

17-
def connect(self):
18-
"""Open db connection when you open application"""
19-
self.con = sqlite3.connect(self.db_name)
20-
self.cur = self.con.cursor()
21-
self.init_db()
18+
def get_session(self):
19+
return self.Session()
2220

23-
def close(self):
24-
"""close db before closing application"""
25-
self.con.close()
26-
27-
def init_db(self):
28-
create_table_query = """
29-
CREATE TABLE IF NOT EXISTS items (
30-
id TEXT PRIMARY KEY,
31-
name TEXT,
32-
category TEXT,
33-
quantity INTEGER,
34-
sku TEXT,
35-
specification TEXT,
36-
created_at TEXT,
37-
updated_at TEXT
38-
);
39-
"""
40-
self.execute_query(create_table_query)
41-
42-
def execute_query(self, query:str, data: tuple | None=None):
43-
with self.con:
44-
if data:
45-
self.cur.execute(query, data)
46-
else:
47-
self.cur.execute(query)
48-
49-
def get_item(self, uid):
50-
query = """
51-
SELECT
52-
id,
53-
name,
54-
category,
55-
quantity,
56-
sku,
57-
specification,
58-
created_at,
59-
updated_at
60-
FROM items
61-
WHERE id = ?;
62-
"""
63-
self.execute_query(query, (uid,))
64-
db_item = self.cur.fetchone()
65-
new_item = Item(
66-
id=db_item[0],
67-
name=db_item[1],
68-
category=db_item[2],
69-
quantity=db_item[3],
70-
sku=db_item[4],
71-
specification=db_item[5],
72-
created_at=db_item[6],
73-
updated_at=db_item[7]
74-
)
75-
return new_item
21+
def get_item(self, item_id:int):
22+
with self.get_session() as session:
23+
return session.query(Item).filter(Item.id == item_id).one_or_none()
7624

7725
def add_item(self, item: Item):
78-
query = """
79-
INSERT INTO items (
80-
id,
81-
name,
82-
category,
83-
quantity,
84-
sku,
85-
specification,
86-
created_at,
87-
updated_at
88-
)
89-
VALUES(?, ?, ?, ?, ?, ?, ?, ?);
90-
"""
91-
data = (str(item.id),
92-
item.name,
93-
item.category,
94-
item.quantity,
95-
item.sku,
96-
item.specification,
97-
item.created_at,
98-
item.updated_at,
99-
)
100-
self.execute_query(query, data)
26+
with self.get_session() as session:
27+
try:
28+
session.add(item)
29+
session.commit()
30+
except SQLAlchemyError as e:
31+
session.rollback()
32+
raise e
10133

102-
def update_item(self, item: Item):
103-
query = """
104-
UPDATE items
105-
SET name = ?,
106-
category = ?,
107-
quantity = ?,
108-
sku = ?,
109-
specification = ?,
110-
updated_at = ?
111-
WHERE id = ?;
112-
"""
113-
data = (item.name,
114-
item.category,
115-
item.quantity,
116-
item.sku,
117-
item.specification,
118-
item.updated_at,
34+
def update_item(self, item_id: int):
11935

120-
item.id)
121-
self.execute_query(query, data)
36+
with self.get_session() as session:
37+
try:
38+
item = session.query(Item).filter(Item.id == item_id).one_or_none()
39+
40+
if item:
41+
session.delete(item)
42+
session.commit()
43+
else:
44+
print(f"Item with ID {item_id} not found.")
45+
except SQLAlchemyError as e:
46+
session.rollback()
47+
raise e
12248

123-
def delete_item(self, uid:str):
124-
query = "DELETE FROM items WHERE id = ?;"
125-
self.execute_query(query, (uid,))
49+
def delete_item(self, item_id: int):
50+
session = self.get_session()
51+
with self.get_session() as session:
52+
try:
53+
item = session.query(Item).filter(Item.id == item_id).one_or_none()
54+
55+
if item:
56+
session.delete(item)
57+
session.commit()
58+
else:
59+
print(f"Item with ID {item_id} not found.")
60+
61+
except SQLAlchemyError as e:
62+
session.rollback()
63+
raise e
12664

12765
def get_all_items(self) -> List[Item]:
128-
query = """
129-
SELECT
130-
id,
131-
name,
132-
category,
133-
quantity,
134-
sku,
135-
specification,
136-
created_at,
137-
updated_at
138-
FROM items;
139-
"""
140-
self.execute_query(query)
141-
db_items = self.cur.fetchall()
142-
items = []
143-
for row in db_items:
144-
# follows query parameter order
145-
new_item = Item(
146-
id=row[0],
147-
name=row[1],
148-
category=row[2],
149-
quantity=row[3],
150-
sku=row[4],
151-
specification=row[5],
152-
created_at=row[6],
153-
updated_at=row[7]
154-
)
155-
items.append(new_item)
156-
return items
157-
66+
with self.get_session() as session:
67+
return list(session.query(Item))

src/tech_cache/models/item.py

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,46 @@
1-
from typing import Any, List
1+
from typing import Any
2+
from typing import List
3+
from typing import Optional
24
from datetime import datetime
3-
from uuid import uuid4
5+
from sqlalchemy import String
6+
from sqlalchemy import DateTime
7+
from sqlalchemy import func
8+
from sqlalchemy import String
9+
from sqlalchemy.orm import DeclarativeBase
10+
from sqlalchemy.orm import Mapped
11+
from sqlalchemy.orm import mapped_column
412

5-
class Item:
6-
"""Represents rows in tableview """
7-
def __init__(self, **kwargs):
8-
self.id:str = str(kwargs.get('id', uuid4())) # for in database data integrity
9-
self._name:str = kwargs.get('name', "")
10-
self.category:str = kwargs.get('category', "")
11-
self._quantity:int = kwargs.get('quantity', 0)
12-
self.sku:str = kwargs.get('sku', "") # TODO: Gen SKU
13-
self.specification:str = kwargs.get('specification', "")
14-
self.created_at:str = kwargs.get('created_at', datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
15-
self.updated_at:str = kwargs.get('updated_at',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
13+
class Base(DeclarativeBase):
14+
pass
1615

17-
@property
18-
def name(self):
19-
"""The name property."""
20-
return self._name
16+
class Item(Base):
17+
__tablename__ = "items"
2118

22-
@name.setter
23-
def name(self, value):
24-
if value == "":
25-
raise ValueError("Mandatory name field can't be empty")
26-
self._name = value
19+
id: Mapped[int] = mapped_column(primary_key=True)
20+
sku: Mapped[str] = mapped_column(String(30))
21+
name: Mapped[str] = mapped_column(String(30))
22+
category: Mapped[str] = mapped_column(String(30))
23+
quantity: Mapped[int] = mapped_column()
24+
specification: Mapped[Optional[str]] = mapped_column(String(255))
25+
created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now())
26+
updated_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), onupdate=func.now())
2727

28-
@property
29-
def quantity(self):
30-
"""The quantity property."""
31-
return self._quantity
28+
def __init__(self,
29+
name:str = "",
30+
sku:str = "",
31+
category:str = "",
32+
quantity:int = 0,
33+
specification:str = "",
34+
):
3235

33-
@quantity.setter
34-
def quantity(self, value:int):
35-
if not isinstance(value, int):
36-
raise ValueError("Quantity must be an integer")
36+
self.name = name
37+
self.sku = sku
38+
self.category = category
39+
self.quantity = quantity
40+
self.specification = specification
3741

38-
if value < 0:
39-
raise ValueError("Quantity cannot be negative")
40-
41-
self._quantity = value
42-
43-
def get_fields(self) -> List[Any]:
44-
"""Item fields to display"""
45-
fields = []
46-
fields.append(self.sku)
47-
fields.append(self.name)
48-
fields.append(self.category)
49-
fields.append(self.quantity)
50-
fields.append(self.specification)
51-
return fields
42+
def __repr__(self) -> str:
43+
return f"Item(name={self.name!r}, category={self.category!r}, quantity={self.quantity!r})"
5244

5345
def __getitem__(self, column_index):
5446
if column_index == 0:
@@ -64,4 +56,12 @@ def __getitem__(self, column_index):
6456
else:
6557
raise IndexError("Invalid column index")
6658

67-
59+
def get_fields(self) -> List[Any]:
60+
"""Item fields to display"""
61+
fields = []
62+
fields.append(self.sku)
63+
fields.append(self.name)
64+
fields.append(self.category)
65+
fields.append(self.quantity)
66+
fields.append(self.specification)
67+
return fields

src/tech_cache/views/main_window.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,6 @@ def closeEvent(self, event):
174174

175175
if reply == QtWidgets.QMessageBox.StandardButton.Yes:
176176
# TODO: clean up
177-
self.database.close()
178177
event.accept()
179178
else:
180179
event.ignore()

0 commit comments

Comments
 (0)