#############################################################
###           Author: Xin Zhao												    ###
###           Date: 2018_06_15                            ###
###        DMnet for Text Clustering                      ###
#############################################################

# ADMM solver for the network-fused model (variable selection / clustering).
#
# Arguments
#   w        : count matrix; rows are addressed by node id via as.character(), V = ncol(w).
#   E        : edge matrix; columns 1 and 2 hold the two node ids of each edge.
#   Ebynode  : list indexed by node id; Ebynode[[d]] holds the neighbours of node d.
#   xi       : edge weights, indexed like the rows of E.
#   lambda   : multiplier of the fusion penalty.
#   rho      : ADMM penalty parameter (also forwarded to grad.desc).
#   method   : likelihood variant, forwarded to grad.desc.
#   epsilon1, epsilon2 : accepted but not referenced anywhere in the body.
#   s        : not used by the solver itself; returned unchanged in the result.
#
# Returns a list with G = exp(Theta), Theta, the edge copies z, the scaled duals u,
# neg.loglikelihood = L(G, w), the accumulated run time, and lambda / rho / s.
#
# The loop always runs exactly `maxiter` sweeps; there is no convergence test.
admm.norm <- function(w, E, Ebynode, xi, lambda, rho = 1000, method = "DM++",
                      epsilon1 = 1.e-8,epsilon2 = 1.e-14, s = 1.e-5){
  
  # G : initial value
  iter = 1           # overwritten by the for loop below
  maxiter = 300
  total.time = 0
  D = length(Ebynode)
  V = ncol(w)
  # Moment-based start, shifted by 0.01 so that log(G) is finite where initial.MOM gives 0.
  G = initial.MOM(w, Ebynode) + 0.01
  Theta = log(G)
  # z[[i]] / u[[i]]: one row per neighbour of node i (row names = neighbour ids).
  # z starts as copies of Theta[i, ], u starts at zero.
  z = u = vector("list", D)
  for (i in 1:D){
    if (i %in% as.numeric(rownames(Theta))){
      z[[i]] = matrix(Theta[as.character(i),], nrow = length(Ebynode[[i]]), ncol = V, byrow = TRUE)   
      u[[i]] = matrix(0, nrow = length(Ebynode[[i]]), ncol=V, byrow=TRUE)
      rownames(z[[i]]) = rownames(u[[i]]) = as.character(Ebynode[[i]]) 
    }
  }
  for (iter in 1:maxiter) {
    start.time = Sys.time()
    # Step 1: update each node's Theta row given its current z and u.
    for (d in 1:D) {
      if (d %in% as.numeric(rownames(Theta))){
        Theta[as.character(d),] = grad.desc(Theta[as.character(d),], w[as.character(d),], z[[d]], u[[d]], rho = rho, method = method, epsilon = 1.e-10)
      }
    }
    # Step 2: update the two copies z attached to every edge (e1, e2).
    for (e in 1:nrow(E)) {
      e1 = E[e,1]
      e2 = E[e,2]
      # difference between the two (Theta + u) end points of this edge
      temp1 = Theta[as.character(e1),] + u[[e1]][as.character(e2), ] - (Theta[as.character(e2),]+u[[e2]][as.character(e1), ])
      temp = rho*sqrt(sum(temp1*temp1))
      # all-zero difference: avoid dividing by zero below
      if (length(which(temp1 == 0)) == V){temp=1}
      # mixing weight, floored at 0.5 (at 0.5 both copies become the same average)
      c_theta = max(1-lambda*xi[e]/temp,0.5)
      z[[e1]][as.character(e2), ] = c_theta*(Theta[as.character(e1),]+u[[e1]][as.character(e2),]) + (1-c_theta)*(Theta[as.character(e2),]+u[[e2]][as.character(e1),])
      z[[e2]][as.character(e1), ] = (1-c_theta)*(Theta[as.character(e1),]+u[[e1]][as.character(e2),]) + c_theta*(Theta[as.character(e2),]+u[[e2]][as.character(e1),])
    }
    # Step 3: dual update, u <- u + (Theta replicated per neighbour) - z.
    for (d in 1:D) {
      if (d %in% as.numeric(rownames(Theta))){
        u[[d]] = u[[d]] + matrix(Theta[as.character(d), ], length(Ebynode[[d]]), V, byrow=TRUE) - z[[d]]
      }
    }
    
    end.time = Sys.time()
    run.time = end.time - start.time
    # progress report every 100 sweeps (temp2 is a whole number only then)
    temp2 = iter/100
    if (trunc(temp2)-temp2 == 0){print(c(iter = iter, time = run.time))}
    total.time = total.time + run.time
  }
  G=exp(Theta)
  logLik = L(G,w)
  return(list(G = G, Theta = Theta, z = z, u = u, neg.loglikelihood = logLik, total.time = total.time, lambda = lambda, rho = rho,s = s))
}

# Re-estimate one parameter vector per cluster, given cluster labels.
#
# Arguments
#   D        : not used by the outer body (the inner closures define their own D).
#   V        : number of columns used when writing a fitted vector back into `a`.
#   labels   : cluster label per document; if named, the names are taken as document ids,
#              otherwise positions are used.
#   w        : count matrix with row names (rows are selected via as.character(index)).
#   Ebynode  : neighbour list, passed to initial.MOM for the starting values.
#   net.type : "net1" or "net2"; selects the minimum cluster size `thrsh` and `maxit`.
#              NOTE(review): any other value leaves `thrsh`/`maxit` undefined.
#   method   : "DM"/"DMM", "DM+" or "DM++"; selects which objective is minimised.
#
# Returns the matrix `a` (one row per document): rows of clusters with at least `thrsh`
# members are replaced by the cluster's fitted vector, other rows keep their
# initial.MOM(w, Ebynode) + 0.001 value.
oracle <- function(D,V,labels, w, Ebynode, net.type = "net1", method){
  
  if (net.type == "net1"){
    thrsh = 5
    maxit = 100
  }
  if (net.type == "net2"){
    thrsh = 40
    maxit = 100
  }
  numc = max(na.omit(labels))
  a = initial.MOM(w, Ebynode) + 0.001
  rownames(a) = rownames(w)
  rest = NULL
  for (i in 1: numc){
    # documents belonging to cluster i
    if (is.null(names(labels)) == TRUE){index = which(labels == i)}
    else{index = as.numeric(names(which(labels == i)))}
    # The closures below read `index` and the outer `w` from this loop iteration.
    # Each objective adds the same 0.005*sum(log(g)) term and returns the negated value.

    # objective used for method "DM"/"DMM" (lgamma-based likelihood)
    L11 <- function(rg) {
      # rg : V-vector
      w = w[as.character(index), ]
      D = length(index)
      ws = rowSums(w)
      g = matrix(rep(rg, D), nrow = D, byrow = T)
      gs = rowSums(g)
      res = sum(lgamma(gs) - lgamma(ws+gs) + rowSums(lgamma(w+g) - lgamma(g))) + 0.005*sum(log(g))
      return(-res)
    }
    
    # objective used for method "DM+" (indicator(w) * log(g) in place of the lgamma ratio)
    L22 <- function(rg) {
      # rg : V-vector
      w = w[as.character(index), ]
      D = length(index)
      ws = rowSums(w)
      g = matrix(rep(rg, D), nrow = D, byrow = T)
      gs = rowSums(g)
      res = sum(lgamma(gs) - lgamma(ws+gs) + rowSums(indicator(w)*log(g))) + 0.005*sum(log(g))
      return(-res)
    }
    
    # objective used for method "DM++" (Ni*log(gs) and w*log(g) terms)
    L33 <- function(rg) {
      # rg : V-vector
      w = w[as.character(index), ]
      D = length(index)
      ws = rowSums(w)
      g = matrix(rep(rg, D),nrow = D,byrow = T)
      gs = rowSums(g)
      Ni = rowSums(w)
      res = - sum(Ni*log(gs) - rowSums(w*log(g))) + 0.005*sum(log(g))
      return(-res)
    }
    # Functions passed as `gr` to optim for L11 / L22 / L33 respectively.
    # NOTE(review): the penalty contribution here is the scalar 0.005*sum(1/g) added to
    # every component — confirm this is the intended derivative of 0.005*sum(log(g)).
    dL11 <- function(rg) {
      # rg : V-vector
      w = w[as.character(index), ]
      D = length(index)
      g = matrix(rep(rg, D), nrow = D, byrow = T)
      gs = rowSums(g)
      ws = rowSums(w)
      res <- 	rowSums(t(digamma(gs) - digamma(ws+gs)+ digamma(w+g) - digamma(g))) + 0.005*sum(1/g)
      return(-res)
    }
    
    dL22 <- function(rg) {
      # rg : V-vector
      w = w[as.character(index), ]
      D = length(index)
      g = matrix(rep(rg, D), nrow = D, byrow = T)
      gs = rowSums(g)
      ws = rowSums(w)
      res = rowSums(t(digamma(gs) - digamma(ws+gs) + indicator(w))) + 0.005*sum(1/g)
      return(-res)
    }
    
    dL33 <- function(rg) {
      # rg : V-vector
      w = w[as.character(index), ]
      D = length(index)
      g = matrix(rep(rg, D), nrow = D, byrow = T)
      gs = rowSums(g)
      ws = rowSums(w)
      res =	rowSums(t(ws*g/gs + w)) + 0.005*sum(1/g)
      return(-res)
    }
    
    # Only clusters with at least `thrsh` documents are refitted.
    if (length(index) >= thrsh){
      # start from the average initial value over the cluster
      par = colMeans(a[as.character(index), ])
      # NOTE(review): optim() is called without `method`; per the optim documentation the
      # default optimiser does not use `gr` — confirm whether a gradient method was intended.
      if (method == "DM"||method == "DMM"){
        temp.op = optim(par = par, fn = L11, gr = dL11, control=c(maxit=maxit))
      }else if(method == "DM+"){
        temp.op = optim(par = par, fn = L22, gr = dL22, control=c(maxit=maxit))
      }else if(method == "DM++"){
        temp.op = optim(par = par, fn = L33, gr = dL33, control=c(maxit=maxit))
      }
      
      temp = temp.op$par
      # every document of the cluster receives the same fitted vector
      a[as.character(index), ] = matrix(temp, nrow = length(index), ncol = V, byrow = TRUE)
    }else{
      # small clusters are only collected; `rest` is used solely by the disabled code below
      rest = c(rest, index)
    }
  }
  # if (length(rest) > 0){
  #   print(TRUE)
  #   mean.a = colMeans(a[as.character(rest), ])
  #   temp.op = optim(par = mean.a, fn = L11, gr = dL11, control=c(maxit=maxit))
  #   a[as.character(rest),] = temp.op$par
  #   # a[as.character(rest),] = matrix(mean.a, nrow = length(rest), ncol = V, byrow = TRUE)
  #   # print(which(a[as.character(rest),] == 0))
  # }
  return (a)
}

