1 module utils.db.mysql;
2 
import
		std.conv,
		std.meta,
		std.array,
		std.string,
		std.traits,
		std.typecons,
		std.exception,
		std.algorithm,

		utils.except,
		utils.db.mysql.binding;
15 
16 
/**
 * Wrapper around the MySQL C client library built on its prepared-statement
 * API.
 *
 * One instance owns one connection handle plus a cache of prepared statements
 * keyed by their SQL text. Failures reported by the client library are raised
 * through `throwError` with the library's error message.
 */
final class MySQL
{
	/**
	 * Allocates a connection handle, enables MYSQL_OPT_RECONNECT on it and
	 * connects to the server.
	 *
	 * Params:
	 *   host = server host name
	 *   user = login name
	 *   pass = password
	 *   db   = default database
	 *   port = TCP port, 3306 unless specified
	 */
	this(string host, string user, string pass, string db, uint port = 3306)
	{
		_db = mysql_init(null);

		// mysql_init yields null when no handle could be allocated; every
		// later call would dereference it, so stop here.
		_db || throwError(`unable to allocate a MySQL connection handle`);

		{
			bool opt = true;
			!mysql_options(_db, MYSQL_OPT_RECONNECT, &opt) || throwError(lastError);
		}

		mysql_real_connect(_db, host.toStringz, user.toStringz, pass.toStringz, db.toStringz, port, null, 0) || throwError(lastError);
	}

	/**
	 * Closes every cached statement and then the connection.
	 *
	 * NOTE(review): `_stmts` is a GC-managed associative array; if this
	 * destructor runs as a GC finalizer the array may already have been
	 * collected. Destroying the object deterministically avoids that - confirm
	 * how callers release instances.
	 */
	~this()
	{
		// plain loop: no delegate or lambda is needed to walk the cache
		foreach(stmt; _stmts.byValue)
		{
			remove(stmt);
		}

		mysql_close(_db);
	}

package(utils.db):

	/// Resets a statement whose result (if any) is not going to be read.
	void process(MYSQL_STMT* stmt)
	{
		mysql_stmt_reset(stmt);
	}

