from sqlalchemy import inspect
import sqlalchemy.sql.sqltypes as Sqltypes
from sqlalchemy.sql.sqltypes import String, Integer, Float, Boolean, Date, DateTime, Text
from sqlalchemy.sql.sqltypes import Enum as SAEnum
from pydantic import BaseModel, create_model, ConfigDict
from typing import Any, Optional, Literal, get_origin, get_args, Union
from specatalog.models.base import TimeStampedModel
import datetime
import textwrap
from enum import Enum
def _type_name_for_doc(typ: type) -> str:
    """
    Get the string of a type for automatic documentation.

    If typ is a union that contains ``NoneType`` (as it is the case for the
    type of optional arguments, written either ``Optional[T]`` or
    ``T | None``), ``NoneType`` is dropped and only the remaining member(s)
    are rendered, joined with `` | `` when there is more than one.

    Parameters
    ----------
    typ : type
        Any type of an object or variable.

    Returns
    -------
    typ_string : str
        The string of the type. Plain classes are rendered by their
        ``__name__`` (e.g. ``int``, ``date``); parametrised typing constructs
        keep their parameters (e.g. ``Literal['asc', 'desc']``).
    """
    import types  # local import: only needed to recognise PEP 604 unions

    def _plain_name(t) -> str:
        # Plain classes have a readable __name__. Parametrised constructs
        # (Literal[...], list[int], ...) would lose their parameters via
        # __name__, so their repr is used with the "typing." prefix removed.
        if get_origin(t) is None and hasattr(t, "__name__"):
            return t.__name__
        return str(t).replace("typing.", "")

    # typing.Union (Optional[T]) and, on Python >= 3.10, types.UnionType (T | None)
    union_origins = (Union, getattr(types, "UnionType", Union))
    if get_origin(typ) in union_origins:
        members = [a for a in get_args(typ) if a is not type(None)]
        if members:
            return " | ".join(_plain_name(m) for m in members)
    # normal types
    return _plain_name(typ)


def _map_sqla_type(sqlatype: Sqltypes.TypeEngine) -> type:
    """
    Map SQLAlchemy types to Python types.

    Parameters
    ----------
    sqlatype : sqlalchemy.types.TypeEngine
        An instance of an SQLAlchemy column type (e.g. ``column.type``).

    Returns
    -------
    type
        The Python type that corresponds to the SQLAlchemy type:

        - Enum columns map to their Python enum class. If the Enum was declared
          with plain string values (no enum class), they map to a ``Literal``
          of those values, or to ``str`` if no values are known.
        - Integer/Float/Boolean/String/Text/Date/DateTime map to
          int/float/bool/str/str/datetime.date/datetime.datetime.
        - Anything else maps to ``typing.Any``.
    """
    # Enum is checked first: SQLAlchemy's Enum type derives from String and
    # would otherwise be mapped to str.
    if isinstance(sqlatype, SAEnum):
        if sqlatype.enum_class is not None:
            return sqlatype.enum_class
        # Enum("a", "b", ...) declared with strings has enum_class None;
        # returning that would turn the field into "only None is allowed".
        values = tuple(sqlatype.enums or ())
        return Literal[values] if values else str
    if isinstance(sqlatype, Integer):
        return int
    if isinstance(sqlatype, Float):
        return float
    if isinstance(sqlatype, Boolean):
        return bool
    if isinstance(sqlatype, (String, Text)):
        return str
    if isinstance(sqlatype, Date):
        return datetime.date
    if isinstance(sqlatype, DateTime):
        return datetime.datetime
    return Any


def _unwrap_optional(typ: Any) -> Any:
    """
    Return ``T`` if *typ* is ``Optional[T]`` / ``T | None``, else *typ*.

    Unions with more than one non-None member are returned unchanged.
    """
    import types  # local import: only needed to recognise PEP 604 unions

    union_origins = (Union, getattr(types, "UnionType", Union))
    if get_origin(typ) in union_origins:
        members = [a for a in get_args(typ) if a is not type(None)]
        if len(members) == 1:
            return members[0]
    return typ


