Windows版本
//VS2022
#include <stdio.h>
#include <stdlib.h>
/* Larger of two values. MSVC's <stdlib.h> already provides a `max` macro for C,
   so only define our own when none exists (avoids a redefinition warning).
   Both arguments may be evaluated twice: do not pass expressions with side effects. */
#ifndef max
#define max(a,b) ((a) > (b) ? (a) : (b))
#endif
/* AVL tree node. `h` is the height of the subtree rooted here: a fresh leaf
   gets height 1 and the NIL sentinel has height 0. */
typedef struct Node {
    int data, h;
    struct Node* lchild, * rchild;
} Node;

/*
 * Shared sentinel used instead of NULL for missing children, so code can read
 * `child->h` / `child->data` without NULL checks.  It is statically initialized
 * to point at itself, so it is a valid self-loop even if init_NIL() has not run
 * (a zero-initialized global would leave its child pointers NULL).
 * NOTE(review): identifiers containing `__` are reserved for the implementation;
 * the name is kept only for compatibility with existing references.
 */
Node __NIL = { .data = 0, .h = 0, .lchild = &__NIL, .rchild = &__NIL };
#define NIL (&__NIL)
#define DEBUG
#ifdef DEBUG
/*
 * Debug trace that forwards its arguments to printf.  The do { } while (0)
 * wrapper turns `LOG(...);` into exactly one statement, so it is safe as the
 * body of an unbraced if/else.  No line continuation follows the final line of
 * the macro, so the `#else` below remains a real preprocessor directive.
 */
#define LOG(frm, ...) do { printf(frm, ##__VA_ARGS__); } while (0)
#else
/* Tracing disabled: expands to an empty statement; arguments are not evaluated. */
#define LOG(frm, ...) do { } while (0)
#endif
/* Reset the NIL sentinel: both children point back at NIL itself,
   and its key and height are zero. */
void init_NIL() {
    NIL->lchild = NIL;
    NIL->rchild = NIL;
    NIL->h = 0;
    NIL->data = 0;
}
/*
 * Allocate a new leaf holding `val`: height 1, both children set to the NIL
 * sentinel.  Ownership passes to the caller (nodes are released by clear()/erase()).
 * If allocation fails the program exits with an error message instead of
 * writing through a NULL pointer.
 */
Node* getNewNode(int val) {
    Node* p = malloc(sizeof *p);
    if (p == NULL) {
        fprintf(stderr, "getNewNode: out of memory\n");
        exit(EXIT_FAILURE);
    }
    p->data = val;
    p->h = 1;
    p->lchild = p->rchild = NIL;
    return p;
}
/* Recompute root's height from its children: one more than the taller child
   (NIL counts as height 0). */
void update_height(Node* root) {
    int left_h = root->lchild->h;
    int right_h = root->rchild->h;
    int taller = (left_h > right_h) ? left_h : right_h;
    root->h = taller + 1;
}
/*
 * Left rotation: root's right child (the pivot) becomes the subtree root, and
 * the pivot's former left subtree is re-hung as root's right subtree.
 * Heights are refreshed bottom-up (old root first). Returns the new root.
 */
Node* left_rotate(Node* root) {
    LOG("%d left rotate\n", root->data);
    Node* pivot = root->rchild;
    Node* moved = pivot->lchild;
    pivot->lchild = root;
    root->rchild = moved;
    update_height(root);   /* old root now sits below the pivot */
    update_height(pivot);
    return pivot;
}
/*
 * Right rotation: root's left child becomes the subtree root, and that child's
 * former right subtree is re-attached as root's left subtree.
 * Heights are refreshed bottom-up (old root first). Returns the new root.
 */
Node* right_rotate(Node* root) {
    LOG("%d right rotate\n", root->data);  /* was mislabelled "left rotate" */
    Node* new_root = root->lchild;
    root->lchild = new_root->rchild;
    new_root->rchild = root;
    update_height(root);
    update_height(new_root);
    return new_root;
}
/* Names of the rebalancing cases, indexed by the code maintain() computes:
   0 = LL, 1 = LR, 2 = RR, 3 = RL. */
const char* type_str[] = {
    "LL", "LR",
    "RR", "RL",
};
/*
 * Restore the AVL balance condition at `root`, whose children are assumed
 * already balanced with correct heights.  If the child heights differ by 2 or
 * more, perform the single (LL/RR) or double (LR/RL) rotation and log the case.
 * Returns the (possibly new) subtree root.
 */