	/**
	 * Binds the result columns of an executed statement to values of types `A`
	 * and returns a non-copyable range over the rows.
	 *
	 * The range yields `Tuple!A` per row, or the bare value when `A` has a
	 * single type. Its destructor frees the stored result and resets the
	 * statement.
	 *
	 * Throws: `Exception` when the statement's column count differs from
	 * `A.length` or the result cannot be stored; whatever `throwError` raises
	 * for other client-library failures.
	 */
	auto process(A...)(MYSQL_STMT* stmt)
	{
		// Checked in every build type: the client library reads one bind per
		// result column, so a shorter bind array would be over-read.
		enforce(mysql_stmt_field_count(stmt) == A.length, `incorrect number of fields to return`);

		{
			// must be set before the result is stored so that max_length is
			// filled in for the string buffers sized in initialize()
			bool attr = true;
			!mysql_stmt_attr_set(stmt, STMT_ATTR_UPDATE_MAX_LENGTH, &attr) || throwError(lastError(stmt));
		}

		// The nested struct reaches the class through this local instead of
		// `this`; the original author marked it as a DMD bug workaround.
		auto self = this; // TODO: DMD BUG

		struct S
		{
			this(this) @disable;

			~this()
			{
				mysql_stmt_free_result(stmt);
				mysql_stmt_reset(stmt);
			}

			/// True once no fetched row is available.
			bool empty() const
			{
				return !_hasRow;
			}

			/// Fetches the next row into the bound buffers.
			void popFront()
			{
				assert(_hasRow);
				_hasRow = self.fetch(stmt);
			}

			/// Drains the remaining rows into an array.
			auto array()
			{
				ReturnType!front[] res;

				for(; _hasRow; popFront)
				{
					res ~= front;
				}

				return res;
			}

			/**
			 * Returns a copy of the current row.
			 *
			 * String columns are cut to the length reported for this row and
			 * duplicated, because the bound buffer is overwritten by the next
			 * fetch.
			 *
			 * NOTE(review): no is_null indicator is bound, so SQL NULL columns
			 * presumably cannot be told apart from regular values - confirm.
			 */
			auto front()
			{
				assert(_hasRow);

				auto r = *_res;

				foreach(i, T; A)
				{
					static if(isSomeString!T)
					{
						r[i] = r[i][0..*_lens[i]].idup;
					}
				}

				static if(A.length > 1)
				{
					return r;
				}
				else
				{
					return r[0];
				}
			}

		private:
			/// Stores the result set, binds the output buffers and fetches the first row.
			void initialize()
			{
				MYSQL_BIND[] arr;
				_res = new Tuple!A;

				// report the statement's own error text instead of a bare enforcement failure
				enforce(!mysql_stmt_store_result(stmt), self.lastError(stmt));

				{
					auto info = mysql_stmt_result_metadata(stmt);

					// null metadata would be dereferenced by info.fields below
					info || throwError(self.lastError(stmt));

					// release the metadata even if binding throws
					scope(exit) mysql_free_result(info);

					foreach(i, ref v; *_res)
					{
						c_ulong* len;

						static if(isSomeString!(A[i]))
						{
							// string buffers are sized to the widest value in the stored result
							_lens[i] = len = new c_ulong;
							v.length = info.fields[i].max_length;
						}

						arr ~= self.makeBind(&v, len);
					}
				}

				!mysql_stmt_bind_result(stmt, arr.ptr) || throwError(self.lastError(stmt));
				_hasRow = self.fetch(stmt);
			}

			Tuple!A* _res;          // heap-allocated row buffer the binds point into
			c_ulong*[uint] _lens;   // per-column length slots for string columns
			bool _hasRow;           // whether the last fetch produced a row
		}

		S s;
		s.initialize;
		return s;
	}

	/**
	 * Returns the prepared statement for `sql`, preparing and caching it on
	 * first use.
	 *
	 * A statement that fails to prepare is closed before the error is raised
	 * and is not cached.
	 */
	auto prepare(string sql)
	{
		if(auto cached = sql in _stmts)
		{
			return *cached;
		}

		auto stmt = mysql_stmt_init(_db);

		// a null handle means the statement could not be allocated; the
		// connection carries the error in that case
		stmt || throwError(lastError);

		if(mysql_stmt_prepare(stmt, sql.ptr, cast(uint)sql.length))
		{
			// lastError copies the text, so it stays valid after the close
			auto msg = lastError(stmt);

			mysql_stmt_close(stmt);
			throwError(msg);
		}

		_stmts[sql] = stmt;
		return stmt;
	}

	/**
	 * Binds `args` as the statement's parameters and executes it.
	 *
	 * Execution happens here because the binds point at this function's own
	 * copies of the arguments.
	 *
	 * Throws: `Exception` when the number of arguments differs from the
	 * statement's parameter count; whatever `throwError` raises for
	 * client-library failures.
	 */
	void bind(A...)(MYSQL_STMT* stmt, A args)
	{
		// Checked in every build type: the client library reads one bind per
		// parameter marker, so a shorter array would be over-read.
		enforce(mysql_stmt_param_count(stmt) == A.length, `incorrect number of bind parameters`);

		MYSQL_BIND[] ps;

		foreach(ref v; args)
		{
			ps ~= makeBind(&v);
		}

		!mysql_stmt_bind_param(stmt, ps.ptr) || throwError(lastError(stmt));
		execute(stmt);
	}

	/// Value of mysql_stmt_insert_id for the statement.
	auto lastId(MYSQL_STMT* stmt)
	{
		return mysql_stmt_insert_id(stmt);
	}