# 5-fold cross-validation over a decreasing sequence of lambda values (simulated networks).
#
# For each lambda (starting at lambda.max and multiplied by `itvl` after every round) the
# model is fitted on the full data to count the clusters, then on each training fold; the
# held-out documents are scored with fd() against every sufficiently large training
# cluster and the smallest score per document is summed into the fold loss.
#
# Arguments
#   w, E, adj, net.dis, Ebynode : counts, edge matrix, adjacency matrix, node distance
#                                 matrix and neighbour list of the full network.
#   D1, D2, D3 : group sizes, only forwarded to plot.vert.
#   method, rho, net.type       : forwarded to admm.norm / oracle.
#   itvl       : multiplicative step applied to lambda after each round.
#   lambda.max : first lambda tried.
#   seeds      : seed used before drawing the fold permutation.
#
# Returns list(Index, lambda, lambda.index, eta, xi) where Index has rows
# "lambda", "K", "loss", "loss.be" and one column per lambda tried.
#
# NOTE(review): `eta` is not a parameter; it is looked up outside this function.
net.cv <- function(w, E, adj, net.dis, Ebynode, D1, D2, D3, method, itvl = 0.96, lambda.max = 35, rho, net.type = "net1", seeds=2){
  
  # net.type = "net2"; itvl = 0.96; lambda.max = 35; rho = 1000; method = "DM"; seeds = 16
  V = ncol(w)
  D = nrow(w)
  fold = 5
  Index = NULL
  ncol.Index = 0
  lambda.index = 0
  # j is only referenced by the commented-out file name below and is later reused as a
  # loop variable
  if (net.type == "net1"){
    j = 1
  }else{
    j = 2}
  temp = cal.weight(w, E, eta)
  xi = temp$xi
  # lambda.max = max.lambda(lambda.max, w, E, Ebynode, xi, rho = rho, itvl = 0.98, method = method)
  # Index.txt = paste("../plotcv", "/cv.net", j, ".", D, ".", method, ".txt", sep = "")
  print(c(net.type = net.type, rho = rho, method = method, D = D))
  groups = list(rep(1,D))
  grps = NULL
  # Stop after 10 lambda values or once the full-data fit has more than 40 groups.
  while ((ncol.Index < 10) && (length(groups)<=40)) {
    # Full-data fit; lambda is shrunk until clus.into.grp returns a non-NULL result.
    while(class(grps) == "NULL"){
      result = admm.norm(w, E, Ebynode, xi, lambda = lambda.max, rho = rho, method = method)
      grps = clus.into.grp(result$z, E, Ebynode, result$lambda, del.vec = NULL, s = result$s)
      if (class(grps) == "NULL") {lambda.max = lambda.max * itvl}
    }
    plot.vert(grps$graph.g, D1, D2, D3, mode = "admm", col = c("coral2", "aquamarine", "cadetblue4"), lambda = result$lambda, rho = result$rho)
    groups = grps$group
    labels= label(groups);labels
    ve= NULL
    loss.fold = loss.fold.be = rep(0, fold)
    # same permutation for every lambda because the seed is reset here
    set.seed(seeds)
    sample.set = sample(D,D)
    for (m in 1:fold) {
      # test.set = c(((m-1)*D1/fold + 1):(m*D1/fold), (D1+(m-1)*D2/fold + 1):(D1+m*D2/fold), (D1+D2+(m-1)*D3/fold + 1) : (D1+D2+m*D3/fold))
      # NOTE(review): the slice bounds are whole numbers only when D is a multiple of `fold`.
      test.set = sort(sample.set[((m-1)*D/fold + 1): (m*D/fold)])
      train.set = (1:D)[-test.set]
      train.adj = adj[train.set, train.set]
      test.adj = adj[test.set, test.set]
      net.dis.train = net.dis[train.set, train.set]

      # deal with scattered nodes
      # A training node that lost all its edges is linked to the node(s) at the smallest
      # distance other than the row minimum.
      scatter.node = which(rowSums(train.adj)==0)
      if (length(scatter.node) != 0){
        for (r in 1:length(scatter.node)) {
          dis.node.r = net.dis.train[scatter.node[r],]
          min.dis.r = min(dis.node.r[dis.node.r!=min(dis.node.r)])
          add2r = which(dis.node.r == min.dis.r)
          train.adj[scatter.node[r], add2r] = 1
          train.adj[add2r, scatter.node[r]] = 1
        }
      }
      # keep the original document ids as dimnames so cal.E returns original ids
      rownames(train.adj) = colnames(train.adj) = train.set
      E.train = cal.E(train.adj)#;print(1)
      Ebynode.train = cal.Ebynode(E.train)
      w.train = w[train.set,]
      D.train = nrow(w.train)
      xi.dis = cal.weight(w[train.set,], E.train, eta)
      xi.train = xi.dis$xi
      net.dis.train = xi.dis$net.dis
      w.test = w[test.set,]
      # rownames(w.train) = train.set
      # rownames(w.test) = test.set
      # w = w.train; E = E.train;Ebynode = Ebynode.train; xi = xi.train; lambda = lambda.max; rho = rho; method = method
      result.train = admm.norm(w = w.train, E = E.train,Ebynode = Ebynode.train, xi = xi.train, lambda = lambda.max, rho = rho, method = method)
      # note: the threshold `s` is taken from the full-data fit (`result`), not `result.train`
      grps.train = clus.into.grp(result.train$z, E.train, Ebynode.train, result.train$lambda, del.vec = NULL,s = result$s)
      group.train = grps.train$group
      labels.train = label(group.train)
      # label() leaves NA at ids that are not in the training set; drop them
      labels.train = labels.train[which(is.na(labels.train)==FALSE)]
      names(labels.train) = train.set

      # "be" = estimates taken directly from the ADMM fit, before the oracle refit
      train.theta.be = result.train$Theta
      # D = D.train; labels =  labels.train; w = w.train; Ebynode = Ebynode.train; g = result.train$G
      train.G = oracle(D.train, V, labels.train, w.train, Ebynode.train, net.type = net.type, method = method)
      rownames(train.G) = rownames(train.theta.be) = train.set
      train.theta = log(train.G)
      # L.mat[j, k]: fd() score of test document j under the mean theta of training cluster k
      L.mat = L.mat.be = matrix(0, nrow = nrow(w.test), ncol = length(group.train))
      rownames(L.mat) = rownames(L.mat.be) = as.character(test.set)
      for (k in 1:max(labels.train)) {
        # only clusters with more than 10 training documents are scored
        if (length(which(labels.train==k)) > 10) {
          try.theta.be = train.theta.be[names(which(labels.train==k)), ]
          try.theta = train.theta[names(which(labels.train==k)), ]
          for(j in 1:length(test.set)){
            L.mat[j, k] = fd(colMeans(try.theta), w.test[j,])
            L.mat.be[j, k] = fd(colMeans(try.theta.be), w.test[j,])
          }
        }
      }
      # columns that were never filled (sum 0) get the largest score so they are not
      # chosen by the row-wise minimum
      L.mat[, which(colSums(L.mat)==0)] = max(L.mat)
      L.mat.print = cbind(L.mat, apply(L.mat, 1, which.min))

      L.mat.be[, which(colSums(L.mat.be)==0)] = max(L.mat.be)
      L.mat.be.print = cbind(L.mat.be, apply(L.mat.be, 1, which.min))
      loss.fold[m] = sum(apply(L.mat, 1, min))
      loss.fold.be[m] = sum(apply(L.mat.be, 1, min))
    }
    # one column per lambda: lambda, number of full-data groups, mean loss, mean "be" loss
    Index = cbind(Index, c(lambda.max, length(groups), mean(loss.fold), mean(loss.fold.be)))
    print(Index)
    # column(s) with the smallest mean loss so far
    lambda.index = which(Index[3,] == min(Index[3,]))
    # print(c(lambda.index = lambda.index))
    ncol.Index = ncol(Index)
    lambda.max = lambda.max * itvl
    grps = NULL
  }
  rownames(Index) = c("lambda", "K", "loss", "loss.be")
  # write.table(Index, Index.txt, sep = ",", col.names=TRUE)
  return(list(Index = Index, lambda = Index[1 , lambda.index], lambda.index = lambda.index, eta = eta, xi = xi))
}

# Search downwards from `testlambda` for the first lambda at which the ADMM fit
# yields more than two "large" groups (at least nrow(w)/7 members each).
#
# Every trial fits admm.norm, clusters with clus.into.grp and plots the result
# (plot.vert reads D1, D2, D3 from outside this function). After each trial lambda is
# multiplied by `itvl`. The value returned is the successful lambda divided once
# more by `itvl`, i.e. the trial just before the one that succeeded.
max.lambda <- function(testlambda, w, E, Ebynode, xi, rho , itvl = 0.96, method = "DM"){

  size.cut <- nrow(w) / 7
  repeat {
    fit <- admm.norm(w, E, Ebynode, xi, lambda = testlambda, rho = rho, method = method)
    clustered <- clus.into.grp(fit$z, E, Ebynode, fit$lambda, del.vec = NULL, s = fit$s)
    plot.vert(clustered$graph.g, D1, D2, D3, mode = "admm",
              col = c("coral2","aquamarine","cadetblue4"),
              lambda = fit$lambda, rho = fit$rho)
    members.by.group <- clustered$group
    label(members.by.group)

    # count the groups that reach the size threshold
    n.large <- 0
    for (members in members.by.group) {
      if (length(members) >= size.cut) {
        n.large <- n.large + 1
      }
    }

    testlambda <- testlambda * itvl
    if (n.large > 2) {
      break
    }
  }
  testlambda / itvl / itvl
}

# Cross-validation over lambda for the real data sets (block-wise folds per class).
#
# NOTE(review): as written this function depends on many names that are neither
# parameters nor assigned before use inside it, so it only runs if they exist in an
# enclosing environment:
#   * E, Ebynode, eta, D4, D5, D6, D7, groupCora / groupCiteseer, max.lambda.realdata
#   * ncol.Index is read by the while condition before any assignment here
#   * count is incremented without being initialised
#   * lambda and lambda.index in the return value are never assigned here
# Further points to confirm:
#   * `dataset` defaults to "cora" but is compared with "Cora", "Citeseer" and 'cora'
#   * `testlambda` is not used (80 is passed to max.lambda.realdata instead)
#   * the call to clus.into.grp on the training fit passes `Del.vec`, whereas
#     clus.into.grp in this file declares `del.vec`
net.cv.realdata <- function(w, adj, xi, net.dis, D1, D2, D3, method, itvl = 0.98, testlambda, rho, dataset = "cora"){
  
  V = ncol(w)
  D = nrow(w)
  fold = 5
  # # net1
  # net2
  # seq.lambda = lambda.grid(lambda.min=13, lambda.max=28, intervals = 0.98)
  Index = NULL
  # groupdata is assigned here but not used below
  if (dataset == "Cora"){groupdata = groupCora}
  if (dataset == "Citeseer"){groupdata = groupCiteseer}
  lambda.max = max.lambda.realdata(testlambda = 80, w, E, Ebynode, xi, rho = rho, itvl = 0.96, method = method, dataset = dataset)
  # up to 15 lambda values, each multiplied by `itvl` after its round
  while (ncol.Index < 15) {
    # full-data fit, used for the number of groups reported in Index
    result = admm.norm(w, E, Ebynode, xi, lambda = lambda.max, rho = rho, method = method)
    grps = clus.into.grp(result$z, E, Ebynode, result$lambda, del.vec = NULL,s = result$s)
    groups = grps$group
    labels= label(groups);labels
    # lst: indices of groups with at least 70 members; their sizes are printed
    lst = NULL
    for (j in 1:length(groups)){if (length(groups[[j]]) >= 70){lst=c(lst,j)}};lst
    for (k in 1:length(lst)) {
      count = count + length(groups[[lst[[k]]]])
      print(length(groups[[lst[[k]]]]))
    };
    loss.fold.be = rep(0, fold)
    
    for (m in 1:fold) {
      # m-th fifth of each consecutive class block (blocks of sizes D1, ..., D6)
      test.1 = ((m-1)*D1/fold + 1):(m*D1/fold)
      test.2 = D1+ ((m-1)*D2/fold + 1):(m*D2/fold)
      test.3 = D1+D2+((m-1)*D3/fold + 1):(m*D3/fold)
      test.4 = D1+D2+D3+((m-1)*D4/fold + 1):(m*D4/fold)
      test.5 = D1+D2+D3+D4+((m-1)*D5/fold + 1):(m*D5/fold)
      test.6 = D1+D2+D3+D4+D5+((m-1)*D6/fold + 1):(m*D6/fold)
      test.set = c(test.1, test.2, test.3, test.4, test.5, test.6)
      if (dataset == 'cora') {
        # test.7 is computed but not used; test.set is replaced by the m-th fifth of 1:D
        test.7 = D1+D2+D3+D4+D5+D6+((m-1)*D7/fold + 1):(m*D7/fold)
        test.set = c(((m-1)*D/fold + 1):(m*D/fold))
      }
      train.set = (1:D)[-test.set]
      train.adj = adj[train.set, train.set]
      test.adj = adj[test.set, test.set]
      net.dis.train = net.dis[train.set, train.set]
      
      # deal with scattered nodes
      # A training node without edges is linked to the node(s) at the smallest distance
      # other than the row minimum.
      scatter.node = which(rowSums(train.adj)==0)
      if (length(scatter.node) != 0){
        for (r in 1:length(scatter.node)) {
          dis.node.r = net.dis.train[scatter.node[r],]
          min.dis.r = min(dis.node.r[dis.node.r!=min(dis.node.r)])
          add2r = which(dis.node.r == min.dis.r)
          train.adj[scatter.node[r], add2r] = 1
          train.adj[add2r, scatter.node[r]] = 1
        }
      }
      # NOTE(review): unlike net.cv, dimnames of train.adj are not set to train.set before
      # cal.E is called — confirm adj already carries suitable row names.
      E.train = cal.E(train.adj)
      Ebynode.train = cal.Ebynode(E.train)
      w.train = w[train.set,]
      D.train = nrow(w.train)
      xi.dis = cal.weight(w[train.set,], E.train, eta)
      xi.train = xi.dis$xi
      net.dis.train = xi.dis$net.dis
      w.test = w[test.set,]
      result.train = admm.norm(w.train, E.train, Ebynode.train, xi.train, lambda.max, rho = rho, method = method)
      grps.train = clus.into.grp(result.train$z, E.train, Ebynode.train, result.train$lambda, Del.vec = NULL,s = result$s)
      group.train = grps.train$group
      labels.train= label(group.train)
      # lst.train: indices of training groups with at least 5 members
      lst.train=NULL
      for (j in 1:length(group.train)){if (length(group.train[[j]]) >= 5){lst.train = c(lst.train, j)}}
      
      train.theta.be = result.train$Theta
      L.mat.be = matrix(0, nrow = nrow(w.test), ncol = length(lst.train))
      for (k in 1:length(lst.train)) {
        # NOTE(review): this selects the row of Theta whose position equals the group
        # index lst.train[[k]], not a member of that group — confirm intent.
        try.theta.be = train.theta.be[lst.train[[k]][1], ]
        for(j in 1:length(test.set)){
          L.mat.be[j, k] = fd(try.theta.be, w.test[j,])
        }
      }
      # fold loss: best (smallest) score per test document, summed
      loss.fold.be[m] = sum(apply(L.mat.be, 1, min))
    }
    Index = cbind(Index, c(lambda.max, length(groups), mean(loss.fold.be)))
    print(Index)
    ncol.Index = ncol(Index)
    lambda.max = lambda.max * itvl
  }
  return(list(Index = Index, lambda = lambda, lambda.index = lambda.index))
}

# External clustering-quality measures for the three-group simulation.
#
# Arguments
#   labels     : estimated cluster label per document (length D1 + D2 + D3, no NA).
#   D1, D2, D3 : sizes of the three true groups; documents are assumed to be ordered
#                group 1, then group 2, then group 3.
#
# Returns a named vector c(Jaccard, ARI, Purity, RI, Fmeasure). ARI is computed by
# randIndex(), which must be available from an attached package.
#
# Pair counts (TP/FP/FN/TN over all document pairs) are derived from the contingency
# table instead of looping over all D*(D-1)/2 pairs.
measures <- function(labels, D1, D2, D3){
  D <- length(labels)
  labelstrue <- c(rep(1, D1), rep(2, D2), rep(3, D3))
  if (length(labelstrue) != D) {
    stop("length(labels) must equal D1 + D2 + D3", call. = FALSE)
  }
  if (anyNA(labels)) {
    stop("labels must not contain NA", call. = FALSE)
  }

  # rows: true groups, columns: estimated clusters
  tab <- table(labelstrue, labels)

  # number of pairs sharing a true group / an estimated cluster / both
  same.both <- sum(choose(tab, 2))
  same.true <- sum(choose(rowSums(tab), 2))
  same.pred <- sum(choose(colSums(tab), 2))

  TP <- same.both
  FN <- same.true - same.both
  FP <- same.pred - same.both
  TN <- choose(D, 2) - TP - FN - FP

  # Adjusted Rand Index
  ARI <- randIndex(tab)

  # Rand Index
  RI <- (TP + TN)/(TP + TN + FN + FP)
  precision <- TP/(TP + FP)
  recall <- TP/(TP + FN)

  # F-measure
  Fmeasure <- 2 * precision * recall/(precision + recall)

  # Jaccard Index
  Jaccard <- TP/(TP + FP + FN)

  # Purity: each estimated cluster is credited with its most frequent true group
  Purity <- sum(apply(tab, 2, max))/D
  return(c(Jaccard = Jaccard, ARI = ARI, Purity = Purity, RI = RI, Fmeasure = Fmeasure))
}

# Run the full pipeline on 100 stored replicates and save the estimated labels.
#
# For every replicate i the function reads the i-th block of D rows from the stacked count
# file and the i-th line of the edge-list file, fits admm.norm at the given lambda/rho,
# clusters with clus.into.grp, plots, refits with oracle() and accumulates the estimation
# error and the clustering measures. The label matrix (one row per replicate) is written to
# "../label_txt/labels_net<j>_<method><D>.txt".
#
# Arguments
#   true.G     : reference parameter matrix compared with the oracle estimate.
#   lambda, rho: passed to admm.norm.
#   D1, D2, D3 : true group sizes (D = D1 + D2 + D3 documents per replicate).
#   net.type   : "net1" selects file index j = 1, anything else j = 2.
#   method     : passed to admm.norm and oracle, and part of the output file name.
#
# NOTE(review): `eta` and `V` are not parameters; they are looked up outside this function.
do.hundred <- function(true.G, lambda, rho, D1, D2, D3, net.type = "net2", method = "DM"){
  
  times = 100
  path.0 = "../label_txt"
  D = D1 + D2 + D3
  if (net.type == "net1"){
    j = 1
  }else{
    j = 2}
  # groupdata and labelstrue are built here but not used further down
  groupdata = NULL
  for (num in 1:3){groupdata[[num]] = seq((1+D1*(num-1)),num*D1)}
  labelstrue = c(rep(1, D1), rep(2, D2), rep(3, D3))
  file_bigw = paste(path.0, "/bigw_", D, ".txt", sep = "")
  file_bigedgelist = paste(path.0, "/bigedgelist_net", j,"_", D , ".txt", sep = "")
  label.txt = paste(path.0, "/labels_net", j, "_", method, D, ".txt", sep = "")
  print(label.txt)
  print(c(net.type = net.type, D = D, method = method, rho = rho, lambda = lambda))
  error.list = mat = NULL
  bigedgelist = read.csv(file_bigedgelist, header = FALSE)
  bigw = read.csv(file_bigw, header = FALSE)
  bigw = as.matrix(bigw)
  label.mat = matrix(0, nrow = times, ncol = D)
  groupdata = NULL; avg.time =NULL
  for (num in 1:3){groupdata[[num]] = seq((1+D1*(num-1)),num*D1)}
  for (i in 1:times) {
    # rows (i-1)*D + 1 ... i*D of the stacked file form replicate i
    w =  bigw[((i-1)*D + 1): (D*i),]
    rownames(w) = 1:D
    # the i-th row of the edge file is split on spaces; ids are shifted by +1
    edgelist = bigedgelist[i, ]
    edgelist = strsplit(as.vector(edgelist)," ")
    tempedge = as.numeric(edgelist[[1]])+1
    # filled column-wise: first half of the ids -> column 1, second half -> column 2
    E = matrix(tempedge, ncol = 2)
    Ebynode = cal.Ebynode(E)
    temp = cal.weight(w, E, eta)
    xi = temp$xi
    # net.dis and adj are computed but not used below
    net.dis = temp$net.dis
    adj <- cal.adj(E, D)
    print(paste("choose lambda equals to", lambda))
    result = admm.norm(w, E, Ebynode, xi, lambda = lambda, rho= rho, method)
    grps = clus.into.grp(result$z, E, Ebynode, result$lambda, del.vec = NULL, s = result$s)
    plot.vert(grps$graph.g, D1, D2, D3, i=NULL, mode = "admm", 
              col=c("coral2", "aquamarine", "cadetblue4"), lambda = result$lambda, rho = result$rho)
    group = grps$group
    labels = label(group)
    names(labels) = 1:D
    label.mat[i, ] = labels
    avg.time = c(avg.time, result$total.time)
    # refit per cluster and record the Frobenius-type distance to true.G
    a = oracle(D,V,labels, w, Ebynode, net.type = net.type, method)
    error.list = c(error.list, sum((a-true.G)^2)^0.5)
    mat = rbind(mat, measures(as.numeric(label.mat[i,]), D1, D2, D3))
    # running averages over the replicates processed so far
    res.mean = apply(mat, 2, mean)
    print(c(i = i, mean.time = mean(avg.time), mean.error = mean(error.list)))
    print(res.mean)
  }
  write.table(label.mat, file = label.txt, col.names = F, sep = ",")
}

