-
Notifications
You must be signed in to change notification settings - Fork 5
Description
Problem with: "indices = T.Placeholder((0,),'int32')", it works if it is replace with: "indices= T.Placeholder((batch_size,),'int32')
import symjax as sj
import symjax.tensor as T
import numpy as np
import scipy
import matplotlib.pyplot as plt
from sklearn.cluster.k_means_ import *
import scipy.sparse as sp
from sklearn.cluster.k_means_ import k_init
from sklearn.cluster.k_means import _kmeans_single_lloyd
from sklearn.cluster import MiniBatchKMeans
class RAI_KMeans:
def init(self,x_shape,n_data_train,n_data_test,batch_size,n_clusters,n_landmarks):
#placeholder
data = T.Placeholder((batch_size,*x_shape),'float32')
indices = T.Placeholder((0,),'int32')
#trainable
centroids = T.Variable(sj.nn.initializers.glorot_uniform((n_clusters,*x_shape)))
theta_aff_train = T.Variable(np.tile(np.array([1,0,0,0,1,0]),(n_data_train,1)))
theta_diff_train = T.Variable(np.zeros((n_data_train,2*n_landmarks)))
theta_aff_test = T.Variable(T.Variable(np.tile(np.array([1,0,0,0,1,0]),(n_data_train,1))))
theta_diff_test = T.Variable(np.zeros((n_data_test,n_clusters,2*n_landmarks)))
theta_aff_train_ind = theta_aff_train[indices]
theta_diff_train_ind = theta_diff_train[indices]
theta_aff_test_ind = theta_aff_test[indices]
theta_diff_test_ind = theta_aff_test[indices]
data_train_aff = T.interpolation.affine_transform(data,theta_aff_train_ind)
print(data_train_aff.shape.get())
data_train_diff = T.interpolation.thin_plate_spline(data_train_aff,theta_diff_train_ind)
self.get_data_aff = sj.function(data,indices,outputs=data_train_diff)
data = sj.data.mnist()["train_set/images"].astype('float32')
x_shape = data.shape[1:]
n_data_train = data.shape[0]
batch_size = 10
ind_train = np.random.permutation(data.shape[0])[:batch_size].astype('int32')
print(x_shape)
data_train = data[ind_train]
test = RAI_KMeans(x_shape,n_data_train,10,batch_size,5,4)
image = test.get_data_aff(data_train,ind_train)