Note
Click here to download the full example code
Automatically use a function at insert or select
Sometimes the application needs a function to be applied automatically on every insert or select.
For example, the application might need geometries in lat/lon coordinates while the
database stores them in a projected coordinate system. To avoid having to tweak every query with an
`ST_Transform()` call, it is possible to define a `TypeDecorator` that adds the call automatically.
11 import re
12
13 from sqlalchemy import Column
14 from sqlalchemy import Integer
15 from sqlalchemy import MetaData
16 from sqlalchemy import func
17 from sqlalchemy import text
18 from sqlalchemy.ext.declarative import declarative_base
19 from sqlalchemy.types import TypeDecorator
20
21 from geoalchemy2 import Geometry
22 from geoalchemy2 import shape
23
24 # Tests imports
25 from tests import test_only_with_dialects
26
# Single MetaData shared by every table of this example, so that the tests can drop and
# recreate all of them in one call.
metadata = MetaData()

# Declarative base whose mapped classes register their tables in the shared ``metadata``.
Base = declarative_base(metadata=metadata)
30
31
class TransformedGeometry(TypeDecorator):
    """Geometry type that inserts an ``ST_Transform()`` call in each insert or select.

    Geometries are stored in the database with ``db_srid`` and exchanged with the
    application with ``app_srid``. The database performs the conversion in both directions.
    """

    impl = Geometry

    cache_ok = True

    def __init__(self, db_srid, app_srid, **kwargs):
        """Build the decorated geometry type.

        Args:
            db_srid: SRID of the geometries stored in the column; passed to the underlying
                ``Geometry`` as its ``srid`` argument.
            app_srid: SRID of the geometries sent and received by the application.
            **kwargs: Extra arguments forwarded to ``Geometry``. Any ``srid`` given here is
                overridden by ``db_srid``.
        """
        super().__init__(**{**kwargs, "srid": db_srid})
        self.app_srid = app_srid
        self.db_srid = db_srid

    def column_expression(self, col):
        """Transform selected geometries to ``app_srid`` and fetch them as binary.

        The result is typed with a new ``TransformedGeometry`` whose ``db_srid`` and
        ``app_srid`` are both ``app_srid``. The returned element is therefore decorated
        too, but it is not transformed a second time, and the resulting WKBElement
        carries the right SRID.
        """
        to_binary = getattr(func, self.impl.as_binary)
        in_app_srid = func.ST_Transform(col, self.app_srid)
        result_type = self.__class__(db_srid=self.app_srid, app_srid=self.app_srid)
        return to_binary(in_app_srid, type_=result_type)

    def bind_expression(self, bindvalue):
        """Transform bound geometries to ``db_srid`` before they are written."""
        geometry = self.impl.bind_expression(bindvalue)
        return func.ST_Transform(geometry, self.db_srid, type_=self)
64
65
class ThreeDGeometry(TypeDecorator):
    """Geometry type that wraps every inserted value in ``ST_Force3D()``.

    This allows 2D geometries to be inserted into a 3D column: the database adds the
    missing Z coordinate.
    """

    impl = Geometry

    cache_ok = True

    def column_expression(self, col):
        """Fetch the column as binary, typed with this decorator.

        This is not needed in this example, but it is needed if one wants to override
        other methods of the TypeDecorator class, like ``process_result_value()``.
        """
        return getattr(func, self.impl.as_binary)(col, type_=self)

    def bind_expression(self, bindvalue):
        """Wrap the bound geometry in ``ST_Force3D()`` and type the call with this decorator.

        SQLAlchemy function constructs take their return type through the ``type_``
        keyword (with a trailing underscore), as ``TransformedGeometry.bind_expression``
        also does.
        """
        return func.ST_Force3D(
            self.impl.bind_expression(bindvalue),
            type_=self,
        )
86
87
class Point(Base):
    """Example table storing the same point through three differently typed geometry columns."""

    __tablename__ = "point"

    id = Column(Integer, primary_key=True)
    # Plain geometry column in EPSG:4326, no decorator involved.
    raw_geom = Column(Geometry(geometry_type="POINT", srid=4326))
    # Stored in EPSG:2154 in the database, exchanged in EPSG:4326 with the application.
    geom = Column(TransformedGeometry(app_srid=4326, db_srid=2154, geometry_type="POINT"))
    # 3D column: inserted values go through ST_Force3D().
    three_d_geom = Column(ThreeDGeometry(dimension=3, geometry_type="POINTZ", srid=4326))
94
95
def check_wkb(wkb, x, y):
    """Assert that the point encoded in ``wkb`` has coordinates ``(x, y)``.

    Coordinates are rounded to 5 decimals before the comparison. X is checked before Y.
    """
    point = shape.to_shape(wkb)
    for actual, expected in ((point.x, x), (point.y, y)):
        assert round(actual, 5) == expected
100
101
@test_only_with_dialects("postgresql")
class TestTypeDecorator:
    """Exercise ``TransformedGeometry`` and ``ThreeDGeometry`` against PostgreSQL."""

    def _create_one_point(self, session, conn):
        """Recreate the tables, insert one ``Point`` and return its primary key.

        The instance is expired after the flush, so reading ``id`` reloads it from the DB.
        """
        metadata.drop_all(conn, checkfirst=True)
        metadata.create_all(conn)

        ewkt = "SRID=4326;POINT(5 45)"
        # The 2D geometry is deliberately also inserted into the 3D column.
        point = Point(raw_geom=ewkt, geom=ewkt, three_d_geom=ewkt)

        session.add(point)
        session.flush()
        session.expire(point)

        return point.id

    def test_transform(self, session, conn):
        self._create_one_point(session, conn)

        # ORM query: both geometry columns come back in EPSG:4326.
        loaded = session.query(Point).one()
        assert loaded.id == 1
        assert loaded.raw_geom.srid == 4326
        check_wkb(loaded.raw_geom, 5, 45)

        assert loaded.geom.srid == 4326
        check_wkb(loaded.geom, 5, 45)

        # A raw textual query bypasses the column types and shows the stored EPSG:2154 value.
        row = session.execute(text("SELECT id, ST_AsEWKT(geom) AS geom FROM point;")).fetchone()
        assert row.id == 1
        assert re.match(r"SRID=2154;POINT\(857581\.8993196681? 6435414\.7478354\)", row.geom)

        # Compare the automatic transform with an explicit ST_Transform() of the raw column.
        obj, raw, explicit = session.query(
            Point,
            Point.raw_geom,
            func.ST_Transform(Point.raw_geom, 2154).label("trans"),
        ).one()

        assert obj.id == 1

        assert obj.geom.srid == 4326
        check_wkb(obj.geom, 5, 45)

        assert obj.raw_geom.srid == 4326
        check_wkb(obj.raw_geom, 5, 45)

        assert raw.srid == 4326
        check_wkb(raw, 5, 45)

        assert explicit.srid == 2154
        check_wkb(explicit, 857581.89932, 6435414.74784)

    def test_force_3d(self, session, conn):
        self._create_one_point(session, conn)

        loaded = session.query(Point).one()

        assert loaded.id == 1
        assert loaded.three_d_geom.srid == 4326
        # Little-endian EWKB: Point with Z and SRID flags, SRID 4326, coordinates (5, 45, 0).
        expected_ewkb = "01010000a0e6100000000000000000144000000000008046400000000000000000"
        assert loaded.three_d_geom.desc.lower() == expected_ewkb