#| default_exp core
#| export
import sqlalchemy
from fastcore.utils import *
from fastcore.net import urlsave
from collections import namedtuple
from sqlalchemy import create_engine,text,MetaData,Table,Column,engine,sql
from sqlalchemy.sql.base import ReadOnlyColumnCollection
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.cursor import CursorResult
# Sample database for the examples: fetched once, then reused from the local file.
url = 'https://github.com/lerocha/chinook-database/raw/master/ChinookDatabase/DataSources/Chinook_Sqlite.sqlite'
path = Path('chinook.sqlite')
if not path.exists():
    urlsave(url, path)
# SQLAlchemy connection string for a file-based SQLite database at `path`.
connstr = 'sqlite:///' + str(path)
#| export
def conn_db(connstr, **kwargs):
    "Connect to the database at `connstr` and return a `MetaData` filled by reflecting its schema"
    # `connstr` is passed straight to `create_engine` (e.g. "sqlite:///chinook.sqlite"),
    # as are any extra `kwargs`.
    eng = create_engine(connstr, **kwargs)
    meta = MetaData()
    # Reflect *before* opening the long-lived connection, so that a reflection
    # failure cannot leave an open connection behind with nothing referencing it.
    meta.reflect(bind=eng)
    meta.bind = eng
    # The `sql`, `get` and `close` patches in this module run through `meta.conn`.
    meta.conn = eng.connect()
    return meta
# Connect to the Chinook SQLite file and reflect all of its tables into `db`.
db = conn_db(connstr)
#| export
def _orig_dir(cls):
    "Return `cls.__dir__` as it was before this module first patched it"
    # Stash the original on the class the first time only. On a re-run of this cell
    # (or a module reload) `cls.__dir__` is already our patch; capturing it again
    # would make the patch call itself, and `dir()` would fail with RecursionError.
    key = '_orig_dunder_dir'
    if key not in vars(cls): setattr(cls, key, cls.__dir__)
    return vars(cls)[key]

old_md_dir = _orig_dir(MetaData)
old_cc_dir = _orig_dir(ReadOnlyColumnCollection)

@patch
def __dir__(self:MetaData):
    "Standard attributes plus the reflected table names (e.g. for tab-completion)"
    return list(old_md_dir(self)) + list(self.tables)

@patch
def __dir__(self:ReadOnlyColumnCollection):
    "Standard attributes plus the column keys (e.g. for tab-completion of `table.c.<name>`)"
    return list(old_cc_dir(self)) + list(self.keys())
def _getattr_(self, n):
    "Fall back to `self.tables[n]`, so a reflected table can be read as `db.<table name>`"
    # Private/dunder probes (copy, pickle, IPython, ...) fail fast. `startswith` also copes
    # with the empty name, where indexing `n[0]` raised IndexError instead of AttributeError.
    # 'tables' itself is excluded: if it were missing, looking it up below would
    # re-enter this method forever.
    if n.startswith('_') or n == 'tables': raise AttributeError(n)
    try: return self.tables[n]
    except KeyError: raise AttributeError(n) from None
MetaData.__getattr__ = _getattr_
# Names of all reflected tables (`db.tables` is keyed by table name).
' '.join(db.tables)
'Album Artist Customer Employee Genre Invoice InvoiceLine Track MediaType Playlist PlaylistTrack'
# The `MetaData.__getattr__` patch exposes each table as an attribute; `.c` holds its columns.
a = db.Album
list(a.c)
[Column('AlbumId', INTEGER(), table=<Album>, primary_key=True, nullable=False), Column('Title', NVARCHAR(length=160), table=<Album>, nullable=False), Column('ArtistId', INTEGER(), ForeignKey('Artist.ArtistId'), table=<Album>, nullable=False)]
#| export
@patch
def tuples(self:CursorResult, nm='Row'):
    "Get all results as named tuples (type name `nm`), one per row"
    # `rename=True` replaces field names that are not valid identifiers (e.g. `count(*)`)
    # or that repeat (e.g. a join's shared key column) with positional names `_0`, `_1`, ...
    # instead of letting `namedtuple` raise ValueError.
    nt = namedtuple(nm, self.keys(), rename=True)
    # Build from positional rows: name-based (mapping) access is ambiguous for duplicate columns.
    return [nt(*row) for row in self.fetchall()]
@patch
def sql(self:Connection, statement, nm='Row', *args, **kwargs):
    "Execute `statement` and return its rows as named tuples of type `nm`, or `None` if it returns no rows"
    # A plain string is wrapped in `text()` so it can be executed as literal SQL.
    if isinstance(statement, str): statement = text(statement)
    res = self.execute(statement)
    # Statements such as INSERT/UPDATE produce no rows; fetching from such a result
    # raises, so honour the "(if any)" contract by returning None instead.
    # NOTE(review): `*args`/`**kwargs` are accepted but currently unused.
    return res.tuples(nm) if res.returns_rows else None
@patch
def sql(self:MetaData, statement, nm='Row', *args, **kwargs):
    "Execute `statement` on this database's connection (`self.conn`) and return the `Connection.sql` result: rows as named tuples of type `nm`"
    # `nm` is spelled out only for discoverability; it is forwarded positionally,
    # exactly as the former `*args` pass-through did.
    return self.conn.sql(statement, nm, *args, **kwargs)
# Run a raw SQL string; each result row comes back as a `Row` named tuple.
rs = db.sql('select AlbumId,Title from Album')
rs[0]
Row(AlbumId=1, Title='For Those About To Rock We Salute You')
#| export
@patch
def get(self:Table, where=None, limit=None):
    "Select from table, optionally limited by `where` and `limit` clauses; returns rows as named tuples"
    q = self.select()
    # Add only the clauses that were supplied: `where(None)` is not a no-op, because
    # SQLAlchemy coerces `None` to SQL NULL (`WHERE NULL`), which matches no rows.
    if where is not None: q = q.where(where)
    if limit is not None: q = q.limit(limit)
    return self.metadata.conn.sql(q)
# Filter with a SQLAlchemy column expression and cap the result at five rows.
a.get(a.c.Title.startswith('F'), limit=5)
[Row(AlbumId=1, Title='For Those About To Rock We Salute You', ArtistId=1), Row(AlbumId=7, Title='Facelift', ArtistId=5), Row(AlbumId=60, Title='Fireball', ArtistId=58), Row(AlbumId=88, Title='Faceless', ArtistId=87), Row(AlbumId=99, Title='Fear Of The Dark', ArtistId=90)]
This is the SQL query that `get` runs behind the scenes for the call above:
# Printing a statement renders its SQL, with bind parameters shown as placeholders.
print(a.select().where(a.c.Title.startswith('F')).limit(5))
SELECT "Album"."AlbumId", "Album"."Title", "Album"."ArtistId" FROM "Album" WHERE ("Album"."Title" LIKE :Title_1 || '%') LIMIT :param_1
#| export
@patch
def close(self:MetaData):
    "Close the connection and release any connections still pooled by the engine"
    conn = getattr(self, 'conn', None)
    if conn is not None: conn.close()
    # `conn_db` stores the engine on `bind`. Disposing its pool closes the connections it
    # still holds, so the database (e.g. the SQLite file) is no longer kept open.
    # The engine remains usable afterwards; it simply opens fresh connections.
    eng = getattr(self, 'bind', None)
    if hasattr(eng, 'dispose'): eng.dispose()
# Close the database connection opened by `conn_db`.
db.close()