-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparse_data_logs.py
More file actions
174 lines (150 loc) · 6.27 KB
/
Copy pathparse_data_logs.py
File metadata and controls
174 lines (150 loc) · 6.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# data_out = {}
# # resamples the data with a potentially new dt
# data_extractor.resample_data(data_out, args, data_in)
# # creates the input / output pairs
# io_pairs = data_extractor.create_io_pairs(data_out, args, train_indexes, validation_indexes)
#
# # appends new values to the dataset
# append_to_dataset(dataset_file_path, data_out, "test_log_" + str(i))
# TODO sample random initial states and inputs and propagate the distribution
# TODO also sample random parameters as requested
# TODO create plots based off of the model
import atexit
import subprocess
import sys
import torch
import util
import model
import time
import yaml
import h5py
from system.dynamics_model_dataset import append_to_dataset
import numpy as np
from tqdm import trange
from data_extractor import DataExtractorBase
import pathlib
from system import DynamicsModelDataset, DynamicsModelFullDataset, getDatasetName
from torch.utils.data import DataLoader
from system import load_system
import plotting
def _terminate_server(proc):
if proc.poll() is None:
proc.terminate()
try:
proc.wait(timeout=5)
except subprocess.TimeoutExpired:
proc.kill()
proc.wait()
def checkDataValid(system, dataset):
    """Iterate over every batch of ``dataset`` and verify it matches ``system``.

    Each batch yielded by the DataLoader is unpacked as
    ``(X_batch, state_batch)`` and checked for:

    * only finite values (no NaN / inf) in both tensors,
    * equal batch sizes (dim 0) in both tensors,
    * ``X_batch`` of shape ``(batch, init_length + traj_length, input_dim)``,
    * ``state_batch`` of shape
      ``(batch, init_length + traj_length + 1, state_dim)``.

    Both tensors must therefore be rank 3.

    Args:
        system: object exposing ``input_dim``, ``state_dim``, ``init_length``
            and ``traj_length``.
        dataset: dataset accepted by ``torch.utils.data.DataLoader``.

    Raises:
        AssertionError: if any batch fails a check. It is raised explicitly
            rather than via ``assert`` so the validation still runs under
            ``python -O``, and the message names the failing batch and shapes.
    """

    def _require(condition, message):
        # Explicit raise: `assert` statements are stripped under `python -O`.
        if not condition:
            raise AssertionError(message)

    # Wide samples get smaller batches and fewer workers
    # (presumably to bound memory use - TODO confirm the threshold).
    if system.input_dim + system.state_dim > 100:
        batch_size, num_workers = 1000, 4
    else:
        batch_size, num_workers = 5000, 8
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        prefetch_factor=1,
    )

    seq_length = system.traj_length + system.init_length
    expected_x_shape = (seq_length, system.input_dim)
    # The state tensor has one more entry along dim 1 than the input tensor.
    expected_state_shape = (seq_length + 1, system.state_dim)

    progress = trange(len(dataloader), desc="validating data", unit="batch")
    # zip drives the progress bar and the loader in lockstep; the index from
    # trange doubles as the batch number used in error messages.
    for batch_idx, (X_batch, state_batch) in zip(progress, dataloader):
        _require(
            torch.isfinite(X_batch).all(),
            f"batch {batch_idx}: X_batch contains non-finite values",
        )
        _require(
            torch.isfinite(state_batch).all(),
            f"batch {batch_idx}: state_batch contains non-finite values",
        )
        x_shape = tuple(X_batch.shape)
        state_shape = tuple(state_batch.shape)
        _require(
            x_shape[0] == state_shape[0],
            f"batch {batch_idx}: batch size mismatch, X_batch {x_shape} "
            f"vs state_batch {state_shape}",
        )
        _require(
            x_shape[1:] == expected_x_shape,
            f"batch {batch_idx}: X_batch has shape {x_shape}, expected "
            f"(batch, {seq_length}, {system.input_dim})",
        )
        _require(
            state_shape[1:] == expected_state_shape,
            f"batch {batch_idx}: state_batch has shape {state_shape}, expected "
            f"(batch, {seq_length + 1}, {system.state_dim})",
        )
