mirror of
https://github.com/wowlikon/LiB.git
synced 2026-02-04 04:31:09 +00:00
Динамическое создание er-диаграммы по моделям
This commit is contained in:
@@ -12,6 +12,7 @@ from .captcha import (
|
||||
REDEEM_TTL,
|
||||
prng,
|
||||
)
|
||||
from .describe_er import SchemaGenerator
|
||||
|
||||
__all__ = [
|
||||
"limiter",
|
||||
@@ -26,4 +27,5 @@ __all__ = [
|
||||
"CHALLENGE_TTL",
|
||||
"REDEEM_TTL",
|
||||
"prng",
|
||||
"SchemaGenerator",
|
||||
]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Модуль создания и проверки capjs"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import hashlib
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
"""Модуль генерации описания схемы БД"""
|
||||
|
||||
import inspect
|
||||
from typing import List, Dict, Any, Set, Type, Tuple
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from sqlalchemy.inspection import inspect as sa_inspect
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
|
||||
class SchemaGenerator:
|
||||
"""Сервис генерации json описания схемы БД"""
|
||||
|
||||
def __init__(self, db_module, dto_module=None):
|
||||
self.db_models = self._get_classes(db_module, is_table=True)
|
||||
self.dto_models = (
|
||||
self._get_classes(dto_module, is_table=False) if dto_module else []
|
||||
)
|
||||
self.link_table_names = self._identify_link_tables()
|
||||
self.field_descriptions = self._collect_all_descriptions()
|
||||
self._table_to_model = {m.__tablename__: m for m in self.db_models}
|
||||
|
||||
def _get_classes(
|
||||
self, module, is_table: bool | None = None
|
||||
) -> List[Type[SQLModel]]:
|
||||
if module is None:
|
||||
return []
|
||||
|
||||
classes = []
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, SQLModel)
|
||||
and obj is not SQLModel
|
||||
):
|
||||
if is_table is True and hasattr(obj, "__table__"):
|
||||
classes.append(obj)
|
||||
elif is_table is False and not hasattr(obj, "__table__"):
|
||||
classes.append(obj)
|
||||
return classes
|
||||
|
||||
def _normalize_model_name(self, name: str) -> str:
|
||||
suffixes = [
|
||||
"Create",
|
||||
"Read",
|
||||
"Update",
|
||||
"DTO",
|
||||
"Base",
|
||||
"List",
|
||||
"Detail",
|
||||
"Response",
|
||||
"Request",
|
||||
]
|
||||
result = name
|
||||
for suffix in suffixes:
|
||||
if result.endswith(suffix) and len(result) > len(suffix):
|
||||
result = result[: -len(suffix)]
|
||||
return result
|
||||
|
||||
def _get_field_descriptions_from_class(self, cls: Type) -> Dict[str, str]:
|
||||
descriptions = {}
|
||||
|
||||
for parent in cls.__mro__:
|
||||
if parent is SQLModel or parent is object:
|
||||
continue
|
||||
|
||||
fields = getattr(parent, "model_fields", {})
|
||||
for field_name, field_info in fields.items():
|
||||
if field_name in descriptions:
|
||||
continue
|
||||
|
||||
desc = getattr(field_info, "description", None) or getattr(
|
||||
field_info, "title", None
|
||||
)
|
||||
if desc:
|
||||
descriptions[field_name] = desc
|
||||
|
||||
return descriptions
|
||||
|
||||
def _collect_all_descriptions(self) -> Dict[str, Dict[str, str]]:
|
||||
result = {}
|
||||
|
||||
dto_map = {}
|
||||
for dto in self.dto_models:
|
||||
base_name = self._normalize_model_name(dto.__name__)
|
||||
if base_name not in dto_map:
|
||||
dto_map[base_name] = {}
|
||||
|
||||
for field, desc in self._get_field_descriptions_from_class(dto).items():
|
||||
if field not in dto_map[base_name]:
|
||||
dto_map[base_name][field] = desc
|
||||
|
||||
for model in self.db_models:
|
||||
model_name = model.__name__
|
||||
result[model_name] = {
|
||||
**dto_map.get(model_name, {}),
|
||||
**self._get_field_descriptions_from_class(model),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _identify_link_tables(self) -> Set[str]:
|
||||
link_tables = set()
|
||||
for model in self.db_models:
|
||||
try:
|
||||
for rel in sa_inspect(model).relationships:
|
||||
if rel.secondary is not None:
|
||||
link_tables.add(rel.secondary.name)
|
||||
except Exception:
|
||||
continue
|
||||
return link_tables
|
||||
|
||||
def _collect_fk_relations(self) -> List[Dict[str, Any]]:
|
||||
relations = []
|
||||
processed: Set[Tuple[str, str, str, str]] = set()
|
||||
|
||||
for model in self.db_models:
|
||||
if model.__tablename__ in self.link_table_names:
|
||||
continue
|
||||
|
||||
for col in sa_inspect(model).columns:
|
||||
for fk in col.foreign_keys:
|
||||
target_table = fk.column.table.name
|
||||
if target_table in self.link_table_names:
|
||||
continue
|
||||
|
||||
target_model = self._table_to_model.get(target_table)
|
||||
if not target_model:
|
||||
continue
|
||||
|
||||
key = (
|
||||
model.__name__,
|
||||
col.name,
|
||||
target_model.__name__,
|
||||
fk.column.name,
|
||||
)
|
||||
|
||||
if key not in processed:
|
||||
relations.append(
|
||||
{
|
||||
"fromEntity": model.__name__,
|
||||
"fromField": col.name,
|
||||
"toEntity": target_model.__name__,
|
||||
"toField": fk.column.name,
|
||||
"fromMultiplicity": "N",
|
||||
"toMultiplicity": "1",
|
||||
}
|
||||
)
|
||||
processed.add(key)
|
||||
return relations
|
||||
|
||||
def _collect_m2m_relations(self) -> List[Dict[str, Any]]:
|
||||
relations = []
|
||||
processed: Set[Tuple[str, str]] = set()
|
||||
|
||||
for model in self.db_models:
|
||||
if model.__tablename__ in self.link_table_names:
|
||||
continue
|
||||
|
||||
try:
|
||||
for rel in sa_inspect(model).relationships:
|
||||
if rel.direction.name != "MANYTOMANY":
|
||||
continue
|
||||
|
||||
target_model = rel.mapper.class_
|
||||
if target_model.__tablename__ in self.link_table_names:
|
||||
continue
|
||||
|
||||
pair = tuple(sorted([model.__name__, target_model.__name__]))
|
||||
if pair not in processed:
|
||||
relations.append(
|
||||
{
|
||||
"fromEntity": pair[0],
|
||||
"fromField": "id",
|
||||
"toEntity": pair[1],
|
||||
"toField": "id",
|
||||
"fromMultiplicity": "N",
|
||||
"toMultiplicity": "N",
|
||||
}
|
||||
)
|
||||
processed.add(pair)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return relations
|
||||
|
||||
def generate(self) -> Dict[str, Any]:
|
||||
entities = []
|
||||
|
||||
for model in self.db_models:
|
||||
table_name = model.__tablename__
|
||||
if table_name in self.link_table_names:
|
||||
continue
|
||||
|
||||
columns = sorted(
|
||||
sa_inspect(model).columns,
|
||||
key=lambda c: (
|
||||
0 if c.primary_key else (1 if c.foreign_keys else 2),
|
||||
c.name,
|
||||
),
|
||||
)
|
||||
|
||||
entity_fields = []
|
||||
descriptions = self.field_descriptions.get(model.__name__, {})
|
||||
|
||||
for col in columns:
|
||||
label = col.name
|
||||
if col.primary_key:
|
||||
label += " (PK)"
|
||||
if col.foreign_keys:
|
||||
label += " (FK)"
|
||||
|
||||
field_obj = {"id": col.name, "label": label}
|
||||
|
||||
if col.name in descriptions:
|
||||
field_obj["tooltip"] = descriptions[col.name]
|
||||
|
||||
entity_fields.append(field_obj)
|
||||
|
||||
entities.append(
|
||||
{"id": model.__name__, "title": table_name, "fields": entity_fields}
|
||||
)
|
||||
|
||||
relations = self._collect_fk_relations() + self._collect_m2m_relations()
|
||||
return {"entities": entities, "relations": relations}
|
||||
Reference in New Issue
Block a user