def make_filter_model(model: TimeStampedModel, creation_model: BaseModel) -> BaseModel:
    """
    Create dynamically a pydantic-class for filtering based on an SQLAlchemy-
    model.

    Parameters
    ----------
    model : TimeStampedModel
        An SQLAlchemy model of the type TimeStampedModel or any subclass.
    creation_model : BaseModel
        The creation_model from models.creation_pydantic_measurements/molecules
        that corresponds to the class of model. E.g. model = ms.TREPR ->
        creation_model = cpm.TREPRModel.

    Returns
    -------
    FilterModel : BaseModel
        The filter-model contains all column names as optional fields. The type
        of each field is taken from the creation_model annotation if present,
        otherwise it is derived from the column type. The default value is
        None.

        - Numerical and date fields (int, float, date, datetime, also when
          annotated as Optional) get additional ``<field>__gt``, ``__lt``,
          ``__ge``, ``__le`` and ``__ne`` fields.
        - String fields get additional ``__like``, ``__ilike`` and
          ``__contains`` fields.

        When using the model no additional fields are allowed. The original
        SQLAlchemy model is stored as ``FilterModel.model``.
    """
    pyd_fields = {name: field.annotation
                  for name, field in creation_model.model_fields.items()}
    mapper = inspect(model)  # get all columns of the SQLA-model

    # fill fields-dictionary
    fields: dict[str, tuple[Any, None]] = {}
    for column in mapper.columns:
        field_name = column.name
        if field_name in pyd_fields:
            py_type = pyd_fields[field_name]
        else:
            py_type = _map_sqla_type(column.type)
        # Creation-model annotations are often Optional[T] / T | None; look
        # through the None so optional columns still get their operators.
        base_type = _unwrap_optional(py_type)

        # add base field (equality) for all columns
        fields[field_name] = (Optional[py_type], None)

        # add comparison operators for numerical and date types
        # (getattr: typing constructs may not expose __name__)
        if (base_type in (int, float)
                or getattr(base_type, "__name__", None) in ("date", "datetime")):
            for op in ("gt", "lt", "ge", "le", "ne"):
                fields[f"{field_name}__{op}"] = (Optional[base_type], None)

        # add text-comparison operators for strings
        if base_type is str:
            for op in ("like", "ilike", "contains"):
                fields[f"{field_name}__{op}"] = (Optional[str], None)

    # create pydantic model from fields-dictionary
    name = f"{model.__name__}Filter"
    FilterModel = create_model(
        name,
        __config__=ConfigDict(extra="forbid", validate_assignment=True),
        **fields,
    )

    # *** create docstring ***
    # Assembled line by line so the result does not depend on how this source
    # file is indented (tab-prefixed bullets mixed with a space-indented
    # template can defeat textwrap.dedent).
    operator_lines = [
        "gt: greater than",
        "lt: less than",
        "ge: greater than or equal to",
        "le: less than or equal to",
        "ne: not equal",
        "like: SQL LIKE pattern match",
        "ilike: case-insensitive LIKE",
        "contains: substring match (for strings)",
    ]
    doc_lines = [
        f"Pydantic filter model for {model.__name__}.",
        "",
        "The following operators can (but do not have to) be applied to the",
        "attributes by appending the operator to the field name, e.g.",
        "temperature__gt=20 (-> temperature > 20):",
        *(f"    - {op}" for op in operator_lines),
        "",
        "The following fields can be selected:",
        *(f"    - {fname}: {_type_name_for_doc(typ)}"
          for fname, (typ, _) in fields.items()),
    ]
    FilterModel.__doc__ = "\n".join(doc_lines)

    FilterModel.model = model  # add the original model to the FilterModel
    return FilterModel


def make_ordering_model(model: TimeStampedModel) -> BaseModel:
    """
    Create dynamically a pydantic-class for ordering based on an SQLAlchemy
    model.

    Parameters
    ----------
    model : TimeStampedModel
        An SQLAlchemy model of the type TimeStampedModel or any subclass.

    Returns
    -------
    OrderingModel : BaseModel
        The ordering-model contains all column keys as optional fields with
        the default-value None. The value of the field can only be set to
        "asc" (to indicate ascending ordering) or "desc" (to indicate
        descending ordering). No additional fields can be created when using
        the model.
    """
    mapper = inspect(model)  # get all columns of the SQLA-model

    # every column becomes an optional "asc"/"desc" field
    # NOTE(review): this builder keys fields by column.key while the filter
    # and update builders use column.name; they can differ for columns
    # declared with an explicit key=... — confirm which the query layer uses.
    fields: dict[str, tuple[Any, None]] = {
        column.key: (Optional[Literal["asc", "desc"]], None)
        for column in mapper.columns
    }

    # create pydantic model from fields-dictionary
    name = f"{model.__name__}Ordering"
    OrderingModel = create_model(
        name,
        __config__=ConfigDict(extra="forbid", validate_assignment=True),
        **fields,
    )

    # *** create docstring ***
    # Assembled line by line so the result does not depend on how this source
    # file is indented.
    doc_lines = [
        f"Pydantic ordering model for {model.__name__}.",
        "",
        'Choose "asc" (for ascending ordering) or "desc" (for descending ordering)',
        "for each attribute that shall be included in the ordering of the results.",
        "",
        "The following fields can be selected:",
        *(f"    - {fname}: {_type_name_for_doc(typ)}"
          for fname, (typ, _) in fields.items()),
    ]
    OrderingModel.__doc__ = "\n".join(doc_lines)
    return OrderingModel