Node* maintain(Node* root) {
    int lh = root->lchild->h;
    int rh = root->rchild->h;
    if (abs(lh - rh) < 2) return root;   /* already balanced */

    int type;
    if (lh > rh) {
        Node* heavy = root->lchild;
        if (heavy->rchild->h > heavy->lchild->h) {
            root->lchild = left_rotate(heavy);   /* LR: reduce to LL first */
            type = 1;
        } else {
            type = 0;                            /* LL */
        }
        root = right_rotate(root);
    } else {
        Node* heavy = root->rchild;
        if (heavy->lchild->h > heavy->rchild->h) {
            root->rchild = right_rotate(heavy);  /* RL: reduce to RR first */
            type = 3;
        } else {
            type = 2;                            /* RR */
        }
        root = left_rotate(root);
    }
    LOG(" TYPE : %s\n", type_str[type]);
    return root;
}
/*
 * Insert `val` into the subtree at `root` and return the rebalanced subtree
 * root.  Duplicate keys are ignored (the subtree is returned unchanged).
 */
Node* insert(Node* root, int val) {
    if (root == NIL) {
        return getNewNode(val);          /* empty spot: new leaf */
    }
    if (val < root->data) {
        root->lchild = insert(root->lchild, val);
    } else if (val > root->data) {
        root->rchild = insert(root->rchild, val);
    } else {
        return root;                     /* key already present */
    }
    update_height(root);
    return maintain(root);
}
/*
 * Remove `val` from the subtree at `root` (no-op if absent) and return the
 * rebalanced subtree root.  A node with two children takes the key of its
 * in-order predecessor, which is then removed from the left subtree.
 */
Node* erase(Node* root, int val) {
    if (root == NIL) return NIL;

    if (val < root->data) {
        root->lchild = erase(root->lchild, val);
    } else if (val > root->data) {
        root->rchild = erase(root->rchild, val);
    } else if (root->lchild != NIL && root->rchild != NIL) {
        Node* pred = root->lchild;           /* rightmost node of left subtree */
        while (pred->rchild != NIL) pred = pred->rchild;
        root->data = pred->data;
        root->lchild = erase(root->lchild, pred->data);
    } else {
        /* zero or one child: splice the node out */
        Node* child = (root->lchild == NIL) ? root->rchild : root->lchild;
        free(root);
        return child;
    }
    update_height(root);
    return maintain(root);
}
/* Free every node in the subtree, children before parent; NIL is never freed. */
void clear(Node* root) {
    if (root != NIL) {
        Node* left = root->lchild;
        Node* right = root->rchild;
        clear(left);
        clear(right);
        free(root);
    }
}
/* Print one node as "[ key(height) | left_key, right_key]"; a NIL child
   shows the sentinel's key. */
void print_node(Node* root) {
    int key = root->data;
    int height = root->h;
    int left_key = root->lchild->data;
    int right_key = root->rchild->data;
    printf("[ %d(%d) | %d, %d]\n", key, height, left_key, right_key);
}
/* Print the subtree in pre-order (node, left subtree, right subtree).
   The right-subtree step is done by looping instead of a second recursive call. */
void output(Node* root) {
    while (root != NIL) {
        print_node(root);
        output(root->lchild);
        root = root->rchild;
    }
}
/*
 * Driver: reads integers from stdin and inserts each one, printing the tree
 * after every insertion, until -1, end of input, or a non-integer token.
 * It then reads integers to erase in the same way, and finally frees the tree.
 */
int main() {
    /* MSVC has no __attribute__((constructor)), so set up the sentinel explicitly. */
    init_NIL();
    Node* root = NIL;
    int val;
    /* insert phase: scanf_s returns the number of converted fields.  Requiring
       exactly 1 stops on EOF *and* on a non-numeric token; `~scanf_s(...)` would
       spin forever on such a token (0 -> ~0 is true) with a stale `val`. */
    while (scanf_s("%d", &val) == 1) {
        if (val == -1) break;
        printf("insert %d to AVL Tree\n", val);
        root = insert(root, val);
        output(root);
    }
    /* erase phase */
    while (scanf_s("%d", &val) == 1) {
        if (val == -1) break;
        printf("erase %d from AVL Tree\n", val);
        root = erase(root, val);
        output(root);
    }
    clear(root);
    return 0;
}
Linux版本
#include <stdio.h>
#include <stdlib.h>
/* Larger of two values; only defined if no system header already provides
   `max` (MSVC's <stdlib.h> does), keeping both builds of this program in sync.
   Both arguments may be evaluated twice: do not pass expressions with side effects. */
