1 /**
2 A SQLite driver for EzDb.
3 */
4 module ezdb.driver.sqlite;
5 
6 import ezdb.repository;
7 import ezdb.entity;
8 import ezdb.foreign;
9 import ezdb.query;
10 
11 import d2sqlite3;
12 import optional;
13 
14 import std.conv;
15 import std.stdio;
16 import std.range;
17 import std.algorithm;
18 import std.traits;
19 import std.exception;
20 version(unittest) import fluent.asserts;
21 
/**
Selects how a repository prepares its backing table when it is opened.
*/
enum DDLStrategy
{
    /// Issues `CREATE TABLE IF NOT EXISTS`; an existing table and its rows
    /// are left untouched.
    create = 0,

    /// Runs `DROP TABLE IF EXISTS` first and then creates the table again,
    /// discarding any rows that were stored before.
    drop_create = 1,
}
34 
/*
Resolves the name of the field marked `@primaryKey` in `Entity`.
Fails at compile time with a readable message when no such field exists.
*/
private template GetIdColumn(Entity)
{
    static assert(getSymbolsByUDA!(Entity, primaryKey).length > 0,
        "The entity " ~ Entity.stringof ~ " has no field marked @primaryKey");

    // __traits(identifier) yields the bare member name; `.stringof` is only
    // meant for diagnostics and is not guaranteed to be a plain identifier.
    enum GetIdColumn = __traits(identifier, getSymbolsByUDA!(Entity, primaryKey)[0]);
}
39 
/**
A factory that creates repositories backed by a single SQLite database.

The factory counts the repositories it has handed out. Once every one of them
has been closed, the underlying database handle is closed as well and the
factory can no longer open new repositories.
*/
final class SqliteFactory
{
    private Database _db;
    private int _openConnections = 0;
    private bool _open = true;

    /**
    Creates a new SQLite factory and enables foreign-key enforcement.
    Params:
      filename = The database file to open; `":memory:"` selects an
                 in-memory database.
    */
    this(string filename = "sqlite.db")
    {
        _db = Database(filename);
        // SQLite only enforces FOREIGN KEY constraints when this pragma is
        // switched on for the connection.
        _db.execute("PRAGMA foreign_keys = ON;");
    }

    /**
    Returns `true` if the factory has been fully closed, `false` if it is still
    possible to open new repositories.
    */
    bool isClosed()
    {
        return !_open;
    }

    /**
    Opens a repository on the shared SQLite database.
    Params:
      strategy = How the repository's table is prepared; defaults to
                 `DDLStrategy.create`.
    Returns: A new `SqliteDriver!Repository` bound to this factory.
    */
    auto open(Repository)(DDLStrategy strategy = DDLStrategy.create)
    in (isClosed == false)
    {
        auto repository = new SqliteDriver!Repository(this, strategy);
        // Count the connection only once the repository exists, so a failing
        // table setup cannot leave a reference that is never released.
        _openConnections++;
        return repository;
    }

    /**
    Releases one repository's hold on the database and closes the database
    once no repositories remain. Calls made after the database has already
    been closed are ignored.
    */
    private void close()
    {
        // Without this guard, a repeated close() would push the counter
        // negative and close the database handle a second time.
        if (!_open)
            return;

        _openConnections--;
        if (_openConnections <= 0)
        {
            _open = false;
            _db.close();
        }
    }

    /**
    Gets a reference to the SQLite database.
    */
    private ref Database db()
    {
        return _db;
    }
}
98 
/**
Implements a repository using a SQLite database.

`Db` is the repository interface (deriving from `Repository!Entity`). The
standard repository operations are implemented directly; every additional
method declared on `Db` is generated from its name via `parseQuery`.
*/
final class SqliteDriver(Db : Repository!Entity, Entity) : Db
{
    private enum Table = Entity.stringof;
    private enum IdColumn = GetIdColumn!Entity;
    private SqliteFactory _factory;
    private immutable DDLStrategy _strategy;

    /**
    Creates a repository on the database owned by `factory` and prepares the
    entity's table according to `strategy`.
    Params:
      factory = The factory that owns the shared database connection.
      strategy = Whether to keep (`create`) or replace (`drop_create`) an
                 existing table.
    */
    this(SqliteFactory factory, DDLStrategy strategy = DDLStrategy.create)
    {
        _strategy = strategy;
        _factory = factory;

        final switch (strategy)
        {
            case DDLStrategy.drop_create:
                dropTable();
                createTable();
                break;
            case DDLStrategy.create:
                createTable();
                break;
        }
    }

    /// Drops the entity's table if it exists.
    private void dropTable()
    {
        _factory.db.run(text("DROP TABLE IF EXISTS ", Table));
    }

    /// Creates the entity's table unless it already exists.
    private void createTable()
    {
        string statement = CreationStatement!Entity;
        _factory.db.run(statement);
    }

    /// Returns the value of SQLite's `last_insert_rowid()` for the shared connection.
    private PrimaryKeyType!Entity lastRowId()
    {
        return _factory.db
            .execute("SELECT last_insert_rowid()")
            .oneValue!(PrimaryKeyType!Entity);
    }

    /// Releases this repository's hold on the factory's database.
    override void close()
    {
        _factory.close();
    }

    /// Deletes the row whose primary key equals `id`.
    override void remove(PrimaryKeyType!Entity id)
    {
        auto statement = _factory.db.prepare(text("DELETE FROM ", Table, " WHERE ", IdColumn, " = :id"));
        scope(exit) statement.reset();
        statement.bind(":id", id);
        statement.execute();
    }

    /**
    Looks up the row whose primary key equals `id`.
    Returns: The entity, or an empty optional when no row matches.
    */
    override Optional!Entity find(PrimaryKeyType!Entity id)
    {
        auto statement = _factory.db.prepare(text("SELECT * FROM ", Table, " WHERE ",
            IdColumn, " = :id"));
        // Reset on every path, including the "not found" early return.
        scope(exit) statement.reset();
        statement.bind(":id", id);
        auto results = statement.execute();
        if (results.empty)
            return no!Entity;
        return some(results.front().as!Entity);
    }

    /// Returns every row of the entity's table.
    override Entity[] findAll()
    {
        auto statement = _factory.db.prepare(text("SELECT * FROM ", Table));
        scope(exit) statement.reset();
        Entity[] entities;
        foreach (result; statement.execute())
        {
            entities ~= result.as!Entity;
        }
        return entities;
    }

    /**
    Inserts `entity` and returns the stored row, including its generated id.
    The primary-key field of `entity` is ignored.
    */
    override Entity save(Entity entity)
    {
        string statementString = InsertStatement!Entity;
        auto statement = _factory.db.prepare(statementString);
        // The primary-key placeholder is deliberately left unbound (NULL),
        // so SQLite generates the id.
        static foreach (name; FieldNameTuple!Entity)
        {
            static if (!hasUDA!(__traits(getMember, Entity, name), primaryKey))
            {
                statement.bind(":" ~ name, __traits(getMember, entity, name));
            }
        }
        statement.execute();
        statement.reset();
        return find(lastRowId()).front;
    }

    /*
    Auto-implementation of custom queries.
    */

    /// Parses `methodName` into a query at compile time and executes it with `args`.
    private auto autoQuery(string methodName, Args...)(Args args)
    {
        // A distinct name avoids shadowing the template parameter.
        enum parsedQuery = parseQuery(methodName);
        return executeQuery!parsedQuery(args);
    }

    /// Runs a SELECT for `query`, binding `args` positionally to its filters.
    private Entity[] executeQuery(Query query, Args...)(Args args)
    if (query.action == QueryAction.select)
    {
        Statement statement = _factory.db.prepare(text("SELECT * FROM ", Table, " WHERE ", createWhereClause!query));
        scope(exit) statement.reset();
        statement.bindAll(args);
        Entity[] entities;
        foreach (result; statement.execute())
            entities ~= result.as!Entity;
        return entities;
    }

