# task_flow/app/storage/repository.py
from __future__ import annotations

import json
import os
import shutil
from pathlib import Path
from threading import RLock
from typing import Iterable
from uuid import UUID

from app.storage.models import DATA_FORMAT_VERSION, StoragePayloadV1, StoredTask, upgrade_payload
class StorageError(Exception):
    """Base class for every error raised by the storage layer."""
class StorageTaskNotFoundError(StorageError):
    """Raised when a requested task does not exist in storage."""
class JsonFileTaskRepository:
    """
    File-based task repository.

    Guarantees:
    - versioned payload format (``DATA_FORMAT_VERSION``)
    - atomic, durable writes: temp file + fsync + ``os.replace`` semantics,
      with the temp file cleaned up if the write fails partway
    - corruption handling: an unreadable file is backed up, then storage
      is reset to an empty payload
    - thread safety: every operation runs under one re-entrant lock
      (process-level only; concurrent *processes* are not coordinated)
    """

    def __init__(self, file_path: str | Path) -> None:
        """Create the repository, ensuring the storage file and its parent directory exist."""
        self._file_path = Path(file_path)
        self._lock = RLock()
        self._ensure_parent_dir()
        self._ensure_storage_exists()

    @property
    def data_format_version(self) -> int:
        """Version of the on-disk payload format this repository writes."""
        return DATA_FORMAT_VERSION

    def list_tasks(self) -> list[StoredTask]:
        """Return a snapshot list of all stored tasks."""
        with self._lock:
            payload = self._load_payload()
            # Copy so callers cannot mutate the payload's internal list.
            return list(payload.tasks)

    def get_task(self, task_id: UUID) -> StoredTask | None:
        """Return the task with *task_id*, or ``None`` when no task matches."""
        with self._lock:
            payload = self._load_payload()
            return next(
                (task for task in payload.tasks if task.id == task_id),
                None,
            )

    def create_task(self, task: StoredTask) -> StoredTask:
        """Append *task* to storage, persist, and return it.

        NOTE(review): no duplicate-id check is performed; callers are
        presumed to supply fresh UUIDs — confirm against the service layer.
        """
        with self._lock:
            payload = self._load_payload()
            payload.tasks.append(task)
            self._save_payload(payload)
            return task

    def update_task(self, task: StoredTask) -> StoredTask:
        """Replace the stored task that shares ``task.id`` and return *task*.

        Raises:
            StorageTaskNotFoundError: if no stored task has ``task.id``.
        """
        with self._lock:
            payload = self._load_payload()
            if all(existing.id != task.id for existing in payload.tasks):
                raise StorageTaskNotFoundError(f"Task {task.id} not found")
            payload.tasks = [
                task if existing.id == task.id else existing
                for existing in payload.tasks
            ]
            self._save_payload(payload)
            return task

    def delete_task(self, task_id: UUID) -> bool:
        """Remove the task with *task_id*; return ``True`` if anything was removed."""
        with self._lock:
            payload = self._load_payload()
            remaining = [task for task in payload.tasks if task.id != task_id]
            if len(remaining) == len(payload.tasks):
                return False  # nothing matched — skip the disk write
            payload.tasks = remaining
            self._save_payload(payload)
            return True

    def replace_all(self, tasks: Iterable[StoredTask]) -> None:
        """Overwrite storage with *tasks* (the iterable is consumed once)."""
        with self._lock:
            self._save_payload(
                StoragePayloadV1(
                    version=DATA_FORMAT_VERSION,
                    tasks=list(tasks),
                )
            )

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _empty_payload() -> StoragePayloadV1:
        """Build a fresh, empty payload at the current format version."""
        return StoragePayloadV1(version=DATA_FORMAT_VERSION, tasks=[])

    def _ensure_parent_dir(self) -> None:
        """Create the storage file's parent directory if it is missing."""
        self._file_path.parent.mkdir(parents=True, exist_ok=True)

    def _ensure_storage_exists(self) -> None:
        """Write an empty payload unless the storage file already exists."""
        if self._file_path.exists():
            return
        self._save_payload(self._empty_payload())

    def _load_payload(self) -> StoragePayloadV1:
        """Read, parse and upgrade the on-disk payload.

        A missing file is recreated empty. Any read/parse/upgrade failure
        is treated as corruption: the broken file is backed up and storage
        is reset to an empty payload.
        """
        if not self._file_path.exists():
            payload = self._empty_payload()
            self._save_payload(payload)
            return payload
        try:
            raw_text = self._file_path.read_text(encoding="utf-8")
            raw_data = json.loads(raw_text)
            return upgrade_payload(raw_data)
        except Exception:
            # Deliberately broad: any failure here means the file is
            # unusable. Preserve the evidence, then start fresh rather
            # than crash the application on startup.
            self._backup_corrupted_file()
            reset_payload = self._empty_payload()
            self._save_payload(reset_payload)
            return reset_payload

    def _save_payload(self, payload: StoragePayloadV1) -> None:
        """Atomically persist *payload*.

        The payload is serialized to a sibling ``.tmp`` file, fsynced for
        durability, then renamed over the real file. On any failure the
        temp file is removed so no stale partial write is left behind.
        """
        tmp_path = self._file_path.with_name(f"{self._file_path.name}.tmp")
        serialized = json.dumps(
            payload.model_dump(mode="json"),
            ensure_ascii=False,
            indent=2,
        )
        try:
            with tmp_path.open("w", encoding="utf-8") as handle:
                handle.write(serialized + "\n")
                handle.flush()
                # Ensure bytes hit the disk before the rename makes them
                # the authoritative copy.
                os.fsync(handle.fileno())
            tmp_path.replace(self._file_path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; re-raise the
            # original error.
            tmp_path.unlink(missing_ok=True)
            raise

    def _backup_corrupted_file(self) -> None:
        """Copy the current (corrupt) storage file to a ``.corrupted`` sibling."""
        if not self._file_path.exists():
            return
        shutil.copy2(self._file_path, self._next_backup_path())

    def _next_backup_path(self) -> Path:
        """Return the first unused ``<name>.corrupted[.N]`` path."""
        base_name = f"{self._file_path.name}.corrupted"
        candidate = self._file_path.with_name(base_name)
        if not candidate.exists():
            return candidate
        index = 1
        while True:
            candidate = self._file_path.with_name(f"{base_name}.{index}")
            if not candidate.exists():
                return candidate
            index += 1