Automatically use a function at insert or select

Sometimes an application needs to apply a SQL function on every insert or select. For example, the application might need geometries in lat/lon coordinates while they are stored projected in the database. To avoid having to add an ST_Transform() call to every query by hand, it is possible to define a TypeDecorator that inserts the call automatically, as shown below.

 11 import re
 12
 13 from sqlalchemy import Column
 14 from sqlalchemy import Integer
 15 from sqlalchemy import MetaData
 16 from sqlalchemy import func
 17 from sqlalchemy import text
 18 from sqlalchemy.ext.declarative import declarative_base
 19 from sqlalchemy.types import TypeDecorator
 20
 21 from geoalchemy2 import Geometry
 22 from geoalchemy2 import shape
 23
 24 # Tests imports
 25 from tests import test_only_with_dialects
 26
# Shared metadata for the tables declared in this example (the tests drop and recreate them).
metadata = MetaData()

# Declarative base attached to the metadata above; the ``Point`` model derives from it.
Base = declarative_base(metadata=metadata)
 30
 31
 32 class TransformedGeometry(TypeDecorator):
 33     """This class is used to insert a ST_Transform() in each insert or select."""
 34
 35     impl = Geometry
 36
 37     cache_ok = True
 38
 39     def __init__(self, db_srid, app_srid, **kwargs):
 40         kwargs["srid"] = db_srid
 41         super().__init__(**kwargs)
 42         self.app_srid = app_srid
 43         self.db_srid = db_srid
 44
 45     def column_expression(self, col):
 46         """The column_expression() method is overridden to set the correct type.
 47
 48         This is needed so that the returned element will also be decorated. In this case we don't
 49         want to transform it again afterwards so we set the same SRID to both the ``db_srid`` and
 50         ``app_srid`` arguments.
 51         Without this the SRID of the WKBElement would be wrong.
 52         """
 53         return getattr(func, self.impl.as_binary)(
 54             func.ST_Transform(col, self.app_srid),
 55             type_=self.__class__(db_srid=self.app_srid, app_srid=self.app_srid),
 56         )
 57
 58     def bind_expression(self, bindvalue):
 59         return func.ST_Transform(
 60             self.impl.bind_expression(bindvalue),
 61             self.db_srid,
 62             type_=self,
 63         )
 64
 65
 66 class ThreeDGeometry(TypeDecorator):
 67     """This class is used to insert a ST_Force3D() in each insert."""
 68
 69     impl = Geometry
 70
 71     cache_ok = True
 72
 73     def column_expression(self, col):
 74         """The column_expression() method is overridden to set the correct type.
 75
 76         This is not needed in this example but it is needed if one wants to override other methods
 77         of the TypeDecorator class, like ``process_result_value()`` for example.
 78         """
 79         return getattr(func, self.impl.as_binary)(col, type_=self)
 80
 81     def bind_expression(self, bindvalue):
 82         return func.ST_Force3D(
 83             self.impl.bind_expression(bindvalue),
 84             type=self,
 85         )
 86
 87
class Point(Base):
    """Mapped ``point`` table with one column per geometry type used in this example."""

    __tablename__ = "point"
    id = Column(Integer, primary_key=True)
    # Plain geometry column in SRID 4326, no automatic function call.
    raw_geom = Column(Geometry(srid=4326, geometry_type="POINT"))
    # Stored in SRID 2154 in the database, read and written in SRID 4326 by the application
    # (ST_Transform() is added automatically by TransformedGeometry).
    geom = Column(TransformedGeometry(db_srid=2154, app_srid=4326, geometry_type="POINT"))
    # 3D point column; inserted values are wrapped in ST_Force3D() by ThreeDGeometry.
    three_d_geom = Column(ThreeDGeometry(srid=4326, geometry_type="POINTZ", dimension=3))
 94
 95
 96 def check_wkb(wkb, x, y):
 97     pt = shape.to_shape(wkb)
 98     assert round(pt.x, 5) == x
 99     assert round(pt.y, 5) == y
100
101
@test_only_with_dialects("postgresql")
class TestTypeDecorator:
    """Round-trip tests for ``TransformedGeometry`` and ``ThreeDGeometry``.

    Only run against PostgreSQL (see ``test_only_with_dialects``). The ``session`` and
    ``conn`` arguments are presumably pytest fixtures giving an ORM session and a connection
    to the same database; TODO confirm in the test configuration.
    """

    def _create_one_point(self, session, conn):
        """Recreate the tables, insert a single ``Point`` and return its primary key.

        The same 2D EWKT value is written to the three geometry columns, so each column
        type can be checked against the same input.
        """
        metadata.drop_all(conn, checkfirst=True)
        metadata.create_all(conn)

        # Create new point instance
        p = Point()
        p.raw_geom = "SRID=4326;POINT(5 45)"
        p.geom = "SRID=4326;POINT(5 45)"
        p.three_d_geom = "SRID=4326;POINT(5 45)"  # Insert 2D geometry into 3D column

        # Insert point, then expire it so its attributes are reloaded from the database
        session.add(p)
        session.flush()
        session.expire(p)

        return p.id

    def test_transform(self, session, conn):
        """``geom`` is stored in SRID 2154 but read back in SRID 4326."""
        # Use the id actually assigned rather than assuming the sequence starts at 1.
        point_id = self._create_one_point(session, conn)

        # Query the point and check the result
        pt = session.query(Point).one()
        assert pt.id == point_id
        assert pt.raw_geom.srid == 4326
        check_wkb(pt.raw_geom, 5, 45)

        assert pt.geom.srid == 4326
        check_wkb(pt.geom, 5, 45)

        # Check that the data is correct in DB using a raw query (bypasses the decorator)
        q = text("SELECT id, ST_AsEWKT(geom) AS geom FROM point;")
        res_q = session.execute(q).fetchone()
        assert res_q.id == point_id
        assert re.match(r"SRID=2154;POINT\(857581\.8993196681? 6435414\.7478354\)", res_q.geom)

        # Compare geom, raw_geom with auto transform and explicit transform
        point, raw_geom, trans = session.query(
            Point,
            Point.raw_geom,
            func.ST_Transform(Point.raw_geom, 2154).label("trans"),
        ).one()

        assert point.id == point_id

        assert point.geom.srid == 4326
        check_wkb(point.geom, 5, 45)

        assert point.raw_geom.srid == 4326
        check_wkb(point.raw_geom, 5, 45)

        assert raw_geom.srid == 4326
        check_wkb(raw_geom, 5, 45)

        assert trans.srid == 2154
        check_wkb(trans, 857581.89932, 6435414.74784)

    def test_force_3d(self, session, conn):
        """A 2D point written to ``three_d_geom`` is read back as a 3D point."""
        point_id = self._create_one_point(session, conn)

        # Query the point and check the result
        pt = session.query(Point).one()

        assert pt.id == point_id
        assert pt.three_d_geom.srid == 4326
        # Hex EWKB of POINT Z (5 45 0) with SRID 4326: geometry type word 0xa0000001
        # (point + Z flag + SRID flag), then SRID 4326 and the three coordinates.
        assert pt.three_d_geom.desc.lower() == (
            "01010000a0e6100000000000000000144000000000008046400000000000000000"
        )

Gallery generated by Sphinx-Gallery