#ifndef max
#define max(a,b) ((a) > (b) ? (a) : (b))
#endif
/* AVL tree node. `h` is the height of the subtree rooted here: a fresh leaf
   gets height 1 and the NIL sentinel has height 0. */
typedef struct Node {
    int data, h;
    struct Node* lchild, * rchild;
} Node;

/*
 * Shared sentinel used instead of NULL for missing children, so code can read
 * `child->h` / `child->data` without NULL checks.  It is statically initialized
 * to point at itself, so it is a valid self-loop even before init_NIL() runs
 * (a zero-initialized global would leave its child pointers NULL).
 * NOTE(review): identifiers containing `__` are reserved for the implementation;
 * the name is kept only for compatibility with existing references.
 */
Node __NIL = { .data = 0, .h = 0, .lchild = &__NIL, .rchild = &__NIL };
#define NIL (&__NIL)
#define DEBUG
#ifdef DEBUG
/*
 * Debug trace that forwards its arguments to printf.  The do { } while (0)
 * wrapper turns `LOG(...);` into exactly one statement, so it is safe as the
 * body of an unbraced if/else.  No line continuation follows the final line of
 * the macro, so the `#else` below remains a real preprocessor directive.
 * Uses the standard `...`/__VA_ARGS__ form rather than GNU named `args...`.
 */
#define LOG(frm, ...) do { printf(frm, ##__VA_ARGS__); } while (0)
#else
/* Tracing disabled: expands to an empty statement; arguments are not evaluated. */
#define LOG(frm, ...) do { } while (0)
#endif
/* Reset the NIL sentinel: both children point back at NIL itself, and its key
   and height are zero.  The GCC/Clang `constructor` attribute runs this
   automatically before main(). */
__attribute__((constructor))
void init_NIL() {
    NIL->lchild = NIL;
    NIL->rchild = NIL;
    NIL->h = 0;
    NIL->data = 0;
}
/*
 * Allocate a new leaf holding `val`: height 1, both children set to the NIL
 * sentinel.  Ownership passes to the caller (nodes are released by clear()/erase()).
 * If allocation fails the program exits with an error message instead of
 * writing through a NULL pointer.
 */
Node* getNewNode(int val) {
    Node* p = malloc(sizeof *p);
    if (p == NULL) {
        fprintf(stderr, "getNewNode: out of memory\n");
        exit(EXIT_FAILURE);
    }
    p->data = val;
    p->h = 1;
    p->lchild = p->rchild = NIL;
    return p;
}
/* Recompute root's height from its children: one more than the taller child
   (NIL counts as height 0). */
void update_height(Node* root) {
    int left_h = root->lchild->h;
    int right_h = root->rchild->h;
    int taller = (left_h > right_h) ? left_h : right_h;
    root->h = taller + 1;
}
/*
 * Left rotation: root's right child (the pivot) becomes the subtree root, and
 * the pivot's former left subtree is re-hung as root's right subtree.
 * Heights are refreshed bottom-up (old root first). Returns the new root.
 */
Node* left_rotate(Node* root) {
    LOG("%d left rotate\n", root->data);
    Node* pivot = root->rchild;
    Node* moved = pivot->lchild;
    pivot->lchild = root;
    root->rchild = moved;
    update_height(root);   /* old root now sits below the pivot */
    update_height(pivot);
    return pivot;
}
/*
 * Right rotation: root's left child becomes the subtree root, and that child's
 * former right subtree is re-attached as root's left subtree.
 * Heights are refreshed bottom-up (old root first). Returns the new root.
 */
Node* right_rotate(Node* root) {
    LOG("%d right rotate\n", root->data);  /* was mislabelled "left rotate" */
    Node* new_root = root->lchild;
    root->lchild = new_root->rchild;
    new_root->rchild = root;
    update_height(root);
    update_height(new_root);
    return new_root;
}
/* Names of the rebalancing cases, indexed by the code maintain() computes:
   0 = LL, 1 = LR, 2 = RR, 3 = RL. */
const char* type_str[] = {
    "LL", "LR",
    "RR", "RL",
};
/*
 * Restore the AVL balance condition at `root`, whose children are assumed
 * already balanced with correct heights.  If the child heights differ by 2 or
 * more, perform the single (LL/RR) or double (LR/RL) rotation and log the case.
 * Returns the (possibly new) subtree root.
 */
