Automatically use a function at insert or select

Sometimes an application needs to apply a SQL function on every insert or every select. For example, the application might need geometries in lat/lon coordinates while they are stored projected in the database. To avoid having to add an ST_Transform() call to every query by hand, it is possible to define a TypeDecorator that inserts it automatically:

 11 from sqlalchemy import create_engine
 12 from sqlalchemy import MetaData
 13 from sqlalchemy import Column
 14 from sqlalchemy import Integer
 15 from sqlalchemy import func
 16 from sqlalchemy.ext.declarative import declarative_base
 17 from sqlalchemy.orm import sessionmaker
 18 from sqlalchemy.types import TypeDecorator
 19
 20 from geoalchemy2 import Geometry
 21 from geoalchemy2 import shape
 22
 23
 24 engine = create_engine('postgresql://gis:gis@localhost/gis', echo=True)
 25 metadata = MetaData(engine)
 26
 27 Base = declarative_base(metadata=metadata)
 28
 29
 30 class TransformedGeometry(TypeDecorator):
 31     """This class is used to insert a ST_Transform() in each insert or select."""
 32     impl = Geometry
 33
 34     def __init__(self, db_srid, app_srid, **kwargs):
 35         kwargs["srid"] = db_srid
 36         self.impl = self.__class__.impl(**kwargs)
 37         self.app_srid = app_srid
 38         self.db_srid = db_srid
 39
 40     def column_expression(self, col):
 41         """The column_expression() method is overrided to ensure that the
 42         SRID of the resulting WKBElement is correct"""
 43         return getattr(func, self.impl.as_binary)(
 44             func.ST_Transform(col, self.app_srid),
 45             type_=self.__class__.impl(srid=self.app_srid)
 46             # srid could also be -1 so that the SRID is deduced from the
 47             # WKB data
 48         )
 49
 50     def bind_expression(self, bindvalue):
 51         return func.ST_Transform(
 52             self.impl.bind_expression(bindvalue), self.db_srid)
 53
 54
 55 class ThreeDGeometry(TypeDecorator):
 56     """This class is used to insert a ST_Force3D() in each insert."""
 57     impl = Geometry
 58
 59     def bind_expression(self, bindvalue):
 60         return func.ST_Force3D(self.impl.bind_expression(bindvalue))
 61
 62
 63 class Point(Base):
 64     __tablename__ = "point"
 65     id = Column(Integer, primary_key=True)
 66     raw_geom = Column(Geometry(srid=4326, geometry_type="POINT"))
 67     geom = Column(
 68         TransformedGeometry(
 69             db_srid=2154, app_srid=4326, geometry_type="POINT"))
 70     three_d_geom = Column(
 71         ThreeDGeometry(srid=4326, geometry_type="POINTZ", dimension=3))
 72
 73
 74 session = sessionmaker(bind=engine)()
 75
 76
 77 def check_wkb(wkb, x, y):
 78     pt = shape.to_shape(wkb)
 79     assert round(pt.x, 5) == x
 80     assert round(pt.y, 5) == y
 81
 82
 83 class TestTypeDecorator():
 84
 85     def setup(self):
 86         metadata.drop_all(checkfirst=True)
 87         metadata.create_all()
 88
 89     def teardown(self):
 90         session.rollback()
 91         metadata.drop_all()
 92
 93     def _create_one_point(self):
 94         # Create new point instance
 95         p = Point()
 96         p.raw_geom = "SRID=4326;POINT(5 45)"
 97         p.geom = "SRID=4326;POINT(5 45)"
 98         p.three_d_geom = "SRID=4326;POINT(5 45)"  # Insert 2D geometry into 3D column
 99
100         # Insert point
101         session.add(p)
102         session.flush()
103         session.expire(p)
104
105         return p.id
106
107     def test_transform(self):
108         self._create_one_point()
109
110         # Query the point and check the result
111         pt = session.query(Point).one()
112         assert pt.id == 1
113         assert pt.raw_geom.srid == 4326
114         check_wkb(pt.raw_geom, 5, 45)
115
116         assert pt.geom.srid == 4326
117         check_wkb(pt.geom, 5, 45)
118
119         # Check that the data is correct in DB using raw query
120         q = "SELECT id, ST_AsEWKT(geom) AS geom FROM point;"
121         res_q = session.execute(q).fetchone()
122         assert res_q.id == 1
123         assert res_q.geom == "SRID=2154;POINT(857581.899319668 6435414.7478354)"
124
125         # Compare geom, raw_geom with auto transform and explicit transform
126         pt_trans = session.query(
127             Point,
128             Point.raw_geom,
129             func.ST_Transform(Point.raw_geom, 2154).label("trans")
130         ).one()
131
132         assert pt_trans[0].id == 1
133
134         assert pt_trans[0].geom.srid == 4326
135         check_wkb(pt_trans[0].geom, 5, 45)
136
137         assert pt_trans[0].raw_geom.srid == 4326
138         check_wkb(pt_trans[0].raw_geom, 5, 45)
139
140         assert pt_trans[1].srid == 4326
141         check_wkb(pt_trans[1], 5, 45)
142
143         assert pt_trans[2].srid == 2154
144         check_wkb(pt_trans[2], 857581.89932, 6435414.74784)
145
146     def test_force_3d(self):
147         self._create_one_point()
148
149         # Query the point and check the result
150         pt = session.query(Point).one()
151
152         assert pt.id == 1
153         assert pt.three_d_geom.srid == 4326
154         assert pt.three_d_geom.desc.lower() == (
155             '01010000a0e6100000000000000000144000000000008046400000000000000000')

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery