
PointNet code

2024-03-31 Web Development

1. PointNet classification network

2. PointNet part segmentation network

Datasets

1. Point clouds sampled from 3D shapes
2. The ShapeNetPart dataset

Structure

The code is divided into three main parts:

1. Data processing
2. Model construction
3. Result selection

Data processing

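This step converts the point clouds into a format the program can use. It is implemented in provider.py and mainly covers data download, preprocessing (shuffle -> rotate -> jitter), and format conversion (hdf5 -> txt).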
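As a rough picture of how the three preprocessing steps chain together during one training epoch, here is a minimal sketch. It assumes the data is already loaded as NumPy arrays of shape (B, N, 3); `augment_epoch` and its parameters are hypothetical names, and the rotation and jitter are condensed, loop-free versions of the provider.py functions in this section rather than the original code.

```python
import numpy as np

def augment_epoch(data, labels, batch_size=32, sigma=0.01, clip=0.05, rng=None):
    """Hypothetical helper: shuffle once, then yield rotated + jittered batches."""
    rng = np.random.default_rng() if rng is None else rng
    # 1. shuffle: permute samples and labels with the same index array
    idx = rng.permutation(len(labels))
    data, labels = data[idx], labels[idx]
    for start in range(0, len(labels), batch_size):
        batch = data[start:start + batch_size]
        # 2. rotate: one random angle per cloud about the y (up) axis
        angles = rng.uniform(0, 2 * np.pi, size=len(batch))
        c, s = np.cos(angles), np.sin(angles)
        zeros, ones = np.zeros_like(c), np.ones_like(c)
        rot = np.stack([np.stack([c, zeros, s], -1),
                        np.stack([zeros, ones, zeros], -1),
                        np.stack([-s, zeros, c], -1)], axis=1)  # (b, 3, 3)
        batch = np.matmul(batch, rot)  # same row-vector convention as np.dot(points, R)
        # 3. jitter: add clipped Gaussian noise to every coordinate
        batch = batch + np.clip(sigma * rng.standard_normal(batch.shape), -clip, clip)
        yield batch.astype(np.float32), labels[start:start + batch_size]
```

Shuffling once per epoch and augmenting per batch means every epoch sees the clouds in a new order and under new random rotations and noise.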

shuffle

import numpy as np

def shuffle_data(data, labels):
    """ Shuffle data and labels.
        Input:
          data: B,N,... numpy array
          label: B,... numpy array
        Return:
          shuffled data, label and shuffle indices
    """
    idx = np.arange(len(labels))  # array of indices 0..B-1
    # print('idx=', idx)  # idx= [   0    1    2 ... 2045 2046 2047]
    np.random.shuffle(idx)  # shuffle the indices in place
    # print('idx=', idx)
    return data[idx, ...], labels[idx], idx
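Because the same index array reorders both arrays, every sample keeps its label after shuffling. A small check with made-up toy data (the function is restated so the snippet runs on its own):

```python
import numpy as np

def shuffle_data(data, labels):
    idx = np.arange(len(labels))
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx

# toy batch: sample i is filled with the value i and has label i
data = np.arange(5, dtype=np.float32)[:, None, None] * np.ones((5, 4, 3), dtype=np.float32)
labels = np.arange(5)
shuffled_data, shuffled_labels, idx = shuffle_data(data, labels)

print(np.all(shuffled_data[:, 0, 0] == shuffled_labels))  # True: data and labels stay aligned
print(np.array_equal(shuffled_labels, labels[idx]))        # True: idx records the permutation
```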

rotate (random rotation)

import numpy as np

def rotate_point_cloud(batch_data):
    # print('batch data shape=', batch_data.shape)  # (32, 1024, 3)
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi  # random angle in [0, 2*pi)
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        # rotation about the y (up) axis
        rotation_matrix = np.array([[cosval, 0, sinval],
                                    [0, 1, 0],
                                    [-sinval, 0, cosval]])
        shape_pc = batch_data[k, ...]
        # reshape shape_pc to (?, 3) first, because the rotation matrix is (3, 3)
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
    return rotated_data
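The rotation matrix leaves the y coordinate untouched, so each cloud is spun about the vertical axis and distances from the origin are preserved. The per-cloud loop can also be written without a Python loop; the sketch below is an alternative, not the provider.py code, and `rotate_point_cloud_vectorized` and its `rng` argument are hypothetical. It builds one (3, 3) matrix per cloud and applies them all with a single `np.einsum`.

```python
import numpy as np

def rotate_point_cloud_vectorized(batch_data, rng=None):
    """Rotate each cloud in a (B, N, 3) batch by its own random angle about the y axis."""
    rng = np.random.default_rng() if rng is None else rng
    B = batch_data.shape[0]
    angles = rng.uniform(0, 2 * np.pi, size=B)
    c, s = np.cos(angles), np.sin(angles)
    rot = np.zeros((B, 3, 3))
    rot[:, 0, 0], rot[:, 0, 2] = c, s
    rot[:, 1, 1] = 1.0
    rot[:, 2, 0], rot[:, 2, 2] = -s, c
    # row-vector convention, like np.dot(points, rotation_matrix): out[b, n] = points[b, n] @ rot[b]
    return np.einsum('bni,bij->bnj', batch_data, rot).astype(np.float32)

batch = np.random.default_rng(0).normal(size=(2, 5, 3))
out = rotate_point_cloud_vectorized(batch, np.random.default_rng(1))
# y is unchanged and each point keeps its distance from the origin
print(np.allclose(out[..., 1], batch[..., 1], atol=1e-5))  # True
```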

jitter (small random noise)

import numpy as np

def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    B, N, C = batch_data.shape
    assert clip > 0
    # Gaussian noise per coordinate, clipped to the range [-clip, clip]
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
    jittered_data += batch_data
    return jittered_data
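With the default arguments no coordinate moves by more than `clip` = 0.05, and the typical displacement is about `sigma` = 0.01. A quick check on an all-zero toy batch (the function is restated so the snippet runs alone):

```python
import numpy as np

def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    B, N, C = batch_data.shape
    assert clip > 0
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
    jittered_data += batch_data
    return jittered_data

batch = np.zeros((4, 1024, 3))
noise = jitter_point_cloud(batch) - batch
print(np.abs(noise).max() <= 0.05)  # True: the clip bounds the displacement
print(noise.std())                   # close to sigma = 0.01 (varies from run to run)
```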

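Model construction: transform net (T-Net)

The snippet below is the input transform: it predicts a 3×3 matrix (K=3) for each point cloud and multiplies the points by it.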

with tf.variable_scope('transform_net1') as sc:  # T-Net
    transform = input_transform_net(point_cloud, is_training, bn_decay, K=3)
print('point cloud=', point_cloud)  # (32, 1024, 3)
# print('input transform=', transform)  # (32, 3, 3)
point_cloud_transformed = tf.matmul(point_cloud, transform)
# print('point_cloud_transformed=', point_cloud_transformed)  # (32, 1024, 3)
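`tf.matmul` on a (32, 1024, 3) tensor and a (32, 3, 3) tensor performs one matrix product per point cloud in the batch. The same batched product in NumPy reproduces the shapes in the comments above; the random matrices are only a stand-in for the T-Net output (an illustrative assumption, a trained T-Net predicts learned matrices).

```python
import numpy as np

B, N, K = 32, 1024, 3
rng = np.random.default_rng(0)
point_cloud = rng.normal(size=(B, N, K))
transform = rng.normal(size=(B, K, K))  # stand-in for input_transform_net's output

point_cloud_transformed = np.matmul(point_cloud, transform)
print(point_cloud_transformed.shape)  # (32, 1024, 3)

# the batched product equals one ordinary matrix product per cloud
print(np.allclose(point_cloud_transformed[5], point_cloud[5] @ transform[5]))  # True
```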

mlp(64,128,1024): a shared MLP implemented with 1×1 convolutions

net = tf_util.conv2d(net_transformed, 64, [1, 1],
                     padding='VALID', stride=[1, 1],
                     bn=True, is_training=is_training,
                     scope='conv3', bn_decay=bn_decay)
print('net3=', net)  # (32, 1024, 1, 64)
net = tf_util.conv2d(net, 128, [1, 1],
                     padding='VALID', stride=[1, 1],
                     bn=True, is_training=is_training,
                     scope='conv4', bn_decay=bn_decay)
print('net4=', net)  # (32, 1024, 1, 128)
net = tf_util.conv2d(net, 1024, [1, 1],
                     padding='VALID', stride=[1, 1],
                     bn=True, is_training=is_training,
                     scope='conv5', bn_decay=bn_decay)
print('net5=', net)  # (32, 1024, 1, 1024)
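A 1×1 convolution over a (B, N, 1, C) tensor applies the same weights to every point independently, which is why these three layers behave as an MLP shared across all points. A NumPy sketch of that equivalence; the random weights, ReLU activation, and missing batch norm are simplifying assumptions, not the tf_util implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N = 2, 16
x = rng.normal(size=(B, N, 1, 3))  # same layout as the conv2d input: (B, N, 1, C)

def shared_mlp(x, widths, rng):
    """Dense + ReLU layers applied to every point with shared weights (1x1 conv equivalent)."""
    params = []
    for out_c in widths:
        w = rng.normal(scale=0.1, size=(x.shape[-1], out_c))
        b = np.zeros(out_c)
        x = np.maximum(np.einsum('bnhc,cd->bnhd', x, w) + b, 0.0)
        params.append((w, b))
    return x, params

out, params = shared_mlp(x, (64, 128, 1024), rng)
print(out.shape)  # (2, 16, 1, 1024)

# pushing a single point through the same weights gives the same features
p = x[1, 7, 0]
for w, b in params:
    p = np.maximum(p @ w + b, 0.0)
print(np.allclose(p, out[1, 7, 0]))  # True
```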

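Class voting: implementation

batch_pred_sum holds, for each sample, the class scores summed over all votes.

batch_pred_sum.shape = (?, 40)  # scores of each sample for the 40 classes

pred_val.shape = (?,)  # the most likely class of each sample

pred_val = np.argmax(batch_pred_sum, 1)  # index of the maximum along axis 1, i.e. the idx (label) of the highest-scoring class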
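A minimal sketch of voting, assuming each vote runs the classifier on a copy of the batch rotated by a different angle about the y axis and the scores are summed before the argmax. `classify`, `vote_predict`, `rotate_by_angle` and `num_votes=12` are hypothetical; a real run would call the trained network instead of the dummy classifier.

```python
import numpy as np

NUM_CLASSES = 40

def rotate_by_angle(batch, angle):
    """Rotate a (B, N, 3) batch by a fixed angle about the y axis."""
    c, s = np.cos(angle), np.sin(angle)
    rot = np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
    return batch @ rot

def vote_predict(classify, batch, num_votes=12):
    """Sum class scores over num_votes rotated copies, then take the argmax."""
    batch_pred_sum = np.zeros((batch.shape[0], NUM_CLASSES))  # (?, 40)
    for vote_idx in range(num_votes):
        angle = vote_idx / float(num_votes) * 2 * np.pi
        batch_pred_sum += classify(rotate_by_angle(batch, angle))
    return np.argmax(batch_pred_sum, 1)  # pred_val, shape (?,)

def dummy_classify(b):
    """Hypothetical stand-in for the network: always favours class 7."""
    scores = np.zeros((b.shape[0], NUM_CLASSES))
    scores[:, 7] = 1.0
    return scores

batch = np.random.default_rng(0).normal(size=(4, 1024, 3))
print(vote_predict(dummy_classify, batch))  # [7 7 7 7]
```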

Evaluation output (predicted label, true label)

/dump/pred_label.txt:

4, 4
0, 0
2, 2
8, 8
14, 23
...

shape_names.txt (one class name per line): airplane, bathtub, bed, bench, bookshelf, bottle, bowl, car, chair, cone, cup
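Each line of pred_label.txt pairs a predicted index with a true index into shape_names.txt, so accuracy and readable class names can be recovered from the two files alone. A sketch that works on the lines shown above as strings (file reading is left out); `class_name` and its `class_%d` fallback are hypothetical, needed only because the name list here is partial.

```python
import numpy as np

# first entries of shape_names.txt; line i of the file is the name of class i
shape_names = ["airplane", "bathtub", "bed", "bench", "bookshelf", "bottle",
               "bowl", "car", "chair", "cone", "cup"]

def class_name(i):
    # hypothetical fallback for indices past the partial list above
    return shape_names[i] if i < len(shape_names) else "class_%d" % i

# the lines of /dump/pred_label.txt shown above, as "predicted, true"
lines = ["4, 4", "0, 0", "2, 2", "8, 8", "14, 23"]
pairs = np.array([[int(v) for v in line.split(",")] for line in lines])
pred, true = pairs[:, 0], pairs[:, 1]

accuracy = np.mean(pred == true)
print(accuracy)  # 0.8
for p, t in zip(pred, true):
    if p != t:
        print("true:", class_name(t), "predicted:", class_name(p))  # true: class_23 predicted: class_14
```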

Save images of the mispredicted samples for visualization

/dump/xxxx_pred_name.jpg
Naming: index of the mispredicted sample + true label + predicted label

Example: /dump/1028_label_bed_pred_sofa.jpg
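The naming rule fits in a single format string. `misprediction_filename` and its `out_dir="dump"` default are hypothetical names for this sketch; the class names would come from shape_names.txt.

```python
import os

def misprediction_filename(error_cnt, true_name, pred_name, out_dir="dump"):
    """Hypothetical helper: '<error index>_label_<true>_pred_<predicted>.jpg' inside out_dir."""
    return os.path.join(out_dir, "%d_label_%s_pred_%s.jpg" % (error_cnt, true_name, pred_name))

print(misprediction_filename(1028, "bed", "sofa"))  # dump/1028_label_bed_pred_sofa.jpg on POSIX systems
```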

[Figure: three point cloud images, showing the current point cloud rotated to three different angles]
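One way to produce such views is to rotate the cloud to each angle and project it onto a 2D image plane. The NumPy-only sketch below stops before drawing; the angles (0°, 90°, 180° about y), the orthographic projection that drops z, and the names `rotation_y` and `three_views` are illustrative assumptions, not necessarily what the original rendering code does.

```python
import numpy as np

def rotation_y(angle):
    """3x3 rotation about the y axis, same form as in rotate_point_cloud."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])

def three_views(points, angles_deg=(0, 90, 180)):
    """Return one (N, 2) orthographic projection per rotation angle."""
    views = []
    for deg in angles_deg:
        rotated = points @ rotation_y(np.deg2rad(deg))
        views.append(rotated[:, :2])  # keep x (horizontal) and y (vertical), drop depth z
    return views

views = three_views(np.array([[1.0, 2.0, 3.0]]))
print(np.round(views[1], 6))  # the 90-degree view of the point (1, 2, 3)
```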

