
A Preview of Big Data + AI Best Practices in Healthcare: Disease Prediction with a Variational Autoencoder (VAE)

Author: 流川疯 · Published 2021-12-06 · Column: 流川疯编写程序的艺术

Outline

Using Variational Autoencoders to predict future diagnosis

VAE for Collaborative Filtering

This work is an adaptation of the work by Dawen et al., who used VAEs for collaborative filtering. Their work exploits the generative nature of VAEs to reconstruct complete user-preference information from an input of partial user-preference information.

Work by Dawen et al.: https://arxiv.org/abs/1802.05814

Diagnosis Codes

In the healthcare industry, the diagnoses a patient receives have been standardized with diagnosis codes: each disease or medical condition is mapped to a diagnosis code. ICD-10 diagnosis codes: https://www.icd10data.com/
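
As a purely illustrative example (these mappings are mine, not taken from the dataset used below), the three-character ICD-10 category codes that appear later as column names correspond to conditions roughly like this:

# Illustrative only: a few ICD-10 category codes and rough descriptions.
# The dataset below uses such three-character categories as column names.
icd10_examples = {
    "E11": "Type 2 diabetes mellitus",
    "I10": "Essential (primary) hypertension",
    "E78": "Disorders of lipoprotein metabolism (e.g. high cholesterol)",
    "Z48": "Encounter for other postprocedural aftercare",
}
print(icd10_examples["E11"])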

Data

In the data we are using, we have information on a set of patients and the diagnoses they have received. Each of these diagnoses is mapped to a diagnosis code. The data contains a total of 1567 unique diagnosis codes, so a given patient is represented by a binary vector of dimension 1567 in which an element is 1 if the patient has received that particular diagnosis and 0 otherwise.
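
To make this representation concrete, here is a minimal sketch (my own, not part of the original notebook) of how a patient's diagnosis history would be turned into such a multi-hot binary vector, assuming a small hypothetical code vocabulary:

import numpy as np

# Hypothetical vocabulary of diagnosis codes; the real data has 1567 of them.
all_codes = ["E11", "I10", "E78", "Z48", "R44"]
code_index = {code: i for i, code in enumerate(all_codes)}

def encode_patient(diagnoses):
    """Multi-hot encode a patient's diagnosis history over the code vocabulary."""
    vec = np.zeros(len(all_codes), dtype=np.int8)
    for code in diagnoses:
        vec[code_index[code]] = 1
    return vec

print(encode_patient(["E11", "E78"]))  # -> [1 0 1 0 0]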

VAE for next diagnosis prediction

Given the patient diagnosis information, the VAE encodes it into a latent space. In doing so it learns the distribution of patients and the clusters of diagnoses they tend to receive together.

To give a simple example, consider diabetes in older adults: the diagnoses that commonly appear together for such patients would be something like diabetes, high cholesterol, high blood pressure, arthritis, and so on. A patient whose diagnosis set contains only, say, diabetes and high cholesterol would then be mapped to the same region of the latent space as the older adults with the fuller set of diagnoses mentioned above.

This mapping of similar patients to nearby points in the latent space has a very favourable effect on decoding back to the original space: during reconstruction, missing diagnoses that have a high probability of occurrence for a particular patient are reconstructed as well.

This ability to fill in missing diagnoses, as a form of collaborative filtering, is why I apply this technique to predict the next diagnosis.
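
A minimal sketch of that idea (my own helper, not from the notebook; vae, x_test, batch_size and df refer to the objects defined further down): take the decoder's probability output, zero out the diagnoses the patient already has, and rank what remains.

import numpy as np

def rank_missing_diagnoses(patient_vector, reconstructed_probs, code_names, top_n=5):
    """Rank diagnoses the patient has NOT yet received by reconstructed probability."""
    probs = np.array(reconstructed_probs, dtype=float)
    probs[patient_vector == 1] = 0.0          # ignore already-observed diagnoses
    top = np.argsort(probs)[::-1][:top_n]     # highest probability first
    return [(code_names[i], float(probs[i])) for i in top]

# Example usage once the model below has been trained:
# _, x_prob = vae.predict(x_test, batch_size=batch_size)
# print(rank_missing_diagnoses(x_test[0], x_prob[0], list(df.columns)))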

Applications

This work can be used for many applications, ranging from insurance companies using it to better anticipate a patient's needs, to healthcare applications that encourage people to improve their lifestyle choices.

Some background

I came across the application of VAEs to collaborative filtering while studying it for my previous work, "Hybrid VAE for Collaborative Filtering". That work processes movie plot information from IMDb and uses it as an additional input to improve movie recommendation systems. It was published in the RecSys 2018 Knowledge Transfer Learning workshop: https://arxiv.org/abs/1808.01006

IMPLEMENTATION

%matplotlib inline
import numpy as np
import pickle
import os
from matplotlib import pyplot as plt
from keras.layers import Input, Dense, Lambda, Multiply, Dropout, Embedding, Flatten, Activation, Reshape
from keras.models import Model
from keras import losses
from keras import backend as K
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, Callback
from IPython.display import clear_output
from sklearn import preprocessing
from keras import regularizers
import keras
import pandas as pd
Using TensorFlow backend.
import os
os.getcwd()
'C:\\Users\\iz\\Desktop'
df = pd.read_csv('test.csv')
df = df.drop(['KEY'],axis = 1)
df.head()

   T40  A08  I69  Z48  R44  N92  R59  B97  M96  I35  ...  H61  T84  M16  J38  Z90  D68  K83  Z87  Z75  Z43
0    0    0    0    0    0    0    0    0    0    0  ...    0    0    0    0    1    0    0    1    0    0
1    0    0    0    0    0    0    0    0    0    0  ...    0    0    0    0    1    0    0    1    0    0
2    0    1    0    0    0    0    0    0    0    0  ...    0    0    0    0    0    1    0    1    0    0
3    0    0    0    0    0    0    1    0    0    0  ...    0    0    0    0    0    0    0    1    0    0
4    0    0    0    0    0    0    0    0    0    0  ...    0    0    0    0    0    0    0    0    0    0

5 rows × 500 columns

from sklearn.model_selection import train_test_split

y = range(df.shape[0])
xtrain, xtest, ytrain, ytest = train_test_split(df, y, test_size = 0.1, random_state = 42)
xtrain, xval, ytrain, yval = train_test_split(xtrain, ytrain, test_size = 0.1, random_state = 42)
# import numpy as np

with open('./train.data', 'wb') as f:
    np.save(f, xtrain)
with open('./test.data', 'wb') as f:
    np.save(f, xtest)
with open('./val.data', 'wb') as f:
    np.save(f, xval)
Function to plot Losses
class PlotLosses(Callback):
    def on_train_begin(self, logs={}):
        self.i = 0
        self.x = []
        self.losses = []
        self.val_losses = []        
        self.fig = plt.figure()
        self.logs = []

    def on_epoch_end(self, epoch, logs={}):
        self.logs.append(logs)
        self.x.append(self.i)
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.i += 1
        
        clear_output(wait=True)
        plt.plot(self.x, self.losses, label="loss")
        plt.plot(self.x, self.val_losses, label="val_loss")
        plt.legend()
        plt.show();
        
plot_losses = PlotLosses()
Load Data
with open('train.data', 'rb') as f:
    x_train = np.load(f)
print("number of training users: ", x_train.shape[0])

with open('val.data', 'rb') as f:
    x_val = np.load(f)
print("number of validation users: ", x_val.shape[0])
number of training users:  5234
number of validation users:  582
x_train.shape,x_val.shape
((5234, 500), (582, 500))
x_train[0].shape
(500,)
x_train = x_train[:5000]
x_val = x_val[:500]
Configure Network
# encoder/decoder network size

batch_size=100
original_dim = x_train.shape[1]
intermediate_dim=200
latent_dim=100
nb_epochs=30
epsilon_std=1.0
Using a two-output network

Here we have two outputs from the network, which is quite different from the original VAE network proposed by Dawen et al.

The first output reconstructs the given input, while the second output produces a probability distribution over the diagnosis codes. Each output has its own loss function that maximizes its particular objective.
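
For reference, the per-patient training objective that the vae_loss functions below implement is the standard beta-weighted VAE loss (my own restatement; the reconstruction term is binary cross-entropy for the sigmoid output, and a dot-product score against the softmax output in vae_loss2 of Network 2):

$$
\mathcal{L}(x) = \mathrm{recon}(x, \hat{x}) + \beta \cdot \frac{1}{2} \sum_{j=1}^{d} \left( e^{\log \sigma_j^{2}} - \log \sigma_j^{2} + \mu_j^{2} - 1 \right)
$$

The second term is the closed-form KL divergence between the encoder's Gaussian q(z|x) = N(mu, diag(sigma^2)) and the standard normal prior; beta is the weight that the increaseBeta callback below is meant to raise by 0.01 after each epoch, so the KL regularization only gradually kicks in.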

Network 1: Same objective functions for the two outputs
#Function to increase the relevance of the KL regularization as the training progresses
class increaseBeta(Callback):
    def __init__(self):
        self.global_beta = 0.0
    def on_train_begin(self, logs={}):
        self.global_beta = 0.0
    def on_epoch_end(self, epoch, logs={}):
        self.global_beta = self.global_beta + 0.01

updateBeta = increaseBeta()

#Function to l2 normalize the inputs
def l2normalize(args):
    _x=args
    return K.l2_normalize(_x, axis = -1)

#Function to do the sampling from Latent Space
def sampling(args):
    _mean,_log_var=args
    epsilon=K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., stddev=epsilon_std)
    return _mean+K.exp(_log_var/2)*epsilon

# encoder network

x=Input(batch_shape=(batch_size,original_dim))
norm_x = Lambda(l2normalize, output_shape=(original_dim,))(x)
norm_x = Dropout(rate = 0.5)(norm_x)
h=Dense(intermediate_dim, activation='relu')(norm_x)
z_mean=Dense(latent_dim)(h)
z_log_var=Dense(latent_dim)(h)

z= Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# decoder network
h_decoder=Dense(intermediate_dim, activation='relu')
x_bar=Dense(original_dim, activation='sigmoid') 
x_prob=Dense(original_dim, activation='softmax')
h_decoded = h_decoder(z)

# We have two outputs, one which reconstructs the given input, the other which reconstructs the probability 
x_decoded = x_bar(h_decoded)
x_probability = x_prob(h_decoded)

def vae_loss(x,x_bar):
    reconst_loss = K.sum(losses.binary_crossentropy(x,x_bar), axis = -1)
    kl_loss = K.sum( 0.5 * (K.exp(z_log_var) - z_log_var + K.square(z_mean) - 1), axis=-1)
    return reconst_loss + (updateBeta.global_beta)*kl_loss

# build and compile model
vae = Model(x, [x_decoded, x_probability])
vae.compile(optimizer='adam', loss=vae_loss, loss_weights=[1., 1.])

# weightsPath = "./weights/weights_vae1.hdf5"
# checkpointer = ModelCheckpoint(filepath=weightsPath, verbose=1, save_best_only=True)
# reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)
# vae.fit(x=x_train, y=[x_train, x_train], batch_size=batch_size, epochs=30,
#         validation_data=(x_val, [x_val, x_val]),
#         callbacks=[checkpointer, reduce_lr, plot_losses, updateBeta])

vae.fit(x=x_train, y=[x_train, x_train], batch_size=batch_size, epochs=30,
        validation_data=(x_val, [x_val, x_val]))
Train on 5000 samples, validate on 500 samples
Epoch 1/30
5000/5000 [==============================] - 4s 758us/step - loss: 78.1787 - dense_23_loss: 34.3244 - dense_24_loss: 43.8543 - val_loss: 64.6742 - val_dense_23_loss: 22.1748 - val_dense_24_loss: 42.4994
Epoch 2/30
5000/5000 [==============================] - 4s 743us/step - loss: 64.4046 - dense_23_loss: 21.9435 - dense_24_loss: 42.4611 - val_loss: 63.6366 - val_dense_23_loss: 21.5553 - val_dense_24_loss: 42.0814
Epoch 3/30
5000/5000 [==============================] - 4s 837us/step - loss: 63.6365 - dense_23_loss: 21.5137 - dense_24_loss: 42.1228 - val_loss: 62.9220 - val_dense_23_loss: 21.1661 - val_dense_24_loss: 41.7559
Epoch 4/30
5000/5000 [==============================] - 4s 727us/step - loss: 62.5700 - dense_23_loss: 20.9471 - dense_24_loss: 41.6229 - val_loss: 61.4582 - val_dense_23_loss: 20.4014 - val_dense_24_loss: 41.0568
Epoch 5/30
5000/5000 [==============================] - 4s 733us/step - loss: 61.5120 - dense_23_loss: 20.3677 - dense_24_loss: 41.1443 - val_loss: 60.5572 - val_dense_23_loss: 19.8999 - val_dense_24_loss: 40.6573
Epoch 6/30
5000/5000 [==============================] - 3s 669us/step - loss: 60.6162 - dense_23_loss: 19.8809 - dense_24_loss: 40.7353 - val_loss: 59.4887 - val_dense_23_loss: 19.3130 - val_dense_24_loss: 40.1757
Epoch 7/30
5000/5000 [==============================] - 4s 879us/step - loss: 59.7845 - dense_23_loss: 19.4161 - dense_24_loss: 40.3683 - val_loss: 58.6560 - val_dense_23_loss: 18.8436 - val_dense_24_loss: 39.8125
Epoch 8/30
5000/5000 [==============================] - 3s 581us/step - loss: 59.0007 - dense_23_loss: 18.9740 - dense_24_loss: 40.0267 - val_loss: 57.7723 - val_dense_23_loss: 18.3500 - val_dense_24_loss: 39.4223
Epoch 9/30
5000/5000 [==============================] - 3s 664us/step - loss: 58.1647 - dense_23_loss: 18.4887 - dense_24_loss: 39.6760 - val_loss: 56.8185 - val_dense_23_loss: 17.8036 - val_dense_24_loss: 39.0149
Epoch 10/30
5000/5000 [==============================] - 4s 705us/step - loss: 57.3740 - dense_23_loss: 18.0296 - dense_24_loss: 39.3444 - val_loss: 56.0715 - val_dense_23_loss: 17.3908 - val_dense_24_loss: 38.6807
Epoch 11/30
5000/5000 [==============================] - 4s 711us/step - loss: 56.7545 - dense_23_loss: 17.6735 - dense_24_loss: 39.0810 - val_loss: 55.1875 - val_dense_23_loss: 16.8682 - val_dense_24_loss: 38.3194
Epoch 12/30
5000/5000 [==============================] - 3s 550us/step - loss: 56.1791 - dense_23_loss: 17.3366 - dense_24_loss: 38.8425 - val_loss: 54.6427 - val_dense_23_loss: 16.5640 - val_dense_24_loss: 38.0787
Epoch 13/30
5000/5000 [==============================] - 3s 568us/step - loss: 55.6814 - dense_23_loss: 17.0418 - dense_24_loss: 38.6397 - val_loss: 54.1010 - val_dense_23_loss: 16.2432 - val_dense_24_loss: 37.8578
Epoch 14/30
C:\ProgramData\Anaconda3\envs\zhongdian\lib\site-packages\keras\callbacks\callbacks.py:95: RuntimeWarning: Method (on_train_batch_end) is slow compared to the batch update (0.111798). Check your callbacks.
  % (hook_name, delta_t_median), RuntimeWarning)

5000/5000 [==============================] - 3s 624us/step - loss: 55.2011 - dense_23_loss: 16.7642 - dense_24_loss: 38.4370 - val_loss: 53.4904 - val_dense_23_loss: 15.8693 - val_dense_24_loss: 37.6210
Epoch 15/30
5000/5000 [==============================] - 3s 699us/step - loss: 54.8549 - dense_23_loss: 16.5582 - dense_24_loss: 38.2967 - val_loss: 53.0989 - val_dense_23_loss: 15.6399 - val_dense_24_loss: 37.4590
Epoch 16/30
5000/5000 [==============================] - 3s 612us/step - loss: 54.5157 - dense_23_loss: 16.3574 - dense_24_loss: 38.1583 - val_loss: 52.5809 - val_dense_23_loss: 15.3208 - val_dense_24_loss: 37.2601
Epoch 17/30
5000/5000 [==============================] - 3s 628us/step - loss: 54.2176 - dense_23_loss: 16.1803 - dense_24_loss: 38.0373 - val_loss: 52.2525 - val_dense_23_loss: 15.1291 - val_dense_24_loss: 37.1233
Epoch 18/30
5000/5000 [==============================] - 4s 744us/step - loss: 53.8684 - dense_23_loss: 15.9648 - dense_24_loss: 37.9036 - val_loss: 51.7984 - val_dense_23_loss: 14.8564 - val_dense_24_loss: 36.9420
Epoch 19/30
5000/5000 [==============================] - 3s 600us/step - loss: 53.5824 - dense_23_loss: 15.7955 - dense_24_loss: 37.7869 - val_loss: 51.4563 - val_dense_23_loss: 14.6374 - val_dense_24_loss: 36.8189
Epoch 20/30
5000/5000 [==============================] - 3s 534us/step - loss: 53.3213 - dense_23_loss: 15.6395 - dense_24_loss: 37.6818 - val_loss: 51.2304 - val_dense_23_loss: 14.5035 - val_dense_24_loss: 36.7269
Epoch 21/30
5000/5000 [==============================] - 4s 701us/step - loss: 53.0680 - dense_23_loss: 15.4855 - dense_24_loss: 37.5825 - val_loss: 50.9055 - val_dense_23_loss: 14.3052 - val_dense_24_loss: 36.6004
Epoch 22/30
5000/5000 [==============================] - 4s 735us/step - loss: 52.8112 - dense_23_loss: 15.3356 - dense_24_loss: 37.4755 - val_loss: 50.6851 - val_dense_23_loss: 14.1706 - val_dense_24_loss: 36.5145
Epoch 23/30
5000/5000 [==============================] - 4s 733us/step - loss: 52.6116 - dense_23_loss: 15.2092 - dense_24_loss: 37.4025 - val_loss: 50.3594 - val_dense_23_loss: 13.9758 - val_dense_24_loss: 36.3836
Epoch 24/30
5000/5000 [==============================] - 4s 849us/step - loss: 52.4208 - dense_23_loss: 15.0949 - dense_24_loss: 37.3259 - val_loss: 50.1555 - val_dense_23_loss: 13.8413 - val_dense_24_loss: 36.3142
Epoch 25/30
5000/5000 [==============================] - 4s 732us/step - loss: 52.2180 - dense_23_loss: 14.9752 - dense_24_loss: 37.2428 - val_loss: 49.9194 - val_dense_23_loss: 13.6918 - val_dense_24_loss: 36.2276
Epoch 26/30
5000/5000 [==============================] - 3s 604us/step - loss: 52.0623 - dense_23_loss: 14.8808 - dense_24_loss: 37.1815 - val_loss: 49.7312 - val_dense_23_loss: 13.5534 - val_dense_24_loss: 36.1778
Epoch 27/30
5000/5000 [==============================] - 3s 507us/step - loss: 51.8958 - dense_23_loss: 14.7736 - dense_24_loss: 37.1222 - val_loss: 49.5618 - val_dense_23_loss: 13.4641 - val_dense_24_loss: 36.0977
Epoch 28/30
5000/5000 [==============================] - 3s 566us/step - loss: 51.7471 - dense_23_loss: 14.6879 - dense_24_loss: 37.0592 - val_loss: 49.3073 - val_dense_23_loss: 13.3000 - val_dense_24_loss: 36.0073
Epoch 29/30
5000/5000 [==============================] - 3s 595us/step - loss: 51.5263 - dense_23_loss: 14.5482 - dense_24_loss: 36.9781 - val_loss: 49.1972 - val_dense_23_loss: 13.2198 - val_dense_24_loss: 35.9775
Epoch 30/30
5000/5000 [==============================] - 3s 551us/step - loss: 51.4380 - dense_23_loss: 14.4985 - dense_24_loss: 36.9395 - val_loss: 48.9715 - val_dense_23_loss: 13.0922 - val_dense_24_loss: 35.8793





<keras.callbacks.callbacks.History at 0x226a912c6d8>
Network 2: Different objective functions for the two outputs
# Function to increase the relevance of the KL regularization as the training progresses

class increaseBeta(Callback):
    def __init__(self):
        self.global_beta = 0.0
    def on_train_begin(self, logs={}):
        self.global_beta = 0.0
    def on_epoch_end(self, epoch, logs={}):
        self.global_beta = self.global_beta + 0.01

updateBeta = increaseBeta()

#Function to l2 normalize the inputs
def l2normalize(args):
    _x=args
    return K.l2_normalize(_x, axis = -1)

#Function to do the sampling from Latent Space
def sampling(args):
    _mean,_log_var=args
    epsilon=K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., stddev=epsilon_std)
    return _mean+K.exp(_log_var/2)*epsilon


# encoder network
x=Input(batch_shape=(batch_size,original_dim))
norm_x = Lambda(l2normalize, output_shape=(original_dim,))(x)
norm_x = Dropout(rate = 0.5)(norm_x)
h=Dense(intermediate_dim, activation='relu')(norm_x)
z_mean=Dense(latent_dim)(h)
z_log_var=Dense(latent_dim)(h)

z= Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# decoder network
h_decoder=Dense(intermediate_dim, activation='relu')
x_bar=Dense(original_dim, activation='sigmoid') 
x_prob=Dense(original_dim, activation='softmax')
h_decoded = h_decoder(z)
#We have two outputs, one which reconstructs the given input, the other which reconstructs the probability 
x_decoded = x_bar(h_decoded)
x_probability = x_prob(h_decoded)

def vae_loss1(x,x_bar):
    reconst_loss = K.sum(losses.binary_crossentropy(x,x_bar), axis = -1)
    kl_loss = K.sum( 0.5 * (K.exp(z_log_var) - z_log_var + K.square(z_mean) - 1), axis=-1)
    return reconst_loss + (updateBeta.global_beta)*kl_loss

def vae_loss2(x,x_bar):
    neg_ll = -K.sum(x_bar*x, axis = -1)
    kl_loss = K.sum( 0.5 * (K.exp(z_log_var) - z_log_var + K.square(z_mean) - 1), axis=-1)
    return neg_ll + (updateBeta.global_beta)*kl_loss

# build and compile model
vae2 = Model(x, [x_decoded, x_probability])
vae2.compile(optimizer='adam', loss=[vae_loss1, vae_loss2], loss_weights=[0.5, 0.5])

# weightsPath = "./weights/weights_vae2.hdf5"
# checkpointer = ModelCheckpoint(filepath=weightsPath, verbose=1, save_best_only=True)
# reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)

# vae2.fit(x = x_train,y = [x_train, x_train], batch_size = batch_size, epochs=30,\
#         validation_data=(x_val, [x_val, x_val]), callbacks=[checkpointer, reduce_lr, plot_losses, updateBeta])
vae2.fit(x = x_train,y = [x_train, x_train], batch_size = batch_size, epochs=30,\
        validation_data=(x_val, [x_val, x_val]), )
Train on 5000 samples, validate on 500 samples
Epoch 1/30
5000/5000 [==============================] - 4s 729us/step - loss: 16.9416 - dense_29_loss: 34.2665 - dense_30_loss: -0.3832 - val_loss: 10.8511 - val_dense_29_loss: 22.3960 - val_dense_30_loss: -0.6938
Epoch 2/30
5000/5000 [==============================] - 3s 596us/step - loss: 10.6567 - dense_29_loss: 22.0604 - dense_30_loss: -0.7470 - val_loss: 10.5063 - val_dense_29_loss: 21.7948 - val_dense_30_loss: -0.7821
Epoch 3/30
5000/5000 [==============================] - 3s 587us/step - loss: 10.4868 - dense_29_loss: 21.7343 - dense_30_loss: -0.7607 - val_loss: 10.3348 - val_dense_29_loss: 21.4568 - val_dense_30_loss: -0.7872
Epoch 4/30
5000/5000 [==============================] - 3s 588us/step - loss: 10.3471 - dense_29_loss: 21.4561 - dense_30_loss: -0.7619 - val_loss: 10.2142 - val_dense_29_loss: 21.2168 - val_dense_30_loss: -0.7883
Epoch 5/30
5000/5000 [==============================] - 3s 583us/step - loss: 10.1537 - dense_29_loss: 21.0695 - dense_30_loss: -0.7622 - val_loss: 9.8815 - val_dense_29_loss: 20.5522 - val_dense_30_loss: -0.7891
Epoch 6/30
5000/5000 [==============================] - 4s 700us/step - loss: 9.8399 - dense_29_loss: 20.4424 - dense_30_loss: -0.7627 - val_loss: 9.6098 - val_dense_29_loss: 20.0090 - val_dense_30_loss: -0.7894
Epoch 7/30
5000/5000 [==============================] - 3s 661us/step - loss: 9.6009 - dense_29_loss: 19.9645 - dense_30_loss: -0.7627 - val_loss: 9.3354 - val_dense_29_loss: 19.4604 - val_dense_30_loss: -0.7896
Epoch 8/30
5000/5000 [==============================] - 4s 732us/step - loss: 9.3988 - dense_29_loss: 19.5603 - dense_30_loss: -0.7628 - val_loss: 9.1374 - val_dense_29_loss: 19.0644 - val_dense_30_loss: -0.7897
Epoch 9/30
5000/5000 [==============================] - 3s 698us/step - loss: 9.1813 - dense_29_loss: 19.1254 - dense_30_loss: -0.7629 - val_loss: 8.9246 - val_dense_29_loss: 18.6389 - val_dense_30_loss: -0.7896
Epoch 10/30
5000/5000 [==============================] - 4s 761us/step - loss: 8.9580 - dense_29_loss: 18.6789 - dense_30_loss: -0.7629 - val_loss: 8.6732 - val_dense_29_loss: 18.1362 - val_dense_30_loss: -0.7898
Epoch 11/30
5000/5000 [==============================] - 3s 695us/step - loss: 8.7604 - dense_29_loss: 18.2836 - dense_30_loss: -0.7629 - val_loss: 8.3996 - val_dense_29_loss: 17.5891 - val_dense_30_loss: -0.7899
Epoch 12/30
5000/5000 [==============================] - 3s 628us/step - loss: 8.5640 - dense_29_loss: 17.8911 - dense_30_loss: -0.7630 - val_loss: 8.1917 - val_dense_29_loss: 17.1734 - val_dense_30_loss: -0.7899
Epoch 13/30
5000/5000 [==============================] - 3s 616us/step - loss: 8.3888 - dense_29_loss: 17.5405 - dense_30_loss: -0.7630 - val_loss: 7.9692 - val_dense_29_loss: 16.7284 - val_dense_30_loss: -0.7900
Epoch 14/30
5000/5000 [==============================] - 3s 638us/step - loss: 8.2512 - dense_29_loss: 17.2653 - dense_30_loss: -0.7630 - val_loss: 7.8081 - val_dense_29_loss: 16.4062 - val_dense_30_loss: -0.7900
Epoch 15/30
5000/5000 [==============================] - 4s 777us/step - loss: 8.1205 - dense_29_loss: 17.0040 - dense_30_loss: -0.7630 - val_loss: 7.6419 - val_dense_29_loss: 16.0737 - val_dense_30_loss: -0.7900
Epoch 16/30
5000/5000 [==============================] - 5s 972us/step - loss: 7.9886 - dense_29_loss: 16.7401 - dense_30_loss: -0.7630 - val_loss: 7.5225 - val_dense_29_loss: 15.8349 - val_dense_30_loss: -0.7900
Epoch 17/30
5000/5000 [==============================] - 4s 724us/step - loss: 7.8793 - dense_29_loss: 16.5216 - dense_30_loss: -0.7630 - val_loss: 7.4022 - val_dense_29_loss: 15.5944 - val_dense_30_loss: -0.7900
Epoch 18/30
5000/5000 [==============================] - 4s 753us/step - loss: 7.7752 - dense_29_loss: 16.3134 - dense_30_loss: -0.7630 - val_loss: 7.2394 - val_dense_29_loss: 15.2688 - val_dense_30_loss: -0.7900
Epoch 19/30
5000/5000 [==============================] - 4s 841us/step - loss: 7.6858 - dense_29_loss: 16.1345 - dense_30_loss: -0.7630 - val_loss: 7.1225 - val_dense_29_loss: 15.0350 - val_dense_30_loss: -0.7900
Epoch 20/30
5000/5000 [==============================] - 4s 801us/step - loss: 7.5966 - dense_29_loss: 15.9562 - dense_30_loss: -0.7630 - val_loss: 6.9988 - val_dense_29_loss: 14.7876 - val_dense_30_loss: -0.7900
Epoch 21/30
5000/5000 [==============================] - 4s 744us/step - loss: 7.5157 - dense_29_loss: 15.7943 - dense_30_loss: -0.7630 - val_loss: 6.9055 - val_dense_29_loss: 14.6010 - val_dense_30_loss: -0.7900
Epoch 22/30
5000/5000 [==============================] - 4s 847us/step - loss: 7.4233 - dense_29_loss: 15.6096 - dense_30_loss: -0.7630 - val_loss: 6.8266 - val_dense_29_loss: 14.4433 - val_dense_30_loss: -0.7900
Epoch 23/30
5000/5000 [==============================] - 3s 668us/step - loss: 7.3547 - dense_29_loss: 15.4724 - dense_30_loss: -0.7630 - val_loss: 6.7169 - val_dense_29_loss: 14.2239 - val_dense_30_loss: -0.7900
Epoch 24/30
5000/5000 [==============================] - 3s 653us/step - loss: 7.2897 - dense_29_loss: 15.3423 - dense_30_loss: -0.7630 - val_loss: 6.6318 - val_dense_29_loss: 14.0535 - val_dense_30_loss: -0.7900
Epoch 25/30
5000/5000 [==============================] - 3s 693us/step - loss: 7.2195 - dense_29_loss: 15.2020 - dense_30_loss: -0.7630 - val_loss: 6.5772 - val_dense_29_loss: 13.9444 - val_dense_30_loss: -0.7900
Epoch 26/30
5000/5000 [==============================] - 4s 846us/step - loss: 7.1611 - dense_29_loss: 15.0852 - dense_30_loss: -0.7630 - val_loss: 6.5330 - val_dense_29_loss: 13.8560 - val_dense_30_loss: -0.7900
Epoch 27/30
5000/5000 [==============================] - 4s 817us/step - loss: 7.1067 - dense_29_loss: 14.9764 - dense_30_loss: -0.7630 - val_loss: 6.4391 - val_dense_29_loss: 13.6681 - val_dense_30_loss: -0.7900
Epoch 28/30
5000/5000 [==============================] - 4s 779us/step - loss: 7.0528 - dense_29_loss: 14.8687 - dense_30_loss: -0.7630 - val_loss: 6.3591 - val_dense_29_loss: 13.5081 - val_dense_30_loss: -0.7900
Epoch 29/30
5000/5000 [==============================] - 4s 728us/step - loss: 6.9904 - dense_29_loss: 14.7439 - dense_30_loss: -0.7630 - val_loss: 6.2726 - val_dense_29_loss: 13.3353 - val_dense_30_loss: -0.7900
Epoch 30/30
5000/5000 [==============================] - 5s 901us/step - loss: 6.9431 - dense_29_loss: 14.6491 - dense_30_loss: -0.7630 - val_loss: 6.2600 - val_dense_29_loss: 13.3100 - val_dense_30_loss: -0.7900





<keras.callbacks.callbacks.History at 0x226ad466748>
with open('test.data', 'rb') as f:
    x_test = np.load(f)
print("number of testing users: ", x_test.shape[0])
number of testing users:  647
x_test = x_test[:600]
x_test.shape
(600, 500)
# x_test[0]
Calculating Recall

The way we test the trained system is as follows. For each patient:

  1. We choose a random diagnosis out of the M diagnoses for which the patient has the value 1 (i.e. the patient has received that diagnosis).
  2. We set that diagnosis to 0.
  3. We pass the resulting vector through the network to obtain a probability distribution over the diagnosis codes.
  4. We sort the diagnosis codes by their probabilities.
  5. The network was given an input with M-1 diagnoses. We now calculate recall@k as the fraction of patients for whom the held-out diagnosis appears among the top (M-1)+k spots of that ranking (see the toy example below).
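
As a toy illustration of step 5 (my own example, not from the notebook): suppose a patient originally had M = 6 diagnoses, one of which we held out, and the decoder scores 10 candidate codes. The held-out code counts towards recall@k if it lands among the top (M-1)+k positions of the ranking:

import numpy as np

# Toy example: 10 candidate codes, patient originally had M = 6 of them;
# one was held out, so sm = M - 1 = 5 diagnoses remain visible to the network.
probs = np.array([0.01, 0.20, 0.05, 0.15, 0.02, 0.30, 0.03, 0.10, 0.08, 0.06])
held_out_index = 7          # index of the diagnosis we removed
sm, k = 5, 1

ranking = np.argsort(probs)                 # ascending, so the best codes are at the end
top_slots = ranking[-(k + sm):]             # top (M-1)+k = 6 positions
print(held_out_index in top_slots)          # True -> counts towards recall@1

The held-out evaluation on the real test set proceeds as follows.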
x_test_hold_new = np.copy(x_test)
hold_out_ind_new = [np.random.choice(np.nonzero(i)[0]) for i in x_test[:,:473]]
for i in range(x_test.shape[0]) :
    x_test_hold_new[i][hold_out_ind_new[i]] = 0
def calc_heldout_recall_new(x_test, x_rec, k):
    count = 1.0
    tot = 1.0
    x_rank = np.argsort(x_rec)
    for i in range(x_rank.shape[0]):
        sm = np.sum(x_test[i])-1
        if sm < 5:
            continue
        else:
            tot +=1
            if hold_out_ind_new[i] in x_rank[i][-(k+sm):]:
                count+=1.0
    return count/tot
x_rec, x_prob = vae.predict(x_test_hold_new, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall_new(x_test, x_prob[:,:473], k))
# x_rec, x_prob = vae2.predict(x_test_hold_new, batch_size=batch_size)
x_rec, x_prob = vae.predict(x_test_hold_new, batch_size=batch_size)

for k in [1, 2, 3, 4, 5, 10,15]:
    print(calc_heldout_recall_new(x_test, x_prob[:,:473], k))
x_test_hold = np.copy(x_test)
hold_out_ind = [np.random.choice(np.nonzero(i)[0]) for i in x_test]
for i in range(x_test.shape[0]) :
    x_test_hold[i][hold_out_ind[i]] = 0
def calc_heldout_recall(x_test, x_rec, k):
    count = 1.0
    tot = 1.0
    x_rank = np.argsort(x_rec)
    for i in range(x_rank.shape[0]):
        sm = np.sum(x_test[i])-1
        if sm < 5:
            continue
        else:
            tot +=1
            if hold_out_ind[i] in x_rank[i][-(k+sm):]:
                count+=1.0
    return count/tot
x_rec, x_prob = vae.predict(x_test_hold, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall(x_test, x_prob, k))
x_rec, x_prob = vae2.predict(x_test_hold, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10]:
    print(calc_heldout_recall(x_test, x_prob, k))
Impact of different ways of calculating the objective functions

We can see that recall@k for k = 1, 2, 3, 4, 5, 10, 15 is quite significant, considering that the network had to choose among roughly 1500 other diagnoses.

An interesting observation is that the second way of calculating the objective captures recall for smaller values of k better than the first approach, while the opposite holds for larger values of k.
