-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_raw_data.py
93 lines (79 loc) · 2.92 KB
/
load_raw_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Airflow DAG to load raw data from a spreadsheet into a database.
Author
------
Nicolas Rojas
"""
# imports
import os
from datetime import datetime
import pandas as pd
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook
def check_table_exists():
    """Check whether the raw_clients table exists in the raw_data database.

    Creates the table when it is missing; otherwise does nothing.

    Returns
    -------
    str
        Status message for the Airflow task log.
    """
    # Count matching tables in information_schema. Filter on table_schema as
    # well as table_name, otherwise a same-named table in ANY database would
    # suppress creation here. Use single quotes for the SQL string literals:
    # double quotes break under MySQL's ANSI_QUOTES mode.
    query = (
        "SELECT COUNT(*) FROM information_schema.tables "
        "WHERE table_schema = 'raw_data' AND table_name = 'raw_clients'"
    )
    mysql_hook = MySqlHook(mysql_conn_id="raw_data", schema="raw_data")
    connection = mysql_hook.get_conn()
    try:
        cursor = connection.cursor()
        try:
            cursor.execute(query)
            results = cursor.fetchall()
        finally:
            # Always release the cursor, even if the query raises.
            cursor.close()
    finally:
        # The hook hands out a raw DB-API connection; close it ourselves.
        connection.close()
    if results[0][0] == 0:
        # Table is absent: create it with the expected client schema.
        print("----- table does not exist, creating it")
        create_sql = (
            "CREATE TABLE `raw_clients` ("
            "`id` BIGINT,"
            "`age` SMALLINT,"
            "`anual_income` BIGINT,"
            "`credit_score` SMALLINT,"
            "`loan_amount` BIGINT,"
            "`loan_duration_years` TINYINT,"
            "`number_of_open_accounts` SMALLINT,"
            "`had_past_default` TINYINT,"
            "`loan_approval` TINYINT"
            ")"
        )
        # hook.run manages its own connection lifecycle.
        mysql_hook.run(create_sql)
    else:
        # Nothing to do; the table is already in place.
        print("----- table already exists")
    return "Table checked"
def store_data():
    """Store raw data in its table in the raw_data database.

    Reads ./data/dataset.csv, maps the CSV's unnamed index column to the
    table's ``id`` column, and bulk-inserts every row into raw_clients.

    Returns
    -------
    str
        Status message for the Airflow task log.
    """
    # Path to the raw training data.
    data_root = "./data"
    data_filename = "dataset.csv"
    data_filepath = os.path.join(data_root, data_filename)
    # Read data; the CSV's index column arrives as "Unnamed: 0" and maps to
    # the table's `id` column (names are lowercased below).
    dataframe = pd.read_csv(data_filepath)
    dataframe.rename(columns={"Unnamed: 0": "ID"}, inplace=True)
    column_names = [name.lower() for name in dataframe.columns]
    # Build the INSERT once: backquoted column list plus one %s placeholder
    # per column, executed with executemany for a single bulk insert.
    columns_sql = ", ".join(f"`{name}`" for name in column_names)
    placeholders = ", ".join(["%s"] * len(column_names))
    query = f"INSERT INTO `raw_clients` ({columns_sql}) VALUES ({placeholders})"
    rows = list(dataframe.itertuples(index=False, name=None))
    mysql_hook = MySqlHook(mysql_conn_id="raw_data", schema="raw_data")
    conn = mysql_hook.get_conn()
    try:
        cur = conn.cursor()
        try:
            cur.executemany(query, rows)
            # Commit only after every row was queued successfully.
            conn.commit()
        finally:
            # Always release the cursor, even if the insert raises.
            cur.close()
    finally:
        # The hook hands out a raw DB-API connection; close it ourselves.
        conn.close()
    return "Data stored"
# DAG wiring: a one-shot pipeline that ensures the target table exists and
# then bulk-loads the spreadsheet rows into it.
with DAG(
    "load_data",
    description="Read data from source and store it in raw_data database",
    start_date=datetime(2024, 9, 18, 0, 0),
    schedule_interval="@once",
) as dag:
    # Step 1: create raw_clients if it is not already present.
    check_table_task = PythonOperator(
        task_id="check_table_exists",
        python_callable=check_table_exists,
    )
    # Step 2: insert the CSV rows.
    store_data_task = PythonOperator(
        task_id="store_data",
        python_callable=store_data,
    )
    # The table must exist before any insert runs.
    check_table_task >> store_data_task