Work in progress
This commit is contained in:
201
RegressionModels/AdvertisementPrediction/advertising.csv
Normal file
201
RegressionModels/AdvertisementPrediction/advertising.csv
Normal file
@ -0,0 +1,201 @@
|
||||
TV,Radio,Newspaper,Sales
|
||||
230.1,37.8,69.2,22.1
|
||||
44.5,39.3,45.1,10.4
|
||||
17.2,45.9,69.3,12
|
||||
151.5,41.3,58.5,16.5
|
||||
180.8,10.8,58.4,17.9
|
||||
8.7,48.9,75,7.2
|
||||
57.5,32.8,23.5,11.8
|
||||
120.2,19.6,11.6,13.2
|
||||
8.6,2.1,1,4.8
|
||||
199.8,2.6,21.2,15.6
|
||||
66.1,5.8,24.2,12.6
|
||||
214.7,24,4,17.4
|
||||
23.8,35.1,65.9,9.2
|
||||
97.5,7.6,7.2,13.7
|
||||
204.1,32.9,46,19
|
||||
195.4,47.7,52.9,22.4
|
||||
67.8,36.6,114,12.5
|
||||
281.4,39.6,55.8,24.4
|
||||
69.2,20.5,18.3,11.3
|
||||
147.3,23.9,19.1,14.6
|
||||
218.4,27.7,53.4,18
|
||||
237.4,5.1,23.5,17.5
|
||||
13.2,15.9,49.6,5.6
|
||||
228.3,16.9,26.2,20.5
|
||||
62.3,12.6,18.3,9.7
|
||||
262.9,3.5,19.5,17
|
||||
142.9,29.3,12.6,15
|
||||
240.1,16.7,22.9,20.9
|
||||
248.8,27.1,22.9,18.9
|
||||
70.6,16,40.8,10.5
|
||||
292.9,28.3,43.2,21.4
|
||||
112.9,17.4,38.6,11.9
|
||||
97.2,1.5,30,13.2
|
||||
265.6,20,0.3,17.4
|
||||
95.7,1.4,7.4,11.9
|
||||
290.7,4.1,8.5,17.8
|
||||
266.9,43.8,5,25.4
|
||||
74.7,49.4,45.7,14.7
|
||||
43.1,26.7,35.1,10.1
|
||||
228,37.7,32,21.5
|
||||
202.5,22.3,31.6,16.6
|
||||
177,33.4,38.7,17.1
|
||||
293.6,27.7,1.8,20.7
|
||||
206.9,8.4,26.4,17.9
|
||||
25.1,25.7,43.3,8.5
|
||||
175.1,22.5,31.5,16.1
|
||||
89.7,9.9,35.7,10.6
|
||||
239.9,41.5,18.5,23.2
|
||||
227.2,15.8,49.9,19.8
|
||||
66.9,11.7,36.8,9.7
|
||||
199.8,3.1,34.6,16.4
|
||||
100.4,9.6,3.6,10.7
|
||||
216.4,41.7,39.6,22.6
|
||||
182.6,46.2,58.7,21.2
|
||||
262.7,28.8,15.9,20.2
|
||||
198.9,49.4,60,23.7
|
||||
7.3,28.1,41.4,5.5
|
||||
136.2,19.2,16.6,13.2
|
||||
210.8,49.6,37.7,23.8
|
||||
210.7,29.5,9.3,18.4
|
||||
53.5,2,21.4,8.1
|
||||
261.3,42.7,54.7,24.2
|
||||
239.3,15.5,27.3,20.7
|
||||
102.7,29.6,8.4,14
|
||||
131.1,42.8,28.9,16
|
||||
69,9.3,0.9,11.3
|
||||
31.5,24.6,2.2,11
|
||||
139.3,14.5,10.2,13.4
|
||||
237.4,27.5,11,18.9
|
||||
216.8,43.9,27.2,22.3
|
||||
199.1,30.6,38.7,18.3
|
||||
109.8,14.3,31.7,12.4
|
||||
26.8,33,19.3,8.8
|
||||
129.4,5.7,31.3,11
|
||||
213.4,24.6,13.1,17
|
||||
16.9,43.7,89.4,8.7
|
||||
27.5,1.6,20.7,6.9
|
||||
120.5,28.5,14.2,14.2
|
||||
5.4,29.9,9.4,5.3
|
||||
116,7.7,23.1,11
|
||||
76.4,26.7,22.3,11.8
|
||||
239.8,4.1,36.9,17.3
|
||||
75.3,20.3,32.5,11.3
|
||||
68.4,44.5,35.6,13.6
|
||||
213.5,43,33.8,21.7
|
||||
193.2,18.4,65.7,20.2
|
||||
76.3,27.5,16,12
|
||||
110.7,40.6,63.2,16
|
||||
88.3,25.5,73.4,12.9
|
||||
109.8,47.8,51.4,16.7
|
||||
134.3,4.9,9.3,14
|
||||
28.6,1.5,33,7.3
|
||||
217.7,33.5,59,19.4
|
||||
250.9,36.5,72.3,22.2
|
||||
107.4,14,10.9,11.5
|
||||
163.3,31.6,52.9,16.9
|
||||
197.6,3.5,5.9,16.7
|
||||
184.9,21,22,20.5
|
||||
289.7,42.3,51.2,25.4
|
||||
135.2,41.7,45.9,17.2
|
||||
222.4,4.3,49.8,16.7
|
||||
296.4,36.3,100.9,23.8
|
||||
280.2,10.1,21.4,19.8
|
||||
187.9,17.2,17.9,19.7
|
||||
238.2,34.3,5.3,20.7
|
||||
137.9,46.4,59,15
|
||||
25,11,29.7,7.2
|
||||
90.4,0.3,23.2,12
|
||||
13.1,0.4,25.6,5.3
|
||||
255.4,26.9,5.5,19.8
|
||||
225.8,8.2,56.5,18.4
|
||||
241.7,38,23.2,21.8
|
||||
175.7,15.4,2.4,17.1
|
||||
209.6,20.6,10.7,20.9
|
||||
78.2,46.8,34.5,14.6
|
||||
75.1,35,52.7,12.6
|
||||
139.2,14.3,25.6,12.2
|
||||
76.4,0.8,14.8,9.4
|
||||
125.7,36.9,79.2,15.9
|
||||
19.4,16,22.3,6.6
|
||||
141.3,26.8,46.2,15.5
|
||||
18.8,21.7,50.4,7
|
||||
224,2.4,15.6,16.6
|
||||
123.1,34.6,12.4,15.2
|
||||
229.5,32.3,74.2,19.7
|
||||
87.2,11.8,25.9,10.6
|
||||
7.8,38.9,50.6,6.6
|
||||
80.2,0,9.2,11.9
|
||||
220.3,49,3.2,24.7
|
||||
59.6,12,43.1,9.7
|
||||
0.7,39.6,8.7,1.6
|
||||
265.2,2.9,43,17.7
|
||||
8.4,27.2,2.1,5.7
|
||||
219.8,33.5,45.1,19.6
|
||||
36.9,38.6,65.6,10.8
|
||||
48.3,47,8.5,11.6
|
||||
25.6,39,9.3,9.5
|
||||
273.7,28.9,59.7,20.8
|
||||
43,25.9,20.5,9.6
|
||||
184.9,43.9,1.7,20.7
|
||||
73.4,17,12.9,10.9
|
||||
193.7,35.4,75.6,19.2
|
||||
220.5,33.2,37.9,20.1
|
||||
104.6,5.7,34.4,10.4
|
||||
96.2,14.8,38.9,12.3
|
||||
140.3,1.9,9,10.3
|
||||
240.1,7.3,8.7,18.2
|
||||
243.2,49,44.3,25.4
|
||||
38,40.3,11.9,10.9
|
||||
44.7,25.8,20.6,10.1
|
||||
280.7,13.9,37,16.1
|
||||
121,8.4,48.7,11.6
|
||||
197.6,23.3,14.2,16.6
|
||||
171.3,39.7,37.7,16
|
||||
187.8,21.1,9.5,20.6
|
||||
4.1,11.6,5.7,3.2
|
||||
93.9,43.5,50.5,15.3
|
||||
149.8,1.3,24.3,10.1
|
||||
11.7,36.9,45.2,7.3
|
||||
131.7,18.4,34.6,12.9
|
||||
172.5,18.1,30.7,16.4
|
||||
85.7,35.8,49.3,13.3
|
||||
188.4,18.1,25.6,19.9
|
||||
163.5,36.8,7.4,18
|
||||
117.2,14.7,5.4,11.9
|
||||
234.5,3.4,84.8,16.9
|
||||
17.9,37.6,21.6,8
|
||||
206.8,5.2,19.4,17.2
|
||||
215.4,23.6,57.6,17.1
|
||||
284.3,10.6,6.4,20
|
||||
50,11.6,18.4,8.4
|
||||
164.5,20.9,47.4,17.5
|
||||
19.6,20.1,17,7.6
|
||||
168.4,7.1,12.8,16.7
|
||||
222.4,3.4,13.1,16.5
|
||||
276.9,48.9,41.8,27
|
||||
248.4,30.2,20.3,20.2
|
||||
170.2,7.8,35.2,16.7
|
||||
276.7,2.3,23.7,16.8
|
||||
165.6,10,17.6,17.6
|
||||
156.6,2.6,8.3,15.5
|
||||
218.5,5.4,27.4,17.2
|
||||
56.2,5.7,29.7,8.7
|
||||
287.6,43,71.8,26.2
|
||||
253.8,21.3,30,17.6
|
||||
205,45.1,19.6,22.6
|
||||
139.5,2.1,26.6,10.3
|
||||
191.1,28.7,18.2,17.3
|
||||
286,13.9,3.7,20.9
|
||||
18.7,12.1,23.4,6.7
|
||||
39.5,41.1,5.8,10.8
|
||||
75.5,10.8,6,11.9
|
||||
17.2,4.1,31.6,5.9
|
||||
166.8,42,3.6,19.6
|
||||
149.7,35.6,6,17.3
|
||||
38.2,3.7,13.8,7.6
|
||||
94.2,4.9,8.1,14
|
||||
177,9.3,6.4,14.8
|
||||
283.6,42,66.2,25.5
|
||||
232.1,8.6,8.7,18.4
|
||||
|
25
RegressionModels/AdvertisementPrediction/predict_sales.py
Normal file
25
RegressionModels/AdvertisementPrediction/predict_sales.py
Normal file
@ -0,0 +1,25 @@
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
df = pd.read_csv("./RegressionModels/AdvertisementPrediction/advertising.csv")
|
||||
|
||||
X = torch.tensor(df[["TV", "Radio", "Newspaper"]].values, dtype=torch.float32)
|
||||
Y = torch.tensor(df["Sales"].values, dtype=torch.float32)
|
||||
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Linear(3,1)
|
||||
)
|
||||
|
||||
loss_fn = torch.nn.MSELoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=3e-5)
|
||||
|
||||
for epoch in range(2000):
|
||||
y_pred = model(X)
|
||||
loss = loss_fn(y_pred, Y)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if epoch % 100 == 99:
|
||||
print(f'Epoch {epoch+1}, Loss: {loss.item():.2f}')
|
||||
|
||||
Reference in New Issue
Block a user