if __name__ == "__main__":
    # Script entry point: build (or load) an hdf5 dataset for the configured
    # prediction system, validate both splits, write dataset-statistics pages
    # and serve them from a helper process.
    args = util.loadSharedArgumentsModelLearning("train")

    # Use a context manager so the config file handle is closed promptly.
    with open(args["model"], "r") as model_config_fh:
        model_config_file = yaml.load(model_config_fh, Loader=yaml.FullLoader)
    system = load_system(model_config_file["prediction_system"])
    data_extractor = DataExtractorBase(system)
    dataset_identifier = "test"  # TODO the name that says what the thing is

    print("\n***************************************************")
    print(f"using system {system.name}")
    print("***************************************************")

    dataset_path_location = pathlib.Path(args["dataset"])  # path to where the data is
    if dataset_path_location.suffix == ".hdf5":
        # Path to an already set-up dataset: use it as-is. Previously a new
        # file name was still joined under it, so this case could never load.
        dataset_path = dataset_path_location
        if not dataset_path.exists():
            raise RuntimeError(f"the dataset does not exist at {dataset_path}")
    else:
        # Otherwise build a new, uniquely named hdf5 inside the given location.
        uid = util.get_datetime_uid()
        dataset_path = dataset_path_location / getDatasetName(
            system, uid + "_" + dataset_identifier
        )  # path to the new created hdf5
        # TODO need to parse here
        data_extractor.createDatasetFromLogs(dataset_path_location, dataset_path, args)
        if not dataset_path.exists():
            print(
                f"\n\nERROR: the dataset does not exist and it should\n"
                f"this can be caused by a couple issues\n"
                f" 1. check that data is not being ignored because you are missing topics there is a warning printed out\n"
                f" 2. The system is filtering out all of the data, reasons listed below\n"
                f"Once you resolve the issues you can rerun using --reload_data flag and you will skip generating npz from rosbags"
            )
            # Without filter statistics nothing was filtered, so cause 1 applies.
            # (The original test was inverted: it raised whenever statistics
            # existed and hit a KeyError below when they did not.)
            if "counter" not in args:
                raise RuntimeError("No data was filtered out, it is number 1")
            counter = args["counter"]
            expected_size = counter["expected_size"]
            print(
                f"TOTAL\n================ ignored data reasons "
                f"{counter['invalid_count']}/{expected_size} ============"
            )
            for reason, ignored in counter["reasons"].items():
                # Guard against a zero expected size instead of dividing by zero.
                percent = (ignored / expected_size) * 100 if expected_size else 0.0
                print(
                    f"ignored {ignored}/{expected_size} or "
                    f"{percent:.1f}% because {reason}"
                )
            print("============================ removed amount total ")
            raise RuntimeError(f"the dataset does not exist at {dataset_path}")

    # TODO need to check that is works with the given model, dt tau and such all work together
    train_dataset = DynamicsModelDataset(
        dataset_path, "train", system.init_length, system.traj_length
    )
    validation_dataset = DynamicsModelDataset(
        dataset_path, "validation", system.init_length, system.traj_length
    )
    print(f"val dataset size {len(validation_dataset)}")
    print(f"train dataset size {len(train_dataset)}")
    checkDataValid(system, train_dataset)
    checkDataValid(system, validation_dataset)

    stats_output_path = dataset_path.parent / (dataset_path.stem + "_statistics")
    stats_output_path.mkdir(parents=True, exist_ok=True)
    plotting.write_server_script(stats_output_path)
    server_proc = subprocess.Popen(
        [sys.executable, str(stats_output_path / "run.py")],
        stdout=subprocess.DEVNULL,
    )
    # Stop the server on normal interpreter exit, including after an
    # uncaught exception, so it does not outlive this script.
    atexit.register(_terminate_server, server_proc)

    plotting.plotlyDatasetStatistics(
        stats_output_path, train_dataset, system, split_name="train", prefix="train"
    )
    plotting.plotlyDatasetStatistics(
        stats_output_path,
        validation_dataset,
        system,
        split_name="validation",
        prefix="validation",
    )
    plotting.write_dataset_landing_page(
        stats_output_path / "index.html",
        dataset_name=dataset_path.stem,
        split_counts={
            "train": len(train_dataset),
            "validation": len(validation_dataset),
        },
    )

    if not args["no_pause"]:
        print(f"\nResults: http://localhost:8080 (serving {stats_output_path})")
        print("Press Ctrl+C to stop the server.")
        try:
            server_proc.wait()
        except KeyboardInterrupt:
            pass  # atexit handler will terminate the server on exit