CHEN Yihui 3 gadi atpakaļ
vecāks
revīzija
56af1ed071
1 mainītis faili ar 60 papildinājumiem un 0 dzēšanām
  1. 60 0
      np.py

+ 60 - 0
np.py

@@ -0,0 +1,60 @@
+import os
+import random
+import shutil
+import time
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.optim
+import torch.multiprocessing as mp
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+
+
+
+def convert(image_folder, gpu_id=None, batch_size=1):
+
+    if gpu_id != None:
+        torch.cuda.set_device(gpu_id)
+
+    # prepare valid dataloader
+    val_transform = transforms.Compose([
+        transforms.Resize(342),
+        transforms.CenterCrop(299),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.5, 0.5, 0.5],
+                             std=[0.5, 0.5, 0.5])
+    ])
+    val_dataset = datasets.ImageFolder(image_folder, val_transform)
+    val_loader = torch.utils.data.DataLoader(
+        val_dataset, batch_size=batch_size, shuffle=False,
+        num_workers=1, pin_memory=False)
+    # valid model in the valid dataloader
+    validate(val_loader, gpu_id)
+
+def validate(val_loader, gpu_id):
+    with torch.no_grad():
+        if gpu_id != None:
+            torch.cuda.synchronize()
+        for i, (images, target) in enumerate(val_loader):
+            images = images.permute(0, 2, 3, 1)
+            #print(images.shape)
+            #print(target.item())
+            inpy = images.numpy()
+            f = open('calib_data_c/%05d.bin'%i, 'wb')
+            f.write(inpy.tobytes('C'))
+            f.close()
+            
+
+
+convert('calib_data', 0)
+convert('val_Data', 0)
+
+