-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathuanl_lr.py
More file actions
40 lines (32 loc) · 1.44 KB
/
uanl_lr.py
File metadata and controls
40 lines (32 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import matplotlib.pyplot as plt
import statsmodels.api as sm
import numbers
import pandas as pd
from tabulate import tabulate
def print_tabulate(df: pd.DataFrame):
    """Write *df* to stdout as an org-mode table, using its column names as headers."""
    rendered = tabulate(df, headers=df.columns, tablefmt="orgtbl")
    print(rendered)
def transform_variable(df: pd.DataFrame, x: str) -> pd.Series:
    """Return column *x* of *df* in a form usable as an OLS regressor.

    Only the first value of the column is inspected. If it is a number, the
    column is returned unchanged. Otherwise (e.g. text labels), each row is
    replaced by its position ``0..n-1``. An empty column yields an empty
    integer Series.

    Args:
        df: Source data frame.
        x: Name of the column to transform.

    Returns:
        A Series indexed like ``df``.
    """
    column = df[x]
    # Look at the first row by position: df[x][0] is a *label* lookup and
    # raises KeyError whenever the index does not contain 0 (or is empty).
    if len(column) > 0 and isinstance(column.iloc[0], numbers.Number):
        return column
    # Keep df's index so statsmodels can align this with the response column.
    return pd.Series(range(len(column)), index=df.index, name=x)
def linear_regression(df: pd.DataFrame, x: str, y: str, out_dir: str = "img") -> None:
    """Fit an OLS regression of column *y* on column *x* and save a plot of it.

    *x* is first passed through ``transform_variable``, so non-numeric columns
    are regressed on their row positions. The plot shows the observations as
    a scatter, the mean of *y* as a green line and the fitted line in red. It
    is written to ``<out_dir>/lr_<y>_<x>.png``.

    Args:
        df: Data containing both columns.
        x: Name of the explanatory column.
        y: Name of the response column (expected to be numeric).
        out_dir: Directory for the PNG; created if it does not exist.
    """
    import os

    fixed_x = transform_variable(df, x)
    # has_constant="add" guarantees a 'const' column even if fixed_x is itself
    # constant, so the intercept lookup below cannot fail.
    model = sm.OLS(df[y], sm.add_constant(fixed_x, has_constant="add")).fit()
    # Read the fitted coefficients directly: parsing the HTML summary needed an
    # HTML parser and rounded the values to their displayed precision.
    intercept = model.params["const"]
    slope = model.params.drop("const").iloc[0]
    mean_y = df[y].mean()

    # Draw on an explicit figure so nothing leaks onto other open figures.
    fig, ax = plt.subplots()
    ax.scatter(df[x], df[y])
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.plot(df[x], [mean_y] * len(df), color="green")
    ax.plot(df[x], slope * fixed_x + intercept, color="red")
    ax.tick_params(axis="x", labelrotation=90)

    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f"lr_{y}_{x}.png"))
    plt.close(fig)
    print("ok")
# Load the typed UANL data set.
df: pd.DataFrame = pd.read_csv("csv/typed_uanl.csv")
# Average net salary ("Sueldo Neto") per "Fecha" value. The string "mean" lets
# pandas apply its own mean per group; pd.DataFrame.mean cannot be called on the
# per-group Series in recent pandas.
# NOTE(review): groupby sorts the keys; if Fecha is stored as text the order is
# lexicographic — confirm that matches chronological order.
df_by_sal = (
    df.groupby("Fecha")
    .aggregate(sueldo_mensual=pd.NamedAgg(column="Sueldo Neto", aggfunc="mean"))
    .reset_index()
)
print_tabulate(df_by_sal.head())
linear_regression(df_by_sal, "Fecha", "sueldo_mensual")