# Estimation error of the oracle refit over 100 stored replicates.
#
# For each replicate the cluster-wise parameters are re-estimated with oracle() twice:
# once with the saved estimated labels and once with the true labels. The Frobenius norm
# of the difference to true.G is recorded on both the original and the log scale.
#
# Arguments
#   true.G     : reference parameter matrix.
#   D1, D2, D3 : true group sizes (D = D1 + D2 + D3).
#   net.type   : "net1" selects file index j = 1, anything else j = 2.
#   method     : selects the label file and is passed to oracle().
#
# Returns the list printed for the last replicate: i plus mean/sd pairs
# msve, msve.real, mslog.ve, mslog.ve.real.
#
# NOTE(review): `V` and `eta` are not parameters; they are looked up outside this function.
# NOTE(review): for method == "DMM" file_bigedgelist is never assigned, yet it is read
# unconditionally below.
cal.ve <- function(true.G, D1, D2, D3, net.type = "net2", method = "DM"){

  # net.type = "net2"; method = "DM+"

  times=100
  path.0 = "../label_txt"
  if (net.type == "net1"){
    j = 1
  }else{j = 2}
  D = D1 + D2 + D3
  print(c(net.type = net.type, D = D, method = method))
  file_bigw = paste(path.0, "/bigw_", D, ".txt", sep = "")

  if (method == "DMM"){
    label.txt = paste(path.0, "/labels", "_", method, D, ".txt", sep = "")
  }else{
    file_bigedgelist = paste(path.0, "/bigedgelist_net", j, "_", D , ".txt", sep = "")
    label.txt = paste(path.0, "/labels_net", j, "_", method, D, ".txt", sep = "")
  }
  bigedgelist = read.csv(file_bigedgelist, header = FALSE)
  bigw = read.csv(file_bigw, header = FALSE)
  bigw = as.matrix(bigw)
  avg.time = log.ve = log.ve.real =ve = ve.real=NULL
  groupdata = NULL
  print(label.txt)
  label.mat = read.csv(label.txt, header = FALSE)
  label.mat = as.matrix(label.mat)
  # write.table(label.mat[,2:ncol(label.mat)], file = label.txt, sep = ",", append = F)
  # drop the first column of the label file
  label.mat = label.mat[,2:ncol(label.mat)]
  # groupdata is built but not used below
  for (num in 1:3){groupdata[[num]] = seq((1+D1*(num-1)),num*D1)}
  labelsreal=c(rep(1, D1), rep(2, D2), rep(3, D3))
  for (i in 1:times) {
    # rows (i-1)*D + 1 ... i*D of the stacked file form replicate i
    w =  bigw[((i-1)*D+1) : (i*D), ]
    rownames(w) = 1:D
    # the i-th row of the edge file is split on spaces; ids are shifted by +1
    edgelist = bigedgelist[i, ]
    edgelist = strsplit(as.vector(edgelist)," ")
    tempedge = as.numeric(edgelist[[1]])+1
    E = matrix(tempedge, ncol = 2)
    Ebynode = cal.Ebynode(E)
    # xi and net.dis are computed but not used below
    temp = cal.weight(w, E, eta = eta)
    xi = temp$xi
    net.dis = temp$net.dis
    labels = label.mat[i, ]

    # a: refit with the estimated labels; a.real: refit with the true labels
    a= oracle(D,V,as.vector(labels),w, Ebynode, net.type = net.type, method = method)
    a.real = oracle(D,V,labelsreal,w, Ebynode, net.type = net.type, method = method)
    log.ve = c(log.ve,norm(log(a)-log(true.G),"F"))
    log.ve.real = c(log.ve.real,norm(log(a.real)-log(true.G),"F"))
    ve = c(ve,norm(a-true.G,"F"))
    ve.real = c(ve.real,norm(a.real-true.G,"F"))
    # avg.time = c(avg.time, result$total.time)
    # running mean / sd over the replicates processed so far
    BigList = list(i = i,
                   msve = c(mean(ve),sd(ve)), msve.real = c(mean(ve.real),sd(ve.real)),
                   mslog.ve = c(mean(log.ve),sd(log.ve)), mslog.ve.real = c(mean(log.ve.real),sd(log.ve.real)))
    print(BigList)
  }
  return(BigList)
}

# Summarise the clustering measures of every competing method on the simulated data.
#
# For each method the saved label file is read (its first row and first column are
# dropped), measures() is applied to every remaining row, and the column means and
# standard deviations are stored. Both summary tables are written to
# "../label_txt/measure.mean_net<j>_<D>.txt" and "../label_txt/measure.var_net<j>_<D>.txt".
#
# Arguments
#   D1, D2, D3 : true group sizes.
#   net.type   : "net1" selects file index 1, anything else 2.
#
# Returns the per-replicate measure matrix of the last method processed.
meas.mat <- function(D1, D2, D3, net.type = "net1"){
  # net.type = "net1"

  D <- D1 + D2 + D3
  compare.lst <- c("DM", "DM+","DM++", "DMM", "spec", "kmeans", "kmeans++", "louvian", "FClouvian", "EM", "scan")
  measure.names <- c("Jaccard", "ARI", "Purity", "RI", "Fmeasure")
  res.mean <- matrix(0, nrow = length(compare.lst), ncol = 5,
                     dimnames = list(compare.lst, measure.names))
  res.var <- res.mean
  j <- if (net.type == "net1") 1 else 2
  path.0 <- "../label_txt"
  # methods whose label files do not depend on the network type
  network.free <- c("kmeans", "kmeans++", "DMM", "spec")

  for (k in seq_along(compare.lst)) {
    method.compare <- compare.lst[k]
    label.txt <- if (method.compare %in% network.free) {
      paste0(path.0, "/labels", "_", method.compare, D, ".txt")
    } else {
      paste0(path.0, "/labels_net", j, "_", method.compare, D, ".txt")
    }
    print(label.txt)
    print(c(net.type = net.type, method.compare = method.compare, D = D))

    raw <- read.csv(label.txt, header = FALSE)
    # write.table(label.mat[,2:ncol(label.mat)], file = label.txt, sep = ",", append = F)
    label.mat <- raw[2:nrow(raw), 2:ncol(raw)]
    print(paste("The dim of label.mat is", c(dim(label.mat)), sep = ":"))

    # one row of measures per replicate
    mat <- NULL
    for (i in seq_len(dim(label.mat)[1])) {
      mat <- rbind(mat, measures(as.numeric(label.mat[i,]), D1, D2, D3))
    }
    res.var[k, ] <- apply(mat, 2, sd)
    res.mean[k, ] <- apply(mat, 2, mean)
    print(list(res.mean = res.mean, res.var = res.var))
  }

  mean.mat.txt <- paste0(path.0, "/measure.mean", "_net", j, "_", D, ".txt")
  var.mat.txt <- paste0(path.0, "/measure.var", "_net", j, "_", D, ".txt")
  write.table(res.mean, file = mean.mat.txt, col.names = F, sep = ",")
  write.table(res.var, file = var.mat.txt, col.names = F, sep = ",")
  mat
}

# Plot a graph with vertices coloured by true group.
#
# Vertices with id <= D1 get col[1], ids in (D1, D1 + D2] get col[2], larger ids col[3].
# In mode "simul" the id is the vertex itself as returned by V(graph); in mode "admm" it is
# the numeric value of the vertex name. Any other mode leaves the colours untouched.
# The plot title is made of lambda, rho and i; `type.w` is not used.
plot.vert <- function(graph, D1, D2, D3, mode = "simul", type.w = NULL, 
                      col = c("coral2", "aquamarine", "cadetblue4"), i = NULL, lambda = NULL, rho=NULL){
  n.total <- D1 + D2 + D3   # not used below; kept so that a missing D3 still fails here
  first.cut <- D1
  second.cut <- D1 + D2

  if (mode == "admm") {
    # vertex names hold the original document ids
    ids <- as.numeric(names(V(graph)))
    V(graph)[which(ids <= first.cut)]$color <- col[1]
    V(graph)[which(ids <= second.cut & ids > first.cut)]$color <- col[2]
    V(graph)[which(ids > second.cut)]$color <- col[3]
  } else if (mode == "simul") {
    V(graph)[which(V(graph) <= first.cut)]$color <- col[1]
    V(graph)[which(V(graph) <= second.cut & V(graph) > first.cut)]$color <- col[2]
    V(graph)[which(V(graph) > second.cut)]$color <- col[3]
  }

  plot(graph, vertex.size = 6, vertex.color = V(graph)$color, main = c(lambda, rho, i))
}

# Turn the fitted edge copies z into clusters.
#
# Edges are first filtered by prune(); the connected components of the remaining graph are
# the clusters (two documents joined by a kept edge end up in the same group).
#
# Arguments
#   z, E, Ebynode, del.vec, s : passed to prune().
#   lambda                    : not used here; kept for the callers' call signature.
#
# Returns NULL when prune() keeps no edge at all (callers such as net.cv test for NULL),
# otherwise list(group, graph.g): `group` is a list of sorted numeric vertex ids, one
# element per connected component, and `graph.g` is the pruned igraph object.
clus.into.grp <- function(z, E, Ebynode, lambda, del.vec = NULL, s){

  prune.E <- prune(z = z, E = E, Ebynode = Ebynode, del.vec = del.vec, s = s)
  # Nothing to cluster: previously this case fell through and failed further down.
  if (is.null(prune.E) || nrow(prune.E) == 0) {
    return(NULL)
  }

  graph.g <- graph_from_data_frame(prune.E, directed = FALSE)
  clus <- igraph::clusters(graph.g, mode = "weak")

  # order the membership vector by numeric vertex id
  membership <- clus$membership[as.character(sort(as.numeric(names(clus$membership))))]

  n.groups <- max(membership)
  group <- vector("list", n.groups)
  for (item in seq_len(n.groups)) {
    group[[item]] <- sort(as.numeric(names(membership)[which(membership == item)]))
  }
  return(list(group = group, graph.g = graph.g))
}

# Select the edges that define the clusters.
#
# An edge (a, b) is kept when the two copies attached to it agree, i.e. when
# sum(abs(z[[a]]["b", ] - z[[b]]["a", ])) <= s. Every node that appears in E but in none of
# the kept edges is then attached to the neighbour (excluding ids in del.vec) with the
# smallest such gap, so that no node of E is dropped from the graph.
#
# Arguments
#   z       : list indexed by node id; z[[a]] has one row per neighbour, named by its id.
#   E       : two-column edge matrix of node ids (may have zero rows).
#   Ebynode : list indexed by node id with each node's neighbours.
#   del.vec : node ids that must not be used as attachment targets (NULL = none).
#   s       : agreement threshold.
#
# Returns a two-column numeric matrix of edges (zero rows if E has zero rows).
prune <- function(z, E, Ebynode, del.vec = NULL, s){
  # L1 gap between the two copies living on edge (a, b)
  edge.gap <- function(a, b) {
    sum(abs(z[[a]][as.character(b), ] - z[[b]][as.character(a), ]))
  }

  gaps <- vapply(seq_len(nrow(E)),
                 function(i) edge.gap(E[i, 1], E[i, 2]),
                 numeric(1))
  if (anyNA(gaps)) {
    stop("NA/NaN gap between edge copies in z", call. = FALSE)
  }
  fused <- gaps <= s
  temp.E <- matrix(as.numeric(c(E[fused, 1], E[fused, 2])), ncol = 2)

  # nodes that lost all of their edges
  orphan <- sort(setdiff(unique(as.vector(E)), unique(as.vector(temp.E))))
  if (length(orphan) == 0) {
    return(temp.E)
  }

  for (node in orphan) {
    nbrs <- setdiff(Ebynode[[node]], del.vec)
    if (length(nbrs) == 0) {
      next  # every neighbour is excluded: nothing to attach to
    }
    nbr.gap <- vapply(nbrs, function(b) edge.gap(node, b), numeric(1))
    temp.E <- rbind(temp.E, c(node, nbrs[which.min(nbr.gap)]))
  }
  unique(temp.E)
}

# Given the true parameters, compute the log probability assigned to one document
# under the three models DM, DM+ and DM++.
#
# Arguments
#   wi     : count vector of the document.
#   V      : vocabulary size; not needed by the computation, kept for existing callers.
#   gammai : parameter vector of the document's group (same length as wi).
#
# Returns list(log.f.DM, log.f.DM.p, log.f.DM.pp).
compare <- function(wi, V, gammai){
  Ni <- sum(wi)
  gamma.sum <- sum(gammai)

  log.f.DM <- lgamma(gamma.sum) - lgamma(Ni + gamma.sum) +
    sum(lgamma(wi + gammai) - lgamma(gammai))

  # seq_len(Ni) - 1 is 0, ..., Ni - 1 and is empty for a document without words
  # (seq(0, Ni - 1) would yield c(0, -1) in that case).
  log.f.DM.p <- -sum(log(gamma.sum + seq_len(Ni) - 1)) +
    sum(indicator(wi) * log(gammai))

  log.f.DM.pp <- -Ni * log(gamma.sum) + sum(wi * log(gammai))

  return(list(log.f.DM = log.f.DM, log.f.DM.p = log.f.DM.p, log.f.DM.pp = log.f.DM.pp))
}

# Log probabilities of every document under two of the three models of compare().
#
# Arguments
#   w                      : count matrix, one row per document, ordered by group.
#   gamma1, gamma2, gamma3 : parameter vectors of group 1, 2 and 3.
#   pair                   : "12", "13" or "23"; which two elements of compare()'s result
#                            are collected.
#
# Returns list(lst1, lst2), each with one value per document.
#
# D, D1, D2 and V are not parameters; they are looked up outside this function.
# Documents 1..D1 use gamma1, the next D2 documents gamma2 and the remaining ones gamma3.
prob.compare <- function(w, gamma1, gamma2, gamma3, pair){
  picks <- switch(as.character(pair),
                  "12" = c(1, 2),
                  "13" = c(1, 3),
                  "23" = c(2, 3),
                  stop("pair must be one of \"12\", \"13\", \"23\"", call. = FALSE))

  lst1 <- numeric(D)
  lst2 <- numeric(D)
  for (d in seq_len(D)) {
    # plain comparisons stay correct when a group is empty, unlike a:b ranges
    gamma.d <- if (d <= D1) {
      gamma1
    } else if (d <= D1 + D2) {
      gamma2
    } else {
      gamma3
    }
    # evaluate the three models once per document
    logp <- compare(w[d, ], V, gamma.d)
    lst1[d] <- logp[[picks[1]]]
    lst2[d] <- logp[[picks[2]]]
  }
  return(list(lst1 = lst1, lst2 = lst2))
}

# Scatter plot of the DM+ and DM++ log probabilities against the DM log probabilities.
#
# Each document contributes two points sharing the same x value (its DM log probability):
# one for DM+ and one for DM++. A dashed identity line is drawn and both axes use the same
# limits. Returns the ggplot object.
compare.plot <- function(w, gamma1, gamma2,gamma3){
  dm.and.plus <- prob.compare(w, gamma1, gamma2, gamma3, pair = "12")
  dm.and.pp <- prob.compare(w, gamma1, gamma2, gamma3, pair = "13")

  D <- nrow(w)
  # V1 = DM, V2 = DM+, V3 = DM++
  df <- as.data.frame(matrix(c(dm.and.plus$lst1, dm.and.plus$lst2, dm.and.pp$lst2), ncol = 3))

  # one colour per method, named by the legend label
  method.tag <- c(rep("DM+", D), rep("DM++", D))
  colorScales <- c(rep("#EC6A5C", D), rep("#0099CC", D))
  names(colorScales) <- method.tag

  axis.lim <- c(min(df$V1, df$V2, df$V3), max(df$V1, df$V2, df$V3))

  ggplot() +
    geom_abline(intercept=0, slope=1, linetype="dashed") +
    geom_point(aes(x=df$V1, y=df$V2, color=method.tag[1:D]), size=0.5) +
    scale_color_manual(name="Method", values=colorScales) +
    geom_point(aes(x=df$V1, y=df$V3, color=method.tag[(D+1):(2*D)]), size=0.5) +
    labs(x="DM log probabilities", y="DM+/++ log probabilities") +
    theme_classic() +
    theme(panel.background=element_rect(fill="white", color="black"),
          axis.line=element_blank(), axis.text=element_text(color="black"),
          axis.ticks=element_line(color="black")) +
    guides(color=guide_legend(override.aes= list(size=2))) +
    coord_fixed(ratio=1, xlim=axis.lim, ylim=axis.lim)
}


# Negative log-likelihood of the whole corpus,
# with a separate parameter row for every document.
#   G : D * V matrix of parameters
#   w : D * V matrix of counts
L <- function(G, w) {
  doc.len <- rowSums(w)
  conc <- rowSums(G)
  # one term per document
  per.doc <- lgamma(conc) - lgamma(doc.len + conc) +
    rowSums(lgamma(w + G) - lgamma(G))
  -sum(per.doc)
}

