Skip to content

Commit ad02e6e

Browse files
committed
Added SQLServer_Persister
1 parent 4af1a72 commit ad02e6e

1 file changed

Lines changed: 225 additions & 0 deletions

File tree

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import platform
2+
import subprocess
3+
from collections.abc import MutableMapping
4+
from py2store.util import ModuleNotFoundErrorNiceMessage
5+
6+
with ModuleNotFoundErrorNiceMessage():
7+
import pyodbc
8+
9+
10+
class SQLServerPersister(MutableMapping):
11+
def __init__(self,
12+
conn_protocol='tcp',
13+
host='localhost',
14+
port='1433',
15+
db_username='SA',
16+
db_pass='Admin123x',
17+
db_name='py2store',
18+
table_name='py2store_default_table',
19+
key_fields=('id',),
20+
data_fields=('name',),
21+
create_table_if_not_exists=True,
22+
NVARCHAR_LENGTH=100):
23+
"""
24+
25+
:param conn_protocol: Which connection protocol to use to connect to MS SQLServer
26+
:param host: HOST where SQLServer is running
27+
:param port: PORT of SQLServer
28+
:param db_username: Username tht will be used to connect to the SQLServer
29+
:param db_pass: Database password
30+
:param db_name: Database Name, to connect to
31+
:param table_name: Table name on which you are going to perform CRUD ops
32+
:param key_fields: Primary key fields of the table
33+
:param data_fields: Non primary fields of the table
34+
:param create_table_if_not_exists: Flag to create database table if it doesn't exist previously
35+
:param NVARCHAR_LENGTH: Length of the NVARCHAR column types in case the table isn't created previously
36+
"""
37+
38+
self.__check_dependencies()
39+
self._sql_server_client = pyodbc.connect('DRIVER={{ODBC Driver 17 for SQL Server}};'
40+
'SERVER={}:{},{};'
41+
'DATABASE={};'
42+
'UID={};'
43+
'PWD={}'
44+
.format(conn_protocol, host, port, db_name, db_username, db_pass))
45+
46+
self._cursor = self._sql_server_client.cursor()
47+
48+
self._table_name = table_name
49+
self.key_fields = key_fields
50+
self.data_fields = data_fields
51+
self.NVARCHAR_LENGTH = NVARCHAR_LENGTH
52+
53+
if create_table_if_not_exists:
54+
self.__create_table()
55+
56+
def __create_table(self):
57+
"""
58+
Create table if it doesn't exist previously
59+
:return:
60+
"""
61+
if not self._cursor.tables(table=self._table_name, tableType='TABLE').fetchone():
62+
base_query = 'CREATE TABLE {table_name} ({attributes})'
63+
attributes = ''
64+
65+
for key_field in self.key_fields:
66+
attributes += '{column} {type} PRIMARY KEY,'.format(column=key_field,
67+
type='NVARCHAR({})'.format(self.NVARCHAR_LENGTH))
68+
69+
for data_field in self.data_fields:
70+
attributes += '{column} {type},'.format(column=data_field, type='NVARCHAR(100)')
71+
72+
create_table_query = base_query.format(table_name=self._table_name, attributes=attributes)
73+
74+
print(create_table_query)
75+
self._cursor.execute(create_table_query)
76+
self._sql_server_client.commit()
77+
78+
@staticmethod
79+
def __check_dependencies():
80+
"""
81+
Checks for the required dependencies for the OS
82+
Currently checks are there for a Ubuntu machine only, more can be added
83+
:return:
84+
"""
85+
if 'ubuntu' in platform.platform().lower():
86+
result = subprocess.Popen(["dpkg", "-s", "msodbcsql17"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
87+
out, err = result.communicate()
88+
if not out:
89+
raise ModuleNotFoundError("ODBC Driver for SQL Server is missing. Please refer "
90+
"https://docs.microsoft.com/en-us/sql/connect/odbc/linux-mac/installing-the-"
91+
"microsoft-odbc-driver-for-sql-server?view=sql-server-2017")
92+
93+
def __getitem__(self, k):
94+
base_query = "SELECT * from {table} where {{conditions}}".format(table=self._table_name)
95+
96+
conditions = ''
97+
for key, val in k.items():
98+
conditions += "{key} = '{val}' AND ".format(key=key, val=val)
99+
100+
select_query = base_query.format(conditions=conditions)
101+
select_query = select_query.rstrip('AND ')
102+
103+
self._cursor.execute(select_query.format(value=k))
104+
record = self._cursor.fetchone()
105+
if record:
106+
return record
107+
raise KeyError(f"No record found for primary_key: {k}")
108+
109+
def __setitem__(self, k, v):
110+
base_query = "INSERT into {table} ({{attributes}}) VALUES ({{values}})".format(table=self._table_name)
111+
112+
attributes = ','.join([key for key in k])
113+
values = ','.join(["'{}'".format(val) for val in v])
114+
insert_query = base_query.format(attributes=attributes, values=values)
115+
116+
try:
117+
self._cursor.execute(insert_query)
118+
except pyodbc.IntegrityError as e:
119+
raise KeyError("Cannot insert a duplicate entry")
120+
121+
self._sql_server_client.commit()
122+
123+
def __delitem__(self, k):
124+
base_query = "DELETE FROM {table} where {{conditions}}".format(table=self._table_name)
125+
126+
conditions = ''
127+
for key, val in k.items():
128+
conditions += "{key} = '{val}' AND ".format(key=key, val=val)
129+
130+
select_query = base_query.format(conditions=conditions)
131+
select_query = select_query.rstrip('AND ')
132+
133+
self._cursor.execute(select_query.format(value=k))
134+
self._sql_server_client.commit()
135+
136+
def __iter__(self):
137+
select_all_query = "SELECT * from {table};".format(table=self._table_name)
138+
self._cursor.execute(select_all_query)
139+
records = self._cursor.fetchall()
140+
141+
for record in records:
142+
yield record
143+
144+
def __len__(self):
145+
select_all_query = "SELECT * from {table};".format(table=self._table_name)
146+
self._cursor.execute(select_all_query)
147+
records = self._cursor.fetchall()
148+
return len(records)
149+
150+
151+
def test_sqlserver_persister():
152+
"""
153+
A test case for the persister which performs many actions
154+
155+
Output of the test case is given below:
156+
157+
Adding Records
158+
=========================
159+
Adding a Duplicate Record
160+
Cannot Enter a Duplicate Record
161+
=========================
162+
Fetching Records
163+
('1', 'Test 1')
164+
('2', 'Test 2')
165+
('3', 'Test 3')
166+
=========================
167+
Iterating over the records
168+
('1', 'Test 1')
169+
('2', 'Test 2')
170+
('3', 'Test 3')
171+
=========================
172+
Getting the length
173+
3
174+
=====================
175+
Deleting Records
176+
=====================
177+
Getting the length AGAIN after deletion
178+
0
179+
=====================
180+
181+
:return:
182+
"""
183+
sql_server_persister = SQLServerPersister()
184+
185+
print("Adding Records")
186+
sql_server_persister[('id', 'name',)] = ('1', 'Test 1',)
187+
sql_server_persister[('id', 'name',)] = ('2', 'Test 2',)
188+
sql_server_persister[('id', 'name',)] = ('3', 'Test 3',)
189+
print("=========================")
190+
191+
print("Adding a Duplicate Record")
192+
try:
193+
sql_server_persister[('id', 'name',)] = ('1', 'Test 4',)
194+
except KeyError:
195+
print("Cannot Enter a Duplicate Record")
196+
print("=========================")
197+
198+
print("Fetching Records")
199+
print(sql_server_persister[{'id': '1',
200+
'name': 'Test 1'}])
201+
print(sql_server_persister[{'name': 'Test 2'}])
202+
print(sql_server_persister[{'id': 3}])
203+
print("=========================")
204+
205+
print("Iterating over the records")
206+
for record in sql_server_persister:
207+
print(record)
208+
print("=========================")
209+
210+
print("Getting the length")
211+
print(len(sql_server_persister))
212+
print("=====================")
213+
214+
print("Deleting Records")
215+
del sql_server_persister[{'id': '1'}]
216+
del sql_server_persister[{'name': 'Test 2'}]
217+
del sql_server_persister[{'id': '3', 'name': 'Test 3'}]
218+
print("=====================")
219+
220+
print("Getting the length AGAIN after deletion")
221+
print(len(sql_server_persister))
222+
print("=====================")
223+
224+
225+
test_sqlserver_persister()

0 commit comments

Comments
 (0)