    /*
    Builds the WHERE clause for `query`: one positional `column=?` test per
    filter. Unknown columns and unsupported filter types are rejected at
    compile time.
    NOTE(review): assumes multiple filters are meant as a conjunction (AND) —
    confirm against parseQuery.
    */
    private static string createWhereClause(Query query)()
    {
        string[] clauses;
        static foreach (filter; query.filters)
        {
            static assert([FieldNameTuple!Entity].canFind(filter.column),
                "The entity " ~ Entity.stringof ~ " does not have the column '" ~ filter.column ~ "'");
            static if (filter.type == QueryFilterType.equal)
                clauses ~= text(filter.column, "=?");
            else
                static assert(0, "Unsupported filter type");
        }
        // Joining with a bare space would yield invalid SQL ("a=? b=?") as
        // soon as a query carries more than one filter.
        return clauses.join(" AND ");
    }

    // Generate an implementation for every user-defined method on `Db`.
    static foreach (member; __traits(allMembers, Db))
    {
        // Skip the standard repository methods implemented above.
        static if (![__traits(allMembers, Repository!Entity)].canFind(member))
        {
            static assert(MemberFunctionsTuple!(Db, member).length == 1, "Overloading is not supported in Db interfaces");
            mixin(`
                Entity[] %member(%params)
                {
                    return autoQuery!("%member")(%args);
                }
            `
            .replace("%member", member)
            .replace("%params", makeParameterList!(MemberFunctionsTuple!(Db, member)[0]))
            .replace("%args", makeArgumentList!(MemberFunctionsTuple!(Db, member)[0]))
            .text());
        }
    }
}
263 
/*
Helpers for generating user-defined repository methods.

Both helpers name the parameters of `func` with the same numbered scheme
(`arg0`, `arg1`, ...), so a parameter list and an argument list generated for
the same function always match, and the names stay valid identifiers for any
number of parameters.
*/
private string generatedParameterName(size_t index)
{
    return "arg" ~ index.to!string;
}

/// Returns a declaration list such as `int arg0, string arg1` for `func`.
private string makeParameterList(alias func)()
{
    string[] params;
    static foreach (index, Param; Parameters!func)
        params ~= Param.stringof ~ " " ~ generatedParameterName(index);
    return params.join(", ");
}

/// Returns a call-argument list such as `arg0, arg1` for `func`.
private string makeArgumentList(alias func)()
{
    string[] args;
    static foreach (index; 0 .. Parameters!(func).length)
        args ~= generatedParameterName(index);
    return args.join(", ");
}
282 
/*
Templates for statement generation
*/

/// The `CREATE TABLE IF NOT EXISTS` statement for `Entity`, built at compile time.
private template CreationStatement(Entity)
{
    // A manifest constant instead of a mutable thread-local variable: the
    // SQL is fixed per entity type and should not be reassignable.
    enum CreationStatement = text("CREATE TABLE IF NOT EXISTS ", Entity.stringof,
        " (", parseCreationMembers!Entity, ")");
}
291 
/*
Builds the column-definition list of the CREATE TABLE statement for `Entity`.

Each field becomes `name TYPE [PRIMARY KEY AUTOINCREMENT] NOT NULL`. Fields
recognised by `IsForeign` additionally produce a trailing FOREIGN KEY
constraint that references the foreign entity's primary-key column.
Unsupported field types are rejected at compile time.
*/
private string parseCreationMembers(Entity)()
{
    string[] lines;
    string[] foreignKeys;
    static foreach (memberName; FieldNameTuple!Entity)
    {{
        string[] attributes = [memberName];
        alias member = __traits(getMember, Entity, memberName);
        alias FieldType = typeof(member);

        // Add the SQL type identifier (SQLite storage class).
        static if (is(FieldType == bool) || isIntegral!FieldType)
            attributes ~= "INTEGER";
        else static if (isFloatingPoint!FieldType)
            attributes ~= "REAL";
        else static if (is(FieldType == string))
            attributes ~= "TEXT";
        else
            static assert(0, "Cannot convert field of type " ~ FieldType.stringof ~ " to a SQL type");

        // Add the primary key attribute if necessary.
        static if (hasUDA!(member, primaryKey))
        {
            // SQLite only accepts AUTOINCREMENT on an INTEGER PRIMARY KEY column.
            static assert(isIntegral!FieldType, "The primary key " ~ Entity.stringof
                ~ "." ~ memberName ~ " must be of an integral type");
            attributes ~= ["PRIMARY KEY", "AUTOINCREMENT"];
        }
        else static if (IsForeign!(member))
        {
            alias foreign = GetForeignEntity!member;
            foreignKeys ~= text("FOREIGN KEY (", memberName, ") REFERENCES ",
                foreign.stringof, "(", GetIdColumn!foreign, ")");
        }

        // Every column is declared NOT NULL.
        attributes ~= "NOT NULL";
        lines ~= attributes.join(' ');
    }}
    // Table constraints must follow all column definitions.
    return (lines ~ foreignKeys).join(", ");
}
327 
/*
The parameterised INSERT statement for `Entity`: every field, including the
primary key, gets a `:name` placeholder. Built at compile time as a manifest
constant rather than a mutable thread-local variable.
*/
private template InsertStatement(Entity)
{
    enum InsertStatement = text("INSERT INTO ", Entity.stringof,
        "(", [FieldNameTuple!Entity].join(", "), ") VALUES ",
        "(", [FieldNameTuple!Entity].map!(member => ":" ~ member).join(", "), ")");
}
334 
@("Can create a SQLite database")
unittest
{
    static struct Entity
    {
        @primaryKey int id;
    }
    static interface Repo : Repository!Entity {}

    // Opening a repository on an in-memory database must yield a live object.
    auto factory = new SqliteFactory(":memory:");
    auto repository = factory.open!Repo;
    scope(exit) repository.close();
    assert(repository !is null);
}
347 
@("Empty database should return no results")
unittest
{
    static struct Entity
    {
        @primaryKey int id;
    }
    static interface Repo : Repository!Entity {}
    auto db = new SqliteFactory(":memory:").open!Repo;
    scope(exit) db.close();

    // A freshly created table holds no rows.
    assert(db.findAll() == [], "findAll() should return an empty list if the database is empty");
}
360 
@("Save() should return a saved instance")
unittest
{
    static struct Entity
    {
        @primaryKey int id;
        int value;
    }
    static interface Repo : Repository!Entity {}
    auto repository = new SqliteFactory(":memory:").open!Repo;
    scope(exit) repository.close();

    auto prototype = Entity(0, 5);

    // Saving the same value twice must produce two rows with ids 1 and 2.
    foreach (expectedId; [1, 2])
    {
        const stored = repository.save(prototype);
        assert(stored.value == 5, "Entity.value was not correctly saved");
        assert(stored.id == expectedId, "Entity.id was not generated");
    }
}
384 
@("findAll() should return all instances when saved")
unittest
{
    static struct Entity
    {
        @primaryKey int id;
        int value;
    }
    static interface Repo : Repository!Entity {}
    auto repository = new SqliteFactory(":memory:").open!Repo;
    scope(exit) repository.close();

    const stored = repository.save(Entity(0, 5));
    auto everything = repository.findAll();

    // The single saved row is the only one in the table.
    assert(everything == [stored], "Did not correctly retrieve all results");
}
404 
@("remove() should remove an instance")
unittest
{
    static struct Entity
    {
        @primaryKey int id;
        int value;
    }
    static interface Repo : Repository!Entity {}
    auto db = new SqliteFactory(":memory:").open!Repo;
    scope(exit) db.close();

    Entity toSave;
    const saved = db.save(toSave);
    assert(!db.find(saved.id).empty, "Entity was not saved before removal");

    db.remove(saved.id);

    // The row must actually be gone, not merely "remove() did not throw".
    assert(db.find(saved.id).empty, "remove() did not delete the entity");
    assert(db.findAll() == [], "remove() left rows in the table");
}
421 
@("find() should return an empty optional if no row can be found")
unittest
{
    static struct Entity
    {
        @primaryKey int id;
        int value;
    }
    static interface Repo : Repository!Entity {}
    auto repository = new SqliteFactory(":memory:").open!Repo;
    scope(exit) repository.close();

    // Nothing has been saved, so id 0 cannot exist.
    auto lookup = repository.find(0);
    assert(lookup.empty, "Result was not empty");
}
436 
@("An invalid foreign key will cause an error")
unittest
{
    static struct Parent
    {
        @primaryKey int id;
    }

    static struct Child
    {
        @primaryKey
        int id;

        @foreign!Parent
        int child;
    }

    static interface ParentRepo : Repository!Parent {}
    static interface ChildRepo : Repository!Child {}

    auto factory = new SqliteFactory(":memory:");
    auto parents = factory.open!ParentRepo;
    scope(exit) parents.close();
    auto children = factory.open!ChildRepo;
    scope(exit) children.close();

    // No Parent with id 5 exists, so the FOREIGN KEY constraint must reject the row.
    auto orphan = Child(0, 5);
    assertThrown(children.save(orphan));
}
466 
@("A valid foreign key will be accepted")
unittest
{
    static struct Parent
    {
        @primaryKey int id;
    }

    static struct Child
    {
        @primaryKey
        int id;

        @foreign!Parent
        int child;
    }

    static interface ParentRepo : Repository!Parent {}
    static interface ChildRepo : Repository!Child {}

    auto factory = new SqliteFactory(":memory:");
    auto parents = factory.open!ParentRepo;
    scope(exit) parents.close();
    auto children = factory.open!ChildRepo;
    scope(exit) children.close();

    // Referencing a freshly stored parent must satisfy the constraint.
    const storedParent = parents.save(Parent.init);
    children.save(Child(0, storedParent.id));
}
499 
@("Custom select statement finds only specific data")
unittest
{
    static struct Entry
    {
        @primaryKey int id;
        int value;

        static Entry withValue(int value)
        {
            return Entry(0, value);
        }
    }

    static interface Repo : Repository!Entry
    {
        Entry[] findByValue(int value);
    }

    auto repository = new SqliteFactory(":memory:").open!Repo;
    scope(exit) repository.close();

    // Rows get ids 1..4 in insertion order; value 3 sits at ids 2 and 4.
    foreach (value; [2, 3, 4, 3])
        repository.save(Entry.withValue(value));

    repository.findByValue(3).should.equal([Entry(2, 3), Entry(4, 3)]);
}
529 
@("Custom select statement also works with strings")
unittest
{
    static struct Entry
    {
        @primaryKey int id;
        string name;

        static Entry withName(string name)
        {
            return Entry(0, name);
        }
    }

    static interface Repo : Repository!Entry
    {
        Entry[] findByName(string name);
    }

    Repo repository = new SqliteFactory(":memory:").open!Repo;
    scope(exit) repository.close();

    foreach (name; ["foo", "bar"])
        repository.save(Entry.withName(name));

    repository.findByName("foo").should.equal([Entry(1, "foo")]);
}
557 
@("Multiple custom select statement are supported")
unittest
{
    static struct Entry
    {
        @primaryKey int id;
        int value;
        string name;
    }

    static interface Repo : Repository!Entry
    {
        Entry[] findByValue(int value);
        Entry[] findByName(string name);
    }

    Repo repository = new SqliteFactory(":memory:").open!Repo;
    scope(exit) repository.close();

    auto foo = Entry(1, 1337, "foo");
    auto bar = Entry(2, 666, "bar");
    repository.save(foo);
    repository.save(bar);

    // Each generated finder filters on its own column.
    repository.findByName("bar").should.equal([bar]);
    repository.findByValue(1337).should.equal([foo]);
}
582 
@("Cannot create interface with incorrect user-defined methods")
unittest
{
    static struct Entry
    {
        @primaryKey int id;
        int value;
    }

    static interface Repo : Repository!Entry
    {
        // Name is wrong on purpose
        Entry[] findByColumn(int value);
    }

    // `column` is not a field of Entry, so instantiating the driver must fail.
    enum driverCompiles = __traits(compiles, SqliteDriver!Repo);
    driverCompiles.should.equal(false)
        .because("'findByColumn' is incorrect and should not compile");
}
601 
@("save() can store booleans")
unittest
{
    static struct Entity
    {
        @primaryKey int id;
        bool value;
    }
    static interface Repo : Repository!Entity {}
    auto repository = new SqliteFactory(":memory:").open!Repo;
    scope(exit) repository.close();

    const trueId = repository.save(Entity(0, true)).id;
    const falseId = repository.save(Entity(0, false)).id;

    // Both values must survive the round trip through an INTEGER column.
    assert(repository.find(trueId).front.value);
    assert(!repository.find(falseId).front.value);
}