# Gradient of L with respect to G.
# The result is a vector of length D*V laid out document by document
# (all V entries of document 1, then document 2, ...).
#   G : D * V matrix of parameters
#   w : D * V matrix of counts
dL <- function(G, w) {
  conc <- rowSums(G)
  doc.len <- rowSums(w)
  # part that is constant within a document (recycled down the rows of the matrix)
  row.part <- digamma(conc) - digamma(doc.len + conc)
  grad <- row.part + digamma(w + G) - digamma(G)
  -as.vector(t(grad))
}

# Indicator of a positive count: 1 where t >= 1, 0 elsewhere (element-wise).
indicator <- function(t){
  at.least.one <- t >= 1
  ifelse(at.least.one, 1, 0)
}

# Negative log-likelihood of a single document (DM model).
#   theta.d : log-parameters of the document (G.d = exp(theta.d))
#   w.d     : counts of the document
fd <- function(theta.d, w.d){
  gam <- exp(theta.d)
  post <- w.d + gam
  loglik <- lgamma(sum(gam)) - lgamma(sum(post)) + sum(lgamma(post) - lgamma(gam))
  -loglik
}

# Gradient of fd with respect to theta.d
# (derivative with respect to exp(theta.d), multiplied by exp(theta.d)).
dfd <- function(theta.d, w.d){
  gam <- exp(theta.d)
  post <- w.d + gam
  d.gamma <- digamma(sum(gam)) - digamma(sum(post)) + digamma(post) - digamma(gam)
  -(d.gamma * gam)
}

# Negative log-likelihood of a single document under the DM+ variant:
# the per-word lgamma ratio of fd is replaced by indicator(w.d) * log(G.d).
fd.p <- function(theta.d, w.d){
  gam <- exp(theta.d)
  norm.part <- lgamma(sum(gam)) - lgamma(sum(w.d + gam))
  word.part <- sum(indicator(w.d) * log(gam))
  -(norm.part + word.part)
}

# Gradient of fd.p with respect to theta.d.
dfd.p <- function(theta.d, w.d){
  gam <- exp(theta.d)
  norm.slope <- digamma(sum(gam)) - digamma(sum(w.d + gam))
  -(norm.slope * gam + indicator(w.d))
}

# Negative log-likelihood of a single document under the DM++ variant:
# -( -sum(w.d) * log(sum(exp(theta.d))) + sum(w.d * theta.d) ).
fd.pp <- function(theta.d, w.d){
  total.gamma <- sum(exp(theta.d))
  loglik <- -sum(w.d) * log(total.gamma) + sum(w.d * theta.d)
  -loglik
}

# Gradient of fd.pp with respect to theta.d.
dfd.pp <- function(theta.d, w.d){
  gam <- exp(theta.d)
  share <- gam / sum(gam)
  -(-sum(w.d) * share + w.d)
}

# Build the neighbour list from an edge matrix.
#
#   E : two-column matrix of positive integer node ids (may have zero rows).
#
# Returns a list indexed by node id: element [[d]] holds the neighbours of node d in the
# order in which the edges appear in E. Ids up to the largest one that do not occur in E
# have a NULL entry. An empty edge matrix gives an empty list.
cal.Ebynode <- function(E){
  node.set <- sort(unique(c(E[ ,1], E[ ,2])))
  if (length(node.set) == 0) {
    return(list())
  }
  # the list is indexed by node id, so it must reach the largest id
  Ebynode <- vector("list", max(node.set))
  for (node in node.set) Ebynode[[node]] <- numeric(0)
  for (i in seq_len(nrow(E))) {
    Ebynode[[E[i,1]]] <- c(Ebynode[[E[i,1]]], E[i,2])
    Ebynode[[E[i,2]]] <- c(Ebynode[[E[i,2]]], E[i,1])
  }
  Ebynode
}

# Edge weights from the distance between the word distributions of the two end points.
#
# Arguments
#   w   : count matrix; its rows are assumed to correspond, in order, to the sorted node
#         ids occurring in E (the distance matrix is labelled with those ids).
#   E   : two-column edge matrix of node ids.
#   eta : exponent; the weight of an edge is 1 / distance^eta.
#
# Returns list(xi, net.dis, eta): xi has one weight per row of E, net.dis is the
# Euclidean distance matrix between the row-normalised counts.
#
# An edge between documents with identical word distributions has distance 0 and hence an
# infinite weight; such weights are replaced by the largest finite weight. (The previous
# in-loop replacement used max(xi), which is itself Inf, and therefore had no effect.)
cal.weight <- function(w, E, eta){
  net.data <- w/rowSums(w)
  net.dis <- as.matrix(dist(net.data))
  rownames(net.dis) <- colnames(net.dis) <- sort(unique(c(E[ ,1], E[ ,2])))

  # look up all edges at once by (row name, column name)
  edge.idx <- cbind(as.character(E[, 1]), as.character(E[, 2]))
  xi <- 1/net.dis[edge.idx]^eta

  unbounded <- is.infinite(xi)
  bounded <- is.finite(xi)
  if (any(unbounded) && any(bounded)) {
    xi[unbounded] <- max(xi[bounded])
  }
  return(list(xi = xi, net.dis = net.dis, eta = eta))
}

# Symmetric D x D adjacency matrix (0/1) of an edge matrix.
#
#   E : two-column matrix of node ids in 1..D (may have zero rows).
#   D : number of nodes.
cal.adj <- function(E, D){
  adj <- matrix(0, D, D)
  if (nrow(E) > 0) {
    # a two-column index matrix addresses one (row, column) cell per edge
    adj[cbind(E[, 1], E[, 2])] <- 1
    adj[cbind(E[, 2], E[, 1])] <- 1
  }
  return(adj)
}

# Edge matrix of an adjacency matrix.
#
#   adj : square matrix; cells equal to 1 above the diagonal are edges. Row names, when
#         present, are the numeric node ids; without row names the row positions are used.
#
# Returns a two-column numeric matrix with one row per edge (smaller row position first),
# ordered by first and then second end point. A matrix without edges gives zero rows.
cal.E <- function(adj){
  ids <- rownames(adj)
  ids <- if (is.null(ids)) seq_len(nrow(adj)) else as.numeric(ids)

  # (row, column) positions of the edges in the upper triangle
  hits <- which(adj == 1 & upper.tri(adj), arr.ind = TRUE)
  # which() walks column by column; restore the row-by-row order
  hits <- hits[order(hits[, 1], hits[, 2]), , drop = FALSE]

  E <- matrix(as.numeric(c(ids[hits[, 1]], ids[hits[, 2]])), ncol = 2)
  return(E)
}

# Fixed-point ("implicit gradient") iteration for the Theta sub-problem of ADMM.
#
# Iterates theta <- (-grad(theta)/rho + colSums(z.d - u.d)) / nrow(z.d) until the relative
# change sum(|theta - theta_old|) / sum(|theta|) is at most `epsilon`, or `max.iter`
# iterations have been done.
#
# Arguments
#   theta0   : starting value (V-vector).
#   w.d      : counts of the document.
#   z.d, u.d : matrices with one row per neighbour of the document.
#   rho      : ADMM penalty parameter.
#   method   : "DM", "DM+" or "DM++"; selects the gradient dfd, dfd.p or dfd.pp.
#   epsilon  : tolerance on the relative change.
#   max.iter : upper bound on the number of iterations.
#
# The relative change uses sum(abs(theta)) as denominator. Dividing by the signed
# sum(theta) made the criterion negative whenever the log-parameters sum to a negative
# value, which ended the iteration after a single step regardless of convergence.
grad.desc <- function(theta0, w.d, z.d, u.d, rho, method, epsilon = 1.e-10, max.iter = 1000){
  # x0 : initial value; V vector
  grad.fn <- switch(method,
                    "DM" = dfd,
                    "DM+" = dfd.p,
                    "DM++" = dfd.pp,
                    stop("unknown method: ", method, call. = FALSE))
  n.nbr <- nrow(z.d)
  a <- colSums(z.d - u.d)
  theta <- theta0
  for (it in seq_len(max.iter)) {
    theta <- (-grad.fn(theta0, w.d)/rho + a)/n.nbr
    # guard the denominator against an all-zero theta
    Dif <- sum(abs(theta - theta0))/max(sum(abs(theta)), .Machine$double.eps)
    if (!is.finite(Dif)) {
      stop("non-finite value in the Theta update", call. = FALSE)
    }
    theta0 <- theta
    if (Dif <= epsilon) {
      break
    }
  }
  return(theta)
}

# Convert a list of groups into a label vector.
#
#   group : list; element k holds the (positive integer) ids of the members of group k.
#
# Returns an integer vector with lst[id] = k; ids that belong to no group are NA.
# An empty list gives NULL.
label <- function(group){
  lst <- NULL
  for (k in seq_along(group)) {
    members <- group[[k]]
    if (length(members) > 0) {
      lst[members] <- k
    }
  }
  return(lst)
}

# Convert a label vector into a list of groups.
#
#   labelresult : label per element.
#
# Returns a list with one element per distinct label, in order of first appearance; each
# element holds the positions carrying that label. Empty input gives an empty list.
groupList <-function(labelresult){
  lapply(unique(labelresult), function(value) which(labelresult == value))
}

# This function generates the initial value
# of the Dirichlet-multinomial parameters.
#
# For every non-empty document d that has neighbours, a method-of-moments estimate is
# computed from the counts of d together with its non-empty neighbours. Documents
# without any words then receive the average estimate of their non-empty neighbours.
#
# Arguments
#   w       : count matrix whose row names are the numeric document ids.
#   Ebynode : list indexed by document id with each document's neighbours.
#
# Returns a matrix with the dimensions and row names of w. Rows that could not be
# estimated (no neighbours, or no non-empty document in the neighbourhood) stay 0.
initial.MOM <- function(w,Ebynode){
  D0 <- length(Ebynode)
  D <- nrow(w)
  V <- ncol(w)

  # ids of the documents without any words
  vec <- as.numeric(rownames(w)[which(rowSums(w) == 0)])

  Gamma <- matrix(0, D, V)
  rownames(Gamma) <- rownames(w)
  Gamma.sum <- alpha0 <- rep(0, D)
  names(Gamma.sum) <- names(alpha0) <- rownames(w)
  # word proportions per document
  B <- w/matrix(rowSums(w), D, V)

  for (d in seq_len(D0)) {
    if (!(d %in% vec) && !is.null(Ebynode[[d]])) {
      # d and its neighbours, without empty documents
      nbhd <- as.character(setdiff(c(d, Ebynode[[d]]), vec))
      tempw <- w[nbhd, ]
      tempb <- B[nbhd, ]
      if (length(nbhd) != 1) {
        beta <- colSums(tempb)/nrow(tempb)
        Dd <- nrow(tempw)
        wrs <- rowSums(tempw)
        wcs <- colSums(tempw)
        ws <- sum(tempw)
        a <- (tempw/wrs - matrix(wcs/ws, Dd, V, byrow = TRUE))^2
        S <- colSums(wrs*a)/(Dd-1)

        b <- colSums(tempw * (wrs-tempw) / wrs)
        c <- ws - Dd
        G <- b/c

        Nc <- (ws - sum(wrs^2)/ws) / (Dd-1)

        # alpha0 is floored at 0.01
        alpha0[as.character(d)] <- pmax(sum(S-G) / sum(S+(Nc-1)*G), 0.01)
        Gamma.sum[as.character(d)] <- (1-alpha0[as.character(d)]) / alpha0[as.character(d)]
        Gamma[as.character(d), ] <- Gamma.sum[as.character(d)] * beta
      }
    }
  }

  # Empty documents borrow the estimate of their non-empty neighbours. Rows are selected by
  # name: the ids are row names, not row positions, so positional indexing picks the wrong
  # rows (or fails) whenever the ids are not exactly 1..D.
  for (d in seq_len(D0)) {
    if ((d %in% vec) && !is.null(Ebynode[[d]])) {
      donors <- as.character(setdiff(Ebynode[[d]], vec))
      if (length(donors) > 1) {
        Gamma[as.character(d), ] <- colMeans(Gamma[donors, ])
      } else if (length(donors) == 1) {
        Gamma[as.character(d), ] <- Gamma[donors, ]
      }
      # no non-empty neighbour: the row stays 0
    }
  }
  return(Gamma*1)
}

# External clustering-quality measures against a given vector of true labels.
#
# Arguments
#   labels     : estimated cluster label per document (no NA).
#   labelstrue : true label per document, same length as `labels`; the default
#                `labelreal` is looked up outside this function.
#
# Returns a named vector c(Jaccard, ARI, Purity, RI, Fmeasure). ARI is computed by
# randIndex(), which must be available from an attached package.
#
# Pair counts (TP/FP/FN/TN over all document pairs) are derived from the contingency
# table instead of looping over all D*(D-1)/2 pairs.
measures.real <- function(labels, labelstrue = labelreal){
  D <- length(labels)
  if (length(labelstrue) != D) {
    stop("labels and labelstrue must have the same length", call. = FALSE)
  }
  if (anyNA(labels) || anyNA(labelstrue)) {
    stop("labels and labelstrue must not contain NA", call. = FALSE)
  }

  # rows: true groups, columns: estimated clusters
  tab <- table(labelstrue, labels)

  # number of pairs sharing a true group / an estimated cluster / both
  same.both <- sum(choose(tab, 2))
  same.true <- sum(choose(rowSums(tab), 2))
  same.pred <- sum(choose(colSums(tab), 2))

  TP <- same.both
  FN <- same.true - same.both
  FP <- same.pred - same.both
  TN <- choose(D, 2) - TP - FN - FP

  # Adjusted Rand Index
  ARI <- randIndex(tab)

  # Rand Index
  RI <- (TP+TN) / (TP+TN+FN+FP)
  precision <- TP / (TP+FP)
  recall <- TP / (TP+FN)

  # F-measure
  Fmeasure <- 2 * precision * recall / (precision+recall)

  # Jaccard Index
  Jaccard <- TP / (TP+FP+FN)

  # Purity: each estimated cluster is credited with its most frequent true group
  Purity <- sum(apply(tab, 2, max)) / D
  return(c(Jaccard = Jaccard, ARI = ARI, Purity = Purity, RI = RI, Fmeasure = Fmeasure))
}

# Summarise the clustering measures of every competing method on a real data set.
#
# For each method the label file "../label_txt/labels_<net.type>_<method>.txt" is read
# (its first column is dropped), measures.real() is applied to every row, and the column
# means and standard deviations are stored. The table of means is written to
# "../label_txt/measure.mean_<net.type>.txt".
#
#   net.type : data set name used in the file names.
#
# Returns the per-run measure matrix of the last method processed.
#
# The column names follow the order of the vector returned by measures.real()
# (Jaccard, ARI, Purity, RI, Fmeasure); values are filled in by position, so any other
# order mislabels the printed columns.
meas.mat.real <- function(net.type = "cora"){

  compare.lst <- c("DM", "DM+", "DM++", "kmeans", "kmeans++", "louvian", "FClouvian", "EM", "scan")
  measure.names <- c("Jaccard", "ARI", "Purity", "RI", "Fmeasure")
  res.mean <- matrix(0, nrow = length(compare.lst), ncol = length(measure.names),
                     dimnames = list(compare.lst, measure.names))
  res.var <- res.mean
  path.0 <- "../label_txt"

  for (k in seq_along(compare.lst)) {
    method.compare <- compare.lst[k]
    label.txt <- paste0(path.0, "/labels", "_", net.type, "_", method.compare, ".txt")

    print(label.txt)
    print(c(net.type = net.type, method.compare = method.compare))
    label.mat <- read.csv(label.txt, header = FALSE)
    # write.table(label.mat[ ,2:ncol(label.mat)], file = label.txt, sep = ",", append = F)
    label.mat <- label.mat[,2:ncol(label.mat)]
    print(paste("The dim of label.mat is", c(dim(label.mat)), sep = ":"))

    # one row of measures per run
    mat <- NULL
    for (i in seq_len(dim(label.mat)[1])) {
      mat <- rbind(mat, measures.real(as.numeric(label.mat[i,])))
    }
    res.var[k, ] <- apply(mat, 2, sd)
    res.mean[k, ] <- apply(mat, 2, mean)
    print(list(res.mean = res.mean, res.var = res.var))
  }
  mean.mat.txt <- paste0(path.0, "/measure.mean", "_", net.type, ".txt")
  write.table(res.mean, file = mean.mat.txt, col.names = F, sep = ",")
  return(mat)
}