Node* maintain(Node* root) {
    int lh = root->lchild->h;
    int rh = root->rchild->h;
    if (abs(lh - rh) < 2) return root;   /* already balanced */

    int type;
    if (lh > rh) {
        Node* heavy = root->lchild;
        if (heavy->rchild->h > heavy->lchild->h) {
            root->lchild = left_rotate(heavy);   /* LR: reduce to LL first */
            type = 1;
        } else {
            type = 0;                            /* LL */
        }
        root = right_rotate(root);
    } else {
        Node* heavy = root->rchild;
        if (heavy->lchild->h > heavy->rchild->h) {
            root->rchild = right_rotate(heavy);  /* RL: reduce to RR first */
            type = 3;
        } else {
            type = 2;                            /* RR */
        }
        root = left_rotate(root);
    }
    LOG(" TYPE : %s\n", type_str[type]);
    return root;
}
/*
 * Insert `val` into the subtree at `root` and return the rebalanced subtree
 * root.  Duplicate keys are ignored (the subtree is returned unchanged).
 */
Node* insert(Node* root, int val) {
    if (root == NIL) {
        return getNewNode(val);          /* empty spot: new leaf */
    }
    if (val < root->data) {
        root->lchild = insert(root->lchild, val);
    } else if (val > root->data) {
        root->rchild = insert(root->rchild, val);
    } else {
        return root;                     /* key already present */
    }
    update_height(root);
    return maintain(root);
}
/*
 * Remove `val` from the subtree at `root` (no-op if absent) and return the
 * rebalanced subtree root.  A node with two children takes the key of its
 * in-order predecessor, which is then removed from the left subtree.
 */
Node* erase(Node* root, int val) {
    if (root == NIL) return NIL;

    if (val < root->data) {
        root->lchild = erase(root->lchild, val);
    } else if (val > root->data) {
        root->rchild = erase(root->rchild, val);
    } else if (root->lchild != NIL && root->rchild != NIL) {
        Node* pred = root->lchild;           /* rightmost node of left subtree */
        while (pred->rchild != NIL) pred = pred->rchild;
        root->data = pred->data;
        root->lchild = erase(root->lchild, pred->data);
    } else {
        /* zero or one child: splice the node out */
        Node* child = (root->lchild == NIL) ? root->rchild : root->lchild;
        free(root);
        return child;
    }
    update_height(root);
    return maintain(root);
}
/* Free every node in the subtree, children before parent; NIL is never freed. */
void clear(Node* root) {
    if (root != NIL) {
        Node* left = root->lchild;
        Node* right = root->rchild;
        clear(left);
        clear(right);
        free(root);
    }
}
/* Print one node as "[ key(height) | left_key, right_key]"; a NIL child
   shows the sentinel's key. */
void print_node(Node* root) {
    int key = root->data;
    int height = root->h;
    int left_key = root->lchild->data;
    int right_key = root->rchild->data;
    printf("[ %d(%d) | %d, %d]\n", key, height, left_key, right_key);
}
/* Print the subtree in pre-order (node, left subtree, right subtree).
   The right-subtree step is done by looping instead of a second recursive call. */
void output(Node* root) {
    while (root != NIL) {
        print_node(root);
        output(root->lchild);
        root = root->rchild;
    }
}
/*
 * Driver: reads integers from stdin and inserts each one, printing the tree
 * after every insertion, until -1, end of input, or a non-integer token.
 * It then reads integers to erase in the same way, and finally frees the tree.
 * (The NIL sentinel is set up by the constructor-attributed init_NIL().)
 */
int main() {
    Node* root = NIL;
    int val;
    /* insert phase: scanf returns the number of converted fields.  Requiring
       exactly 1 stops on EOF *and* on a non-numeric token; `~scanf(...)` would
       spin forever on such a token (0 -> ~0 is true) with a stale `val`. */
    while (scanf("%d", &val) == 1) {
        if (val == -1) break;
        printf("insert %d to AVL Tree\n", val);
        root = insert(root, val);
        output(root);
    }
    /* erase phase */
    while (scanf("%d", &val) == 1) {
        if (val == -1) break;
        printf("erase %d from AVL Tree\n", val);
        root = erase(root, val);
        output(root);
    }
    clear(root);
    return 0;
}