	/// Value of mysql_stmt_affected_rows for the statement.
	auto affected(MYSQL_STMT* stmt)
	{
		return mysql_stmt_affected_rows(stmt);
	}

private:
	/**
	 * Builds a MYSQL_BIND describing the value `v` points to.
	 *
	 * Supported: `typeof(null)`, float, double, 1/2/4/8-byte integers and
	 * strings. Any other type is rejected at compile time.
	 *
	 * Params:
	 *   v   = value to bind; its storage must outlive the use of the bind
	 *   len = length slot for string binds, may be null
	 */
	auto makeBind(T)(T* v, c_ulong* len = null)
	{
		MYSQL_BIND b;

		static if(is(T == typeof(null)))
		{
			b.buffer_type = MYSQL_TYPE_NULL;
		}
		else static if(isFloatingPoint!T)
		{
			// only 4- and 8-byte floats have a matching buffer type; wider
			// types such as `real` would be read with the wrong width
			static if(T.sizeof == float.sizeof)
			{
				b.buffer_type = MYSQL_TYPE_FLOAT;
			}
			else static if(T.sizeof == double.sizeof)
			{
				b.buffer_type = MYSQL_TYPE_DOUBLE;
			}
			else
			{
				static assert(false, `unsupported floating point type ` ~ T.stringof);
			}

			b.buffer = v;
		}
		else static if(isIntegral!T)
		{
			// buffer type picked at compile time from the integer width
			static if(T.sizeof == 1)
			{
				b.buffer_type = MYSQL_TYPE_TINY;
			}
			else static if(T.sizeof == 2)
			{
				b.buffer_type = MYSQL_TYPE_SHORT;
			}
			else static if(T.sizeof == 4)
			{
				b.buffer_type = MYSQL_TYPE_LONG;
			}
			else static if(T.sizeof == 8)
			{
				b.buffer_type = MYSQL_TYPE_LONGLONG;
			}
			else
			{
				static assert(false, `unsupported integer type ` ~ T.stringof);
			}

			b.is_unsigned = isUnsigned!T;
			b.buffer = v;
		}
		else static if(isSomeString!T)
		{
			// NOTE(review): the length is taken in code units; for wstring or
			// dstring that is not a byte count - confirm only char-based
			// strings are passed.
			auto s = *v;

			b.length = len;
			b.buffer = cast(void*)s.ptr;
			b.buffer_length = cast(uint)s.length;
			b.buffer_type = MYSQL_TYPE_STRING;
		}
		else
		{
			static assert(false);
		}

		return b;
	}

	/**
	 * Fetches the next row of the statement into the bound buffers.
	 *
	 * Returns: true when a row was fetched, false on MYSQL_NO_DATA.
	 * Truncation and every other non-zero status are raised via `throwError`.
	 */
	bool fetch(MYSQL_STMT* stmt)
	{
		auto r = mysql_stmt_fetch(stmt);

		r != MYSQL_DATA_TRUNCATED || throwError(`data was truncated`);
		r == MYSQL_NO_DATA || !r || throwError(lastError(stmt));

		return !r;
	}

	/// Closes a statement handle.
	void remove(MYSQL_STMT* stmt)
	{
		mysql_stmt_close(stmt);
	}

	/// Executes a statement, raising its error text on failure.
	void execute(MYSQL_STMT* stmt)
	{
		!mysql_stmt_execute(stmt) || throwError(lastError(stmt));
	}

	/**
	 * Error text of the connection.
	 *
	 * Returned as a copy: mysql_error hands out a pointer into the handle, so
	 * an uncopied slice would not be safe to keep once the handle is used
	 * again or closed.
	 */
	auto lastError()
	{
		return mysql_error(_db).fromStringz.idup;
	}

	/// Error text of a statement, copied for the same reason as above.
	auto lastError(MYSQL_STMT* stmt)
	{
		return mysql_stmt_error(stmt).fromStringz.idup;
	}

	MYSQL* _db;                   // connection handle
	MYSQL_STMT*[string] _stmts;   // prepared statements keyed by SQL text
}