def make_update_model(model: TimeStampedModel, creation_model: BaseModel
                      ) -> BaseModel:
    """
    Create dynamically a pydantic-class for updating based on an SQLAlchemy
    model.

    Parameters
    ----------
    model : TimeStampedModel
        An SQLAlchemy model of the type TimeStampedModel. Has to be of the
        class Molecule or Measurement (or a subclass).
    creation_model : BaseModel
        The creation_model from models.creation_pydantic_measurements/molecules
        that corresponds to the class of model. E.g. model = ms.TREPR ->
        creation_model = cpm.TREPRModel.

    Raises
    ------
    ValueError
        If neither the model nor any of its base classes is defined in
        ``specatalog.models.measurements`` or ``specatalog.models.molecules``.

    Returns
    -------
    UpdateModel : BaseModel
        The update-model contains all column names as optional fields with
        the default-value None. The value of the field can be set to any
        value but must be the same type as the original column type. Fields
        that must not be updated are excluded from the UpdateModel.
        No additional fields can be created when using the model. The
        original SQLAlchemy model is stored as ``UpdateModel.model``.
    """
    # fields that must not be updated, keyed by the module that defines the
    # model family
    # NOTE(review): confirm "molecular_id" is the actual FK column name; if it
    # is not, the FK column would silently remain updatable.
    excludes_by_module = {
        "specatalog.models.measurements":
            ["id", "molecular_id", "method", "created_at", "updated_at"],
        "specatalog.models.molecules":
            ["id", "group", "created_at", "updated_at"],
    }
    # Walk the MRO (model itself first) so that subclasses defined in other
    # modules still resolve to their Measurement/Molecule family, as the
    # docstring promises; models defined directly in those modules resolve
    # exactly as before.
    exclude_fields = None
    for klass in model.__mro__:
        exclude_fields = excludes_by_module.get(klass.__module__)
        if exclude_fields is not None:
            break
    if exclude_fields is None:
        raise ValueError("Unknown model class")

    pyd_fields = {name: field.annotation
                  for name, field in creation_model.model_fields.items()}
    mapper = inspect(model)  # get all columns of the SQLA-model

    # fill fields-dictionary
    fields: dict[str, tuple[Any, None]] = {}
    for column in mapper.columns:
        field_name = column.name
        if field_name in exclude_fields:
            continue
        # prefer the creation-model annotation, fall back to the column type
        if field_name in pyd_fields:
            py_type = pyd_fields[field_name]
        else:
            py_type = _map_sqla_type(column.type)
        fields[field_name] = (Optional[py_type], None)

    # create pydantic model from fields-dictionary
    name = f"{model.__name__}Update"
    UpdateModel = create_model(
        name,
        __config__=ConfigDict(extra="forbid", validate_assignment=True),
        **fields,
    )

    # *** create docstring ***
    # Assembled line by line so the result does not depend on how this source
    # file is indented.
    doc_lines = [
        f"Pydantic update model for {model.__name__}. The fields that are set",
        "are the parameters that shall be updated in the database.",
        "",
        "The following fields can be selected:",
        *(f"    - {fname}: {_type_name_for_doc(typ)}"
          for fname, (typ, _) in fields.items()),
    ]
    UpdateModel.__doc__ = "\n".join(doc_lines)

    UpdateModel.model = model  # reference SQLA-model
    return UpdateModel


def _enum_to_value(v):
    """
    Unwrap an enum member to its underlying value.

    Parameters
    ----------
    v : Any
        Any object.

    Returns
    -------
    Any
        ``v.value`` when ``v`` is an ``enum.Enum`` member, otherwise ``v``
        itself, unchanged.
    """
    return v.value if isinstance(v, Enum) else v