Skip to content

Placeholder default size #17

@Koldh

Description

@Koldh

Problem with `indices = T.Placeholder((0,), 'int32')`: the code below fails with it, but it works if it is replaced with `indices = T.Placeholder((batch_size,), 'int32')`.

import symjax as sj
import symjax.tensor as T
import numpy as np
import scipy
import matplotlib.pyplot as plt
from sklearn.cluster.k_means_ import *
import scipy.sparse as sp
from sklearn.cluster.k_means_ import k_init
from sklearn.cluster.k_means
import _kmeans_single_lloyd
from sklearn.cluster import MiniBatchKMeans

class RAI_KMeans:
    """Build a symjax graph that warps a batch of images with per-sample
    affine and thin-plate-spline (TPS) parameters.

    After construction, ``self.get_data_aff(data, indices)`` applies the
    training-set affine transform and then the TPS warp selected by
    ``indices`` to a batch of ``batch_size`` images.
    """

    def __init__(self, x_shape, n_data_train, n_data_test, batch_size,
                 n_clusters, n_landmarks):
        """Create the placeholders, parameters and compiled warp function.

        Args:
            x_shape: shape of one sample, without the batch axis.
            n_data_train: number of training samples; each gets its own row
                of affine and TPS parameters.
            n_data_test: number of test samples; each gets its own row of
                affine parameters and per-cluster TPS parameters.
            batch_size: number of samples per call; fixes the leading
                dimension of both the ``data`` and ``indices`` placeholders.
            n_clusters: number of centroids.
            n_landmarks: number of TPS landmarks; each sample gets
                ``2 * n_landmarks`` TPS parameters.
        """
        # Placeholders. The index placeholder must have the same length as
        # the data batch: a (0,)-shaped placeholder gathers zero parameter
        # rows, which cannot line up with the `batch_size` images in `data`.
        data = T.Placeholder((batch_size, *x_shape), 'float32')
        indices = T.Placeholder((batch_size,), 'int32')

        # Trainable parameters. The affine rows start at [1,0,0,0,1,0],
        # presumably the flattened 2x3 identity expected by
        # `affine_transform`. All parameters are float32 to match `data`
        # (np.tile of an int literal array would otherwise be integer-typed).
        identity_affine = np.array([1, 0, 0, 0, 1, 0], dtype='float32')
        centroids = T.Variable(
            sj.nn.initializers.glorot_uniform((n_clusters, *x_shape)))
        theta_aff_train = T.Variable(
            np.tile(identity_affine, (n_data_train, 1)))
        theta_diff_train = T.Variable(
            np.zeros((n_data_train, 2 * n_landmarks), dtype='float32'))
        # Test affine parameters are sized by n_data_test (not n_data_train)
        # and wrapped in a single Variable.
        theta_aff_test = T.Variable(
            np.tile(identity_affine, (n_data_test, 1)))
        theta_diff_test = T.Variable(
            np.zeros((n_data_test, n_clusters, 2 * n_landmarks),
                     dtype='float32'))

        # Gather the per-sample parameters for the current batch.
        theta_aff_train_ind = theta_aff_train[indices]
        theta_diff_train_ind = theta_diff_train[indices]
        # NOTE(review): the test-set gathers are not used by any output
        # below. Each gathers from its own parameter tensor.
        theta_aff_test_ind = theta_aff_test[indices]
        theta_diff_test_ind = theta_diff_test[indices]

        # Affine warp, then TPS warp, using the training-set parameters.
        data_train_aff = T.interpolation.affine_transform(
            data, theta_aff_train_ind)
        print(data_train_aff.shape.get())  # debug: shape of the affine output
        data_train_diff = T.interpolation.thin_plate_spline(
            data_train_aff, theta_diff_train_ind)
        self.get_data_aff = sj.function(data, indices, outputs=data_train_diff)

# Load the MNIST training images as float32 and derive the sizes the model
# needs from them.
data = sj.data.mnist()["train_set/images"].astype('float32')
x_shape = data.shape[1:]
n_data_train = data.shape[0]

# Pick a random mini-batch of distinct training indices (int32, matching the
# model's index placeholder).
batch_size = 10
ind_train = np.random.permutation(n_data_train)[:batch_size].astype('int32')

print(x_shape)

data_train = data[ind_train]

# Build the model (10 test samples, 5 clusters, 4 TPS landmarks) and warp the
# selected batch with its per-sample parameters.
test = RAI_KMeans(
    x_shape=x_shape,
    n_data_train=n_data_train,
    n_data_test=10,
    batch_size=batch_size,
    n_clusters=5,
    n_landmarks=4,
)
image = test.get_data_aff(data_train, ind_train)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions