Слияние кода завершено, страница обновится автоматически
# coding=utf-8
from gnn.data.dataset import GraphDataset, WhiteSpaceTokenizer
from gnn.data.example import load_M10, load_cora, load_dblp
from gnn.model.gcn import GCN, GCNTrainer
import tensorflow as tf
# eager mode must be enabled
from tensorflow.contrib.eager.python import tfe
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
tfe.enable_eager_execution()
# read graph dataset: data/M10 data/dblp
# dataset = GraphDataset("data/dblp", ignore_featureless_node=True)
dataset = load_M10("data/M10", ignore_featureless_node=True)
adj = dataset.adj_matrix(sparse=True)
feature_matrix, feature_masks = dataset.feature_matrix(bag_of_words=True, sparse=True)
labels, label_masks = dataset.label_list_or_matrix(one_hot=False)
train_node_indices, test_node_indices, train_masks, test_masks = dataset.split_train_and_test(training_rate=0.3)
gcn_model = GCN([16, dataset.num_classes()], drop_rate=0.1)
gcn_trainer = GCNTrainer(gcn_model)
gcn_trainer.train(adj, feature_matrix, labels, train_masks, test_masks, learning_rate=1e-3, l2_coe=1e-3)
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )