From 42b0e6eace676a0e7ee5eb42bf20206b5fc90219 Mon Sep 17 00:00:00 2001 From: Raphael Gibson <42935757+raphaelgibson@users.noreply.github.com> Date: Sat, 27 Aug 2022 20:49:29 -0300 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Allow=20setting=20`unique`=20in=20`?= =?UTF-8?q?Field()`=20for=20a=20column=20(#83)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Raphael Gibson Co-authored-by: Sebastián Ramírez --- sqlmodel/main.py | 6 +++ tests/test_main.py | 93 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 tests/test_main.py diff --git a/sqlmodel/main.py b/sqlmodel/main.py index bdfe6df..7c79edd 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -61,6 +61,7 @@ class FieldInfo(PydanticFieldInfo): primary_key = kwargs.pop("primary_key", False) nullable = kwargs.pop("nullable", Undefined) foreign_key = kwargs.pop("foreign_key", Undefined) + unique = kwargs.pop("unique", False) index = kwargs.pop("index", Undefined) sa_column = kwargs.pop("sa_column", Undefined) sa_column_args = kwargs.pop("sa_column_args", Undefined) @@ -80,6 +81,7 @@ class FieldInfo(PydanticFieldInfo): self.primary_key = primary_key self.nullable = nullable self.foreign_key = foreign_key + self.unique = unique self.index = index self.sa_column = sa_column self.sa_column_args = sa_column_args @@ -141,6 +143,7 @@ def Field( regex: Optional[str] = None, primary_key: bool = False, foreign_key: Optional[Any] = None, + unique: bool = False, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore @@ -171,6 +174,7 @@ def Field( regex=regex, primary_key=primary_key, foreign_key=foreign_key, + unique=unique, nullable=nullable, index=index, sa_column=sa_column, @@ -426,12 +430,14 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore nullable = not primary_key and _is_field_nullable(field) args = [] foreign_key = getattr(field.field_info, "foreign_key", None) + unique = getattr(field.field_info, "unique", False) if foreign_key: args.append(ForeignKey(foreign_key)) kwargs = { "primary_key": primary_key, "nullable": nullable, "index": index, + "unique": unique, } sa_default = Undefined if field.field_info.default_factory: diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..22c6232 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,93 @@ +from typing import Optional + +import pytest +from sqlalchemy.exc import IntegrityError +from sqlmodel import Field, Session, SQLModel, create_engine + + +def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel): + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero_1) + session.commit() + session.refresh(hero_1) + + with Session(engine) as session: + session.add(hero_2) + session.commit() + session.refresh(hero_2) + + with Session(engine) as session: + heroes = session.query(Hero).all() + assert len(heroes) == 2 + assert heroes[0].name == heroes[1].name + + +def test_should_allow_duplicate_row_if_unique_constraint_is_false(clear_sqlmodel): + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str = Field(unique=False) + age: Optional[int] = None + + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero_1) + session.commit() + session.refresh(hero_1) + + with Session(engine) as session: + session.add(hero_2) + session.commit() + session.refresh(hero_2) + + with Session(engine) as session: + heroes = session.query(Hero).all() + assert len(heroes) == 2 + assert heroes[0].name == heroes[1].name + + +def test_should_raise_exception_when_try_to_duplicate_row_if_unique_constraint_is_true( + clear_sqlmodel, +): + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str = Field(unique=True) + age: Optional[int] = None + + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero_1) + session.commit() + session.refresh(hero_1) + + with pytest.raises(IntegrityError): + with Session(engine) as session: + session.add(hero_2) + session.commit() + session.refresh(hero_2)