-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict_close.py
More file actions
71 lines (48 loc) · 1.84 KB
/
predict_close.py
File metadata and controls
71 lines (48 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
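# Predict a stock's close price with Lasso regression, using a 10-row window
# of Open/High/Low/Volume values read from a CSV price history (ge.csv).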
from numpy import mean
from numpy import array
import sklearn
from sklearn import linear_model
def read_csv(filename):
    """Read a CSV file and return its data rows, without the header row.

    The header row (field names) is printed to stdout as a quick sanity
    check and then dropped. Completely blank lines (which ``csv.reader``
    yields as empty rows) are skipped so they never reach the numeric
    parsing done by the callers.

    Args:
        filename: path of the CSV file to read.

    Returns:
        A list of tuples of strings, one per data row. The list is empty if
        the file is empty or contains only a header.
    """
    import csv
    with open(filename) as f:
        rows = [tuple(row) for row in csv.reader(f) if row]
    if not rows:
        # Empty file: there is no header to print and no data to return.
        return []
    print(rows[0])  # field names
    return rows[1:]  # data only
def compile_features_and_values(rows, num_days=10):
    """Build a sliding-window regression dataset from daily price rows.

    Sample ``ii`` uses the ``num_days`` consecutive rows ``rows[ii]`` ..
    ``rows[ii + num_days - 1]``. From each of those rows it takes Open,
    High, Low and Volume, in row order, so every sample has
    ``4 * num_days`` features. The target value is the Close of
    ``rows[ii]``, the first row of the window.

    Args:
        rows: sequence of CSV rows with fields
            Date,Open,High,Low,Close,Volume,Adj Close. The numeric fields
            may be strings or numbers accepted by ``float``.
        num_days: window length in rows. Must be >= 1. Defaults to 10.

    Returns:
        A tuple ``(feature_sets, value_sets)``: a list of feature lists and
        a parallel list of float targets. Both are empty when ``rows`` has
        fewer than ``num_days`` entries.

    Raises:
        ValueError: if ``num_days`` < 1.
    """
    if num_days < 1:
        raise ValueError("num_days must be >= 1")
    # fields: Date,Open,High,Low,Close,Volume,Adj Close
    feature_indices = (1, 2, 3, 5)  # Open, High, Low, Volume
    close_index = 4
    feature_sets = []
    value_sets = []
    # The +1 keeps the final full window (rows[-num_days:]). Without it,
    # the last valid sample is silently dropped.
    for ii in range(len(rows) - num_days + 1):
        window = rows[ii:ii + num_days]
        feature_sets.append(
            [float(row[k]) for row in window for k in feature_indices])
        value_sets.append(float(rows[ii][close_index]))
    return feature_sets, value_sets
def predict(regr, rows, day, num_days=10):
    """Print and return the model's close-price prediction for one window.

    The feature vector is Open, High, Low and Volume for each of the
    ``num_days`` rows starting at ``rows[day]``, in row order. This is the
    same layout used when building the training set.

    Args:
        regr: fitted estimator exposing ``predict(X)`` for a 2-D ``X``.
        rows: CSV rows with fields Date,Open,High,Low,Close,Volume,Adj Close.
        day: index of the first row of the window. The original author notes
            that day 0 is the most recent day, which assumes the CSV is
            ordered newest-first (TODO confirm for the input file).
        num_days: window length. Must match the value used for training.
            Defaults to 10.

    Returns:
        The predicted close price, i.e. the single element of
        ``regr.predict``'s output.

    Raises:
        ValueError: if ``num_days`` < 1.
        IndexError: if the window ``rows[day:day + num_days]`` does not lie
            inside ``rows``. This includes a negative ``day``, which would
            otherwise silently wrap around via negative indexing.
    """
    if num_days < 1:
        raise ValueError("num_days must be >= 1")
    if day < 0 or day + num_days > len(rows):
        raise IndexError("window rows[%d:%d] is outside the %d available rows"
                         % (day, day + num_days, len(rows)))
    # fields: Date,Open,High,Low,Close,Volume,Adj Close
    feature_indices = (1, 2, 3, 5)  # Open, High, Low, Volume
    features = [float(row[k])
                for row in rows[day:day + num_days]
                for k in feature_indices]
    # scikit-learn estimators expect a 2-D (n_samples, n_features) input.
    # A bare 1-D feature list is rejected with ValueError in current releases.
    prediction = regr.predict([features])
    print("close prediction for day %s : %s" % (day, prediction))
    return prediction[0]
# Load the price history and build the sliding-window training set.
rows = read_csv('ge.csv')
# `closes` holds the regression targets: one Close price per training sample.
features, closes = compile_features_and_values(rows)

# Lasso is L1-regularized linear regression. scikit-learn names the
# regularization strength (often written lambda) `alpha`.
# (linear_model.LinearRegression() is the unregularized alternative.)
regr = linear_model.Lasso(alpha=1)

# Train the model on the full training set and show the learned parameters.
regr.fit(features, closes)
print(regr.coef_)
print(regr.intercept_)

# Predict the close for windows starting at each of the first 10 data rows.
for day in range(10):
    predict(regr, rows, day)