Source code for config.Config

#coding:utf-8
import torch 
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
import torch.optim as optim
import os
import time
import datetime
import ctypes
import json

class Config(object):
    r"""This class holds the configuration parameters and wraps a C library
    for data sampling and memory handling. It also drives model training
    and testing.
    """

    def __init__(self):
        self.lib = ctypes.cdll.LoadLibrary("./release/Base.so")
        self.lib.sampling.argtypes = [
            ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p,
            ctypes.c_int64, ctypes.c_int64, ctypes.c_int64,
        ]
        self.lib.getHeadBatch.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
        self.lib.getTailBatch.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
        self.lib.testHead.argtypes = [ctypes.c_void_p]
        self.lib.testTail.argtypes = [ctypes.c_void_p]
        self.test_flag = False
        self.in_path = "./"
        self.out_path = "./"
        self.bern = 0
        self.hidden_size = 100
        self.ent_size = self.hidden_size
        self.rel_size = self.hidden_size
        self.train_times = 0
        self.margin = 1.0
        self.nbatches = 100
        self.negative_ent = 1
        self.negative_rel = 0
        self.workThreads = 1
        self.alpha = 0.001
        self.lmbda = 0.000
        self.log_on = 1
        self.lr_decay = 0.000
        self.weight_decay = 0.000
        self.exportName = None
        self.importName = None
        self.export_steps = 0
        self.opt_method = "SGD"
        self.optimizer = None

    def init(self):
        self.trainModel = None
        if self.in_path is not None:
            self.lib.setInPath(ctypes.create_string_buffer(self.in_path, len(self.in_path) * 2))
            self.lib.setBern(self.bern)
            self.lib.setWorkThreads(self.workThreads)
            self.lib.randReset()
            self.lib.importTrainFiles()
            self.relTotal = self.lib.getRelationTotal()
            self.entTotal = self.lib.getEntityTotal()
            self.trainTotal = self.lib.getTrainTotal()
            # Integer division keeps batch_size an int on Python 3 as well.
            self.batch_size = self.lib.getTrainTotal() // self.nbatches
            # Each positive triple is followed by its negative samples.
            self.batch_seq_size = self.batch_size * (1 + self.negative_ent + self.negative_rel)
            self.batch_h = np.zeros(self.batch_seq_size, dtype=np.int64)
            self.batch_t = np.zeros(self.batch_seq_size, dtype=np.int64)
            self.batch_r = np.zeros(self.batch_seq_size, dtype=np.int64)
            self.batch_y = np.zeros(self.batch_seq_size, dtype=np.float32)
            # Raw buffer addresses handed to the C library for in-place writes.
            self.batch_h_addr = self.batch_h.__array_interface__['data'][0]
            self.batch_t_addr = self.batch_t.__array_interface__['data'][0]
            self.batch_r_addr = self.batch_r.__array_interface__['data'][0]
            self.batch_y_addr = self.batch_y.__array_interface__['data'][0]
        if self.test_flag:
            self.lib.importTestFiles()
            self.test_h = np.zeros(self.lib.getEntityTotal(), dtype=np.int64)
            self.test_t = np.zeros(self.lib.getEntityTotal(), dtype=np.int64)
            self.test_r = np.zeros(self.lib.getEntityTotal(), dtype=np.int64)
            self.test_h_addr = self.test_h.__array_interface__['data'][0]
            self.test_t_addr = self.test_t.__array_interface__['data'][0]
            self.test_r_addr = self.test_r.__array_interface__['data'][0]
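    # A minimal usage sketch for the setup above (kept as comments so the
    # listing stays valid source). The benchmark path is a hypothetical
    # example; the ordering is the important part: init() loads the benchmark
    # files and allocates the batch buffers, so it must run after the input
    # path and sampling options are set.
    #
    #     con = Config()
    #     con.set_in_path("./benchmarks/FB15K/")  # hypothetical dataset path
    #     con.set_work_threads(4)
    #     con.set_nbatches(100)
    #     con.init()  # loads data, computes batch_size, allocates buffers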
    def get_ent_total(self):
        r"""Returns the total number of entities in the knowledge base."""
        return self.entTotal

    def get_rel_total(self):
        r"""Returns the total number of relations in the knowledge base."""
        return self.relTotal

    def set_lmbda(self, lmbda):
        # lmbda is the coefficient of the regularization term in the loss.
        self.lmbda = lmbda
    def set_opt_method(self, method):
        r"""Sets the optimization method for training.

        Arguments:
            method (str): one of ``SGD``, ``Adagrad``, ``Adam`` or ``Adadelta``.
        """
        self.opt_method = method

    def set_test_flag(self, flag):
        r"""Sets whether the model is tested after training.

        Arguments:
            flag (bool): if True, the model is tested.

        .. note:: test_flag must be set **after** all the other
            configuration parameters are set.
        """
        self.test_flag = flag

    def set_log_on(self, flag):
        r"""Sets whether the loss value is logged during training.

        Arguments:
            flag (bool): if True, the loss value is printed every epoch.
        """
        self.log_on = flag

    def set_alpha(self, alpha):
        r"""Sets the learning rate for gradient descent.

        Arguments:
            alpha (float): the learning rate.
        """
        self.alpha = alpha

    def set_in_path(self, path):
        r"""Sets the path of the benchmark dataset."""
        self.in_path = path

    def set_out_files(self, path):
        r"""Sets the path to which the embedding matrices are exported."""
        self.out_path = path

    def set_bern(self, bern):
        r"""Sets the strategy for negative sampling.

        Arguments:
            bern (int): 1 for the ``bern`` strategy, 0 for ``unif``.
        """
        self.bern = bern
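    # Background on the two strategies (not part of this API): "unif" corrupts
    # a triple by replacing its head or tail uniformly at random, while "bern"
    # (introduced with TransH, Wang et al. 2014) biases the replacement by the
    # average number of tails per head and heads per tail, which reduces false
    # negatives. E.g. con.set_bern(1) selects "bern", con.set_bern(0) "unif".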
    def set_dimension(self, dim):
        r"""Sets the entity dimension and the relation dimension at the same time.

        Arguments:
            dim (int): the dimension of entity and relation embeddings.
        """
        self.hidden_size = dim
        self.ent_size = dim
        self.rel_size = dim

    def set_ent_dimension(self, dim):
        r"""Sets the dimension of entity embeddings.

        Arguments:
            dim (int): the dimension of entity embeddings.
        """
        self.ent_size = dim

    def set_rel_dimension(self, dim):
        r"""Sets the dimension of relation embeddings.

        Arguments:
            dim (int): the dimension of relation embeddings.
        """
        self.rel_size = dim

    def set_train_times(self, times):
        r"""Sets the number of training rounds (epochs).

        Arguments:
            times (int): number of training rounds.
        """
        self.train_times = times

    def set_nbatches(self, nbatches):
        r"""Sets the number of batches per epoch.

        Arguments:
            nbatches (int): number of batches.
        """
        self.nbatches = nbatches

    def set_margin(self, margin):
        r"""Sets the margin for the widely used pairwise margin-based ranking loss.

        Arguments:
            margin (float): margin for the margin-based ranking function.
        """
        self.margin = margin
    def set_work_threads(self, threads):
        r"""Sets the number of worker threads; multi-threaded training can be
        used for acceleration.

        Arguments:
            threads (int): number of working threads.
        """
        self.workThreads = threads

    def set_ent_neg_rate(self, rate):
        r"""Sets the number of negative entities constructed per positive
        sample. The number of negatives generated per positive training
        sample influences the experimental results.

        Arguments:
            rate (int): the number of negative entities per positive sample.
        """
        self.negative_ent = rate

    def set_rel_neg_rate(self, rate):
        r"""Sets the number of negative relations constructed per positive sample.

        Arguments:
            rate (int): the number of negative relations per positive sample.
        """
        self.negative_rel = rate

    def set_import_files(self, path):
        r"""Model parameters are exported automatically every few rounds.
        This method sets the path from which exported model parameters are
        restored.

        Arguments:
            path: path to the automatically exported model parameters.
        """
        self.importName = path

    def set_export_files(self, path):
        r"""Sets the file to which model parameters are exported automatically.

        Arguments:
            path: file that model parameters will be exported to.
        """
        self.exportName = path

    def set_export_steps(self, steps):
        r"""Sets how often model parameters are exported automatically.

        Arguments:
            steps (int): the model is exported via ``torch.save()`` every
                ``steps`` training rounds.
        """
        self.export_steps = steps

    def set_lr_decay(self, lr_decay):
        r"""Sets the learning rate decay for the ``Adagrad`` optimization method.

        Arguments:
            lr_decay (float): learning rate decay.
        """
        self.lr_decay = lr_decay

    def set_weight_decay(self, weight_decay):
        r"""Sets the weight decay for the ``Adagrad`` optimization method.

        Arguments:
            weight_decay (float): weight decay for ``Adagrad``.
        """
        self.weight_decay = weight_decay
    def sampling(self):
        r"""Draws a batch of positive samples and constructs the corresponding
        negative samples; the C library writes the results into the batch
        buffers allocated in :func:`init`.
        """
        self.lib.sampling(
            self.batch_h_addr, self.batch_t_addr, self.batch_r_addr,
            self.batch_y_addr, self.batch_size, self.negative_ent,
            self.negative_rel,
        )
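    # Sketch of the resulting batch layout, as the buffer sizing in init()
    # suggests (an assumption about the C sampler, not extra API): the first
    # batch_size entries of batch_h/batch_t/batch_r hold positive triples and
    # the remaining batch_size * (negative_ent + negative_rel) entries hold
    # corrupted ones, with batch_y distinguishing positives from negatives.
    #
    #     con.sampling()
    #     pos_h = con.batch_h[:con.batch_size]  # heads of positive triples
    #     neg_h = con.batch_h[con.batch_size:]  # heads of corrupted triples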
    def save_pytorch(self):
        r"""Saves the model parameters to ``self.exportName``, which was set
        by :func:`set_export_files`.
        """
        torch.save(self.trainModel.state_dict(), self.exportName)

    def restore_pytorch(self):
        r"""Restores the model parameters from ``self.importName`` via
        ``torch.load``.
        """
        self.trainModel.load_state_dict(torch.load(self.importName))

    def export_variables(self, path=None):
        r"""Exports the model parameters through ``torch.save``.

        Arguments:
            path: if None, this function is equivalent to
                :func:`save_pytorch`; otherwise parameters are saved to ``path``.
        """
        if path is None:
            torch.save(self.trainModel.state_dict(), self.exportName)
        else:
            torch.save(self.trainModel.state_dict(), path)

    def import_variables(self, path=None):
        r"""Imports the model parameters through ``torch.load``.

        Arguments:
            path: if None, this function is equivalent to
                :func:`restore_pytorch`; otherwise parameters are loaded from ``path``.
        """
        if path is None:
            self.trainModel.load_state_dict(torch.load(self.importName))
        else:
            self.trainModel.load_state_dict(torch.load(path))

    def get_parameter_lists(self):
        # Note: .cpu() moves the model to host memory as a side effect.
        return self.trainModel.cpu().state_dict()

    def get_parameters_by_name(self, var_name):
        return self.trainModel.cpu().state_dict().get(var_name)
    def get_parameters(self, mode="numpy"):
        r"""Returns the model parameters.

        Arguments:
            mode: if ``numpy``, parameters are returned as numpy arrays;
                if ``list``, as nested Python lists; any other value returns
                the torch tensors.
        """
        res = {}
        lists = self.get_parameter_lists()
        for var_name in lists:
            # elif/else here fixes the original control flow, in which the
            # final else was paired with the second if and silently
            # overwrote the numpy result.
            if mode == "numpy":
                res[var_name] = lists[var_name].numpy()
            elif mode == "list":
                res[var_name] = lists[var_name].numpy().tolist()
            else:
                res[var_name] = lists[var_name]
        return res
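    # Example (a sketch, not part of this API): reading learned embeddings
    # after training. The key "ent_embeddings.weight" is an assumption that
    # holds for the embedding models shipped with this library; other models
    # may use different parameter names.
    #
    #     params = con.get_parameters("numpy")
    #     ent_emb = params.get("ent_embeddings.weight")  # (entTotal, ent_size)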
    def save_parameters(self, path=None):
        r"""Saves the model parameters as a JSON file when training has finished.

        Arguments:
            path: if None, parameters are saved to ``self.out_path``, which
                was set by :func:`set_out_files`.
        """
        if path is None:
            path = self.out_path
        with open(path, "w") as f:
            f.write(json.dumps(self.get_parameters("list")))

    def set_parameters_by_name(self, var_name, tensor):
        self.trainModel.state_dict().get(var_name).copy_(torch.from_numpy(np.array(tensor)))

    def set_parameters(self, lists):
        for i in lists:
            self.set_parameters_by_name(i, lists[i])
    def set_model(self, model):
        r"""Sets the training model and builds the optimizer.

        Arguments:
            model: the training model class. We can choose from
                :class:`models.TransE`, :class:`models.TransH`,
                :class:`models.TransR`, :class:`models.TransD`,
                :class:`models.RESCAL`, :class:`models.DistMult` and
                :class:`models.ComplEx`.
        """
        self.model = model
        self.trainModel = self.model(config=self)
        self.trainModel.cuda()
        if self.optimizer is not None:
            pass
        elif self.opt_method in ("Adagrad", "adagrad"):
            self.optimizer = optim.Adagrad(
                self.trainModel.parameters(),
                lr=self.alpha,
                lr_decay=self.lr_decay,
                weight_decay=self.weight_decay,
            )
        elif self.opt_method in ("Adadelta", "adadelta"):
            self.optimizer = optim.Adadelta(self.trainModel.parameters(), lr=self.alpha)
        elif self.opt_method in ("Adam", "adam"):
            self.optimizer = optim.Adam(self.trainModel.parameters(), lr=self.alpha)
        else:
            self.optimizer = optim.SGD(self.trainModel.parameters(), lr=self.alpha)
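    # Example (a sketch with illustrative hyperparameters): the optimizer is
    # constructed inside set_model(), so alpha, lr_decay and weight_decay
    # must be set before it is called.
    #
    #     con.set_opt_method("Adagrad")
    #     con.set_alpha(0.001)
    #     con.set_lr_decay(0.0)
    #     con.set_weight_decay(0.0)
    #     con.set_model(models.TransE)  # assumes an `import models` alongside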
    def run(self):
        r"""Trains the model."""
        if self.importName is not None:
            self.restore_pytorch()
        for epoch in range(self.train_times):
            res = 0.0
            for batch in range(self.nbatches):
                self.sampling()
                self.optimizer.zero_grad()
                # The model's forward() reads the sampled batch through the
                # shared config, so it is called without arguments.
                loss = self.trainModel()
                # loss.data[0] matches the pre-0.4, Variable-based PyTorch
                # API used in this module; on PyTorch >= 0.4 this would be
                # loss.item().
                res = res + loss.data[0]
                loss.backward()
                self.optimizer.step()
            if self.exportName is not None and (self.export_steps != 0 and epoch % self.export_steps == 0):
                self.save_pytorch()
            if self.log_on == 1:
                print(epoch)
                print(res)
        if self.exportName is not None:
            self.save_pytorch()
        if self.out_path is not None:
            self.save_parameters(self.out_path)
    def test(self):
        r"""Tests the model on link prediction: for every test triple, all
        entities are scored as candidate heads and as candidate tails.
        """
        if self.importName is not None:
            self.restore_pytorch()
        # self.trainModel.cuda()
        total = self.lib.getTestTotal()
        for epoch in range(total):
            self.lib.getHeadBatch(self.test_h_addr, self.test_t_addr, self.test_r_addr)
            res = self.trainModel.predict(self.test_h, self.test_t, self.test_r)
            self.lib.testHead(res.data.numpy().__array_interface__['data'][0])
            self.lib.getTailBatch(self.test_h_addr, self.test_t_addr, self.test_r_addr)
            res = self.trainModel.predict(self.test_h, self.test_t, self.test_r)
            self.lib.testTail(res.data.numpy().__array_interface__['data'][0])
            if self.log_on:
                print(epoch)
        self.lib.test()
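# End-to-end usage sketch, assuming the standard layout of this repository
# (a `models` package next to this module, a compiled ./release/Base.so, and
# a benchmark directory with the usual train/test files). All paths and
# hyperparameters below are illustrative assumptions, not fixed defaults.
#
#     import config
#     import models
#
#     con = config.Config()
#     con.set_in_path("./benchmarks/FB15K/")    # hypothetical benchmark path
#     con.set_work_threads(4)
#     con.set_train_times(500)
#     con.set_nbatches(100)
#     con.set_alpha(0.001)
#     con.set_margin(1.0)
#     con.set_dimension(100)
#     con.set_ent_neg_rate(1)
#     con.set_export_files("./res/model.pt")    # hypothetical output paths
#     con.set_out_files("./res/embedding.json")
#     con.set_test_flag(True)   # set after the other parameters (see note)
#     con.init()
#     con.set_model(models.TransE)
#     con.run()    # train
#     con.test()   # evaluate link prediction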